Products
GG网络技术分享 2026-04-17 05:56 0
说实话, 玩PyTorch的朋友们大多都有过那种“哎呀,我的训练刚跑到第1999步,服务器突然宕机,我的心都凉了半截”的经历。特别是用DFGAN这种结构, 模型庞大、 捡漏。 参数多,一不小心就会掉进断点续训的深渊。下面这篇文章, 我将把自己摸爬滚打的血泪史搬上来顺便抛出几招“妙招”,希望能帮你少点儿崩溃,多点儿笑声。
在PyTorch里checkpoint最核心的就是两大块:

model.state_dict——网络权重;optimizer.state_dict——优化器状态。如果你只保存了前者, 而忽略了后者,那恢复训练时学习率往往会回到初始值,导致后面几百轮都像在原地打转。 所以“断点续训主要保存的是网络模型的参数以及优化器optimizer的状态”。记住这句话,它比任何文档都管用。
def save_models:
if and != 0):
None
else:
state = {
'model': {
'netG': _dict,
'netD': _dict,
'netC': _dict
},
'optimizers': {
'optimizer_G': _dict,
'optimizer_D': _dict
},
'epoch': epoch
}
)
对应读取:
def load_model_opt(netG, netD, netC, optim_G, optim_D, path,
multi_gpus=False):
checkpoint = )
netG = load_model_weights(netG,
checkpoint,
multi_gpus)
netD = load_model_weights(netD,
checkpoint,
multi_gpus)
netC = load_model_weights(netC,
checkpoint,
multi_gpus)
optim_G = load_opt_weights(optim_G,
checkpoint)
optim_D = load_opt_weights(optim_D,
checkpoint)
return netG, netD, netC, optim_G, optim_D
Pitfall #1:size mismatch
报错类似:
size mismatch for : copying a param with shape () from checkpoint,
shape in current model is ().size mismatch for : copying a param with shape
from checkpoint,
shape in current model is .
改进一下。 原因大多是「模型结构改动」或「超参数不统一」导致。解决办法:
del state_dict。{k: v for k,v in state_dict.items if k in model.state_dict}过滤。牛逼。 Pitfall #2:忘记保存lr_scheduler!
到位。 LRScheduler本质也是一个优化器状态, 如果不恢复,它会从头开始衰减,训练曲线瞬间掉坑。下面给出完整保存方式:
state = scheduler.state_dict
torch.save(state,
os.path.join(save_path,
f'state_epoch_{epoch}.pth'))
If you are lazy like me and already use Lightning:
This will auto‑handle both model & optimizer states.,总的来说...
| #方案 | 磁盘IO | EFA | 恢复时间 |
|---|---|---|---|
| A. 单文件 | 120 | 3500+ | 12.4 |
| B. 分文件 | 95 | 2800 | 15.8 |
| C. 使用torch.save + zip压缩 | 70 | 2100 | 22.1 |
| D. SSD NVMe直写 | 340 | 8000 | 4.7 |
| E. 网络挂载 | 45 |
ckpt_dir = './saved_models/bird/pretrained/'
resume_epoch = 300 # 想从哪儿接着跑就改这个数字
checkpoint_path = f"{ckpt_dir}state_epoch_{resume_epoch}.pth"
checkpoint = torch.load(checkpoint_path,
map_location='cuda' if torch.cuda.is_available else 'cpu')
netG.load_state_dict
netD.load_state_dict
optG.load_state_dict
optD.load_state_dict
if 'scheduler' in checkpoint:
scheduler.load_state_dict
start_epoch = checkpoint + 1
print
for epoch in range(start_epoch,
start_epoch + 100): # 再跑100轮
train_one_epoch
if epoch % 10 == 0:
save_models
print
, 防止显存碎片化导致OOM。CNN里的卷积核可以丢失, 梯度可以爆炸,但只要你把`torch.save`+`torch.load`+`state_dict`三件套搞定,就没有不可跨越的断点。祝各位勇士们在DFGAN的大海里划桨不止, 造起来。 有时候还能踩到漂浮的小岛——那就是我们精心准备好的Checkpoint啦!如果你看完还没笑,那一定是太累了需要先去喝杯奶茶再继续撸代码。
© 2026 DFGAN爱好者社群 保留所有权利,仅供学习交流使用。Demand feedback