Python Bindings for SymInts by Krovatkin · Pull Request #78135 · pytorch/pytorch · GitHub

Python Bindings for SymInts #78135


Closed
wants to merge 15 commits into from

Conversation

Krovatkin
Contributor
@Krovatkin Krovatkin commented May 23, 2022

This PR adds support for SymInts in Python. Namely:

  • THPVariable_size now returns sym_sizes()
  • the Python arg parser is modified to parse PyObjects into ints and SymbolicIntNodes
  • pybind11 bindings for SymbolicIntNode are added, so size expressions can be traced
  • a large number of tests are added to demonstrate how to implement Python SymInts.
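
To make the tracing bullet concrete, here is a minimal, self-contained sketch of a symbolic integer node that records a size expression instead of computing it. The names `SketchSymNode` and `TracingSymNode` are hypothetical and are not the classes from this PR; the real `SymbolicIntNode` and its pybind11 bindings differ in interface and ownership.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

// Hypothetical interface, loosely modeled on the idea of a symbolic int node.
struct SketchSymNode {
  virtual ~SketchSymNode() = default;
  virtual std::shared_ptr<SketchSymNode> add(const std::shared_ptr<SketchSymNode>& other) = 0;
  virtual std::shared_ptr<SketchSymNode> mul(const std::shared_ptr<SketchSymNode>& other) = 0;
  virtual std::shared_ptr<SketchSymNode> wrap(int64_t num) = 0;
  virtual std::string str() const = 0;
};

// Records each operation as text instead of evaluating it.
struct TracingSymNode : SketchSymNode {
  explicit TracingSymNode(std::string expr) : expr_(std::move(expr)) {}

  std::shared_ptr<SketchSymNode> add(const std::shared_ptr<SketchSymNode>& other) override {
    return binary(" + ", other);
  }
  std::shared_ptr<SketchSymNode> mul(const std::shared_ptr<SketchSymNode>& other) override {
    return binary(" * ", other);
  }
  // Plain integers become constant nodes so they can appear in expressions.
  std::shared_ptr<SketchSymNode> wrap(int64_t num) override {
    return std::make_shared<TracingSymNode>(std::to_string(num));
  }
  std::string str() const override { return expr_; }

 private:
  std::shared_ptr<SketchSymNode> binary(const char* op,
                                        const std::shared_ptr<SketchSymNode>& other) const {
    assert(other != nullptr);
    return std::make_shared<TracingSymNode>("(" + expr_ + op + other->str() + ")");
  }
  std::string expr_;
};
```

Starting from a node created with the expression `s0`, calling `s0->mul(s0->wrap(2))->add(s0->wrap(1))` produces a node whose `str()` is `((s0 * 2) + 1)`.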

@facebook-github-bot
Contributor
facebook-github-bot commented May 23, 2022

✅ No Failures (0 Pending)

As of commit 50810d1 (more details on the Dr. CI page):

💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot facebook-github-bot added the oncall: jit label May 23, 2022
@Chillee Chillee mentioned this pull request May 31, 2022
7 tasks
@Krovatkin Krovatkin changed the title Python SymInts collab [WIP] Python Bindings for SymInts Jun 1, 2022
@Krovatkin Krovatkin force-pushed the krovatkin/pybind_symint branch from e43d362 to 8bdcc93 Compare June 1, 2022 17:59
virtual SymbolicIntNode* wrap(int64_t num) { TORCH_CHECK(false, "NYI"); };
virtual bool bool_() { TORCH_CHECK(false, "NYI"); };
virtual int64_t int_() { TORCH_CHECK(false, "NYI"); }
virtual std::string str() { TORCH_CHECK(false, "NYI"); };
virtual std::ostream& operator<<(std::ostream& os) {
Contributor

does this still need to be virtual

Contributor Author

I think it does; the LTC impl won't have a Python object to redispatch or fall back to.

Contributor

but they'll implement str instead?

Contributor Author

OH, just realized what you meant ... operator<< does NOT need to be virtual, methinks :)

virtual SymbolicIntNode* gt(SymbolicIntNode* other) { TORCH_CHECK(false, "NYI"); };
virtual SymbolicIntNode* lt(SymbolicIntNode* other) { TORCH_CHECK(false, "NYI"); };
virtual SymbolicIntNode* wrap(int64_t num) { TORCH_CHECK(false, "NYI"); };
virtual bool bool_() { TORCH_CHECK(false, "NYI"); };
Contributor

Hmm, not sure why we need this on top of the int conversion. Is the problem that you also want a SymbolicBoolNode as well?

Contributor Author

@horace overrode bool in his PoC, so I added it as well. I can provide a default implementation, static_cast<bool>(this->int_()), or just let implementers implement it via int_.
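
The default implementation mentioned in this reply could look roughly like the sketch below. It is only an illustration under stated assumptions: `SketchIntNode` and `ConstantNode` are hypothetical names, not classes from this PR, and the sketch makes `int_()` pure virtual rather than throwing "NYI".

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical base class; subclasses only have to provide int_().
struct SketchIntNode {
  virtual ~SketchIntNode() = default;
  virtual int64_t int_() = 0;
  // Default truthiness: any nonzero integer is true, as with bool(int) in Python.
  // Subclasses may still override this if they need different behavior.
  virtual bool bool_() { return static_cast<bool>(this->int_()); }
};

// Hypothetical node that wraps a known constant.
struct ConstantNode : SketchIntNode {
  explicit ConstantNode(int64_t v) : value(v) {}
  int64_t int_() override { return value; }
  int64_t value;
};
```

With this default, an implementer who only overrides `int_()` gets a consistent `bool_()` for free, which is the second option the reply describes.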

@@ -1306,6 +1306,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return numel() == 0;
}

// if we are going to use sym sizes, we should be setting sym strides at the same time,
// otherwise it's very easy to misuse this API
virtual void set_sym_sizes_and_strides(c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides);
Contributor

it's not clear to me why this needs to be virtual

Contributor Author

agreed, @suo mentioned he would help to formalize this API. I'll remove virtual in the meantime.

@@ -2330,6 +2335,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
void set_custom_device(bool custom_device) {
custom_device_ = custom_device;
}
protected:
Contributor

lol thanks

@@ -2321,6 +2325,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
//
// Can override: strides(), is_contiguous(), sizes(), dim(), numel()
CustomSizes = 2,
CustomSymSizes = 3,
Contributor

Skeptical about this. I'll read through the uses first

Contributor Author

removed

requirements.txt Outdated
@@ -10,3 +10,5 @@ setuptools
six
types-dataclasses
typing_extensions
dataclasses; python_version<"3.7"
Contributor

we're py3.7 and up only, so this really shouldn't be needed

Contributor Author

ah this is a bad merge.. will fix..

@@ -40,6 +45,29 @@ PyObject * THPSize_NewFromSizes(int dim, const int64_t *sizes)
return self.release();
}

PyObject * THPSize_NewFromSymSizes(const at::Tensor& self_)
{
HANDLE_TH_ERRORS
Contributor

you shouldn't need this macro here; this function is not directly bound to python

@@ -389,9 +390,52 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
return intlistWithDefault(i, signature.params[i].default_intlist);
}

TORCH_API bool is_symint_node(py::handle obj);
Contributor

Why not just include the header?

Contributor Author

yeah let me inline it back.

const auto size1 = signature.params[i].size;
if (size1 > 0 && THPUtils_checkLong(args[i])) {
return std::vector<c10::SymInt>(size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
}
Contributor

You need to replicate this logic for a solitary symint arg as well

Contributor Author

yes, a good catch, thank you!
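
The branch quoted in this thread repeats a lone integer to fill a parameter declared with a fixed list size. A minimal sketch of that expansion step, using plain `int64_t` instead of `c10::SymInt` and a hypothetical helper name `expand_scalar`, might look like this:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: a single integer passed where a list of
// `declared_size` elements is expected is repeated to fill the list.
inline std::vector<int64_t> expand_scalar(int64_t value, int64_t declared_size) {
  assert(declared_size > 0);  // only meaningful for fixed-size list parameters
  return std::vector<int64_t>(static_cast<std::size_t>(declared_size), value);
}
```

For example, `expand_scalar(3, 2)` returns a two-element list holding 3 twice. How the real parser handles a solitary SymInt argument is not shown in the quoted diff, so this sketch covers only the list case.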

// we need to clear SymIntTable until we have python
// otherwise python classes are already deregistered

//c10::getSymIntTable().clear();
Contributor

this is dead now

Contributor Author

as a dinosaur

@@ -14,6 +14,11 @@ std::shared_ptr<SymbolicIntNode> SymIntTable::getNode(size_t index) {
return nodes_[index];
}

void SymIntTable::clear() {
std::lock_guard<std::mutex> lock(mutex_);
nodes_.clear();
Contributor

this is dead now

@@ -792,6 +792,13 @@ void TensorImpl::ShareExternalPointer(
}
}

void TensorImpl::set_sym_sizes_and_strides(c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides) {
has_symbolic_sizes_strides_ = true;
sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
Contributor

This seems to me like CustomSymSizes isn't actually being used!

Contributor Author

We are setting the CustomSizes policy for python tensors (i.e. made via make_wrapper_class) so calls to sizes() would throw for those. Unfortunately, it means that sym_sizes() also throws. We actually would like to just run the default implementation in this case hence CustomSymSizes which is indeed overridden by LTC. I'm open to how we can make this cleaner.

Contributor Author

@suo any suggestions?

Contributor

I think the easiest thing is to just call into python if python key is set and you have a custom sizes policy.

@@ -622,9 +622,6 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py
if (r.toBool(10)) {
data.unsafeGetTensorImpl()->set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomStrides);
}
if (r.toBool(11)) {
data.unsafeGetTensorImpl()->set_custom_device(true);
}
Contributor

?

Contributor Author

oh crap, a bad merge :(

// NB: pin_memory doesn't actually do anything
// TODO: strides variant?
static PythonArgParser parser({
"_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, int64_t? storage_offset=None, *, MemoryFormat? memory_format=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
Contributor

How about keeping only one _make_wrapper_subclass and just having a second overload for PythonArgParser? Having it as an overload should also help reduce duplication in this variant of the function.

Contributor Author

mkkk... it's already a pretty branchy and long function

Contributor

It being long is a good reason not to copy paste right ;)

pyobj_ = std::make_shared<c10::SafePyObject>(pyobj.release().ptr(), getPyInterpreter());
};

virtual SymbolicIntNode* wrap(int64_t num) {
Contributor

remind me again why we are doing raw pointer memory management here


virtual bool bool_() {
py::gil_scoped_acquire acquire;
return py::str(getPyObj().attr("__bool__")()).is(py::str(Py_True));
Contributor

why are you doing a string comparison to test what the bool result is?

Contributor Author

Lemme try simplifying it a bit. This is what SO recommends, but I do agree it's convoluted. There's no py::cast to bool, but maybe we don't have to do py::str on both sides.

Contributor

When taking questionable advice from stack overflow I highly recommend leaving a link to the URL of the question

Contributor Author

cleaned up this part. Now it should make more sense.

virtual SymbolicIntNode* dispatch_common_(const char* fname, SymbolicIntNode* other) {
auto pother = dynamic_cast<PythonSymbolicIntNode*>(other);
TORCH_CHECK(pother);
auto magic_fname = std::string("__") + fname + std::string("__");
Contributor

I'd much rather you had taken the magic_fname as argument lol. With a macro you could paste the __ together with a string constant without having to do a string concat every function call (which is wasteful)

Contributor Author

haha sorry, I was going to do that but didn't get to it before your review.
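
The macro idea from this thread relies on the compiler concatenating adjacent string literals, so the decorated name costs nothing at run time. A minimal sketch follows; `SKETCH_MAGIC_NAME`, `dispatch_name`, and `add_name` are hypothetical names used only for illustration, not code from this PR.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical macro: "__" fname "__" is three adjacent string literals,
// which the compiler joins into one constant. `fname` must itself be a literal.
#define SKETCH_MAGIC_NAME(fname) "__" fname "__"

// Stand-in for a dispatch helper that receives the already-decorated name,
// so no std::string concatenation happens per call.
inline const char* dispatch_name(const char* magic_fname) {
  return magic_fname;
}

inline const char* add_name() {
  return dispatch_name(SKETCH_MAGIC_NAME("add"));
}
```

Here `add_name()` returns the constant `__add__` without building a `std::string` on each call.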

.def_static("isinstance", [](py::object obj, bool convert) -> bool {
return pybind11::detail::type_caster<std::shared_ptr<c10::SymbolicIntNode>>().load(obj, convert);
//return false;
})
Contributor

This method is pretty weird

Contributor Author

dead code. sorry..

if (torch::is_symint_node(b)) {
return std::shared_ptr<c10::SymbolicIntNode>(a->add(b.cast<c10::SymbolicIntNode*>()));
} else {
return std::shared_ptr<c10::SymbolicIntNode> (a->add(a->wrap(b.cast<int64_t>())));
Contributor

It feels like a helper function would help a bit here. But it's also not entirely clear you want to wrap integers into symbolic int nodes (that denote plain integers); it seems like it would be more user friendly if these showed up at the dispatch site as plain integers. It might be a bit easier here to make the add method accept an IValue instead of a SymbolicIntNode, so you can pass in either an int or a symbolic int without needing to unconditionally accept a SymbolicIntNode.

Contributor Author

it looks way nicer rn.
I don't want to introduce a dependency on IValue into SymbolicIntNode :( . It seems more complex architecturally and possibly less user friendly since both LTC and AOTAutograd will need to parse IValues explicitly. Both LTC and AOTAutograd already wrap ints into sympy.Integer or prim::Constant.
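
One way to picture the helper-function suggestion while keeping the author's choice to wrap plain integers into nodes is a pair of overloads that normalize the operand before the arithmetic call. This is only a sketch: `SketchNode`, `to_node`, and `add_any` are hypothetical names, and the node here carries a concrete value rather than a traced expression.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>

// Hypothetical concrete node that simply carries an integer value.
struct SketchNode {
  explicit SketchNode(int64_t v) : value(v) {}
  std::shared_ptr<SketchNode> wrap(int64_t num) const {
    return std::make_shared<SketchNode>(num);
  }
  std::shared_ptr<SketchNode> add(const SketchNode& other) const {
    return std::make_shared<SketchNode>(value + other.value);
  }
  int64_t value;
};

// An operand that is already a node passes through unchanged.
inline std::shared_ptr<SketchNode> to_node(const SketchNode&,
                                           std::shared_ptr<SketchNode> operand) {
  assert(operand != nullptr);
  return operand;
}

// A plain integer is wrapped into a node by the left-hand operand.
inline std::shared_ptr<SketchNode> to_node(const SketchNode& self, int64_t operand) {
  return self.wrap(operand);
}

// Single entry point for both operand kinds, replacing the if/else at each call site.
template <typename Operand>
std::shared_ptr<SketchNode> add_any(const SketchNode& self, Operand operand) {
  return self.add(*to_node(self, operand));
}
```

The branching on operand kind then lives in one place, and no dependency on a variant-like type such as IValue is needed in the node class.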

@Krovatkin Krovatkin force-pushed the krovatkin/pybind_symint branch from b539c7d to 4e3fa42 Compare June 13, 2022 19:06
@Krovatkin
Contributor Author

@pytorchbot merge this

@pytorch-bot
pytorch-bot bot commented Jun 14, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: this

usage: @pytorchbot [-h] {merge,revert,rebase} ...

Try @pytorchbot help for more info.

@ezyang
Contributor
ezyang commented Jun 14, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

Hey @Krovatkin.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
return sym_sizes_custom();
}
virtual c10::SymIntArrayRef sym_sizes() const {
Contributor

Why do we want this to be virtual now, instead of doing the policy thing that we do with all of our other de-virtualized methods?

(I need to fix it up for functionalization, which will be pretty easy - just curious on the reasoning)

Contributor

And do we even need a sym_sizes_custom() anymore if sym_sizes() is virtual?

Contributor Author
@Krovatkin Krovatkin Jun 15, 2022

@bdhirsh python tensor subclasses and LTC want to do different things when policy is set to CustomSizes and there's no easy way to implement both via _custom so we had to make sym_sizes() virtual for now
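
For readers unfamiliar with the "policy thing" raised in this thread, the sketch below shows the general pattern: a non-virtual accessor checks a policy flag and only takes the virtual path when a subclass has opted in. All names (`SketchPolicy`, `SketchTensorImpl`, `LazySketchImpl`) are hypothetical and the code is far simpler than the real `TensorImpl`.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

enum class SketchPolicy : uint8_t { Default = 0, CustomStrides = 1, CustomSizes = 2 };

class SketchTensorImpl {
 public:
  explicit SketchTensorImpl(std::vector<int64_t> sizes) : sizes_(std::move(sizes)) {}
  virtual ~SketchTensorImpl() = default;

  // Non-virtual fast path: ordinary tensors pay for one integer comparison,
  // not for a virtual call.
  const std::vector<int64_t>& sizes() const {
    if (static_cast<uint8_t>(policy_) >= static_cast<uint8_t>(SketchPolicy::CustomSizes)) {
      return sizes_custom();
    }
    return sizes_;
  }

  void set_policy(SketchPolicy policy) { policy_ = policy; }

 protected:
  // Slow path, reached only when the policy asks for custom sizes.
  virtual const std::vector<int64_t>& sizes_custom() const { return sizes_; }
  std::vector<int64_t> sizes_;

 private:
  SketchPolicy policy_ = SketchPolicy::Default;
};

// Hypothetical backend that reports its own sizes once the policy is set.
class LazySketchImpl : public SketchTensorImpl {
 public:
  explicit LazySketchImpl(std::vector<int64_t> sizes) : SketchTensorImpl(std::move(sizes)) {}

 protected:
  const std::vector<int64_t>& sizes_custom() const override { return lazy_sizes_; }

 private:
  std::vector<int64_t> lazy_sizes_{42};
};
```

The limitation described in the reply shows up here as well: there is a single `sizes_custom()` hook per policy value, so two consumers that want different behavior under the same policy cannot both be served by it.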

@datumbox
Contributor

We suspect that this PR broke TorchVision's tests. Information available here: pytorch/vision#6166 (comment)

@ezyang
Contributor
ezyang commented Jun 15, 2022

@pytorchbot revert -m "broke torchvision tests" -c weird

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here

pytorchmergebot added a commit that referenced this pull request Jun 15, 2022
This reverts commit d332724.

Reverted #78135 on behalf of https://github.com/ezyang due to broke torchvision tests
ezyang added a commit that referenced this pull request Jun 15, 2022
This reverts commit b8db0a0.

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Jun 15, 2022
ezyang added a commit that referenced this pull request Jun 15, 2022
This reverts commit b8db0a0.

ghstack-source-id: 602ffd6
Pull Request resolved: #79608
Krovatkin pushed a commit to Krovatkin/pytorch that referenced this pull request Jun 15, 2022
This reverts commit b8db0a0.

ghstack-source-id: 602ffd6
Pull Request resolved: pytorch#79608
wconstab pushed a commit that referenced this pull request Jun 17, 2022
This reverts commit b8db0a0.

ghstack-source-id: 602ffd6
Pull Request resolved: #79608
@github-actions github-actions bot deleted the krovatkin/pybind_symint branch February 17, 2024 01:50
Labels
cla signed, Merged, oncall: jit, Reverted
9 participants