[LTC] Memory growing up during training until OOM · Issue #65640 · pytorch/pytorch · GitHub

Open
leslie-fang-intel opened this issue Sep 24, 2021 · 19 comments
Labels
lazy (Lazy Tensor work items) · module: lazy · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@leslie-fang-intel
Collaborator

Hi, we are running ResNet50 FP32 training with Lazy Tensor Core, and we find that memory keeps growing on each iteration until OOM. To reproduce this issue, I have written a simple test case as below:

import torch
import torch.nn as nn
import copy
import time
import lazy_tensor_core
import lazy_tensor_core.core.lazy_model as ltm
import lazy_tensor_core.debug.metrics as met

lazy_tensor_core._LAZYC._ltc_init_ts_backend()  # register the TorchScript (TS) backend for lazy tensors
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv = torch.nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.conv2 = torch.nn.Conv2d(128, 128, (3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conv3 = torch.nn.Conv2d(64, 128, (1, 1), stride=(2, 2), padding=(0, 0), bias=False)
    def forward(self, x):
        x1 = self.conv(x)
        y1 = self.conv2(x1)
        y2 = self.conv3(x)
        y = y1 + y2
        y = torch.flatten(y, start_dim=1)
        return y
lazy_device = ltm.lazy_device()
model = SimpleNet()
model.train().to(lazy_device)

for i in range(10000):
    x = torch.rand(64, 64, 3, 3, requires_grad=True).to(lazy_device)
    y = model(x)
    yg = torch.rand(64, 512).to(lazy_device)
    loss = nn.CrossEntropyLoss()
    output = loss(y, yg)
    output.backward()
    print(output)
    ltm.mark_step()  # materialize the lazy graph accumulated during this iteration
    print(met.metrics_report())

From this test case, I also see the memory growing while it runs. Is there anything I misunderstand when adapting the training test case to lazy_tensor? BTW: I am using the lazy_tensor_staging branch (7f3d592).

@leslie-fang-intel added the lazy (Lazy Tensor work items) and module: lazy labels on Sep 24, 2021
@wconstab
Contributor

@leslie-fang-intel I threw this into a script and ran it, and I didn't see an obvious (or fast) increase in memory using htop. Can you say what amount of memory you are seeing used near the beginning of the loop and how fast it grows?

Also, I deleted the two prints and added another:
if i % 100 == 0: print(f"Iter {i} complete")

So far I got to iter 2600 without seeing an issue.
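
For reference, the modified loop would look roughly like this (a sketch reconstructed from the description above, not code posted verbatim in the thread):

for i in range(10000):
    x = torch.rand(64, 64, 3, 3, requires_grad=True).to(lazy_device)
    y = model(x)
    yg = torch.rand(64, 512).to(lazy_device)
    output = nn.CrossEntropyLoss()(y, yg)
    output.backward()
    ltm.mark_step()
    # the two print() calls are dropped; report progress every 100 iterations instead
    if i % 100 == 0:
        print(f"Iter {i} complete")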

@leslie-fang-intel
Collaborator Author

With RN50 training, we can quickly reproduce the OOM issue with batch size 112. Since the model in this test case is quite small, we only see the memory keep growing (similar to RN50 training), but it will not OOM (or maybe only after a very long time).
Did you see your system memory usage keep growing? Or should I provide the scripts for RN50 training?

@wconstab
Contributor

I didn't see a noticeable increase in memory with this example. For RN50, were you using torchbenchmark or something else?

Note that our convolution operator is currently broken, so don't read too much into convnet performance (but that should be unrelated to a memory leak issue).

@leslie-fang-intel
Collaborator Author

I am using htop to check the system memory usage for this simple test case. On my system:

  • At iteration 0, total system memory usage is 3.05 GB.
  • At iteration 1000, total system memory usage is 10.6 GB.

The RN50 run is based on https://github.com/pytorch/examples/tree/master/imagenet. Let me check if I can make it public.
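
As an aside, this kind of growth can also be logged from inside the script rather than read off htop. A minimal sketch, assuming psutil is installed (psutil and the log_rss helper are my addition, not something used in the thread):

import os
import psutil

proc = psutil.Process(os.getpid())

def log_rss(step):
    # resident set size of the current process in GiB,
    # e.g. sampled right after ltm.mark_step() each iteration
    rss_gib = proc.memory_info().rss / (1024 ** 3)
    print(f"iter {step}: RSS = {rss_gib:.2f} GiB")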

@leslie-fang-intel
Collaborator Author

Here is the RN50 lazy tensor benchmark, from which I can reproduce the memory leak issue.
https://github.com/leslie-fang-intel/examples/tree/lazy_tensor_memory_leak

cd imagenet && bash run.sh resnet50 DATA_PATH lazy

@gchanan added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label on Sep 27, 2021
@wconstab
Contributor
wconstab commented Oct 1, 2021

Sorry for the delay, we'll try to debug this soon; we haven't had time yet. I did confirm that I can see a very slow leak on the small model, but it's probably easier to debug it on a bigger one, or see if ASAN provides clues.

@leslie-fang-intel
Collaborator Author

Hi @wconstab, do we have any clues about this issue😁?

@wconstab
Contributor

We still haven't investigated. We're focused mostly on refactoring and adding coverage to our op codegen, so we can start testing/benchmarking across a wide set of models. But @Krovatkin is planning to do a run with ASAN soon and see if we find anything that way.

@leslie-fang-intel
Collaborator Author

Got it. Thanks for the comments.

@wenzhe-nrv
Collaborator

@Krovatkin just checking if you had a chance to run with ASAN. Any update?
I still see an out-of-memory issue when running hf_Bert with lazy_bench on CPU recently.

@Krovatkin
Contributor

@wenzhe-nrv We fixed a few issues related to OOM, but we believe there's something else lurking in the depths of LTC. Could you try the latest LTC with your example and use case? Note, we will be using a very small amount of memory for caching, but after 20 iterations the memory usage should stabilize. I believe I was able to get stable numbers for this example.

@wenzhe-nrv
Collaborator

@Krovatkin just to check before running with the latest: I'm using commit d059c4f from Jan 11, 2022, which shows the OOM issue when running hf_Bert on Xeon. Does this commit contain the fixes you mention? If not, I will test with the latest.

@wenzhe-nrv
Collaborator

@Krovatkin I tried the latest LTC (97888c4) and I was still able to observe a constant memory increase with a simple MLP model running 10k iterations. I don't think it stabilized after 20 iterations. What would you recommend to debug on my side?

How did you build both pytorch and LTC with ASAN?
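
For context, the simple MLP run mentioned above might look roughly like the sketch below (hypothetical; the actual model and script are not shown in the thread, so the layer sizes and optimizer are assumptions):

import torch
import torch.nn as nn
import lazy_tensor_core
import lazy_tensor_core.core.lazy_model as ltm

lazy_tensor_core._LAZYC._ltc_init_ts_backend()
lazy_device = ltm.lazy_device()

# small two-layer MLP on the lazy device; sizes are illustrative only
mlp = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 10)).to(lazy_device)
opt = torch.optim.SGD(mlp.parameters(), lr=0.01)

for i in range(10000):
    x = torch.rand(64, 256).to(lazy_device)
    target = torch.randint(0, 10, (64,)).to(lazy_device)
    loss = nn.CrossEntropyLoss()(mlp(x), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    ltm.mark_step()  # flush the accumulated lazy graph each iteration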

@wenzhe-nrv
Collaborator

Part of this memory leak may relate to #64412.
I could reproduce the code in this issue with the pytorch master branch without using LTC. And using jemalloc helped.
Though for the conv example, there is still a memory leak, but it is smaller than without jemalloc. I'm checking the multi-threading malloc behavior.
cc: @wconstab

@wconstab
Contributor
wconstab commented Feb 23, 2022

Thanks for the updates @wenzhe-nrv - just checking my understanding:

I could reproduce the code in this issue with pytorch master branch without using LTC.

You mean that the SimpleNet example above also leaks on eager pytorch?

And using jemalloc helped.

Helps on lazy tensor as well as eager?

Though for the conv example, there is still a memory leak, but it is smaller than without jemalloc. I'm checking the multi-threading malloc behavior. cc: @wconstab

Ok, interesting.

@wenzhe-nrv
Collaborator

No, not the SimpleNet. I mean this example with a GRU here: #64412 (comment)

jemalloc helped for both cases. My assumption is that the memory optimization from jemalloc is mainly on the Python object side, based on the valgrind memcheck log.

Valgrind memcheck reported a leak under the stack of the conv and linear ops (MKL and MKLDNN involved). Its upper stack points to memory allocation during thread creation. The GRU example's valgrind log didn't show anything mkl/mkldnn related. This could be the reason jemalloc fully fixed the GRU memory leak while the conv example is not fully stable (less memory usage/jump, though).

@wenzhe-nrv
Collaborator

Forgot to mention this: I tried OMP_NUM_THREADS=1 on the conv example with LTC. It helps in that the memory usage is lower, but it's not fully stable, so my assumption is that multi-threading causes part of the memory leak.
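
For what it's worth, the intra-op thread count can also be pinned from inside the script; a minimal sketch (torch.set_num_threads is standard PyTorch, but using it here is my suggestion rather than something tried in the thread):

import torch

# limit intra-op parallelism to one thread, roughly equivalent to
# launching the script with OMP_NUM_THREADS=1
torch.set_num_threads(1)
print(torch.get_num_threads())  # should print 1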

@wconstab
Contributor

OK, so it sounds like two leaks were identified: (a) somewhere in Python (jemalloc helps), and (b) in the mkldnn conv and mkl linear ops.

I'm not sure if (a) has been fully explained, but it sounds like both (a) and (b) are not LTC related? Are there any more leaks that seem LTC related?

@wenzhe-nrv
Collaborator

Yeah, (a) and (b) are not LTC related. I'm looking into (b) for now. I'll let you know if I find other issues related to LTC.
