超参数pydantic

Posted by Shallow Dreameron April 29, 2025
参数 类型 示例值 说明
model_type str ‘resnet18’, ‘bert-base’ 模型结构名称
epochs int 10 训练轮数
batch_size int 32 每批样本数
learning_rate float 0.001 学习率
optimizer str ‘adam’ / ‘sgd’ 优化器选择
loss_function str ‘cross_entropy’ 损失函数名称
scheduler str ‘StepLR’, ‘Cosine’ 学习率调度器
dataset_version_id int 3 使用的数据集版本
train_val_split float 0.8 训练验证集拆分比例
input_columns List[str] [‘image_path’] 输入列名(从表头中选择)
label_column str ‘label’ 标签列名

要实现一个能够自动适配超参数 JSON 字段的 pydantic 模型,而不依赖于对模型名称的手动判断,我们可以利用 Python 的动态特性,通过 pydanticConfigextra = 'allow' 以及递归结构来实现。

基本的想法是:

  1. 让 pydantic 模型自动适配未知的超参数:使用递归模型验证任意层次嵌套的 JSON 字段。
  2. 动态生成模型:通过解析 JSON 字段的结构,动态生成合适的验证规则。

方案:动态适配 JSON 超参数字段

可以通过 pydanticBaseModel 结合 DictList 等类型来适配动态和嵌套的 JSON 结构,而无需显式判断模型类型。

实现步骤

  1. 设计一个递归的 pydantic 模型,它可以自动适应嵌套的 JSON 结构。
  2. 使用 extra = 'allow' 配置,使得模型能够处理未定义的字段。
  3. 通过 json.loads 动态加载和验证 JSON 字段,使得模型可以直接适应不同的参数结构。

示例代码

1. 定义一个通用的递归超参数验证模型

我们将创建一个通用的模型,能够递归地验证嵌套的超参数。

from pydantic import BaseModel, Field
from typing import Dict, Any, Optional, Union

# 定义一个支持嵌套的通用超参数模型
class Hyperparameters(BaseModel):
    # 支持任意的嵌套字段,可以是字符串、数字、布尔值或字典
    param: Optional[Union[str, int, float, bool, Dict[str, Any], list]] = None

    class Config:
        # 允许额外字段
        extra = 'allow'

# 为了支持递归验证,我们定义一个模型来递归解析嵌套的超参数
class RecursiveHyperparameters(BaseModel):
    # 使用递归的方式允许无限嵌套的结构
    param: Optional[Union[str, int, float, bool, Dict[str, 'RecursiveHyperparameters'], list]] = None

    class Config:
        extra = 'allow'

# 这类模型可以自动适配任何 JSON 结构,无论它是扁平的还是嵌套的。

2. 在 Django 模型中存储 JSON 字段并动态验证

假设您的 ModelConfig 模型中的 hyperparameters 字段是一个 JSONField,我们将通过 pydanticRecursiveHyperparameters 类来自动验证该字段。

from django.db import models
from pydantic import ValidationError
import json

class ModelConfig(models.Model):
    name = models.CharField(max_length=100)
    hyperparameters = models.JSONField()

    def validate_hyperparameters(self):
        """动态验证超参数,适应任何结构"""
        try:
            # 自动选择递归模型进行验证
            hyperparameters_model = RecursiveHyperparameters(**self.hyperparameters)
            return hyperparameters_model
        except ValidationError as e:
            print(f"Validation error: {e.json()}")
            return None

3. 使用示例

假设您从数据库获取了一个 ModelConfig 实例,并希望验证 hyperparameters 字段:

# 假设从数据库中获取了一个实例
instance = ModelConfig.objects.get(id=1)

# 获取验证后的超参数
validated_hyperparameters = instance.validate_hyperparameters()

if validated_hyperparameters:
    print("Hyperparameters are valid:", validated_hyperparameters)
else:
    print("Invalid hyperparameters!")

解释

  • 递归模型RecursiveHyperparameters 使用了 Union 和嵌套的 Dict 类型来允许超参数是一个字典,可以包含各种类型的值(字符串、整数、布尔值、嵌套字典或列表)。
  • 自动适配 JSON 结构:这个模型可以适应任何动态或嵌套的 JSON 结构,而无需手动判断每种模型类型。
  • extra = 'allow':这一配置允许 pydantic 接受 JSON 字段中的任何未在模型中明确声明的字段。
  • 嵌套参数:通过递归结构,您可以验证任何层次的嵌套 JSON 数据。
  • Django JSONField 兼容:Django 的 JSONField 存储 JSON 数据,validate_hyperparameters 方法会验证并解析该字段。

优势

  • 灵活性:该模型适应性强,可以处理任意结构的 JSON 数据,不需要事先判断模型类型。
  • 自动验证:通过 pydantic 的自动验证,您不需要手动编写大量的字段检查逻辑,能够自动根据 JSON 结构进行验证。
  • 递归支持:可以自动处理嵌套的超参数配置,无论嵌套多深,都可以递归地验证。

通过这种方式,您可以将超参数 JSON 字段的验证过程完全自动化,适应各种复杂的嵌套结构,并避免手动判断和代码重复。

理解了您的问题!您是希望在使用 pydantic 自动适配学习模型超参数时,能够验证这些超参数的值是否符合预期的要求或数据类型,而不仅仅是检查字段是否存在。

实现目标

  1. 自动适配:根据学习模型的超参数 JSON 字段动态生成模型,并进行验证。
  2. 类型验证:确保每个超参数的值符合预期的数据类型。
  3. 自定义约束:允许为某些超参数添加自定义的验证规则,比如数值范围、字符串长度、正负限制等。

方案

