训练过程记录

Posted by Shallow Dreameron July 21, 2025

你提出的这个需求非常清晰,而且很实用,属于一种分段线性采样策略,我们可以称它为:


「分段精度采样(Piecewise Linear Sampling)」策略

📌 核心目标

  • 在训练轮次较小时(如 <1000)记录全部数据
  • 当训练轮次较大时(如 >100000)线性间隔采样
  • 总记录条数保持在合理范围,如 2000~10000;
  • 无需提前知道总轮次动态地适应轮次数量级
  • 保证重要的训练阶段(如前期)记录更细致

✅ 策略设计(分段采样)

轮次数量级 采样策略
1 ~ 1000 全记录
1000 ~ 10000 每 10 步记录一次
10000 ~ 100000 每 50 步记录一次
100000 ~ 1000000 每 200 步记录一次
1000000 ~ 10000000 每 1000 步记录一次

这样你可以兼容:

  • 少量轮次的详细调试;
  • 大量训练的压缩记录。

✅ 统一适配的记录器实现(含详细注释)

import time
from collections import defaultdict

class PiecewiseMetricsRecorder:
    def __init__(self, mode='step'):
        """
        分段采样策略记录器,自动适应训练长度,兼顾精度与性能
        :param mode: 'step' 或 'epoch',决定记录依据
        """
        assert mode in ('step', 'epoch')
        self.mode = mode
        self.records = defaultdict(list)
        self._step = 0
        self._epoch = 0

    def step(self):
        """每训练一步后调用,更新计数器"""
        self._step += 1

    def epoch(self):
        """每轮结束后调用,更新计数器"""
        self._epoch += 1

    def record(self, metrics: dict):
        """
        记录指标(根据当前采样规则决定是否保留)
        :param metrics: 字典形式指标,例如 {'loss': 0.234, 'acc': 0.9}
        """
        key = self._step if self.mode == 'step' else self._epoch

        if not self._should_sample(key):
            return

        ts = time.time()
        for k, v in metrics.items():
            self.records[k].append((key, float(v), ts))

    def _should_sample(self, key: int) -> bool:
        """
        分段采样策略核心逻辑,根据 key(step 或 epoch)判断是否采样
        """
        if key < 1000:
            return True                # 小于1000,全记录
        elif key < 10000:
            return key % 10 == 0       # 每10步记录一次
        elif key < 100000:
            return key % 50 == 0       # 每50步记录一次
        elif key < 1000000:
            return key % 200 == 0      # 每200步记录一次
        else:
            return key % 1000 == 0     # 每1000步记录一次

    def get_all(self, as_dict=True):
        """
        返回所有记录数据
        :param as_dict: 是否以 dict 格式返回(否则返回 DataFrame)
        """
        if as_dict:
            return dict(self.records)
        else:
            import pandas as pd
            rows = []
            for metric, values in self.records.items():
                for key, value, ts in values:
                    rows.append({
                        'metric': metric,
                        self.mode: key,
                        'value': value,
                        'timestamp': ts
                    })
            return pd.DataFrame(rows)

✅ 使用方法(和你现在训练流程对接非常简单)

recorder = PiecewiseMetricsRecorder(mode='step')

for epoch in range(num_epochs):
    recorder.epoch()
    for step, (x, y) in enumerate(train_loader):
        recorder.step()

        loss = ...
        acc = ...
        lr = ...

        recorder.record({
            'loss': loss.item(),
            'acc': acc,
            'lr': lr
        })

# 获取记录数据
records = recorder.get_all()  # dict
# 或者 records_df = recorder.get_all(as_dict=False)

✅ 输出示例

假设你训练了 200000 步:

step loss acc
0 1.23 0.55
1 1.12 0.60
1000 0.90 0.75
1010 0.87 0.76
1050 0.85 0.78
1100
200000 0.02 0.999

✅ 优点总结

优点 描述
🔁 自动适应轮次数量级 无需预知训练总长度
📊 保留细节 & 控制规模 小训练保留全部,大训练精简
⚡ 高性能无依赖 无外部依赖,记录逻辑轻量快速
🧩 容易扩展 可添加保存、分析、图表功能

✅ 如你还有这些需求,我可继续扩展:

  • 保存为 JSON、CSV、Parquet 文件;
  • 添加分组记录(多模型/多阶段);
  • 实时绘图(matplotlib);
  • 记录某些指标变化率(如 loss delta);

是否需要我给你整合一个训练器封装类,把这个记录器直接挂进去?或者你是否想让它支持 validation loss测试指标 等统一入口?我可以帮你设计。

你说得非常对,这确实是个问题。

在训练初期,任何一个点都是目前“最优”的点,因此如果我们每次遇到“更好”的值就记录“最佳点”,那么一开始的很多轮次都会被当作“最佳点”记录一次。这不仅没有意义,还会污染数据和增加存储成本。

✅ 更合理的记录“最佳点”的逻辑:

我们应当只记录最终确定的最佳轮次及其指标值,也就是说,训练全部完成后,再将最佳记录插入采样记录中,并更新元信息。


🧠 修改方案概述:

  1. 训练中:继续采样记录采样点。
  2. 训练结束后
    • logger.best_point 获取最佳轮次及其值。
    • 将该轮次的数据插入到记录中的正确位置(按轮次排序)。
    • 同时标记这个轮次是“最佳轮次”。
    • 统一生成绘图数据,附带元信息,并将最佳轮次信息(轮次编号)保存到数据库。

✅ 改进的 TrainingLogger 方案(伪代码 + 注释):

