✅ 好的! 这句话非常关键:
“我希望在加载时能完全还原训练时保存的模型,包括训练出的参数等。”
⚡ 这里的重点:要“完全还原”不仅仅是结构和参数,还包括优化器状态、超参数、epoch 等训练上下文。
🟢 你需要什么?
需要什么 | 解释 |
---|---|
1️⃣ 结构 | 必须保留(或用 TorchScript/ONNX 冻结) |
2️⃣ 参数权重 | state_dict |
3️⃣ 优化器状态 | optimizer.state_dict() |
4️⃣ 超参数 | 你自己用 config 保存 |
5️⃣ 训练到第几轮等上下文 | 一并保存到 checkpoint |
6️⃣ Python、PyTorch 版本 | 建议写进去防止不兼容 |
✅ 最常见的可重复训练/恢复方案
在 PyTorch,完整保存训练状态最通用的是:
# === 保存 ===
torch.save({
'model_name': 'resnet50',
'model_version': 'v1',
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': current_epoch,
'train_loss': train_loss,
'hyperparams': hyperparams,
'pytorch_version': torch.__version__,
}, 'checkpoint.pt')
✅ 加载时:完全恢复
# === 1) 加载 ckpt ===
checkpoint = torch.load('checkpoint.pt')
# === 2) 从注册表找结构(或者直接 import)===
ModelClass = MODEL_REGISTRY[(checkpoint['model_name'], checkpoint['model_version'])]
model = ModelClass(**checkpoint['hyperparams'])
model.load_state_dict(checkpoint['model_state_dict'])
# === 3) 恢复优化器 ===
optimizer = torch.optim.Adam(model.parameters(), lr=checkpoint['hyperparams']['lr'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# === 4) 恢复 epoch ===
start_epoch = checkpoint['epoch'] + 1
✅ 这样你得到的好处
✅ 模型结构还原(通过 version + 注册表)
✅ 参数还原(state_dict
)
✅ 优化器状态还原(momentum、lr schedulers)
✅ 可以从上次训练中断点接着跑
✅ 不会有 pickle 跨版本坑
✅ 同时也能额外导出 TorchScript/ONNX(用于推理)
⚡ 完整示例:一份可执行的可重复训练保存加载脚本
下面给你直接一份最小版:
📂 my_model.py
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
return self.net(x)
📂 train_and_save.py
import torch
from my_model import MyModel
# === 模型配置 ===
hyperparams = {
'input_dim': 4,
'hidden_dim': 8,
'output_dim': 2,
'lr': 1e-3
}
model = MyModel(
input_dim=hyperparams['input_dim'],
hidden_dim=hyperparams['hidden_dim'],
output_dim=hyperparams['output_dim']
)
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['lr'])
# === 假装训练 ===
x = torch.randn(10, 4)
y = torch.randint(0, 2, (10,))
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(3):
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
print(f"Epoch {epoch} Loss: {loss.item()}")
# === 保存完整 checkpoint ===
torch.save({
'model_name': 'mymodel',
'model_version': 'v1',
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'hyperparams': hyperparams,
'pytorch_version': torch.__version__
}, 'checkpoint.pt')
print("✅ Saved full training checkpoint.")
📂 resume_training.py
import torch
from my_model import MyModel
# === 加载 ===
checkpoint = torch.load('checkpoint.pt')
hyperparams = checkpoint['hyperparams']
model = MyModel(
input_dim=hyperparams['input_dim'],
hidden_dim=hyperparams['hidden_dim'],
output_dim=hyperparams['output_dim']
)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['lr'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
print(f"✅ Resumed training from epoch {start_epoch}")
# === 继续训练 ===
x = torch.randn(10, 4)
y = torch.randint(0, 2, (10,))
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(start_epoch, start_epoch + 2):
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
print(f"Epoch {epoch} Loss: {loss.item()}")
✅ 这样保证了什么?
你要的 | 方案怎么做到 |
---|---|
完全一致性 | state_dict + optimizer.state_dict() |
可继续训练 | 保存了 epoch,优化器动量 |
不依赖 pickle | 只保存张量字典 |
可长期使用 | 结构用 version 管理,不随项目变动丢失 |
可推理 | 可以额外保存 TorchScript 或 ONNX |
✅ 小结
训练时
- 保留结构文件(或导出 TorchScript)
state_dict
+optimizer.state_dict
+ 版本信息
加载时
- 根据 version 找到正确结构
- 用一致的超参数初始化
- 完全还原
✅ 这样你无论是继续训练、后续微调还是推理,都不会踩坑!
要不要我给你打包一个:
state_dict
+TorchScript
+ONNX
- 完整的一致性验证
- 训练、继续训练、推理都可复用的全套示例?
要的话我马上帮你出一份! 🚀
好的,我来帮你整理一个“跨服务器、跨环境、可长期使用”的 PyTorch 官方推荐保存方案,一份可直接运行的最小示例,不踩 pickle 坑,还能在不同环境、不同 Python 路径下稳定使用!
✅ 目标
- 只用
torch.save
,但避开 pickle 的全局对象依赖 - 使用
state_dict
,可继续训练和复现 - 同时导出
TorchScript
,可部署和推理,不依赖 Python 源文件 - 验证两个文件都能正常推理、结果一致
📦 推荐结构
1️⃣ 保存 state_dict
➜ 用于训练/微调
2️⃣ 保存 TorchScript
➜ 用于跨平台推理,不需要模型代码
✅ 完整可运行示例
保存为 safe_model_save.py
,直接运行即可。
import torch
import torch.nn as nn
# =====================================================
# ✅ 定义模型
# =====================================================
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4, 2)
def forward(self, x):
return self.linear(x)
# =====================================================
# ✅ 训练阶段(演示)
# =====================================================
# 初始化模型
model = MyModel()
# 伪造输入
x = torch.randn(1, 4)
# 正向推理
output_before = model(x)
print(f"✅ Output BEFORE saving: {output_before}")
# =====================================================
# ✅ 保存 state_dict
# =====================================================
torch.save(model.state_dict(), 'model_state.pt')
print("✅ Saved state_dict to model_state.pt")
# =====================================================
# ✅ 保存 TorchScript 模型
# =====================================================
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')
print("✅ Saved TorchScript model to model_scripted.pt")
# =====================================================
# ✅ === 重新加载 state_dict ===
# 需要有相同结构定义
# =====================================================
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load('model_state.pt'))
loaded_model.eval()
output_after_state = loaded_model(x)
print(f"✅ Output AFTER loading state_dict: {output_after_state}")
# =====================================================
# ✅ === 重新加载 TorchScript ===
# 不需要任何结构代码
# =====================================================
loaded_scripted = torch.jit.load('model_scripted.pt')
loaded_scripted.eval()
output_after_scripted = loaded_scripted(x)
print(f"✅ Output AFTER loading TorchScript: {output_after_scripted}")
# =====================================================
# ✅ 验证三者一致
# =====================================================
print("\n🟢 SAME? state_dict:", torch.allclose(output_before, output_after_state))
print("🟢 SAME? TorchScript:", torch.allclose(output_before, output_after_scripted))
🟢 怎么跑?
1️⃣ 第一次跑:
python safe_model_save.py
你会看到:
state_dict
保存成功 ✅TorchScript
保存成功 ✅- 加载后输出一致 ✅
2️⃣ 验证可跨平台:
- 把
model_scripted.pt
拷贝到别的服务器 - 在新的 Python 环境里只要
torch
可用 ➜torch.jit.load
就能推理
3️⃣ 模拟结构丢失:
- 把
class MyModel
注释掉,重新跑:- 加载
state_dict
会挂(找不到结构) - 加载
TorchScript
正常(不依赖结构)
- 加载
✅ 优缺点总结
保存方式 | 可继续训练 | 可跨平台 | 依赖 Python 源文件 | 推荐用途 |
---|---|---|---|---|
state_dict |
✅ | ⚠️(需结构) | ✅ 需要 | 训练、微调 |
TorchScript |
❌ | ✅ | ❌ 不需要 | 推理部署 |
🟢 生产最佳实践
1️⃣ 训练 ➜ 保存 state_dict
+ TorchScript
2️⃣ 部署 ➜ 加载 TorchScript
,不依赖 Python 源码
3️⃣ 版本管理 ➜ 在文件名里带版本号或哈希:model_v1_scripted.pt
如果需要,我可以帮你再写:
- 跨机器导入
ONNX
的完整对比示例 - 或者带更复杂控制流、
torch.jit.script
和torch.jit.trace
的示例
要不要?我可以直接帮你打包成【可部署模板】! 🚀
好!我给你做一份【动态生成结构 + 嵌入参数 + register_buffer + 验证一致性】的可直接运行的完整示例, 这样你就能一眼看懂:
- 哪些东西必须注册成
Parameter
或Buffer
- 保存
state_dict
时,哪些会保存,哪些不会 - 怎么保证加载后结果和训练时一致
✅ 保存成一个文件:dynamic_model_safe_save.py
直接复制保存,然后:
python dynamic_model_safe_save.py
就能看效果。
import torch
import torch.nn as nn
import random
# ======================================================
# ✅ 动态生成结构 + 嵌入权重的模型示例
# ======================================================
class DynamicModel(nn.Module):
def __init__(self, input_size, hidden_size, use_buffer=True):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# 这里是动态生成的 Linear 层
self.linear = nn.Linear(input_size, hidden_size)
# 这里是动态生成的嵌入权重:必须注册
rand_weight = torch.randn(hidden_size)
if use_buffer:
self.register_buffer('my_embed', rand_weight)
else:
self.my_embed = rand_weight # ⚠️ 这样就不会进 state_dict!
def forward(self, x):
out = self.linear(x)
# 使用动态生成的嵌入参数
out = out + self.my_embed
return out
# ======================================================
# ✅ 保证可复现
# ======================================================
torch.manual_seed(42)
random.seed(42)
# ======================================================
# ✅ 训练时
# ======================================================
model = DynamicModel(input_size=4, hidden_size=2, use_buffer=True)
# 做个前向
x = torch.randn(1, 4)
output_before = model(x)
print(f"✅ Output BEFORE saving: {output_before}")
# 保存结构参数
torch.save({
'model_state_dict': model.state_dict(),
'model_args': {
'input_size': 4,
'hidden_size': 2,
'use_buffer': True
}
}, "dynamic_model_state.pt")
print("✅ Saved state_dict and config!")
# ======================================================
# ✅ === 加载 ===
# ======================================================
checkpoint = torch.load("dynamic_model_state.pt")
# 用保存的结构参数重新生成模型
model2 = DynamicModel(**checkpoint['model_args'])
model2.load_state_dict(checkpoint['model_state_dict'])
model2.eval()
# 用相同输入验证输出
output_after = model2(x)
print(f"✅ Output AFTER loading: {output_after}")
print("\n🟢 SAME? ", torch.allclose(output_before, output_after))
# ======================================================
# ✅ === 演示如果不注册 Buffer 会怎样 ===
# ======================================================
print("\n⚠️ 演示: 如果不注册 Buffer,my_embed 不会被保存!")
model_bad = DynamicModel(input_size=4, hidden_size=2, use_buffer=False)
torch.manual_seed(42)
output_bad = model_bad(x)
torch.save(model_bad.state_dict(), "bad_state.pt")
# 重新加载: 注意我们重新生成了结构,但 my_embed 没被保存
model_bad2 = DynamicModel(input_size=4, hidden_size=2, use_buffer=False)
model_bad2.load_state_dict(torch.load("bad_state.pt"))
model_bad2.eval()
output_bad_after = model_bad2(x)
print(f"⚠️ Output BEFORE (bad): {output_bad}")
print(f"⚠️ Output AFTER (bad): {output_bad_after}")
print("🟡 SAME? (bad)", torch.allclose(output_bad, output_bad_after))
✅ 这份示例里做了啥
✅ 正确做法:
my_embed
用register_buffer
注册 ➜ 会进state_dict
- 保存结构参数 ➜ 保证动态生成可重现
- 加载时
.eval()
+ 用同样输入验证 ➜ 输出一致
⚠️ 错误做法:
use_buffer=False
时my_embed
没注册 ➜ 没保存进state_dict
- 加载后
my_embed
会重新随机生成 ➜ 输出变了
🟢 跑完会看到:
✅ Output BEFORE saving: ...
✅ Saved state_dict and config!
✅ Output AFTER loading: ...
🟢 SAME? True
⚠️ 演示: 如果不注册 Buffer,my_embed 不会被保存!
⚠️ Output BEFORE (bad): ...
⚠️ Output AFTER (bad): ...
🟡 SAME? (bad) False
✅ 核心 takeaway
做法 | 会不会丢状态? | 结果一致吗? |
---|---|---|
使用 register_buffer |
不丢 | ✅ 一致 |
不注册 | 丢 | ❌ 不一致 |
🔑 你的动态生成结构想要“只保存 state_dict 也能跑”必须做到:
1️⃣ 所有需要保存的 tensor 都是 nn.Parameter
或 register_buffer
2️⃣ 动态生成的超参数一并保存
3️⃣ 推理时 .eval()
4️⃣ 输入一致 ➜ 结果一致
要不要我帮你把这个示例再加上:
TorchScript
的版本 ➜ 没有结构也能跑- ONNX 的版本 ➜ 跨框架部署
需要的话我给你整理到一份可跑的跨环境模板!🚀