我们可以通过在 pydantic 模型中使用字段类型(如 int, float, str, bool, Dict, List 等)和自定义 @validator 来实现自动的值验证。

更新后的方案:自动验证超参数值的类型和范围

我们依然使用递归的 pydantic 模型来自动适配不同的超参数结构,并且通过类型和验证器来确保数据符合要求。

示例:自动匹配并验证学习模型超参数

1. 递归的 pydantic 模型

我们将为不同类型的参数添加自定义验证规则,同时保持自动适配嵌套的超参数结构。

from pydantic import BaseModel, validator, Field
from typing import Dict, Any, Union, Optional

# 自定义超参数验证模型
class Hyperparameters(BaseModel):
    learning_rate: float  # float 类型,学习率
    batch_size: int       # int 类型,批量大小
    epochs: int           # int 类型,训练轮数
    optimizer: str        # str 类型,优化器名称
    momentum: Optional[float] = None  # 可选的 momentum 参数,float 类型
    nested_param: Optional[Dict[str, Any]] = None  # 支持动态嵌套

    # 额外的验证器:检查学习率是否在合理范围内
    @validator('learning_rate')
    def check_learning_rate(cls, value):
        if value <= 0 or value > 1:
            raise ValueError('learning_rate should be between 0 and 1')
        return value

    # 检查批量大小是否为正整数
    @validator('batch_size')
    def check_batch_size(cls, value):
        if value <= 0:
            raise ValueError('batch_size must be a positive integer')
        return value

    # 检查轮数是否为正整数
    @validator('epochs')
    def check_epochs(cls, value):
        if value <= 0:
            raise ValueError('epochs must be a positive integer')
        return value

    # momentum 必须是正数
    @validator('momentum')
    def check_momentum(cls, value):
        if value is not None and (value < 0 or value > 1):
            raise ValueError('momentum should be between 0 and 1')
        return value

    class Config:
        extra = 'allow'  # 允许额外的字段(不在模型中定义的字段)

2. 递归的超参数结构(处理嵌套)

对于有嵌套超参数的情况,我们使用递归的模型:

class RecursiveHyperparameters(BaseModel):
    param: Optional[Union[str, int, float, bool, Dict[str, 'RecursiveHyperparameters'], list]] = None

    class Config:
        extra = 'allow'

3. 验证超参数字段

使用 pydantic 验证输入数据时,您可以自动进行类型验证,并且在必要时通过验证器进行额外检查。

from pydantic import ValidationError

# 示例输入数据
input_data = {
    'learning_rate': 0.01,   # valid
    'batch_size': 32,         # valid
    'epochs': 10,             # valid
    'optimizer': 'adam',      # valid
    'momentum': 0.9,          # valid
    'extra_param': 'value',   # 额外字段,pydantic 会接受
}

try:
    # 创建并验证模型
    hyperparameters = Hyperparameters(**input_data)
    print("Hyperparameters are valid:", hyperparameters)
except ValidationError as e:
    print("Validation error:", e.json())

4. 示例输入错误

# 错误的输入数据
input_data = {
    'learning_rate': 1.5,    # invalid: 超过范围
    'batch_size': -32,        # invalid: 负数
    'epochs': 0,              # invalid: 0 轮数
    'optimizer': 'adam',
}

try:
    # 创建并验证模型
    hyperparameters = Hyperparameters(**input_data)
    print("Hyperparameters are valid:", hyperparameters)
except ValidationError as e:
    print("Validation error:", e.json())

5. 输出示例

合法输入数据的输出

Hyperparameters are valid: learning_rate=0.01 batch_size=32 epochs=10 optimizer='adam' momentum=0.9 nested_param=None

非法输入数据的输出

Validation error: [
  {
    'loc': ['learning_rate'],
    'msg': 'learning_rate should be between 0 and 1',
    'type': 'value_error'
  },
  {
    'loc': ['batch_size'],
    'msg': 'batch_size must be a positive integer',
    'type': 'value_error'
  },
  {
    'loc': ['epochs'],
    'msg': 'epochs must be a positive integer',
    'type': 'value_error'
  }
]

解释

  1. 类型验证
    • learning_rate: floatbatch_size: intepochs: int 等字段类型会自动检查字段值的类型是否符合要求。
    • 如果字段值不符合类型,pydantic 会抛出 ValidationError
  2. 自定义验证器
    • @validator 用于确保 learning_rate 在 0 到 1 之间,batch_sizeepochs 为正整数。
    • momentum 是可选的,且值必须在 0 到 1 之间,如果提供了该字段。
  3. 嵌套参数
    • recursiveparam 字段可以适应任意嵌套的结构。如果您的超参数中有复杂的嵌套结构,它会被自动验证。
  4. 额外字段
    • extra = 'allow' 允许输入数据中存在模型未定义的字段(如 extra_param),这不会导致验证失败。

总结

  • 自动类型匹配pydantic 会根据字段类型(如 float, int, str 等)自动验证字段的类型。
  • 自定义验证:您可以通过 @validator 来为特定字段添加自定义验证逻辑,确保值符合业务规则(如数值范围、正负值限制等)。
  • 递归适应嵌套结构:使用递归模型来适配任意嵌套的超参数结构。
  • 灵活性:通过 extra = 'allow',允许额外的字段出现在输入数据中。

这样,您可以实现一个通用的超参数验证机制,不仅能够适配不同模型的超参数,还能确保每个超参数的值符合预期的类型和规则。

https://gitee.com/yanglinxin0611/books/blob/master/PyTorch%E5%AE%98%E6%96%B9%E6%95%99%E7%A8%8B%E4%B8%AD%E6%96%87%E7%89%88.pdf

https://blog.csdn.net/qq_19334535/article/details/120815623

https://blog.csdn.net/xuguoliang757/article/details/118269944