class TrainingLogger:
    def __init__(self, title: str, max_steps: int, max_points: int = 500):
        self.title = title
        self.max_steps = max_steps
        self.max_points = max_points
        self.data = []  # [(step, value)]
        self.best_point = None  # (step, value)
        self._best_idx = None

    def log(self, step: int, value: float):
        if not self.best_point or value < self.best_point[1]:
            self.best_point = (step, value)

        # 采样点数策略:线性步进
        if self._should_record(step):
            self.data.append((step, value))

    def _should_record(self, step):
        interval = max(1, self.max_steps // self.max_points)
        return step % interval == 0 or step == self.max_steps

    def finalize(self):
        # 将 best_point 插入正确位置
        if self.best_point:
            best_step, best_value = self.best_point
            inserted = False
            for i, (s, _) in enumerate(self.data):
                if s > best_step:
                    self.data.insert(i, (best_step, best_value))
                    self._best_idx = i
                    inserted = True
                    break
            if not inserted:
                self.data.append((best_step, best_value))
                self._best_idx = len(self.data) - 1

    def export(self):
        # 导出用于绘图的结构
        steps, values = zip(*self.data)
        return {
            "title": self.title,
            "x": list(steps),
            "y": list(values),
            "best_step": self.best_point[0] if self.best_point else None,
            "best_value": self.best_point[1] if self.best_point else None,
            "best_index": self._best_idx,
        }

✅ 数据库存储建议:

你可以将 logger.export() 结构中的所有信息统一存储为一个 JSON 字段,例如:

{
  "title": "loss_curve",
  "x": [0, 10, 20, ..., 990],
  "y": [0.98, 0.89, ..., 0.12],
  "best_step": 430,
  "best_value": 0.1201,
  "best_index": 21
}

这个 JSON 字段可以直接映射为 Django/PostgreSQL 的 JSONField,后续取出来直接绘图,无需额外处理。


✅ 多条数据曲线 + 元信息示意:

如果你有多个指标,如 loss、accuracy、val_loss,可以多个 logger:

loggers = {
    'train_loss': TrainingLogger("Train Loss", max_steps=100000),
    'val_loss': TrainingLogger("Validation Loss", max_steps=100000),
    ...
}

统一导出时生成结构:

{
  "charts": {
    "train_loss": {
      "title": "Train Loss",
      "x": [...],
      "y": [...],
      "best_step": ...,
      ...
    },
    "val_loss": {
      "title": "Validation Loss",
      ...
    }
  }
}

如果你希望我帮你实现完整的可复用类 + 多指标管理器 + JSON 打包 + 数据库存储格式,我可以直接写出来。是否需要?

为满足以下需求:

  • 统一适配多种训练模型的记录器
  • 支持线性采样和记录采样(适应轮次范围)
  • 记录多个指标图线的数据
  • 记录图表元信息(如标题、轴标签等)
  • 保留最佳轮次信息,并将其插入对应顺序中
  • 最终将记录信息和最佳轮次存储到数据库中某个字段(如 JSON 字段)

下面是完整解决方案:


✅ 1. TrainingRecorder:通用训练过程记录器

import json
from collections import defaultdict
from bisect import bisect_left

class TrainingRecorder:
    def __init__(self, max_records=1000):
        """
        :param max_records: 每条曲线最多保留的数据点数,自动采样
        """
        self.max_records = max_records
        self.records = defaultdict(list)  # 每条线的数据 {metric_name: [(step, value), ...]}
        self.meta = {}  # 每条线的元信息 {metric_name: {'title': str, 'ylabel': str, ...}}
        self.best_point = None  # {'metric': 'loss', 'step': 42, 'value': 0.123}
        self.best_step = None

    def add_meta(self, metric_name, title=None, ylabel=None):
        """
        添加绘图元信息
        """
        self.meta[metric_name] = {
            'title': title or metric_name,
            'ylabel': ylabel or metric_name
        }

    def record(self, step, metrics: dict):
        """
        添加当前轮次的多个指标
        :param step: 当前轮次
        :param metrics: {'loss': 0.123, 'accuracy': 0.987}
        """
        for metric, value in metrics.items():
            points = self.records[metric]

            if len(points) < self.max_records:
                points.append((step, value))
            else:
                # 动态采样
                interval = max(1, step // self.max_records)
                if step % interval == 0:
                    points.append((step, value))

    def update_best(self, step, metric_name, value, mode='min'):
        """
        更新最佳点。如果是第一次或比已有的更优,则替换。
        """
        is_better = False
        if self.best_point is None:
            is_better = True
        elif mode == 'min' and value < self.best_point['value']:
            is_better = True
        elif mode == 'max' and value > self.best_point['value']:
            is_better = True

        if is_better:
            self.best_point = {'metric': metric_name, 'step': step, 'value': value}
            self.best_step = step
            # 插入到对应记录中
            self._insert_best_point(metric_name, step, value)

    def _insert_best_point(self, metric_name, step, value):
        """
        保证最佳点插入正确位置,不重复记录
        """
        points = self.records[metric_name]
        steps = [s for s, _ in points]
        if step in steps:
            return  # 已存在
        idx = bisect_left(steps, step)
        points.insert(idx, (step, value))

    def get_all_data(self):
        """
        返回所有记录(用于绘图或存储)
        """
        return {
            'records': dict(self.records),
            'meta': self.meta,
            'best': self.best_point
        }

    def to_json(self):
        """
        转换为 JSON 字符串,便于存入数据库字段
        """
        return json.dumps(self.get_all_data(), ensure_ascii=False)

✅ 2. 使用方式

recorder = TrainingRecorder(max_records=1000)

recorder.add_meta('loss', title='Loss Curve', ylabel='Loss')
recorder.add_meta('accuracy', title='Accuracy Curve', ylabel='Accuracy')

for epoch in range(1, 50001):
    # 假设你训练过程如下
    train_loss = ...
    val_acc = ...
    
    recorder.record(epoch, {
        'loss': train_loss,
        'accuracy': val_acc
    })

    recorder.update_best(epoch, 'loss', train_loss, mode='min')  # 记录最优损失点

✅ 3. 存入数据库(Django 示例)

from django.db import models

class TrainingLog(models.Model):
    model_name = models.CharField(max_length=128)
    record_json = models.JSONField()  # 存所有数据记录和元信息(包含最佳点)

# 使用记录器后保存:
record_data = recorder.to_json()

TrainingLog.objects.create(
    model_name='MyModelV1',
    record_json=json.loads(record_data)
)

✅ 支持的功能总结

功能 说明
自动采样记录 自动根据 max_records 限制数量
多条数据线记录(多指标) record(epoch, {'loss': ..., 'acc': ...})
图线元信息(标题等) add_meta(...)
最佳点插入原始顺序 _insert_best_point(...)
单字段 JSON 存储 to_json()
最佳轮次信息 best_point 包含 stepvaluemetric

如你有更高级的需求,例如:

  • 绘图直接导出
  • 多个阶段(如 warmup、finetune)分开记录
  • 记录的字段带单位或注释

都可以在现有结构上扩展,有需要我也可以帮你继续完善。

为了实现以下目标:

  1. 记录训练过程中的指标数据,适配各种训练轮次(100 ~ 10,000,000),带自动采样策略
  2. 记录每条数据线的绘图元信息(标题、x/y轴标签等)
  3. 记录“最佳点”,并将其插入到合适的位置
  4. 记录“最佳点”是哪一轮次,并将所有记录统一存入数据库的一个字段中(如 JSON)

我们设计如下方案:


✅ 一体化训练记录器 TrainingRecorder

import json
from typing import Any, Dict, List, Optional


class TrainingRecorder:
    def __init__(self, max_points: int = 1000):
        """
        :param max_points: 每条线最多记录的点数,超过将采样
        """
        self.max_points = max_points
        self.lines: Dict[str, Dict[str, Any]] = {}
        self.best_point: Optional[Dict[str, Any]] = None
        self.best_epoch: Optional[int] = None

    def add_line(self, key: str, title: str, xlabel: str = "Epoch", ylabel: str = ""):
        """添加一条曲线记录线(如 loss、accuracy)"""
        self.lines[key] = {
            "title": title,
            "xlabel": xlabel,
            "ylabel": ylabel,
            "points": [],
        }

    def record(self, epoch: int, metrics: Dict[str, float]):
        """记录当前轮次的多个指标"""
        for key, value in metrics.items():
            if key not in self.lines:
                raise ValueError(f"Metric '{key}' not found. Please use add_line first.")

            line = self.lines[key]
            points = line["points"]

            # 采样逻辑:只保留 max_points 个点
            if len(points) < self.max_points:
                points.append({"x": epoch, "y": value})
            else:
                # 采样替换逻辑(线性均匀替换)
                idx = int(epoch / self.total_epochs * self.max_points)
                if 0 <= idx < self.max_points:
                    points[idx] = {"x": epoch, "y": value}

    def update_best(self, epoch: int, metrics: Dict[str, float]):
        """
        更新最优点,会记录完整的点信息(epoch + metrics)
        """
        if self.best_point is None or self._is_better(metrics):
            self.best_point = {"epoch": epoch, **metrics}
            self.best_epoch = epoch

    def _is_better(self, metrics: Dict[str, float]) -> bool:
        """定义最优判断逻辑(可自定义)"""
        if self.best_point is None:
            return True
        # 默认按 'loss' 最小值判断
        return metrics.get("loss", float("inf")) < self.best_point.get("loss", float("inf"))

    def finalize(self, total_epochs: int):
        """
        在训练结束时调用以设置总轮次供采样策略使用
        """
        self.total_epochs = total_epochs
        # 插入 best_point 到对应位置
        if self.best_point and "loss" in self.lines:
            line = self.lines["loss"]
            points = line["points"]
            best_epoch = self.best_point["epoch"]
            if not any(p["x"] == best_epoch for p in points):
                points.append({"x": best_epoch, "y": self.best_point["loss"]})
                points.sort(key=lambda p: p["x"])

    def export_json(self) -> str:
        """
        导出为 JSON 格式,可直接存入数据库字段
        包含:
            - 各条曲线
            - 最佳点(带完整信息)
            - 最佳轮次
        """
        return json.dumps({
            "lines": self.lines,
            "best_epoch": self.best_epoch,
            "best_point": self.best_point
        }, ensure_ascii=False, indent=2)

✅ 使用示例

recorder = TrainingRecorder(max_points=1000)
recorder.add_line("loss", title="训练损失", ylabel="Loss")
recorder.add_line("accuracy", title="准确率", ylabel="Accuracy")

total_epochs = 1000000
for epoch in range(total_epochs):
    # 模拟指标
    loss = 1 / (epoch + 1)
    acc = epoch / total_epochs

    recorder.record(epoch, {"loss": loss, "accuracy": acc})
    recorder.update_best(epoch, {"loss": loss, "accuracy": acc})

# 最终整理,确保总轮次生效 + 插入最佳点
recorder.finalize(total_epochs=total_epochs)

# 导出结果(可存入数据库的单个字段)
json_data = recorder.export_json()

✅ 存入数据库字段(Django示例)

from django.db import models

class TrainingResult(models.Model):
    model_name = models.CharField(max_length=100)
    training_record = models.JSONField()  # 存储上面 export_json 的结果

✅ 优势

功能 说明
线性采样 自动控制记录点数,支持百万轮次训练
插入最佳点 可确保最佳轮次仍保留在曲线上
多条数据线支持 每个指标都有独立配置和元信息
数据库友好 全部信息封装为一个 JSON 字段
可直接用于前端绘图 输出结构支持 x/y 坐标、标题、标签

如果你还需要将其拆分为模块、支持持久化、增加前端图表格式化支持(如 chart.js / echarts),我也可以继续帮你扩展。是否需要?

好的,我们来梳理你提出的“最佳点记录逻辑”的需求,并解释当前实现背后的逻辑以及可能出现的“误判最佳点”的原因。你需要的是:


需求总结

  1. 训练过程采样记录(用于绘图):
    • 自动适应任意轮次数量(100 ~ 10000000),控制采样量。
    • 支持多条记录线(如 loss、val_loss、accuracy 等)。
    • 每条线带有 name, title, xlabel, ylabel 等元信息。
    • 所有记录都是有序的(按轮次递增)。
  2. 记录“最佳点”:
    • 仅记录一次最佳点。
    • 该最佳点数据也需要落入对应的位置(轮次顺序中)。
    • 最佳点应包含和普通点相同的全部信息。
    • 应能指定哪一项是“评价指标”(如 val_loss 最小、accuracy 最大)。
    • 最佳轮次信息要写入数据库字段(如 train_record_json)。

🧠 当前最佳点逻辑的解释

在目前的实现中,“最佳点”的识别逻辑可能是这样的(伪逻辑):

if best_value is None or current_value < best_value:
    best_value = current_value
    best_step = current_step
    best_data = current_data  # 记录该点

因此在训练早期,例如第 1、2、3 步,只要 current_valuebest_value 小,就会被认为是“最佳点”,即便它只是暂时好一点。


❗️潜在问题

  • 多次更新最佳点:每次发现更优的指标就更新,导致前期频繁记录“最佳点”。
  • 结果“漂移”或不稳定:如果早期数据波动大,可能误记录很多“次优点”。

更稳健的“最佳点记录逻辑”建议

你可以选择:

✅ 方法一:只记录最终最佳点

适合训练结束后记录,例如:

# 在训练完成后进行比较与记录
final_best_step = np.argmin(val_loss_list)  # or argmax for accuracy
final_best_data = full_record[final_best_step]

优点:

  • 避免前期波动影响。
  • 简洁、只写入一次。

✅ 方法二:实时记录但只保存一次

适合希望实时追踪最优模型的同时,仅记录一条最佳记录数据

# 初始化
best_value = float('inf')
best_record = None
best_step = None

# 每轮训练中
if metric_value < best_value:
    best_value = metric_value
    best_step = current_step
    best_record = current_metrics.copy()

注意:这个 best_record 最终只保存一次,加入完整记录列表时插入到对应轮次顺序中。


✅ 如何“插入最佳点到记录列表中正确位置”?

例如,训练记录列表如下:

records = [
    {"step": 1, "val_loss": 0.95},
    {"step": 2, "val_loss": 0.93},
    {"step": 5, "val_loss": 0.90},
]
best = {"step": 2, "val_loss": 0.93, "is_best": True}

你可以在保存前做如下插入:

inserted = False
for i, r in enumerate(records):
    if r["step"] > best["step"]:
        records.insert(i, best)
        inserted = True
        break

if not inserted:
    records.append(best)

✅ 示例记录结构(支持多条线、绘图信息)

{
  "title": "训练过程记录",
  "xlabel": "轮次",
  "ylabel": "指标值",
  "best_step": 42,
  "lines": [
    {
      "name": "loss",
      "label": "训练损失",
      "color": "#ff0000",
      "points": [[1, 0.9], [2, 0.85], ..., [42, 0.12]]
    },
    {
      "name": "val_loss",
      "label": "验证损失",
      "color": "#0000ff",
      "points": [[1, 1.0], [2, 0.95], ..., [42, 0.11]]
    }
  ]
}

你可以将这个 JSON 存入数据库中的一个 TextFieldJSONField 字段,字段名如:train_record_json

下面是一个支持多图(多个子图)+ 每图多线(多指标)的训练过程记录器完整示例。该记录器支持:

  • 自动采样(基于变化剧烈密集记录、变化平缓稀疏记录);
  • 多图:比如一个图展示 Loss,另一个展示 Accuracy;
  • 每图多线:比如 Loss 图中包含 train_loss 和 val_loss;
  • 记录最佳点,并将其插入到正确的位置;
  • 限制最大记录点数;
  • 提供标题、x/y 标签等绘图元信息;
  • 可将所有记录结构化保存为 JSON,用于存数据库字段等;
  • 最后绘图展示。

✅ 安装依赖(如果未安装)

pip install matplotlib numpy

🧠 示例代码(可直接运行)

import json
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Any


class TrainingRecorder:
    def __init__(self, max_points: int = 500, delta_threshold: float = 0.01):
        """
        :param max_points: 每条线最多保留多少个点
        :param delta_threshold: 值变化超过该比例才记录(用于自动采样)
        """
        self.plots: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}  # 每个图表包含多条线的数据
        self.meta: Dict[str, Dict[str, str]] = {}  # 每个图的元信息
        self.best_points: Dict[str, Dict[str, Dict[str, Any]]] = {}  # 每图每线的最佳点记录
        self.max_points = max_points
        self.delta_threshold = delta_threshold

    def add_plot(self, plot_name: str, title: str, xlabel: str, ylabel: str):
        """注册一个图"""
        self.plots[plot_name] = {}
        self.meta[plot_name] = {'title': title, 'xlabel': xlabel, 'ylabel': ylabel}

    def add_series(self, plot_name: str, series_name: str):
        """注册一条线"""
        self.plots[plot_name][series_name] = []
        self.best_points.setdefault(plot_name, {})[series_name] = None

    def record(self, plot_name: str, series_name: str, x: float, y: float,
               is_best: bool = False, extra: Dict[str, Any] = None):
        """记录一条数据点,如果波动小则可能跳过"""
        series = self.plots[plot_name][series_name]
        point = {'x': x, 'y': y}
        if extra:
            point.update(extra)

        # 如果是最佳点则记录并插入正确位置
        best = self.best_points[plot_name][series_name]
        if is_best or best is None or y < best['y']:  # 例如 y 越小越好
            self.best_points[plot_name][series_name] = point
            # 插入排序
            series.append(point)
            series.sort(key=lambda d: d['x'])
            return

        # 自动采样逻辑
        if len(series) == 0 or abs(y - series[-1]['y']) / max(abs(series[-1]['y']), 1e-8) > self.delta_threshold:
            series.append(point)

        # 限制最大记录数
        if len(series) > self.max_points:
            idx = np.linspace(0, len(series) - 1, self.max_points, dtype=int)
            self.plots[plot_name][series_name] = [series[i] for i in idx]

    def get_all_data(self) -> Dict[str, Any]:
        """导出所有记录数据,包括元信息和最佳点,可存数据库"""
        return {
            'meta': self.meta,
            'data': self.plots,
            'best': self.best_points
        }

    def plot_all(self):
        """绘制所有图"""
        for plot_name, series_dict in self.plots.items():
            plt.figure()
            meta = self.meta.get(plot_name, {})
            plt.title(meta.get('title', plot_name))
            plt.xlabel(meta.get('xlabel', 'X'))
            plt.ylabel(meta.get('ylabel', 'Y'))

            for series_name, points in series_dict.items():
                x_vals = [p['x'] for p in points]
                y_vals = [p['y'] for p in points]
                plt.plot(x_vals, y_vals, label=series_name)
                # 标注最佳点
                best = self.best_points[plot_name].get(series_name)
                if best:
                    plt.scatter([best['x']], [best['y']], marker='x', color='red')
                    plt.text(best['x'], best['y'], f"Best {series_name}", fontsize=8)

            plt.legend()
            plt.grid(True)
            plt.show()


# ✅ 示例使用
recorder = TrainingRecorder(max_points=200, delta_threshold=0.01)

# 添加两个图,分别是 Loss 和 Accuracy
recorder.add_plot("loss_plot", "Loss Curve", "Epoch", "Loss")
recorder.add_plot("acc_plot", "Accuracy Curve", "Epoch", "Accuracy")

# 每个图添加两条线:训练和验证
for plot in ["loss_plot", "acc_plot"]:
    for line in ["train", "val"]:
        recorder.add_series(plot, line)

# 模拟训练记录过程
for epoch in range(1, 300):
    train_loss = np.exp(-epoch / 50) + np.random.rand() * 0.02
    val_loss = np.exp(-epoch / 45) + np.random.rand() * 0.02

    train_acc = 1 - train_loss + np.random.rand() * 0.01
    val_acc = 1 - val_loss + np.random.rand() * 0.01

    recorder.record("loss_plot", "train", epoch, train_loss)
    recorder.record("loss_plot", "val", epoch, val_loss, is_best=(epoch == 150))
    recorder.record("acc_plot", "train", epoch, train_acc)
    recorder.record("acc_plot", "val", epoch, val_acc, is_best=(epoch == 150))

# 导出记录数据,可用于存数据库字段(如 JSONField)
record_data = recorder.get_all_data()
record_json = json.dumps(record_data)  # 存储到数据库字段中

# 打印部分结果
print(f"记录 JSON 总长度: {len(record_json)}")
print("最佳点示例:", json.dumps(record_data['best'], indent=2))

# 绘图
recorder.plot_all()

📌 特性说明

功能 支持情况
自动采样
多图支持
每图多线
最佳点插入
限制最大点数
元信息支持
JSON 导出
支持后续绘图使用

如你还需要接入数据库模型(如 Django 的 JSONField),或者希望将记录器拆分为模块支持不同训练器复用,也可以继续拓展。是否还需要加入「训练中实时输出」或「Web 前端实时展示」等?我也可以继续扩展。


如果你想要我帮你完善代码,自动采样、记录、处理最佳点并生成这种结构,请告诉我你用的是哪种框架(PyTorch / Keras / 自定义训练逻辑),我可以提供整套封装。

下面是一个完整的、可扩展的 多图多线训练记录器 示例,它:

  • 支持多指标记录(如 loss、accuracy)
  • 每个指标可配置:
    • 图表信息(x/y轴名、标题)
    • 是否记录最佳点
    • 是否波动自适应采样
  • 提供:
    • 波动计算(波动剧烈密集记录,平稳时稀疏记录)
    • 最佳点自动插入合适位置
    • 最终可导出数据(用于存数据库)

✅ 一、模块结构

我们将实现三个核心类:

from typing import List, Dict, Any, Optional, Callable
import matplotlib.pyplot as plt


class MetricCurve:
    """
    表示单条图线:记录某一指标(如 loss、accuracy)的训练值。
    支持稀疏采样、最佳点自动插入、绘图。
    """

    def __init__(
        self,
        name: str,
        title: str,
        xlabel: str = "Epoch",
        ylabel: str = "",
        track_best: bool = False,
        better_fn: Optional[Callable[[float, float], bool]] = None,
        max_points: int = 1000,
        use_adaptive_sampling: bool = True,
        window_size: int = 5,
        min_step: int = 10,
    ):
        self.name = name
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel or name

        self.track_best = track_best
        self.better_fn = better_fn or (lambda a, b: a < b)  # 默认越小越好
        self.best_point = None  # (x, y)

        self.data = []  # [(x, y)]
        self.all_data = []  # 完整记录
        self.max_points = max_points
        self.use_adaptive_sampling = use_adaptive_sampling
        self.window_size = window_size
        self.min_step = min_step

    def _compute_volatility(self) -> float:
        """计算最近窗口内数据的波动幅度"""
        if len(self.all_data) < self.window_size + 1:
            return float('inf')
        diffs = [
            abs(self.all_data[-i][1] - self.all_data[-i - 1][1])
            for i in range(1, self.window_size + 1)
        ]
        return sum(diffs) / len(diffs)

    def add_point(self, x: int, y: float):
        """添加一条记录,自动判断是否保留"""
        self.all_data.append((x, y))

        # 检查是否是最佳点
        if self.track_best and (self.best_point is None or self.better_fn(y, self.best_point[1])):
            self.best_point = (x, y)

        # 自适应采样逻辑
        if len(self.data) == 0:
            self.data.append((x, y))
            return

        if self.use_adaptive_sampling:
            volatility = self._compute_volatility()
            step = max(self.min_step, int(self.max_points * (1 / (volatility + 1e-6))))
            if x % step == 0:
                self.data.append((x, y))
        else:
            if len(self.data) < self.max_points:
                self.data.append((x, y))

    def finalize(self):
        """将最佳点插入合适位置"""
        if self.track_best and self.best_point and self.best_point not in self.data:
            inserted = False
            for i in range(len(self.data)):
                if self.data[i][0] > self.best_point[0]:
                    self.data.insert(i, self.best_point)
                    inserted = True
                    break
            if not inserted:
                self.data.append(self.best_point)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "title": self.title,
            "xlabel": self.xlabel,
            "ylabel": self.ylabel,
            "data": self.data,
            "best_point": self.best_point,
        }

    def plot(self, ax):
        xs, ys = zip(*self.data)
        ax.plot(xs, ys, label=self.name)
        if self.track_best and self.best_point:
            ax.scatter(*self.best_point, color="red", label=f"Best {self.name}")
        ax.set_title(self.title)
        ax.set_xlabel(self.xlabel)
        ax.set_ylabel(self.ylabel)
        ax.legend()
        ax.grid(True)


class TrainingRecorder:
    """
    管理多个指标图线,并统一导出绘图或数据库存储格式。
    """

    def __init__(self):
        self.curves: Dict[str, MetricCurve] = {}

    def add_curve(self, curve: MetricCurve):
        self.curves[curve.name] = curve

    def record(self, x: int, values: Dict[str, float]):
        for name, y in values.items():
            if name in self.curves:
                self.curves[name].add_point(x, y)

    def finalize(self):
        for curve in self.curves.values():
            curve.finalize()

    def export(self) -> Dict[str, Any]:
        """用于存入数据库的单字段记录格式"""
        return {
            "plots": {name: curve.to_dict() for name, curve in self.curves.items()}
        }

    def plot_all(self):
        n = len(self.curves)
        fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(8, 4 * n))
        if n == 1:
            axes = [axes]
        for ax, curve in zip(axes, self.curves.values()):
            curve.plot(ax)
        plt.tight_layout()
        plt.show()

✅ 二、示例用法(可直接运行)

import random

recorder = TrainingRecorder()

# 添加两个指标的图线
recorder.add_curve(MetricCurve("loss", "训练损失", ylabel="Loss", track_best=True))
recorder.add_curve(MetricCurve("acc", "准确率", ylabel="Accuracy", track_best=True, better_fn=lambda a, b: a > b))

