segfault in python multithreaded setting #1868
Comments
also, he mentions:
How do you know it's because of the engine? The stack trace points to Python interpreter shutdown.
see his comment:
Of course, I'll try to repro the same.
Yeah, I know, but the script doesn't really show what was guarded by the `net_lock`.
What I locked were the forward and backward calls of the module.
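A minimal sketch of the kind of guard that comment describes, i.e. one lock shared by all worker threads around the forward and backward calls; the names (`net_lock`, `locked_step`, `model`, `criterion`) are illustrative, not from the issue:

```python
import threading

# One lock shared by every worker thread.
net_lock = threading.Lock()

def locked_step(model, criterion, x, y):
    # Serialize the forward and backward passes across threads.
    with net_lock:
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
    return loss
```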
Not sure it's the same issue, but I am experiencing a segfault with multithreading as well:

```python
import torch
import torch.functional as f
from concurrent.futures import ThreadPoolExecutor as ThreadPool

def build(cuda=False):
    # Two-layer model; the Linear layers are randomly initialized here.
    nn = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.Linear(1024, 1)
    )
    return nn.cuda() if cuda else nn

def train(nn, X, y, epoch=100):
    X = torch.autograd.Variable(X)
    y = torch.autograd.Variable(y)
    optim = torch.optim.SGD(nn.parameters(), lr=0.1)
    for i in range(epoch):
        yhat = nn(X)
        loss = ((yhat - y) ** 2).mean()
        loss.backward()
        optim.step()

def data(cuda=False):
    X = torch.rand(10, 1024)
    y = torch.rand((10, 1))
    return (X.cuda(), y.cuda()) if cuda else (X, y)

def cpu_run(i=None):
    # Each thread builds and trains its own small model on random data.
    nn = build(cuda=False)
    d = data(cuda=False)
    train(nn, *d)

def thread_cpu_run():
    pool = ThreadPool()
    threads = pool.map(cpu_run, list(range(5)))
    return list(threads)

thread_cpu_run()
```
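For reference, two places in this repro draw from the shared CPU random-number generator: the `torch.nn.Linear` layers randomly initialize their weights inside `build()`, and `data()` calls `torch.rand`. The comments below point at exactly this kind of unsynchronized generator use as the trigger.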
I got a similar segmentation fault to @Louis-Tian's, due to random number generation. If I add a line of code setting a manual seed just before the random function that I am using, the segfault goes away.
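A minimal sketch of that workaround, assuming the re-seed happens inside each thread immediately before its random call; `sample_noise` and the fixed seed are illustrative, not from the comment:

```python
import torch

def sample_noise(shape, seed=0):
    # Workaround sketch: re-seed the global generator right before the random
    # call made from this thread. (Wrapping this pair in a lock, as other
    # commenters do, would be the more robust guard.)
    torch.manual_seed(seed)
    return torch.rand(shape)
```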
Same here. I also find that torch.random generates huge numbers in a multi-GPU setting, so I have to replace it with numpy.random. I also hit the segfault when using multiple GPUs. Here is the log:
I have the exact same issue. In my case, the calls to `torch.Tensor(var.size()).normal_()` and `torch.Tensor(var.size()).bernoulli_()` in the thread function were causing the problem. Once I took them out, the segfaults stopped appearing.
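An alternative to deleting those calls, sketched under the assumption (in line with the other lock-based workarounds in this thread) that serializing the in-place RNG ops is enough; `rng_lock` and `noisy_mask` are illustrative names:

```python
import threading

import torch

rng_lock = threading.Lock()

def noisy_mask(var, p=0.5):
    # Serialize the in-place random fills so concurrent threads never touch
    # the shared CPU generator at the same time.
    with rng_lock:
        noise = torch.Tensor(var.size()).normal_()
        mask = torch.Tensor(var.size()).bernoulli_(p)
    return noise, mask
```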
My code had almost the same structure as @Louis-Tian's example, and I was able to get around it by putting a lock where I instantiate my module in each thread. Working code below (pay attention to the lock):

```python
import torch
import threading
import torch.functional as f
from concurrent.futures import ThreadPoolExecutor as ThreadPool

def build(cuda=False):
    nn = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.Linear(1024, 1)
    )
    return nn.cuda() if cuda else nn

def train(nn, X, y, epoch=100):
    X = torch.autograd.Variable(X)
    y = torch.autograd.Variable(y)
    optim = torch.optim.SGD(nn.parameters(), lr=0.1)
    for i in range(epoch):
        yhat = nn(X)
        loss = ((yhat - y) ** 2).mean()
        loss.backward()
        optim.step()

def data(cuda=False):
    X = torch.zeros(10, 1024)
    y = torch.zeros((10, 1))
    return (X.cuda(), y.cuda()) if cuda else (X, y)

def cpu_run(lock):
    with lock:  # serialize module construction across threads
        nn = build(cuda=False)
    d = data(cuda=False)
    train(nn, *d)

def thread_cpu_run():
    pool = ThreadPool()
    lock = threading.Lock()
    threads = pool.map(cpu_run, [lock for _ in range(5)])
    return list(threads)

thread_cpu_run()
```
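Presumably this placement works because `build()` is where the `torch.nn.Linear` layers draw their random initial weights from the shared CPU generator, so the lock serializes the only random calls left; this version also swaps `torch.rand` for `torch.zeros` in `data()`, so nothing else touches the generator concurrently.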
Summary: When we added `randperm_cpu` and `THTensor_(randperm)` we forgot to lock the `THGenerator` mutex before calling `THRandom_random`, which causes the segfault mentioned in facebookresearch/maskrcnn-benchmark#93 (comment). This PR fixes the bug. Closes pytorch/pytorch#1868.
Pull Request resolved: pytorch/pytorch#13832
Differential Revision: D13025453
Pulled By: yf225
fbshipit-source-id: 6e363a35c72b4862412eaea6516a154126634c9d
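For context, a minimal sketch (not from the PR) of the kind of concurrent `randperm` use that could hit the unlocked generator before this fix:

```python
from concurrent.futures import ThreadPoolExecutor

import torch

def worker(_):
    # Each call samples from the shared CPU generator; without the mutex,
    # concurrent calls could race inside THRandom_random.
    return torch.randperm(1000).sum().item()

with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(worker, range(32)))
```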
Zihang Dai reports (and I've reproduced) that the autograd engine is not thread-safe.
Here's a repro script: https://gist.github.com/zihangdai/fc8f76fbb8a0f6323a6b31e6d98ceb50
Run it a few times; occasionally it segfaults.
The segfault is from a quite different location, when cleaning up imports: