Fix persistent worker exits before pin_memory thread by ejguan · Pull Request #71579 · pytorch/pytorch · GitHub

Fix persistent worker exits before pin_memory thread #71579

Closed
wants to merge 8 commits into from

Conversation

ejguan
Contributor
@ejguan ejguan commented Jan 20, 2022

Fixes #1551

Stack from ghstack:

As the comment in the code explains, this registers a function to terminate persistent workers.
Holding a reference to these workers in `atexit` prevents the Python interpreter from killing the persistent worker processes before `pin_memory_thread` exits.
If the user explicitly kills the DataLoader iterator, the function registered with `atexit` is a no-op.
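
For readers unfamiliar with the pattern, here is a minimal sketch of the idea described above, not the code in this PR. `FakeWorker`, `clean_up_worker`, and `LoaderIterator` are hypothetical names; a real worker would be a daemonic `multiprocessing.Process`. Each `atexit.register` call stores a reference to its worker, and the cleanup function does nothing once the worker has already been shut down.

```python
import atexit


class FakeWorker:
    """Hypothetical stand-in for a daemonic worker process."""

    def __init__(self):
        self._alive = True
        self.terminate_calls = 0

    def is_alive(self):
        return self._alive

    def terminate(self):
        self.terminate_calls += 1
        self._alive = False


def clean_up_worker(worker):
    """Terminate the worker only if it is still running; otherwise do nothing."""
    if worker.is_alive():
        worker.terminate()


class LoaderIterator:
    """Hypothetical iterator that owns persistent workers."""

    def __init__(self, num_workers):
        self.workers = [FakeWorker() for _ in range(num_workers)]
        for worker in self.workers:
            # atexit keeps a reference to each worker until the exit hooks run.
            atexit.register(clean_up_worker, worker)

    def shutdown(self):
        """Explicit shutdown; the atexit hooks become no-ops afterwards."""
        for worker in self.workers:
            clean_up_worker(worker)
```

If `shutdown()` has already run, the hooks that fire at interpreter exit find no live workers and return immediately.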

Differential Revision: D33896537

@facebook-github-bot
Contributor
facebook-github-bot commented Jan 20, 2022

💊 CI failures summary and remediations

As of commit 0802991 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



Fixes #1551


As the comment in the code explains, this adds a function to manually delete the iterator with persistent workers. It invokes `__del__` on the iterator object and makes sure `pin_memory_thread` exits before the worker processes.

I chose `atexit` rather than adding `__del__` to DataLoader because `__del__` is not safe: the destructor may not be invoked when the Python interpreter exits.
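
To illustrate the `atexit` side of this choice, here is a small sketch that runs a child interpreter and shows a registered hook firing at normal exit even though the object is never explicitly deleted. `Resource`, `CHILD_SCRIPT`, and `run_child` are hypothetical names used only for this example.

```python
import subprocess
import sys

# Script executed in a fresh interpreter so that its exit hooks actually run.
CHILD_SCRIPT = """\
import atexit

class Resource:
    def close(self):
        print("atexit hook closed the resource")

resource = Resource()
# Registering the bound method keeps `resource` alive until the exit hooks run.
atexit.register(resource.close)
print("main finished")
"""


def run_child():
    """Run the script in a child interpreter and return its stdout lines."""
    result = subprocess.run(
        [sys.executable, "-c", CHILD_SCRIPT],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.splitlines()
```

Calling `run_child()` returns the two printed lines in order: the main body finishes first, then the hook runs during interpreter shutdown.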

[ghstack-poisoned]
ejguan added a commit that referenced this pull request Jan 20, 2022
ghstack-source-id: 57a8e9e
Pull Request resolved: #71579
@ejguan ejguan changed the title Fix persistent worker exists before pin_memory thread Fix persistent worker exits before pin_memory thread Jan 20, 2022
ejguan added a commit that referenced this pull request Jan 20, 2022
ghstack-source-id: 05cc8f7
Pull Request resolved: #71579
Comment on lines 949 to 955
# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked
# when main process exits
# It would cause failure when pin_memory_thread tries to read
# corrupted data from worker_result_queue
# atexit is used to prevent persistent workers exiting before
# pin_memory_thread
Contributor Author

Here is the comment.

Comment on lines 2305 to 2339
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_early_exit(self):
        import subprocess
        proc = subprocess.run([sys.executable, '-c', """\
import torch
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super(RandomDataset).__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

if __name__ == '__main__':
    dl = DataLoader(
        RandomDataset(64, (28, 28)),
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
    )

    for _ in dl:
        break
"""], stderr=subprocess.PIPE)
        self.assertTrue(proc.stderr == b'')
Contributor Author

This test won't pass before this PR.
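
The checking pattern the test relies on can be sketched without PyTorch, using only the standard library: run a script in a child interpreter and inspect what it wrote to stderr. `run_and_capture_stderr`, `QUIET_SCRIPT`, and `NOISY_SCRIPT` are hypothetical names for illustration only.

```python
import subprocess
import sys


def run_and_capture_stderr(script):
    """Run `script` in a child interpreter; return (returncode, decoded stderr)."""
    result = subprocess.run(
        [sys.executable, "-c", script],
        stderr=subprocess.PIPE,
    )
    return result.returncode, result.stderr.decode(errors="replace")


# A script that exits silently, and one that writes to stderr but still exits 0.
QUIET_SCRIPT = "x = 1 + 1"
NOISY_SCRIPT = "import sys; sys.stderr.write('boom')"
```

A check of this kind fails on anything written to stderr, even when the child exits with code 0. Decoding stderr also makes it possible to include the captured text in a failure message.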

@ejguan
Contributor Author
ejguan commented Jan 20, 2022

@ejguan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ejguan ejguan requested a review from VitalyFedyunin January 20, 2022 23:23
# corrupted data from worker_result_queue
# atexit is used to prevent persistent workers exiting before
# pin_memory_thread
if self._persistent_workers and self._pin_memory:
Contributor Author

In order to potentially fix part of #71187 for the case of persistent workers, we may want to remove this condition on self._pin_memory

facebook-github-bot pushed a commit that referenced this pull request Jan 21, 2022
Summary:
Pull Request resolved: #71579

Fixes #1551

As the comment in the code explains, this registers a function to terminate persistent workers. `atexit` is used to make sure that termination of persistent workers always happens at the end (after `pin_memory_thread` exits).
Such a mechanism is needed because, in some rare cases, the Python interpreter cleans up the worker processes before the DataLoader iterator.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33694867

Pulled By: ejguan

fbshipit-source-id: 0847f4d424a0cd6b3c0be8235d505415970254e8
pytorchmergebot pushed a commit that referenced this pull request Jan 21, 2022
(cherry picked from commit 18ad462)
@malfet
Contributor
malfet commented Jan 22, 2022

Reverting as it broke Windows test sanity (from https://github.com/pytorch/pytorch/runs/4903195822?check_suite_focus=true):

```
Traceback (most recent call last):
  File "test_dataloader.py", line 2339, in test_early_exit
    self.assertTrue(proc.stderr == b'')
AssertionError: False is not true
```

@facebook-github-bot
Contributor

This pull request has been reverted by 401ac7b52a96377814c37dfeaa3f0bdea231d058. To re-land this change, follow these steps.

@facebook-github-bot
Contributor

This pull request has been reverted by 86aefdc. To re-land this change, follow these steps.

@ejguan ejguan reopened this Jan 24, 2022
@atalman atalman closed this Jan 26, 2022
@ejguan ejguan reopened this Jan 26, 2022
Fixes #1551


As the comment in the code explains, this registers a function to terminate persistent workers.
Holding a reference to these workers in `atexit` prevents the Python interpreter from killing the persistent worker processes before `pin_memory_thread` exits.
If the user explicitly kills the DataLoader iterator, the function registered with `atexit` is a no-op.


[ghstack-poisoned]
@ejguan
Contributor Author
ejguan commented Jan 31, 2022

@ejguan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Fixes #1551


As the comment in the code explains, this registers a function to terminate persistent workers.
Holding a reference to these workers in `atexit` prevents the Python interpreter from killing the persistent worker processes before `pin_memory_thread` exits.
If the user explicitly kills the DataLoader iterator, the function registered with `atexit` is a no-op.

Differential Revision: [D33896537](https://our.internmc.facebook.com/intern/diff/D33896537)

[ghstack-poisoned]
ejguan added a commit that referenced this pull request Jan 31, 2022
ghstack-source-id: 2d15b14
Pull Request resolved: #71579
@ejguan
Contributor Author
ejguan commented Jan 31, 2022

@ejguan has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Feb 1, 2022
Summary:
Pull Request resolved: #71579

Fixes #1551

As the comment in the code explains, this registers a function to terminate persistent workers.
Holding a reference to these workers in `atexit` prevents the Python interpreter from killing the persistent worker processes before `pin_memory_thread` exits.
If the user explicitly kills the DataLoader iterator, the function registered with `atexit` is a no-op.

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D33896537

Pulled By: ejguan

fbshipit-source-id: 36b57eac7523d8aa180180c2b61fc693ea4638ae
pytorchmergebot pushed a commit that referenced this pull request Feb 1, 2022
(cherry picked from commit 05add2a)
@github-actions
Contributor
github-actions bot commented Feb 1, 2022

Hey ejguan.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@ejguan ejguan added the release notes: dataloader release notes category label Feb 2, 2022
@anjali411
Contributor

@ejguan can you also put a topic label to it? 😊

@ejguan ejguan added the topic: bug fixes topic category label Feb 2, 2022
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 3, 2022
(cherry picked from commit 18ad462)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 3, 2022
(cherry picked from commit 05add2a)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 3, 2022
(cherry picked from commit 18ad462)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 3, 2022
(cherry picked from commit 05add2a)
ejguan added a commit to ejguan/pytorch that referenced this pull request Feb 3, 2022
malfet pushed a commit that referenced this pull request Feb 4, 2022
* release 1.11 Install torch from test channel, Pin builder and xla repo (#72217)

* Fix persistent worker exits before pin_memory thread

ghstack-source-id: 2d15b14
Pull Request resolved: #71579

Co-authored-by: Andrey Talman <atalman@fb.com>
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 9, 2022
(cherry picked from commit 18ad462)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 9, 2022
(cherry picked from commit 05add2a)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 9, 2022
(cherry picked from commit 18ad462)
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Feb 9, 2022
(cherry picked from commit 05add2a)
@facebook-github-bot facebook-github-bot deleted the gh/ejguan/106/head branch March 4, 2022 15:17
5 participants