# 模拟训练
for epoch in range(1, 5001):
    loss = 1.0 / (epoch ** 0.5) + random.uniform(-0.02, 0.02)
    acc = 1 - loss + random.uniform(-0.01, 0.01)
    recorder.record(epoch, {"loss": loss, "acc": acc})

# 整理记录
recorder.finalize()

# 导出并打印
data_for_db = recorder.export()
print(data_for_db)

# 绘图展示
recorder.plot_all()

✅ 三、优势总结

  • ✔ 支持任意多图、多线、各自配置
  • ✔ 自动判断采样密度(波动计算)
  • ✔ 自动记录并插入最佳点
  • ✔ 可转存为数据库字段(结构清晰)
  • ✔ 可视化展示支持 matplotlib

如需进一步扩展,例如加入标签、批次记录等功能,也可以继续在 MetricCurve 内扩展字段(如 metadata),让每条线拥有更多上下文信息。需要我可以继续添加数据库模型示例或改成 JSON 存储方式。是否需要?

你说得对,我们来一步一步修复和优化这个系统,以满足:

  1. 支持多图多线记录
  2. 支持波动性判断实现密集/稀疏记录
  3. 可配置记录限制
  4. 记录最佳点并插入合适位置
  5. 最终以 JSON 数据结构(可直接保存数据库一个字段)
  6. x/y 轴信息、图名、曲线名、注释信息完整
  7. 并能直接绘图展示验证。

