模型保存测试

Posted by Shallow Dreameron July 16, 2025

✅ 好的! 这句话非常关键:

“我希望在加载时能完全还原训练时保存的模型,包括训练出的参数等。”


这里的重点:要“完全还原”不仅仅是结构和参数,还包括优化器状态、超参数、epoch 等训练上下文。


🟢 你需要什么?

需要什么 解释
1️⃣ 结构 必须保留(或用 TorchScript/ONNX 冻结)
2️⃣ 参数权重 state_dict
3️⃣ 优化器状态 optimizer.state_dict()
4️⃣ 超参数 你自己用 config 保存
5️⃣ 训练到第几轮等上下文 一并保存到 checkpoint
6️⃣ Python、PyTorch 版本 建议写进去防止不兼容

最常见的可重复训练/恢复方案

在 PyTorch,完整保存训练状态最通用的是:

# === 保存 ===
torch.save({
    'model_name': 'resnet50',
    'model_version': 'v1',
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': current_epoch,
    'train_loss': train_loss,
    'hyperparams': hyperparams,
    'pytorch_version': torch.__version__,
}, 'checkpoint.pt')

加载时:完全恢复

# === 1) 加载 ckpt ===
checkpoint = torch.load('checkpoint.pt')

# === 2) 从注册表找结构(或者直接 import)===
ModelClass = MODEL_REGISTRY[(checkpoint['model_name'], checkpoint['model_version'])]
model = ModelClass(**checkpoint['hyperparams'])
model.load_state_dict(checkpoint['model_state_dict'])

# === 3) 恢复优化器 ===
optimizer = torch.optim.Adam(model.parameters(), lr=checkpoint['hyperparams']['lr'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# === 4) 恢复 epoch ===
start_epoch = checkpoint['epoch'] + 1

这样你得到的好处

✅ 模型结构还原(通过 version + 注册表) ✅ 参数还原(state_dict) ✅ 优化器状态还原(momentum、lr schedulers) ✅ 可以从上次训练中断点接着跑 ✅ 不会有 pickle 跨版本坑 ✅ 同时也能额外导出 TorchScript/ONNX(用于推理)


完整示例:一份可执行的可重复训练保存加载脚本

下面给你直接一份最小版:


📂 my_model.py

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

📂 train_and_save.py

import torch
from my_model import MyModel

# === 模型配置 ===
hyperparams = {
    'input_dim': 4,
    'hidden_dim': 8,
    'output_dim': 2,
    'lr': 1e-3
}

model = MyModel(
    input_dim=hyperparams['input_dim'],
    hidden_dim=hyperparams['hidden_dim'],
    output_dim=hyperparams['output_dim']
)

optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['lr'])

# === 假装训练 ===
x = torch.randn(10, 4)
y = torch.randint(0, 2, (10,))

criterion = torch.nn.CrossEntropyLoss()

for epoch in range(3):
    optimizer.zero_grad()
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch} Loss: {loss.item()}")

# === 保存完整 checkpoint ===
torch.save({
    'model_name': 'mymodel',
    'model_version': 'v1',
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch,
    'hyperparams': hyperparams,
    'pytorch_version': torch.__version__
}, 'checkpoint.pt')

print("✅ Saved full training checkpoint.")

📂 resume_training.py

import torch
from my_model import MyModel

# === 加载 ===
checkpoint = torch.load('checkpoint.pt')

hyperparams = checkpoint['hyperparams']

model = MyModel(
    input_dim=hyperparams['input_dim'],
    hidden_dim=hyperparams['hidden_dim'],
    output_dim=hyperparams['output_dim']
)

model.load_state_dict(checkpoint['model_state_dict'])

