[jit] support tracing tensor __setitem__ with dynamic shape #45828
Summary: fix pytorch#43548

Test Plan: buck test mode/dev-nosan //caffe2/test:jit -- 'test_trace_slice' --jobs 1

Differential Revision: D24106641

fbshipit-source-id: 240a6a69f44bccb277e21fd76406e31e48c8968b
This pull request was exported from Phabricator. Differential Revision: D24106641
@eellison could you help take a look? I am not familiar with this part of the code.
💊 CI failures summary: as of commit 2719ba8, 4 new failures were recognized by patterns; they do not appear to be due to upstream breakages.
I can believe that this "helps", but the fix is still subtly wrong: you are still baking in static sizes by fast-pathing to `copy_`. I don't have enough information to assess whether the hack is worth it.
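To make the static-size concern concrete, here is a minimal, hypothetical repro of how tracing can bake a shape-dependent slice bound into the graph (the function name and shapes are illustrative, not taken from the PR):

```python
import torch

def fill_prefix(x, y):
    # Write y into the first y.size(0) elements of x; the slice bound
    # depends on y's runtime shape.
    x = x.clone()
    x[: y.size(0)] = y
    return x

# Tracing records concrete sizes: the slice bound 3 can be baked into the
# graph as a constant, so calling the traced function with a y of a
# different length may silently misbehave instead of erroring.
traced = torch.jit.trace(fill_prefix, (torch.zeros(5), torch.ones(3)))
out = traced(torch.zeros(5), torch.ones(3))
assert out.tolist() == [1.0, 1.0, 1.0, 0.0, 0.0]
```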
Yeah, this is not perfect, and I think a perfect solution for tracing to support dynamic sizes everywhere is very hard. However, the edge case that's left in this PR can easily be worked around: if one needs to do
```cpp
// A shortcut to avoid generating hard-coded constant sizes during tracing.
dst.copy_(src);
return;
}
```
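For intuition, the shortcut replaces the general indexing path with a plain `copy_` into the destination view, which is recorded as a single op with no constant sizes baked in. A rough eager-mode sketch of the equivalent behavior (shapes are illustrative):

```python
import torch

src = torch.ones(3)
dst = torch.zeros(5)

# Equivalent of the fast path: copy src into the selected view of dst.
# Under tracing this records one copy_ op rather than a chain of ops
# containing hard-coded sizes.
dst[:3].copy_(src)
assert dst.tolist() == [1.0, 1.0, 1.0, 0.0, 0.0]
```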
We might need to test whether this would hurt indexing perf for the fallback path. It would be awesome to collect numbers like the following (copy-pasted from #32841 (comment)):
| Test | Old (us) | New (us) | % difference (negative: speedup) |
|---|---|---|---|
| `v[0]` | 1.68 | 1.69 | 0.60% |
| `v[0] = 1` | 7.33 | 7.04 | -3.96% |
| `v[:]` | 1.67 | 1.69 | 1.20% |
| `v[:] = 1` | 7.58 | 6.98 | -7.92% |
| `v[...]` | 1.29 | 1.29 | 0.00% |
| `v[...] = 1` | 5.6 | 5.33 | -4.82% |
| `v[None]` | 1.67 | 1.68 | 0.60% |
| `v[None] = 1` | 7.39 | 7.32 | -0.95% |
| `v[False]` | 9.37 | 9.03 | -3.63% |
| `v[False] = 1` | 1.15 | 1 | -13.04% |
| `v[True]` | 12.3 | 11.7 | -4.88% |
| `v[True] = 1` | 7.52 | 7.22 | -3.99% |
| `v[0, 0]` | 3.32 | 3.38 | 1.81% |
| `v[0, 0] = 1` | 9 | 8.83 | -1.89% |
| `a[b, None, ..., :, True]` | 37.8 | 37.4 | -1.06% |
| `a[b, None, ..., :, True] = 1` | 39.4 | 39.2 | -0.51% |
to better understand the impact on performance (note: we might need to change the test cases to make sure we are executing the fallback path).
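For reference, the "% difference" column in tables like the one above can be computed with a small helper (not part of the PR):

```python
def pct_diff(old_us: float, new_us: float) -> float:
    """Percent change from old to new timing; negative means a speedup."""
    return (new_us - old_us) / old_us * 100.0

# e.g. the `v[0] = 1` row: 7.33 us -> 7.04 us
print(round(pct_diff(7.33, 7.04), 2))  # -3.96
```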
Makes sense. Which code produced this table?
I think this is what I used:
```bash
#!/bin/bash
# Array of indexing statements to benchmark with timeit.
declare -a StmtArray=(
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[0]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[0] = 1'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[0, 0]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[0, 0] = 1'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[...]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[...] = 1'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[:]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[:] = 1'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[None]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[None] = 1'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[False]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[False] = 1'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[True]'"
  "python -m timeit -s 'import torch; v = torch.randn(1,1,1);' 'v[True] = 1'"
  "python -m timeit -s 'import torch; a = torch.zeros(100, 100, 1, 1, 1); b = torch.arange(99, -1, -1).long(); ' 'a[b, None, ..., :, True]'"
  "python -m timeit -s 'import torch; a = torch.zeros(100, 100, 1, 1, 1); b = torch.arange(99, -1, -1).long(); ' 'a[b, None, ..., :, True] = 1'"
)

for (( i = 0; i < ${#StmtArray[@]}; i++ )); do
  printf "\n**** Running: ${StmtArray[$i]} *****\n\n"
  for (( j = 0; j < 20; j++ )); do
    eval "${StmtArray[$i]}"
  done
done
```
The way to do it nowadays is something like:
```python
from torch.utils.benchmark import Timer, Compare

SCALAR_SETUP = "v = torch.randn(1, 1, 1)"
VECTOR_SETUP = "a = torch.zeros(100, 100, 1, 1, 1); b = torch.arange(99, -1, -1).long()"

TASKS = (
    (SCALAR_SETUP, 'v[0]'),
    (SCALAR_SETUP, 'v[0] = 1'),
    (SCALAR_SETUP, 'v[0, 0]'),
    (SCALAR_SETUP, 'v[0, 0] = 1'),
    (SCALAR_SETUP, 'v[...]'),
    (SCALAR_SETUP, 'v[...] = 1'),
    (SCALAR_SETUP, 'v[:]'),
    (SCALAR_SETUP, 'v[:] = 1'),
    (SCALAR_SETUP, 'v[None]'),
    (SCALAR_SETUP, 'v[None] = 1'),
    (SCALAR_SETUP, 'v[False]'),
    (SCALAR_SETUP, 'v[False] = 1'),
    (SCALAR_SETUP, 'v[True]'),
    (SCALAR_SETUP, 'v[True] = 1'),
    (VECTOR_SETUP, 'a[b, None, ..., :, True]'),
    (VECTOR_SETUP, 'a[b, None, ..., :, True] = 1'),
)

results = []
for setup, stmt in TASKS:
    timer = Timer(
        stmt,
        setup=setup,
        description="HEAD",
    )
    results.append(timer.blocked_autorange(min_run_time=1))

cmp = Compare(results)
# cmp.trim_significant_figures()
cmp.print()
```
```
[--------------------  -------------------]
                                  |  HEAD
1 threads: --------------------------------
      v[0]                        |   1.7
      v[0] = 1                    |   7.2
      v[0, 0]                     |   3.3
      v[0, 0] = 1                 |   9.5
      v[...]                      |   1.3
      v[...] = 1                  |   6.0
      v[:]                        |   1.6
      v[:] = 1                    |   6.7
      v[None]                     |   1.6
      v[None] = 1                 |   7.0
      v[False]                    |   8.3
      v[False] = 1                |   1.1
      v[True]                     |  11.3
      v[True] = 1                 |   7.0
      a[b, None, ..., :, True]    |  33.2
      a[b, None, ..., :, True] = 1 | 35.1

Times are in microseconds (us).
```
`trim_significant_figures` is currently a bit conservative with really fast ops, so I didn't use it here, but in the future it will be the recommended option.
Thanks! I added one more setindex test case with the same shape. These are the results after running both the old and the new code twice.
```
$ cat old*.txt
[---------------------  ---------------------]
                                     |   HEAD
1 threads: -----------------------------------
      v[0] = 1                       |   5830.6
      v[0, 0] = 1                    |   6796.2
      v[0, 0, 0] = 1                 |   6797.3
      v[...] = 1                     |   4329.7
      v[:] = 1                       |   5482.9
      v[None] = 1                    |   5664.0
      v[False] = 1                   |    964.3
      v[True] = 1                    |   5707.1
      a[b, None, ..., :, True] = 1   |  30175.3

Times are in nanoseconds (ns).

[---------------------  ---------------------]
                                     |   HEAD
1 threads: -----------------------------------
      v[0] = 1                       |   5890.9
      v[0, 0] = 1                    |   7040.5
      v[0, 0, 0] = 1                 |   7034.4
      v[...] = 1                     |   4518.2
      v[:] = 1                       |   5635.5
      v[None] = 1                    |   5664.1
      v[False] = 1                   |    937.6
      v[True] = 1                    |   5705.7
      a[b, None, ..., :, True] = 1   |  29280.1

Times are in nanoseconds (ns).
```
```
$ cat new*.txt
[---------------------  ---------------------]
                                     |   HEAD
1 threads: -----------------------------------
      v[0] = 1                       |   5597.9
      v[0, 0] = 1                    |   6769.5
      v[0, 0, 0] = 1                 |   5938.4
      v[...] = 1                     |   4368.9
      v[:] = 1                       |   5516.4
      v[None] = 1                    |   5608.9
      v[False] = 1                   |    978.5
      v[True] = 1                    |   5643.1
      a[b, None, ..., :, True] = 1   |  32076.3

Times are in nanoseconds (ns).

[---------------------  ---------------------]
                                     |   HEAD
1 threads: -----------------------------------
      v[0] = 1                       |   5842.8
      v[0, 0] = 1                    |   6934.9
      v[0, 0, 0] = 1                 |   5923.4
      v[...] = 1                     |   4381.3
      v[:] = 1                       |   5548.5
      v[None] = 1                    |   5705.6
      v[False] = 1                   |    972.3
      v[True] = 1                    |   5658.0
      a[b, None, ..., :, True] = 1   |  30968.8

Times are in nanoseconds (ns).
```
The new test case `v[0, 0, 0] = 1` is considerably faster. It's hard to obtain signals from the others, as they seem to be within the noise range. Maybe the last one gets slower? I can test more if needed.
@robieta To answer your question further above, this benchmark is the one that I'd expect to see added.
I propose a verbose mode for tracing that helps uncover potentially incorrect traces (or emits a warning) with respect to indexing. It's better than silently recording static shapes and then failing at runtime with implicitly divergent behavior. It would also be cool to allow adding "shape hints" to help shape inference, which would e.g. guarantee dimensionality, some static shapes, or equality of one tensor's dim to another's (or to some known "dim symbol").
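As a partial step in that direction, `torch.jit.trace` already accepts `check_inputs`, which re-runs the trace on extra inputs and flags divergence from eager execution. A sketch (the function and shapes are illustrative, and the error-handling path is an assumption about how a divergent check surfaces):

```python
import warnings
import torch

def halve(x):
    x = x.clone()
    x[: x.size(0) // 2] = 0  # slice bound depends on x's runtime shape
    return x

# check_inputs re-executes the trace check with differently shaped inputs;
# if static sizes were baked into the graph, the check reports divergence.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing may emit TracerWarnings here
    try:
        traced = torch.jit.trace(halve, (torch.ones(4),),
                                 check_inputs=[(torch.ones(6),)])
    except torch.jit.TracingCheckError:
        # Divergence detected: re-trace without the failing check so we can
        # still inspect the (shape-specialized) graph.
        traced = torch.jit.trace(halve, (torch.ones(4),), check_trace=False)

out = traced(torch.ones(4))
assert out.tolist() == [0.0, 0.0, 1.0, 1.0]
```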
OK, this is sufficient justification to me.
Can you bulk up the comment on the shortcut to include the discussion we had here, including the suggested workaround?
@vadimkantorov's comment seems relevant in the long term; maybe file a follow-up issue about this too.
@ppwwyyxx has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@vadimkantorov do you want to open a follow-up issue and elaborate more on the ideas?
Are we still developing tracing functionality? I thought we were moving away from tracing to instead rely on scripting the model.
Currently scripting works really badly for export to ONNX (I filed a few issues about this)...
@ppwwyyxx @ezyang I had two proposals:
@vadimkantorov wanna stick these in some fresh issues?
#40373 covers it, doesn't it? There's no substantial development, but the motivation is well-discussed there, no?
Maybe? That issue as a whole is not so likely to get resolved in the near future, but more focused fixes for tracing might be easier to tackle.
I added an UPD there and will rename the issue as well! |
Okay, I'll make a new issue as well.