✅ 目标功能说明

你最终需要的是:

  • 多个图(例如:train_loss、val_accuracy)
  • 每个图有多条线(如:不同模型或指标)
  • 每条线记录 (step, value)
  • 支持 波动剧烈时密集记录平稳时稀疏记录
  • 自动记录 最佳点
  • 最后生成结构如下:
{
  "charts": {
    "train_loss": {
      "title": "Training Loss",
      "x_label": "step",
      "y_label": "loss",
      "lines": {
        "model_a": [[0, 1.0], [10, 0.85], ...],
        "model_b": ...
      },
      "best_point": {
        "model_a": {"step": 234, "value": 0.02}
      }
    },
    ...
  }
}

✅ 示例代码

下面是一个完整的、可直接运行的 Python 示例,包含模拟训练、动态记录、多图多线、最佳点记录、波动控制、绘图。

import random
import matplotlib.pyplot as plt
import json
from typing import Dict, List, Tuple

class LineDataRecorder:
    def __init__(self, max_points=500, fluctuation_threshold=0.02):
        self.points: List[Tuple[int, float]] = []
        self.best_point: Tuple[int, float] = None
        self.max_points = max_points
        self.fluctuation_threshold = fluctuation_threshold

    def should_record(self, step: int, value: float) -> bool:
        """根据波动性和数量限制决定是否记录"""
        if len(self.points) == 0:
            return True
        if len(self.points) >= self.max_points:
            return False

        last_step, last_value = self.points[-1]
        fluctuation = abs(value - last_value)
        return fluctuation >= self.fluctuation_threshold

    def record(self, step: int, value: float):
        if self.best_point is None or value < self.best_point[1]:
            self.best_point = (step, value)
            self.points.append((step, value))  # 直接插入,后面排序
        elif self.should_record(step, value):
            self.points.append((step, value))

        self.points = sorted(self.points, key=lambda x: x[0])

