Python Bindings for SymInts #78135
Conversation
✅ No failures (0 pending) as of commit 50810d1 (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. 💚 This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
Force-pushed from e43d362 to 8bdcc93 (Compare)
c10/core/SymbolicIntNode.h (Outdated)
virtual SymbolicIntNode* wrap(int64_t num) { TORCH_CHECK(false, "NYI"); };
virtual bool bool_() { TORCH_CHECK(false, "NYI"); };
virtual int64_t int_() { TORCH_CHECK(false, "NYI"); }
virtual std::string str() { TORCH_CHECK(false, "NYI"); };
virtual std::ostream& operator<<(std::ostream& os) {
Does this still need to be virtual?
I think it does; the LTC impl won't have a Python object to redispatch or fall back to.
But they'll implement `str` instead?
Oh, just realized what you meant... `operator<<` does NOT need to be virtual, methinks :)
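For what it's worth, a minimal sketch of the resolution being discussed: keep `str()` virtual and let a non-virtual, non-member `operator<<` delegate to it, so subclasses such as the LTC impl only override `str()`. The class shape here is illustrative, not the PR's final code.

```cpp
#include <ostream>
#include <string>

struct SymbolicIntNode {
  virtual ~SymbolicIntNode() = default;
  // The only customization point subclasses need for printing.
  virtual std::string str() { return "NYI"; }
};

// Non-virtual: always delegates to the virtual str().
inline std::ostream& operator<<(std::ostream& os, SymbolicIntNode& node) {
  return os << node.str();
}
```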
c10/core/SymbolicIntNode.h (Outdated)
virtual SymbolicIntNode* gt(SymbolicIntNode* other) { TORCH_CHECK(false, "NYI"); };
virtual SymbolicIntNode* lt(SymbolicIntNode* other) { TORCH_CHECK(false, "NYI"); };
virtual SymbolicIntNode* wrap(int64_t num) { TORCH_CHECK(false, "NYI"); };
virtual bool bool_() { TORCH_CHECK(false, "NYI"); };
Hmm, not sure why we need this on top of the int conversion. Is the problem that you also want a `SymbolicBoolNode`?
@horace overrode `bool` in his PoC, so I added it as well. I can provide a default implementation, `static_cast<bool>(this->int_())`, or just let implementers implement it via `int_`.
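A sketch of that default, under the assumption that truthiness should simply follow the integer value:

```cpp
// Hedged sketch: bool_() derived from int_(), so implementers only need to
// provide int_() unless they want custom truthiness.
virtual bool bool_() {
  return static_cast<bool>(this->int_());
}
```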
c10/core/TensorImpl.h (Outdated)
@@ -1306,6 +1306,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    return numel() == 0;
  }

  // if we are going to use sym sizes, we should be setting sym strides at the same time,
  // otherwise it's very easy to misuse this API
  virtual void set_sym_sizes_and_strides(c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides);
It's not clear to me why this needs to be virtual.
Agreed; @suo mentioned he would help formalize this API. I'll remove `virtual` in the meantime.
c10/core/TensorImpl.h (Outdated)
@@ -2330,6 +2335,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  void set_custom_device(bool custom_device) {
    custom_device_ = custom_device;
  }
 protected:
lol thanks
c10/core/TensorImpl.h (Outdated)
@@ -2321,6 +2325,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  //
  // Can override: strides(), is_contiguous(), sizes(), dim(), numel()
  CustomSizes = 2,
  CustomSymSizes = 3,
Skeptical about this. I'll read through the uses first.
removed
requirements.txt (Outdated)
@@ -10,3 +10,5 @@ setuptools
six
types-dataclasses
typing_extensions
dataclasses; python_version<"3.7"
We're py3.7 and up only, so this really shouldn't be needed.
Ah, this is a bad merge... will fix.
torch/csrc/Size.cpp (Outdated)
@@ -40,6 +45,29 @@ PyObject * THPSize_NewFromSizes(int dim, const int64_t *sizes)
  return self.release();
}

PyObject * THPSize_NewFromSymSizes(const at::Tensor& self_)
{
  HANDLE_TH_ERRORS
You shouldn't need this macro here; this function is not directly bound to Python.
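For context, the macro pair exists to convert C++ exceptions into Python errors at the binding boundary, so it belongs only on functions Python calls directly. An illustrative directly-bound function (body abridged, not from this PR):

```cpp
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS open and close a try/catch that
// turns a thrown c10::Error into a set Python exception and returns nullptr.
static PyObject* THPSize_repr(THPSize* self) {
  HANDLE_TH_ERRORS
  // ... body that may throw ...
  return THPUtils_packString("torch.Size(...)");
  END_HANDLE_TH_ERRORS
}
```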
torch/csrc/utils/python_arg_parser.h (Outdated)
@@ -389,9 +390,52 @@ inline std::vector<int64_t> PythonArgs::intlist(int i) {
  return intlistWithDefault(i, signature.params[i].default_intlist);
}

TORCH_API bool is_symint_node(py::handle obj);
Why not just include the header?
Yeah, let me inline it back.
const auto size1 = signature.params[i].size;
if (size1 > 0 && THPUtils_checkLong(args[i])) {
  return std::vector<c10::SymInt>(size1, c10::SymInt(THPUtils_unpackIndex(args[i])));
}
You need to replicate this logic for a solitary SymInt arg as well.
Yes, a good catch, thank you!
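A hedged sketch of what that solitary-arg path might look like; the `symint` accessor name, the `default_int` field, and the conversion helper are assumptions, not the final API:

```cpp
// Hypothetical sketch: mirror the intlist fast path for a single SymInt arg.
inline c10::SymInt PythonArgs::symint(int i) {
  if (!args[i]) {
    return c10::SymInt(signature.params[i].default_int);  // assumed default field
  }
  if (THPUtils_checkLong(args[i])) {
    // Plain Python int: wrap the concrete value, as in the list case above.
    return c10::SymInt(THPUtils_unpackLong(args[i]));
  }
  // Otherwise assume a bound SymbolicIntNode and convert it (assumed helper).
  auto node = py::handle(args[i]).cast<std::shared_ptr<c10::SymbolicIntNode>>();
  return node->toSymInt();
}
```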
torch/csrc/jit/python/init.cpp (Outdated)
// we need to clear SymIntTable until we have python
// otherwise python classes are already deregistered

//c10::getSymIntTable().clear();
this is dead now
as a dinosaur
c10/core/SymIntTable.cpp (Outdated)
@@ -14,6 +14,11 @@ std::shared_ptr<SymbolicIntNode> SymIntTable::getNode(size_t index) {
  return nodes_[index];
}

void SymIntTable::clear() {
  std::lock_guard<std::mutex> lock(mutex_);
  nodes_.clear();
this is dead now
c10/core/TensorImpl.cpp (Outdated)
@@ -792,6 +792,13 @@ void TensorImpl::ShareExternalPointer(
  }
}

void TensorImpl::set_sym_sizes_and_strides(c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides) {
  has_symbolic_sizes_strides_ = true;
  sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
This seems to me like CustomSymSizes isn't actually being used!
We are setting the `CustomSizes` policy for Python tensors (i.e. ones made via `make_wrapper_class`), so calls to `sizes()` would throw for those. Unfortunately, it means that `sym_sizes()` also throws. We actually would like to just run the default implementation in this case, hence `CustomSymSizes`, which is indeed overridden by LTC. I'm open to ideas on how we can make this cleaner.
@suo any suggestions?
I think the easiest thing is to just call into Python if the Python key is set and you have a custom sizes policy.
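Roughly, that suggestion could look like the following sketch; the interpreter hook and member names here are guesses, not code from this PR:

```cpp
// Assumed names throughout: when the Python dispatch key is set, defer to the
// Python subclass for sizes instead of throwing from the custom-sizes policy.
c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const {
  if (C10_UNLIKELY(key_set_.has(c10::DispatchKey::Python))) {
    // Hypothetical hook asking the Python interpreter for the subclass's sizes.
    return pyobj_interpreter_.load()->sym_sizes(this);
  }
  TORCH_CHECK(false, "Tensors of type ", tensorimpl_type_name(), " do not have sym sizes");
}
```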
@@ -622,9 +622,6 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py
  if (r.toBool(10)) {
    data.unsafeGetTensorImpl()->set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomStrides);
  }
  if (r.toBool(11)) {
    data.unsafeGetTensorImpl()->set_custom_device(true);
  }
?
oh crap, a bad merge :(
// NB: pin_memory doesn't actually do anything
// TODO: strides variant?
static PythonArgParser parser({
  "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, int64_t? storage_offset=None, *, MemoryFormat? memory_format=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
How about keeping only one `_make_wrapper_subclass` and just having a second overload for PythonArgParser? Having it as an overload should also help reduce duplication in this variant of the function.
Mkkk... it's already a pretty branchy and long function.
It being long is a good reason not to copy-paste, right ;)
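For illustration, the overload approach would register both signatures with one parser and branch on `r.idx`. The signatures below are abridged stand-ins, not the exact strings from the PR:

```cpp
// Sketch: one parser, two overloads; r.idx says which signature matched.
static PythonArgParser parser({
    "_make_wrapper_subclass(PyObject* cls, IntArrayRef size, IntArrayRef? strides=None, ...)",      // existing overload (abridged)
    "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef strides, ...)",      // SymInt overload (abridged)
});
ParsedArgs<12> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
if (r.idx == 1) {
  // SymInt overload: take the set_sym_sizes_and_strides(...) path.
} else {
  // Plain int overload: existing path, shared tail logic stays in one place.
}
```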
torch/csrc/jit/python/init.cpp (Outdated)
pyobj_ = std::make_shared<c10::SafePyObject>(pyobj.release().ptr(), getPyInterpreter());
};

virtual SymbolicIntNode* wrap(int64_t num) {
Remind me again why we are doing raw-pointer memory management here?
torch/csrc/jit/python/init.cpp (Outdated)
virtual bool bool_() {
  py::gil_scoped_acquire acquire;
  return py::str(getPyObj().attr("__bool__")()).is(py::str(Py_True));
Why are you doing a string comparison to test what the bool result is?
Lemme try simplifying it a bit. This is what SO recommends, but I do agree it's convoluted. There's no `py::cast` to bool, but maybe we don't have to do `py::str` on both sides.
When taking questionable advice from Stack Overflow, I highly recommend leaving a link to the URL of the question.
Cleaned up this part. Now it should make more sense.
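For reference, one way to avoid the string round-trip entirely is to let pybind11 cast the `__bool__` result directly; a sketch, not necessarily the exact cleanup that landed:

```cpp
// __bool__ must return a real Python bool, so pybind11's bool caster
// handles it without any string conversion.
virtual bool bool_() {
  py::gil_scoped_acquire acquire;
  return getPyObj().attr("__bool__")().cast<bool>();
}
```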
torch/csrc/jit/python/init.cpp (Outdated)
virtual SymbolicIntNode* dispatch_common_(const char* fname, SymbolicIntNode* other) {
  auto pother = dynamic_cast<PythonSymbolicIntNode*>(other);
  TORCH_CHECK(pother);
  auto magic_fname = std::string("__") + fname + std::string("__");
I'd much rather you had taken the magic_fname as an argument lol. With a macro you could paste the `__`s together with a string constant without having to do a string concat on every function call (which is wasteful).
Haha sorry, I was going to do it but didn't before your review.
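A sketch of the macro idea: adjacent string literals concatenate at compile time, so the dunder name costs nothing per call. The macro name and the set of generated methods are illustrative:

```cpp
// "__" #op "__" is pasted by the preprocessor into a single string literal,
// so dispatch_common_ receives the ready-made magic name with no
// std::string concatenation at dispatch time.
#define SYMINT_MAGIC_OP(op)                                    \
  SymbolicIntNode* op(SymbolicIntNode* other) override {       \
    return dispatch_common_("__" #op "__", other);             \
  }

SYMINT_MAGIC_OP(add)
SYMINT_MAGIC_OP(mul)
SYMINT_MAGIC_OP(gt)
#undef SYMINT_MAGIC_OP
```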
torch/csrc/jit/python/init.cpp (Outdated)
.def_static("isinstance", [](py::object obj, bool convert) -> bool {
  return pybind11::detail::type_caster<std::shared_ptr<c10::SymbolicIntNode>>().load(obj, convert);
  //return false;
})
This method is pretty weird
dead code. sorry..
torch/csrc/jit/python/init.cpp (Outdated)
if (torch::is_symint_node(b)) {
  return std::shared_ptr<c10::SymbolicIntNode>(a->add(b.cast<c10::SymbolicIntNode*>()));
} else {
  return std::shared_ptr<c10::SymbolicIntNode>(a->add(a->wrap(b.cast<int64_t>())));
This feels like a helper function would help a bit here. But it's also not entirely clear you want to wrap integers into symbolic int nodes (that denote plain integers); it seems like it would be more user-friendly if these showed up at the dispatch site as plain integers. It might be a bit easier here to make the add method accept an IValue instead of a SymbolicIntNode, so you can pass in either an int or a symbolic int without needing to unconditionally accept a SymbolicIntNode.
It looks way nicer rn. I don't want to introduce a dependency on `IValue` into `SymbolicIntNode` :(. It seems more complex architecturally, and possibly less user-friendly, since both LTC and AOTAutograd would need to parse `IValue`s explicitly. Both LTC and AOTAutograd already wrap ints into `sympy.Integer` or `prim::Constant`.
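For the record, the helper-function half of the suggestion could look like this; the helper name is an assumption:

```cpp
// Hypothetical helper: normalize the right-hand operand once, then dispatch.
static c10::SymbolicIntNode* to_node(c10::SymbolicIntNode* a, py::object b) {
  if (torch::is_symint_node(b)) {
    return b.cast<c10::SymbolicIntNode*>();
  }
  // Plain Python int: wrap it as a constant node via the left operand.
  return a->wrap(b.cast<int64_t>());
}

// Each binding body then collapses to one line, e.g.:
//   return std::shared_ptr<c10::SymbolicIntNode>(a->add(to_node(a, b)));
```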
Force-pushed from b539c7d to 4e3fa42 (Compare)
@pytorchbot merge this
❌ 🤖 pytorchbot command failed:
Try
@pytorchbot merge
@pytorchbot successfully started a merge job. Check the current status here
Hey @Krovatkin.
      static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
    return sym_sizes_custom();
  }
  virtual c10::SymIntArrayRef sym_sizes() const {
Why do we want this to be virtual now, instead of doing the policy thing that we do with all of our other de-virtualized methods? (I need to fix it up for functionalization, which will be pretty easy; just curious about the reasoning.)
And do we even need a `sym_sizes_custom()` anymore if `sym_sizes()` is virtual?
@bdhirsh Python tensor subclasses and LTC want to do different things when the policy is set to `CustomSizes`, and there's no easy way to implement both via `_custom`, so we had to make `sym_sizes()` virtual for now.
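For contrast, a sketch of the de-virtualized policy pattern referenced above, mirroring the snippet in the diff; `sym_sizes_default()` is assumed, not taken from this PR:

```cpp
// Policy-based (non-virtual) alternative: a single _custom hook guarded by
// the sizes/strides policy byte, as other de-virtualized methods do.
c10::SymIntArrayRef sym_sizes() const {
  if (C10_UNLIKELY(
          sizes_strides_policy_ >=
          static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
    return sym_sizes_custom();  // one hook shared by all subclasses
  }
  return sym_sizes_default();  // assumed default accessor
}
```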
We suspect that this PR broke TorchVision's tests. Information available here: pytorch/vision#6166 (comment)
@pytorchbot revert -m "broke torchvision tests" -c weird
@pytorchbot successfully started a revert job. Check the current status here
This reverts commit d332724. Reverted #78135 on behalf of https://github.com/ezyang due to broke torchvision tests
This reverts commit b8db0a0. [ghstack-poisoned]
This reverts commit b8db0a0. ghstack-source-id: 602ffd6 Pull Request resolved: pytorch#79608
This PR adds support for `SymInt`s in Python. Namely, `THPVariable_size` now returns `sym_sizes()`, and Python bindings for `SymbolicIntNode` are added, so size expressions can be traced.