基于PyTorch Lightning的图像增强框架,集成超分辨率/去噪/低光增强等核心功能,支持快速实验部署。
-
代码简洁性
消除重复训练循环代码,聚焦核心模型设计(相比原生PyTorch减少70%样板代码)# 原生PyTorch需手动编写训练循环 # Lightning只需定义LightningModule各阶段逻辑
-
工程化训练支持
• 自动混合精度训练(precision=16
开启FP16加速)• 多GPU/TPU支持(修改
devices
参数即可)• 梯度累积(
accumulate_grad_batches=4
优化显存) -
智能实验管理
# 自动记录超参数/指标(TensorBoard/W&B集成) # 模型检查点自动保存(val_loss监控与早停) # 100%复现性保障(seed_everything全局种子控制)
-
灵活扩展机制
通过Callback
系统实现:# 自定义学习率策略 # 梯度裁剪/可视化回调 # 自定义分布式策略
git clone https://github.com/cuncunsama/IRLit.git
conda create --name IRLit python=3.10
cd IRLit
pip install -e .
pip install -U 'jsonargparse[signatures]>=4.27.7'
pip install -r requirements.txt
命令 | 说明 | 参数示例 |
---|---|---|
fit |
完整训练流程 | python main.py fit -c config/NAFNet.yaml |
validate |
验证集评估 | python main.validate fit -c config/NAFNet.yaml |
test |
测试集评估 | python main.py test -c config/NAFNet.yaml |
predict |
推理 | python main.py predict -c config/NAFNet.yaml |
# config/NAFNet.yaml 核心参数
seed_everything: 10
trainer:
accelerator: auto
strategy: auto
devices: [1]
num_nodes: 1
ckpt_path: null
model:
class_path: IRLitModule
init_args:
net:
class_path: NAFNet
init_args:
img_channel: 3
width: 64
enc_blk_nums: [2, 2, 4, 8]
middle_blk_num: 12
dec_blk_nums: [2, 2, 2, 2]
data:
class_path: IRLitDataModule
init_args:
train:
dataset:
flag: 1 # 0: grayscale, 1: color, -1: unchanged
lq_dir: /home/yxq/project/datasets/SIDD/train/lq512.lmdb
gt_dir: /home/yxq/project/datasets/SIDD/train/gt512.lmdb
patch_size: 256
dataloader:
batch_size: 8
num_workers: 8
-
调试模式
# 快速验证数据流 python main.py fit --config.debug_mode=True --trainer.fast_dev_run=3
-
实验追踪
# 启动TensorBoard tensorboard --logdir=logs/
-
生产部署
# ONNX/TensorRT导出支持 model = IRLit.load_from_checkpoint("model.ckpt") torch.onnx.export(model, input_sample, "model.onnx")