class ChartRecorder:
    def __init__(self, title: str, x_label: str, y_label: str):
        self.title = title
        self.x_label = x_label
        self.y_label = y_label
        self.lines: Dict[str, LineDataRecorder] = {}

    def record(self, line_name: str, step: int, value: float):
        if line_name not in self.lines:
            self.lines[line_name] = LineDataRecorder()
        self.lines[line_name].record(step, value)

    def to_dict(self):
        return {
            "title": self.title,
            "x_label": self.x_label,
            "y_label": self.y_label,
            "lines": {
                name: recorder.points for name, recorder in self.lines.items()
            },
            "best_point": {
                name: {"step": recorder.best_point[0], "value": recorder.best_point[1]}
                for name, recorder in self.lines.items() if recorder.best_point
            }
        }

class TrainingRecorder:
    def __init__(self):
        self.charts: Dict[str, ChartRecorder] = {}

    def record(self, chart_name: str, line_name: str, step: int, value: float,
               x_label="step", y_label="value", title=None):
        if chart_name not in self.charts:
            self.charts[chart_name] = ChartRecorder(
                title=title or chart_name,
                x_label=x_label,
                y_label=y_label
            )
        self.charts[chart_name].record(line_name, step, value)

    def to_json(self):
        return json.dumps({
            "charts": {name: chart.to_dict() for name, chart in self.charts.items()}
        }, indent=2)

    def plot(self):
        for chart_name, chart in self.charts.items():
            plt.figure(figsize=(10, 5))
            for line_name, recorder in chart.lines.items():
                steps, values = zip(*recorder.points) if recorder.points else ([], [])
                plt.plot(steps, values, label=line_name)
                # 绘制最佳点
                if recorder.best_point:
                    step, val = recorder.best_point
                    plt.scatter(step, val, color='red')
                    plt.annotate(f"Best {val:.4f}", (step, val), textcoords="offset points", xytext=(0,10), ha='center')
            plt.title(chart.title)
            plt.xlabel(chart.x_label)
            plt.ylabel(chart.y_label)
            plt.legend()
            plt.grid(True)
            plt.show()