optimizer = torch.optim.Adam(model.parameters(), lr=hyperparams['lr'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch'] + 1

print(f"✅ Resumed training from epoch {start_epoch}")

# === 继续训练 ===
x = torch.randn(10, 4)
y = torch.randint(0, 2, (10,))

criterion = torch.nn.CrossEntropyLoss()

for epoch in range(start_epoch, start_epoch + 2):
    optimizer.zero_grad()
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch} Loss: {loss.item()}")

这样保证了什么?

你要的 方案怎么做到
完全一致性 state_dict + optimizer.state_dict()
可继续训练 保存了 epoch,优化器动量
不依赖 pickle 只保存张量字典
可长期使用 结构用 version 管理,不随项目变动丢失
可推理 可以额外保存 TorchScript 或 ONNX

小结

训练时

  • 保留结构文件(或导出 TorchScript)
  • state_dict + optimizer.state_dict + 版本信息

加载时

  • 根据 version 找到正确结构
  • 用一致的超参数初始化
  • 完全还原

✅ 这样你无论是继续训练、后续微调还是推理,都不会踩坑!


要不要我给你打包一个:

  • state_dict + TorchScript + ONNX
  • 完整的一致性验证
  • 训练、继续训练、推理都可复用的全套示例?

要的话我马上帮你出一份! 🚀

好的,我来帮你整理一个“跨服务器、跨环境、可长期使用”的 PyTorch 官方推荐保存方案,一份可直接运行的最小示例,不踩 pickle 坑,还能在不同环境、不同 Python 路径下稳定使用!


目标

  • 只用 torch.save,但避开 pickle 的全局对象依赖
  • 使用 state_dict,可继续训练和复现
  • 同时导出 TorchScript,可部署和推理,不依赖 Python 源文件
  • 验证两个文件都能正常推理、结果一致

📦 推荐结构

1️⃣ 保存 state_dict ➜ 用于训练/微调 2️⃣ 保存 TorchScript ➜ 用于跨平台推理,不需要模型代码


完整可运行示例

保存为 safe_model_save.py,直接运行即可。


import torch
import torch.nn as nn

# =====================================================
# ✅ 定义模型
# =====================================================
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

# =====================================================
# ✅ 训练阶段(演示)
# =====================================================
# 初始化模型
model = MyModel()

# 伪造输入
x = torch.randn(1, 4)

# 正向推理
output_before = model(x)
print(f"✅ Output BEFORE saving: {output_before}")

# =====================================================
# ✅ 保存 state_dict
# =====================================================
torch.save(model.state_dict(), 'model_state.pt')
print("✅ Saved state_dict to model_state.pt")

# =====================================================
# ✅ 保存 TorchScript 模型
# =====================================================
scripted_model = torch.jit.script(model)
scripted_model.save('model_scripted.pt')
print("✅ Saved TorchScript model to model_scripted.pt")

# =====================================================
# ✅ === 重新加载 state_dict ===
# 需要有相同结构定义
# =====================================================
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load('model_state.pt'))
loaded_model.eval()

output_after_state = loaded_model(x)
print(f"✅ Output AFTER loading state_dict: {output_after_state}")

# =====================================================
# ✅ === 重新加载 TorchScript ===
# 不需要任何结构代码
# =====================================================
loaded_scripted = torch.jit.load('model_scripted.pt')
loaded_scripted.eval()

output_after_scripted = loaded_scripted(x)
print(f"✅ Output AFTER loading TorchScript: {output_after_scripted}")

# =====================================================
# ✅ 验证三者一致
# =====================================================
print("\n🟢 SAME? state_dict:", torch.allclose(output_before, output_after_state))
print("🟢 SAME? TorchScript:", torch.allclose(output_before, output_after_scripted))

🟢 怎么跑?

1️⃣ 第一次跑:

python safe_model_save.py

你会看到:

  • state_dict 保存成功 ✅
  • TorchScript 保存成功 ✅
  • 加载后输出一致 ✅

2️⃣ 验证可跨平台

  • model_scripted.pt 拷贝到别的服务器
  • 在新的 Python 环境里只要 torch 可用 ➜ torch.jit.load 就能推理

3️⃣ 模拟结构丢失:

  • class MyModel 注释掉,重新跑:
    • 加载 state_dict 会挂(找不到结构)
    • 加载 TorchScript 正常(不依赖结构)

优缺点总结

保存方式 可继续训练 可跨平台 依赖 Python 源文件 推荐用途
state_dict ⚠️(需结构) ✅ 需要 训练、微调
TorchScript ❌ 不需要 推理部署

🟢 生产最佳实践

1️⃣ 训练 ➜ 保存 state_dict + TorchScript 2️⃣ 部署 ➜ 加载 TorchScript,不依赖 Python 源码 3️⃣ 版本管理 ➜ 在文件名里带版本号或哈希:model_v1_scripted.pt


如果需要,我可以帮你再写:

  • 跨机器导入 ONNX 的完整对比示例
  • 或者带更复杂控制流、torch.jit.scripttorch.jit.trace 的示例

要不要?我可以直接帮你打包成【可部署模板】! 🚀

好!我给你做一份【动态生成结构 + 嵌入参数 + register_buffer + 验证一致性】的可直接运行的完整示例, 这样你就能一眼看懂:

  • 哪些东西必须注册成 ParameterBuffer
  • 保存 state_dict 时,哪些会保存,哪些不会
  • 怎么保证加载后结果和训练时一致

