Python Bindings for SymInts #78135
Changes from all commits
df516fc
e3b0b07
cee36ef
b4f0bdb
177cb91
57293f2
9acb95e
42df407
ef2bb04
8d61f11
8c59c65
b6f75a7
fac7500
4e3fa42
50810d1
```diff
@@ -25,4 +25,5 @@ SymIntTable& getSymIntTable() {
   static SymIntTable sit;
   return sit;
 }
 
 } // namespace c10
```
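For readers unfamiliar with the idiom, `getSymIntTable()` is a function-local static (a Meyers singleton): the table is constructed on first use and every later call returns the same instance, with initialization guaranteed thread-safe since C++11. A minimal sketch of the pattern using a hypothetical `ExampleTable` (not the real c10 type):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical registry type, used only to illustrate the singleton idiom.
struct ExampleTable {
  std::vector<int64_t> entries;
};

// Function-local static: constructed once, on first call, thread-safely.
ExampleTable& getExampleTable() {
  static ExampleTable table;
  return table;
}

int main() {
  getExampleTable().entries.push_back(42);
  // Any later call site sees the same instance.
  return getExampleTable().entries.size() == 1 ? 0 : 1;
}
```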
```diff
@@ -552,12 +552,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return sizes_default();
   }
 
-  c10::SymIntArrayRef sym_sizes() const {
-    if (C10_UNLIKELY(
-            sizes_strides_policy_ >=
-            static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
-      return sym_sizes_custom();
-    }
+  virtual c10::SymIntArrayRef sym_sizes() const {
     return sym_sizes_default();
   }
```

Review discussion on this change:

> Why do we want this to be virtual now, instead of doing the policy thing that we do with all of our other de-virtualized methods? (I need to fix it up for functionalization, which will be pretty easy - just curious on the reasoning)

> And do we even need a

> @bdhirsh python tensor subclasses and LTC want to do different things when policy is set to
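For context on the first question: the "policy thing" is the pattern visible in the removed lines, where the public method is non-virtual and only falls back to a virtual `*_custom()` hook when a subclass has raised `sizes_strides_policy_`. A simplified, self-contained sketch of that de-virtualization pattern (class and member names are illustrative, not the real `TensorImpl`):

```cpp
#include <cstdint>
#include <iostream>

// Simplified stand-in for TensorImpl's sizes/strides policy mechanism.
class ImplBase {
 public:
  virtual ~ImplBase() = default;

  // Non-virtual entry point: the common case avoids a virtual call entirely.
  int64_t dim() const {
    if (policy_ >= kCustomSizes) {
      return dim_custom();  // only subclasses that opted in pay for dispatch
    }
    return dim_default();
  }

 protected:
  // Virtual hook for subclasses (e.g. nested or lazy tensors) that customize sizes.
  virtual int64_t dim_custom() const { return dim_default(); }
  int64_t dim_default() const { return 4; }

  void set_custom_sizes_policy() { policy_ = kCustomSizes; }

 private:
  static constexpr uint8_t kCustomSizes = 2;
  uint8_t policy_ = 0;
};

class CustomImpl : public ImplBase {
 public:
  CustomImpl() { set_custom_sizes_policy(); }

 protected:
  int64_t dim_custom() const override { return 2; }
};

int main() {
  ImplBase base;
  CustomImpl custom;
  std::cout << base.dim() << " " << custom.dim() << "\n";  // prints "4 2"
}
```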
```diff
@@ -1311,6 +1306,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return numel() == 0;
   }
 
+  // if we are going to use sym sizes, we should be setting sym strides at the
+  // same time, otherwise it's very easy to misuse this API
+  void set_sym_sizes_and_strides(
+      c10::SymIntArrayRef sizes,
+      c10::SymIntArrayRef strides);
+
   /**
    * Change the size at some dimension. This DOES NOT update strides;
    * thus, most changes to size will not preserve contiguity. You probably
```
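The comment on the declaration is the key design point: symbolic sizes and strides are installed through a single call so they cannot drift out of sync. A self-contained sketch of that invariant, using stand-in types rather than the real `c10::SymInt`/`c10::SymIntArrayRef`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-ins for c10::SymInt / c10::SymIntArrayRef; illustration only.
using SymInt = int64_t;
using SymIntList = std::vector<SymInt>;

class ExampleImpl {
 public:
  // Mirrors the shape of the new TensorImpl API: sizes and strides are
  // installed together so their ranks always match.
  void set_sym_sizes_and_strides(const SymIntList& sizes, const SymIntList& strides) {
    assert(sizes.size() == strides.size() && "sizes and strides must have the same rank");
    sym_sizes_ = sizes;
    sym_strides_ = strides;
  }

  const SymIntList& sym_sizes() const { return sym_sizes_; }
  const SymIntList& sym_strides() const { return sym_strides_; }

 private:
  SymIntList sym_sizes_;
  SymIntList sym_strides_;
};

int main() {
  ExampleImpl impl;
  // Contiguous strides for a (2, 3, 4) shape, set in one call.
  impl.set_sym_sizes_and_strides({2, 3, 4}, {12, 4, 1});
  return impl.sym_sizes().size() == 3 ? 0 : 1;
}
```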
```diff
@@ -2325,7 +2326,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     // Customizable sizes behavior, e.g., nested tensor
     //
     // Can override: strides(), is_contiguous(), sizes(), dim(), numel()
-    CustomSizes = 2,
+    CustomSizes = 2
   };
 
   void set_sizes_strides_policy(SizesStridesPolicy policy) {
```
```diff
@@ -2336,6 +2337,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     custom_device_ = custom_device;
   }
 
  protected:
   Storage storage_;
 
  private:
```
```diff
@@ -10,3 +10,4 @@ setuptools
 six
 types-dataclasses
 typing_extensions
+sympy
```
```diff
@@ -104,15 +104,20 @@ def testAddSubFallback(self):
     def testBatchNorm(self):
         device = get_test_device()
         x = torch.randn(16, 3, 224, 224, device=device)
-        bn = torch.nn.BatchNorm2d(3).to(device=device)
+        weight = torch.randn(3, device=device)
+        bias = torch.randn(3, device=device)
 
         for i in range(10):
-            z = bn(x)
+            # BatchNorm2d does extra checks on dimensions which SymInts don't support yet
+            # so we call `torch.ops.aten.native_batch_norm` to bypass the checks.
+            z, _, _ = torch.ops.aten.native_batch_norm(x, weight, bias, None, None, True, 0.1, 1e-5)
 
         device = "lazy"
         x_lazy = x.detach().clone().to(device=device)
-        bn = bn.to(device=device)
+        weight_lazy = weight.detach().clone().to(device=device)
+        bias_lazy = bias.detach().clone().to(device=device)
         for i in range(10):
-            z_lazy = bn(x_lazy)
+            z_lazy, _, _ = torch.ops.aten.native_batch_norm(x_lazy, weight_lazy, bias_lazy, None, None, True, 0.1, 1e-5)
             torch._lazy.mark_step()
 
         torch.testing.assert_close(z.cpu(), z_lazy.cpu())
```

Review discussion on this change:

> I don't see any use of symbolic ints here, how come we need to make this change?

> The BatchNorm Python module does multiplication and an assertion, and LTC doesn't have this part implemented yet :(
One more review comment:

> Another possibility is we could just assert that the policy is sufficiently high for symbolic sizes. This is OK because we will never store symbolic sizes in a stock tensor; we wouldn't be able to represent the data in question. It would have to be some sort of Python subclass or some special subclass type like lazy tensor. Assuming this code works, however, I wouldn't bother changing it for now.
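If one did want the assertion suggested here, it could sit at the top of `set_sym_sizes_and_strides`. The sketch below is hypothetical: the PR does not add such a check, and it uses simplified stand-ins (a plain exception instead of `TORCH_CHECK`, `int64_t` instead of `c10::SymInt`):

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative stand-in for TensorImpl's policy enum.
enum class SizesStridesPolicy : uint8_t { Default = 0, CustomStrides = 1, CustomSizes = 2 };

struct ExampleImpl {
  uint8_t sizes_strides_policy_ = 0;

  void set_sym_sizes_and_strides(const std::vector<int64_t>& sizes,
                                 const std::vector<int64_t>& strides) {
    // The suggested guard: a stock tensor can never hold symbolic sizes, so
    // require that the caller is a subclass that opted into CustomSizes.
    if (sizes_strides_policy_ < static_cast<uint8_t>(SizesStridesPolicy::CustomSizes)) {
      throw std::runtime_error(
          "set_sym_sizes_and_strides requires SizesStridesPolicy::CustomSizes");
    }
    if (sizes.size() != strides.size()) {
      throw std::runtime_error("sizes and strides must have the same rank");
    }
    // ... store the symbolic sizes and strides ...
  }
};

int main() {
  ExampleImpl impl;
  impl.sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
  impl.set_sym_sizes_and_strides({2, 3}, {3, 1});  // passes the guard
  return 0;
}
```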