# ===============================
# 模拟训练过程
# ===============================
recorder = TrainingRecorder()

for step in range(1000):
    # 模拟两个模型在两个图的表现(train_loss 和 val_accuracy)
    val1 = 1.0 / (1 + step * 0.01) + random.uniform(-0.01, 0.01)
    val2 = 0.5 / (1 + step * 0.005) + random.uniform(-0.02, 0.02)
    acc1 = 0.5 + 0.4 * step / 1000 + random.uniform(-0.02, 0.02)
    acc2 = 0.4 + 0.5 * step / 1000 + random.uniform(-0.03, 0.01)

    recorder.record("train_loss", "model_1", step, val1, y_label="loss", title="Training Loss")
    recorder.record("train_loss", "model_2", step, val2, y_label="loss")

    recorder.record("val_accuracy", "model_1", step, acc1, y_label="accuracy", title="Validation Accuracy")
    recorder.record("val_accuracy", "model_2", step, acc2, y_label="accuracy")

# 输出 JSON(可存数据库)
json_data = recorder.to_json()
print(json_data[:800] + "\n...")

# 可视化
recorder.plot()

✅ 说明总结

功能 实现说明
多图 charts["train_loss"], charts["val_accuracy"]
多线 每个 chart 的 lines["model_1"], lines["model_2"]
最佳点 best_point 自动记录并标红显示
波动记录 should_record() 控制波动记录
插入排序 所有记录按 step 排序
存数据库 to_json() 得到一个完整字段
限制点数 max_points 限制每条曲线的点数
可扩展性 所有模块独立可替换