保存成一个文件:dynamic_model_safe_save.py

直接复制保存,然后:

python dynamic_model_safe_save.py

就能看效果。


import torch
import torch.nn as nn
import random

# ======================================================
# ✅ 动态生成结构 + 嵌入权重的模型示例
# ======================================================
class DynamicModel(nn.Module):
    def __init__(self, input_size, hidden_size, use_buffer=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 这里是动态生成的 Linear 层
        self.linear = nn.Linear(input_size, hidden_size)

        # 这里是动态生成的嵌入权重:必须注册
        rand_weight = torch.randn(hidden_size)
        if use_buffer:
            self.register_buffer('my_embed', rand_weight)
        else:
            self.my_embed = rand_weight  # ⚠️ 这样就不会进 state_dict!

    def forward(self, x):
        out = self.linear(x)
        # 使用动态生成的嵌入参数
        out = out + self.my_embed
        return out

# ======================================================
# ✅ 保证可复现
# ======================================================
torch.manual_seed(42)
random.seed(42)

# ======================================================
# ✅ 训练时
# ======================================================
model = DynamicModel(input_size=4, hidden_size=2, use_buffer=True)

# 做个前向
x = torch.randn(1, 4)
output_before = model(x)
print(f"✅ Output BEFORE saving: {output_before}")

# 保存结构参数
torch.save({
    'model_state_dict': model.state_dict(),
    'model_args': {
        'input_size': 4,
        'hidden_size': 2,
        'use_buffer': True
    }
}, "dynamic_model_state.pt")
print("✅ Saved state_dict and config!")

# ======================================================
# ✅ === 加载 ===
# ======================================================
checkpoint = torch.load("dynamic_model_state.pt")

# 用保存的结构参数重新生成模型
model2 = DynamicModel(**checkpoint['model_args'])
model2.load_state_dict(checkpoint['model_state_dict'])
model2.eval()

# 用相同输入验证输出
output_after = model2(x)
print(f"✅ Output AFTER loading: {output_after}")

print("\n🟢 SAME? ", torch.allclose(output_before, output_after))

# ======================================================
# ✅ === 演示如果不注册 Buffer 会怎样 ===
# ======================================================
print("\n⚠️ 演示: 如果不注册 Buffer,my_embed 不会被保存!")
model_bad = DynamicModel(input_size=4, hidden_size=2, use_buffer=False)
torch.manual_seed(42)
output_bad = model_bad(x)
torch.save(model_bad.state_dict(), "bad_state.pt")

# 重新加载: 注意我们重新生成了结构,但 my_embed 没被保存
model_bad2 = DynamicModel(input_size=4, hidden_size=2, use_buffer=False)
model_bad2.load_state_dict(torch.load("bad_state.pt"))
model_bad2.eval()

output_bad_after = model_bad2(x)

print(f"⚠️ Output BEFORE (bad): {output_bad}")
print(f"⚠️ Output AFTER (bad): {output_bad_after}")
print("🟡 SAME? (bad)", torch.allclose(output_bad, output_bad_after))

这份示例里做了啥

✅ 正确做法:

  • my_embedregister_buffer 注册 ➜ 会进 state_dict
  • 保存结构参数 ➜ 保证动态生成可重现
  • 加载时 .eval() + 用同样输入验证 ➜ 输出一致

⚠️ 错误做法:

  • use_buffer=Falsemy_embed 没注册 ➜ 没保存进 state_dict
  • 加载后 my_embed 会重新随机生成 ➜ 输出变了

🟢 跑完会看到:

✅ Output BEFORE saving: ...
✅ Saved state_dict and config!
✅ Output AFTER loading: ...
🟢 SAME?  True

⚠️ 演示: 如果不注册 Buffer,my_embed 不会被保存!
⚠️ Output BEFORE (bad): ...
⚠️ Output AFTER (bad): ...
🟡 SAME? (bad) False

核心 takeaway

做法 会不会丢状态? 结果一致吗?
使用 register_buffer 不丢 ✅ 一致
不注册 ❌ 不一致

🔑 你的动态生成结构想要“只保存 state_dict 也能跑”必须做到:

1️⃣ 所有需要保存的 tensor 都是 nn.Parameterregister_buffer 2️⃣ 动态生成的超参数一并保存 3️⃣ 推理时 .eval() 4️⃣ 输入一致 ➜ 结果一致


要不要我帮你把这个示例再加上:

  • TorchScript 的版本 ➜ 没有结构也能跑
  • ONNX 的版本 ➜ 跨框架部署

需要的话我给你整理到一份可跑的跨环境模板!🚀