feat: reserve interface for other torch devices #27
Reviewer's Guide by Sourcery
This PR refactors device handling in the Final2x_core library by introducing a dedicated device utility function and updating the supported device list. The changes centralize device initialization logic and add support for DirectML devices.
Class diagram for updated device handling in Final2x_core
classDiagram
class CCRestoration {
-SRBaseModel _SR_class
+CCRestoration(SRConfig config)
+process(np.ndarray img) np.ndarray
}
class SRConfig {
+String device
+String pretrained_model_name
+String gh_proxy
}
class AutoModel {
+from_pretrained(String pretrained_model_name, Boolean fp16, Union<torch.device, String> device, String gh_proxy)
}
class device {
+get_device(String device) Union<torch.device, String>
}
CCRestoration --> SRConfig
CCRestoration --> AutoModel
AutoModel --> device
note for device "New utility function for device handling"
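One part of a centralized `get_device` that benefits from being isolated is the "auto" fallback mentioned in the review below. The sketch here is hypothetical: the name `resolve_auto`, the priority order, and the idea of passing availability flags in as a mapping are all assumptions, not the Final2x_core implementation. In real code the flags would come from probes such as `torch.cuda.is_available()` or `torch.backends.mps.is_available()`; passing them in keeps the selection logic testable without torch installed.

```python
from typing import Mapping

# Hypothetical priority order for "auto"; the real get_device may differ.
AUTO_PRIORITY = ("cuda", "xpu", "mps", "directml")


def resolve_auto(available: Mapping[str, bool]) -> str:
    """Return the first backend flagged as available, else "cpu".

    `available` stands in for runtime probes (e.g. torch.cuda.is_available()),
    so this selection logic runs without torch being importable.
    """
    for name in AUTO_PRIORITY:
        if available.get(name, False):
            return name
    return "cpu"
```

Note that a result of `"directml"` cannot be passed to `torch.device(...)` directly; DirectML devices are obtained through the `torch_directml` package, which is why that branch needs its own import.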
Hey @Tohrusky - I've reviewed your changes and they look great!
Here's what I looked at during the review:
- 🟡 General issues: 2 issues found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟢 Complexity: all looks good
- 🟢 Documentation: all looks good
elif device.startswith("xpu"):
    return torch.device("xpu")
else:
    print(f"Unknown device: {device}, use auto instead.")
suggestion: Consider using logger.warning instead of print for unknown device fallback
Using the logger would be more consistent with the rest of the codebase and provide better error tracking. Also consider raising a ValueError instead of falling back silently.
logger.warning(f"Unknown device: {device}, use auto instead.")
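Both options from this suggestion (log a warning, or raise `ValueError`) can be supported behind one flag. The sketch below is hypothetical: `normalize_device`, the `strict` parameter, and the prefix list are illustrative assumptions (the excerpt only shows the `xpu`, `mps`, and `directml` branches), and the standard-library `logging` module stands in for whatever logger the project actually uses.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical prefix list; "cpu" and "cuda" are assumed, not shown in the diff.
KNOWN_PREFIXES = ("auto", "cpu", "cuda", "mps", "xpu", "directml")


def normalize_device(device: str, strict: bool = False) -> str:
    """Validate a device string by prefix.

    Unknown values raise ValueError when strict=True; otherwise a warning is
    logged and "auto" is returned, mirroring the fallback in this PR.
    """
    if device.startswith(KNOWN_PREFIXES):  # str.startswith accepts a tuple
        return device
    if strict:
        raise ValueError(f"Unknown device: {device!r}; expected a prefix from {KNOWN_PREFIXES}")
    logger.warning("Unknown device: %s, use auto instead.", device)
    return "auto"
```

Passing the arguments to `logger.warning` separately (instead of an f-string) defers formatting until the record is actually emitted, which is the usual `logging` idiom.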
elif device.startswith("mps"):
    return torch.device("mps")
elif device.startswith("directml"):
    import torch_directml
suggestion: Add explicit error handling for optional directml import
Consider wrapping this in a try-except block to provide a more informative error message if torch_directml is not installed.
try:
    import torch_directml
except ImportError:
    raise ImportError("torch_directml is not installed. Please install it to use DirectML device.")
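If more optional backends are added later, the same try/except pattern can be factored into a small helper. This is a hypothetical sketch (`import_optional` is not part of Final2x_core); it uses `importlib.import_module` and chains the original exception with `from exc` so the underlying cause stays visible in the traceback.

```python
import importlib
from types import ModuleType


def import_optional(module_name: str, hint: str) -> ModuleType:
    """Import an optional dependency, or raise ImportError with an install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:  # also catches ModuleNotFoundError
        # Chain the original error so the real cause is kept in the traceback.
        raise ImportError(f"{module_name} is not installed. {hint}") from exc
```

In the DirectML branch this would read `torch_directml = import_optional("torch_directml", "Please install it to use DirectML device.")`, followed by `torch_directml.device()` to obtain the device object.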
Summary by Sourcery
Add support for 'directml' device and refactor device initialization to improve flexibility in selecting torch devices. Update the project version to 3.0.2.
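New Features:
- Add support for the 'directml' device.
Enhancements:
- Refactor device initialization to improve flexibility in selecting torch devices.
Chores:
- Update the project version to 3.0.2.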