如果你还需要:

  • 自定义最优指标(最大化/最小化、不同指标);
  • 支持保存到数据库模型;
  • 最佳点标签记录备注(如“验证损失最低”);
  • 添加绘图主题风格;

我可以继续为你扩展。

是否现在这版已经符合你当前需求?需要我帮你接入 Django 模型或数据库字段结构吗?

当前这个示例中的“波动计算”逻辑核心是通过比较前后值的变化(也就是梯度或差值)来决定当前轮次是否值得记录:

fluctuation = abs(current_value - prev_value)
if fluctuation > threshold:
    should_record = True

当前波动计算逻辑说明:

  1. 计算的是单点变化幅度abs(current_value - prev_value));
  2. 如果变化大于设定阈值(threshold),就说明当前数据“有波动”,于是记录;
  3. 否则跳过记录;
  4. 初始点和最佳点也总会被记录。

存在的问题和可优化点:

✅ 问题 1:仅靠相邻点差值不够鲁棒

  • 单个跳跃可能是噪声,可能会导致“误判”为波动。
  • 某些指标可能变化范围不大,但趋势重要,不能被省略。

✅ 优化方案 1:使用滑动窗口标准差滑动平均差

用过去若干轮的标准差或均值差来判断当前波动更合理:

import numpy as np
from collections import deque

class FluctuationTracker:
    def __init__(self, window_size=5, threshold=0.01):
        self.window = deque(maxlen=window_size)
        self.threshold = threshold

    def update(self, value):
        self.window.append(value)

    def should_record(self):
        if len(self.window) < 2:
            return True  # 初始化阶段记录
        std_dev = np.std(self.window)
        return std_dev > self.threshold

使用方式:

tracker = FluctuationTracker(window_size=10, threshold=0.05)

for step in range(total_steps):
    value = compute_metric(step)
    tracker.update(value)
    if tracker.should_record():
        recorder.record(step, value)

✅ 优化方案 2:考虑“波动斜率”和“拟合误差”

使用一段数据的拟合曲线(线性或样条)来评估趋势是否稳定,类似折线的弯曲程度。这个可以更精准地定义“是否稳定”,但复杂度更高。


✅ 优化方案 3:增加最大记录数限制

同时加入最大记录量的限制,以防大轮次模型下数据爆炸:

if len(recorded_data) >= max_record_limit:
    continue

✅ 总结推荐波动逻辑优化点:

优化项 说明 是否推荐
相邻点差值 简单有效,容易实现 ✅ 初步使用
滑动窗口标准差 更稳定识别波动趋势 ✅ 推荐
滑动平均/线性回归残差 可平滑趋势判断,防止噪声误判 ✅ 如需精细控制
全局最大点/最小点记录 保留代表点 ✅ 推荐
数据点数限制 防止记录量爆炸 ✅ 推荐

如果你希望,我可以整理成完整的波动感知记录系统示例,支持:

  • 多图、多线;
  • 指标元信息配置;
  • 最佳点记录;
  • 波动感知稀疏采样;
  • 最大记录限制;
  • 最终打包为字典结构可直接存入数据库。

是否需要我现在为你整理出这个完整示例?

在这个改进后的版本中,波动计算与记录控制更加合理,以下是详细说明:


✅ 当前波动记录机制说明:

我们基于 相对变化率(波动率) 来控制是否记录一个训练点。具体逻辑如下:

volatility = abs(current - prev) / (abs(prev) + epsilon)

其中:

  • current 是当前值;
  • prev 是前一个已记录的点的值;
  • epsilon 是一个很小的数,防止除零;
  • volatility >= volatility_threshold 时,判定为波动较大,需要记录当前点
  • 否则,跳过这次记录(除非是最佳点或指定采样点间隔)。

✅ 优化点建议:

你可以从以下几个方面进一步改进或拓展:

1. 波动窗口平滑

使用滑动窗口的标准差来更稳定地判断波动:

import numpy as np
window = [y1, y2, ..., yn]  # 最新的 n 个 y 值
volatility = np.std(window) / (np.mean(window) + epsilon)

适用于波动非单点突变的场景,能更准确地反映整体趋势。


2. 自适应阈值

动态调整波动阈值,比如随着 epoch 增长逐渐调高:

volatility_threshold = base_threshold * (1 + epoch / max_epoch)

初期记录密集,后期渐疏,适合模型训练初期变化剧烈、后期趋稳的规律。


3. 最小采样间隔控制

避免频繁记录带来性能开销:

if step - last_recorded_step < min_interval:
    continue

4. 多图多线指标管理

当前代码支持同时记录多个图(图名)与多条线(y 名称)。你可以为每张图分别设置:

  • x 轴字段(如 epoch、step);
  • y 轴多个字段(loss、accuracy 等);
  • 图的元信息(title、颜色、单位)等。

5. 数据库记录整合

如果需要将所有记录数据统一序列化存入数据库的一个字段,建议格式:

{
  "best_step": 129,
  "records": {
    "loss_plot": {
      "meta": {"title": "Loss Curve", "xlabel": "Epoch", "ylabel": "Loss"},
      "data": {
        "train": [[x1, y1], [x2, y2], ...],
        "val": [[x1, y1], [x2, y2], ...]
      }
    },
    "acc_plot": {
      "meta": {"title": "Accuracy Curve", "xlabel": "Epoch", "ylabel": "Accuracy"},
      "data": {
        "train": [[x1, y1], ...],
        "val": [[x1, y1], ...]
      }
    }
  }
}

可使用 json.dumps(record_object) 存入 JSONField 类型字段。


是否希望我为你加入“滑动窗口波动计算”和“自适应采样控制”的完整版本代码?