代码功能概述
- 导入相关包与设置环境变量:
- 首先导入了如
os
、numpy
、pandas
等常用的 Python 库,同时设置了一些与特定库(如XLA_PYTHON_CLIENT_PREALLOCATE
和JAX_PM AP_USE_TENSORSTORE
)相关的环境变量,用于优化计算等操作。
- 首先导入了如
- 加载预训练的 TimesFM 模型:
- 通过指定相关超参数(如后端为
gpu
、每核心批处理大小等)以及预训练模型在 Hugging Face 上的仓库id
,实例化了TimesFm
模型对象,用于后续的评估和微调等操作。
- 通过指定相关超参数(如后端为
- 准备数据集相关信息并加载数据:
- 定义了一个数据集字典
DATA_DICT
,包含多个数据集(如ettm1
等)的详细信息,包括数据文件路径、时间频率、划分边界等。 - 根据选定的数据集(示例中初始化为
ettm1
),读取对应的数据文件为DataFrame
,然后配置TimeSeriesdata
类的实例来进行数据加载、划分训练集、验证集和测试集,同时对数据进行了一些规范化等预处理操作,并生成对应的批次数据(train_batches
、val_batches
、test_batches
)。
- 定义了一个数据集字典
- 评估预训练模型在测试集上的 MAE(平均绝对误差):
- 通过迭代测试集批次数据,利用预训练模型进行预测,计算预测值和实际值之间的平均绝对误差,以此来评估模型在当前数据集上的性能表现。
- 微调模型:
- 导入了一系列用于构建和训练模型的
praxis
、paxml
相关的模块和类,进行了诸如定义学习器(包括优化器、学习率调度等配置)、构建任务、初始化模型状态等操作。 - 将预训练模型的参数设置为微调模型的初始权重,然后通过定义训练步和评估步函数,在多个
epoch
内循环进行训练和定期评估(利用早停机制,根据验证集损失决定是否提前停止训练),在每个训练步中对模型参数进行更新,每个评估步计算验证集上的损失,保存最优模型状态的检查点。
- 导入了一系列用于构建和训练模型的
- 加载并评估微调后的模型:
- 从保存的检查点中恢复最优的模型状态,将其参数更新到原
TimesFM
模型中,然后再次在测试集上计算平均绝对误差,以对比微调前后模型性能的变化情况。
- 从保存的检查点中恢复最优的模型状态,将其参数更新到原
针对车辆销售数据的改写步骤
-
数据准备与加载部分(适配车辆销售数据):
- 修改数据集字典
DATA_DICT
:- 创建一个新的字典项来对应你的车辆销售数据集,例如取名为
vehicle_sales
。 - 填写对应的数据文件路径(假设你的车辆销售数据存储在
../datasets/vehicle_sales.csv
,则data_path
设置为此路径)。 - 根据你的数据时间粒度来设置
freq
,比如如果是按天记录的,就可以设置为"D"
(代表 Daily),如果是按月记录的,可设置为"M"
(代表 Monthly)等。 - 按照你的训练、验证、测试集划分的时间范围来设置
boundaries
列表中的值,例如如果前 3 年数据作为训练集,第 4 年作为验证集,第 5 年作为测试集,你需要根据数据点数量等信息确定对应的时间点边界值填入该列表。
- 创建一个新的字典项来对应你的车辆销售数据集,例如取名为
- 调整数据读取和数据加载配置部分:
- 在
data_df = pd.read_csv(open(data_path, "r"))
这行代码中,确认数据文件格式正确能被read_csv
方法读取,如果数据有特定的分隔符、编码等情况,按需调整参数(比如添加sep
参数指定分隔符、encoding
参数指定编码格式等)。 - 根据你的车辆销售数据列名,修改
ts_cols
、num_cov_cols
、cat_cov_cols
的定义。例如,销售量和销售价格等数值型的时间序列列可添加到ts_cols
,车型、经销商这些分类列可以根据需求分配到num_cov_cols
(如果要进行数值编码等处理)或者cat_cov_cols
(作为分类特征)中。同时修改TimeSeriesdata
实例化时传入的参数,确保数据能正确划分和预处理,例如datetime_col
设置为数据中代表日期的列名。
- 在
- 修改数据集字典
-
模型微调部分(可能无需大改,但检查配置合理性):
- 确认微调时定义的学习器配置(如优化器、学习率调度等参数)是否适合车辆销售数据预测任务。你可能需要根据实际情况调整学习率、训练总步数等参数,例如车辆销售数据如果比较复杂,可能需要适当调小学习率、增加训练总步数等,以保证模型能更好地收敛和学习到数据中的模式。
- 检查
build_learner
函数中设置的bprop_variable_exclusion
参数是否合理,对于车辆销售数据微调场景下想要固定或者放开训练的模型层,根据模型结构和需求进行调整,确保只训练希望更新参数的那些部分。
-
模型评估部分(保持逻辑基本一致):
- 在计算微调前后模型在测试集上的平均绝对误差(MAE)部分,确保数据维度等处理符合车辆销售数据的特点。例如,在预测结果和实际结果对比计算
MAE
时,确认预测的销售量、销售价格等和实际值的对应关系和维度对齐正确,特别是如果有多个时间序列维度或者特征维度时,保证forecasts
和actuals
的形状匹配能正确计算误差。
- 在计算微调前后模型在测试集上的平均绝对误差(MAE)部分,确保数据维度等处理符合车辆销售数据的特点。例如,在预测结果和实际结果对比计算
以下是一个python代码(假设你的车辆销售数据 vehicle_sales.csv
有 date
(日期)、car_model
(车型)、dealer
(经销商)、sales_volume
(销售量)、sales_price
(销售价格)这几列,并且想将车型和经销商作为分类特征,销售量和销售价格作为时间序列特征,数据按年划分训练、验证、测试集,这里简化假设前 3 年训练、第 4 年验证、第 5 年测试,并且时间频率是按年 "Y"
):
# 以下代码主要用于基于TimesFM模型对车辆销售数据进行预训练模型评估、模型微调以及微调后模型的评估操作
# 导入相关包用于微调操作,同时设置一些环境变量来优化计算等相关配置
## Importing relevant packages for finetuning
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'
import timesfm
import gc
import numpy as np
import pandas as pd
from timesfm import patched_decoder
from timesfm import data_loader
from tqdm import tqdm
import dataclasses
import IPython
import IPython.display
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['axes.grid'] = False# 加载预训练的TimesFM模型,通过指定相关超参数(如后端使用的设备、每核心批处理大小、预测长度等)以及从Hugging Face获取预训练模型的仓库id
# 实例化TimesFm模型对象,后续将利用该模型进行数据评估和微调等操作
## Loading TimesFM pretrained checkpoint
tfm = timesfm.TimesFm(hparams=timesfm.TimesFmHparams(backend="gpu",per_core_batch_size=32,horizon_len=128,),checkpoint=timesfm.TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m"),)# 配置车辆销售数据集相关信息,包括数据集划分边界、数据文件路径、时间频率等,用于后续的数据加载和处理
# 此处简化假设按年划分数据集,前3年作为训练集,第4年作为验证集,第5年作为测试集,时间频率设置为按年("Y")
# 根据实际数据情况和需求,这些设置都可以进行相应调整
## Evaluating pretrained checkpoint on vehicle sales dataset
DATA_DICT = {"vehicle_sales": {"boundaries": [3, 4, 5], # 简化按年划分,前3年训练,第4年验证,第5年测试"data_path": "../datasets/vehicle_sales.csv","freq": "Y", # 按年的时间频率}
}
dataset = "vehicle_sales"
data_path = DATA_DICT[dataset]["data_path"]
freq = DATA_DICT[dataset]["freq"]
int_freq = timesfm.freq_map(freq)
boundaries = DATA_DICT[dataset]["boundaries"]# 读取车辆销售数据文件为DataFrame格式,后续将基于此数据进行进一步处理,例如划分数据集、提取特征列等操作
# 需要确保数据文件路径正确以及数据格式能被read_csv方法正常读取,如有特殊格式可按需调整参数(如分隔符、编码等)
data_df = pd.read_csv(open(data_path, "r"))# 定义时间序列特征列,这里选取销售量和销售价格作为时间序列特征,将用于模型的输入和预测等相关操作
# 根据实际业务需求和数据特点,可调整此列表包含的列名
ts_cols = ["sales_volume", "sales_price"]
# 暂未定义数值型协变量列,可根据后续是否需要添加额外数值型特征进行设置
num_cov_cols = None
# 定义分类特征列,这里选取车型和经销商作为分类特征,模型可以根据这些特征学习不同分类下的销售模式等信息
cat_cov_cols = ["car_model", "dealer"] context_len = 512
pred_len = 96num_ts = len(ts_cols)
batch_size = 16# 实例化TimeSeriesdata类,用于加载、划分和预处理车辆销售数据,配置训练集、验证集、测试集的范围,以及设置数据归一化等参数
# 该类内部会根据设置对数据进行相应处理,生成对应的批次数据,便于后续模型训练和评估使用
dtl = data_loader.TimeSeriesdata(data_path=data_path,datetime_col="date",num_cov_cols=num_cov_cols,cat_cov_cols=cat_cov_cols,ts_cols=np.array(ts_cols),train_range=[0, boundaries[0]],val_range=[boundaries[0], boundaries[1]],test_range=[boundaries[1], boundaries[2]],hist_len=context_len,pred_len=pred_len,batch_size=num_ts,freq=freq,normalize=True,epoch_len=None,holiday=False,permute=True,)
# 获取训练集批次数据,每个批次的数据将按照设置的batch_size进行划分,便于在训练循环中迭代使用
train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
# 获取验证集批次数据,同样按照设置进行划分,用于在模型训练过程中的定期验证,以监控模型性能和防止过拟合等
val_batches = dtl.tf_dataset(mode="val", shift=pred_len)
# 获取测试集批次数据,用于最终评估模型在未见过的数据上的性能表现
test_batches = dtl.tf_dataset(mode="test", shift=pred_len)
# 简单遍历训练集批次数据的迭代器,此处主要是为了触发数据加载等相关操作,确保数据可以正常获取,暂未对数据做具体处理
for tbatch in tqdm(train_batches.as_numpy_iterator()):pass
# 打印训练集批次数据中第一个元素(通常是输入数据部分)的形状,用于检查数据维度是否符合预期
print(tbatch[0].shape)# 以下代码块用于计算预训练模型在测试集上的平均绝对误差(MAE),通过迭代测试集批次数据
# 利用预训练模型进行预测,然后对比预测值和实际值计算平均绝对误差,以此评估模型初始性能
### MAE on the test split for the pretrained TimesFM model
mae_losses = []
for batch in tqdm(test_batches.as_numpy_iterator()):past = batch[0]actuals = batch[3]forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)forecasts = forecasts[:, 0 : actuals.shape[1]]mae_losses.append(np.abs(forecasts - actuals).mean())print(f"MAE: {np.mean(mae_losses)}")# 导入一系列用于构建和训练模型的praxis、paxml相关的模块和类,这些模块提供了配置模型、定义学习器、优化训练过程等功能
# 后续将利用这些工具来对模型进行微调操作,使其能更好地适应车辆销售数据特点和预测任务
## Finetuning the model on the vehicle sales dataset
import jax
from jax import numpy as jnp
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis import base_model
from praxis import optimizers
from praxis import schedules
from praxis import base_hyperparams
from praxis import base_layer
from paxml import tasks_lib
from paxml import trainer_lib
from paxml import checkpoints
from paxml import learners
from paxml import partitioning
from paxml import checkpoint_types
# PAX shortcuts,定义一些便捷使用的类型和函数别名,方便后续代码中调用相关功能时书写简洁
NestedMap = py_utils.NestedMap
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
InstantiableParams = py_utils.InstantiableParams
JTensor = pytypes.JTensor
NpTensor = pytypes.NpTensor
WeightedScalars = pytypes.WeightedScalars
instantiate = base_hyperparams.instantiate
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
AuxLossStruct = base_layer.AuxLossStructAUX_LOSS = base_layer.AUX_LOSS
template_field = base_layer.template_field# 定义标准的伪随机数生成器(PRNG)键名称,用于在模型训练等过程中需要随机数的地方保持一致性和可复现性
# PARAMS和RANDOM是常见的用于区分不同用途随机数的标识
PARAMS = base_layer.PARAMS
RANDOM = base_layer.RANDOM# 生成一个初始的随机数生成器的键,设置种子为1234,以便后续在需要随机初始化等操作时能复现结果
key = jax.random.PRNGKey(seed=1234)
# 配置微调模型的结构,使用PatchedDecoderFinetuneModel作为基础结构,并将之前加载的预训练模型的核心层配置传入
# 以此构建微调模型的初始结构,后续将在此基础上进行参数更新等微调操作
model = pax_fiddle.Config(patched_decoder.PatchedDecoderFinetuneModel,name='patched_decoder_finetune',core_layer_tpl=tfm.model_p,
)# 定义构建学习器的函数,配置学习器相关参数,如损失函数名称、优化器(这里使用Adam优化器)及其参数(学习率、学习率调度策略、梯度裁剪阈值、指数移动平均衰减等)
# 同时设置在微调过程中要固定的模型层(通过bprop_variable_exclusion参数指定,这里示例中固定了变压器层,可根据实际情况调整)
### We will hold the transformer layers fixed while finetuning, while training all other components.
@pax_fiddle.auto_config
def build_learner() -> learners.Learner():return pax_fiddle.Config(learners.Learner,name='learner',loss_name='avg_qloss',optimizer=optimizers.Adam(epsilon=1e-7,clip_threshold=1e2,learning_rate=1e-3, # 示例中适当调整学习率,可根据实际情况进一步优化lr_schedule=pax_fiddle.Config(schedules.Cosine,initial_value=1e-4,final_value=1e-5,total_steps=40000,),ema_decay=0.9999,),# 线性探测,固定变压器层(可根据实际情况调整要固定的层)bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],)# 构建训练任务配置,将之前定义的模型和学习器配置组合起来,同时设置模型相关的分布式训练的一些参数(如mesh形状和轴名称等)
# 用于后续在多设备等分布式环境下进行模型训练时的配置管理
task_p = tasks_lib.SingleTask(name='vehicle_sales_learn',model=model,train=tasks_lib.SingleTask.Train(learner=build_learner(),),
)
task_p.model.ici_mesh_shape = [1, 1, 1]
task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']# 获取可用的设备(如GPU或CPU等)信息,并将其整理为特定的形状,用于构建分布式训练的Mesh对象,以支持模型在多设备上并行训练
DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])# 获取本地设备数量,用于后续在数据划分、并行操作等方面根据设备数量进行相应处理,并打印设备相关信息方便查看
num_devices = jax.local_device_count()
print(f'num_devices: {num_devices}')
print(f'device kind: {jax.local_devices()[0].device_kind}')
jax_task = task_p
key, init_key = jax.random.split(key)# 以下两个函数用于处理训练批次数据和评估批次数据,主要功能是对数据的形状进行调整和整理
# 确保数据格式符合模型输入和后续处理的要求,例如将数据按照设备数量和批次大小等进行合理重塑
# 方便在分布式训练和评估过程中正确使用数据
# To correctly prepare a batch of data for model initialization (now that shape
# inference is merged), we take one devices*batch_size tensor tuple of data,
# slice out just one batch, then run the prepare_input_batch function over it.
def process_train_batch(batch):past_ts = batch[0].reshape(batch_size * num_ts, -1)actual_ts = batch[3].reshape(batch_size * num_ts, -1)return NestedMap(input_ts=past_ts, actual_ts=actual_ts)def process_eval_batch(batch):past_ts = batch[0]actual_ts = batch[3]return NestedMap(input_ts=past_ts, actual_ts=actual_ts)# 初始化模型状态,传入训练任务配置、初始化随机数键以及处理后的训练批次数据等信息
# 根据指定的检查点类型(这里是GDA类型)进行模型状态的初始化操作,得到初始的模型状态信息
jax_model_states, _ = trainer_lib.initialize_model_state(jax_task,init_key,process_train_batch(tbatch),checkpoint_type=checkpoint_types.CheckpointType.GDA,
)# 将预训练模型的参数设置为微调模型的初始权重,具体是将预训练模型的参数赋值给微调模型状态中对应核心层的参数部分
# 这样微调模型就可以在预训练的基础上进行进一步优化,加快收敛并利用预训练学到的通用特征表示
### Setting the initial model weights to the pretrained TimesFM parameters.
jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
jax_vars = jax_model_states.mdl_vars
gc.collect()# 以下是模型微调的训练循环部分,定义了训练步和评估步的函数,在多个训练轮次(epoch)内循环进行训练和定期评估
# 通过早停机制(根据验证集损失决定是否提前停止训练)来避免过拟合,在每个训练步中更新模型参数,每个评估步计算验证集上的损失并保存最优模型状态的检查点
### Training loop# 将之前配置好的训练任务(task_p)赋值给jax_task变量,后续在训练和评估步骤中会用到这个任务配置信息
jax_task = task_p# 定义训练步函数train_step,该函数内部调用了trainer_lib.train_step_single_learner函数,用于执行单个学习器的一次训练步骤
# 它接收当前的模型状态(states)、伪随机数生成器的键(prng_key)以及输入数据(inputs)作为参数,返回训练后的模型状态等相关信息
def train_step(states, prng_key, inputs):return trainer_lib.train_step_single_learner(jax_task, states, prng_key, inputs)# 定义评估步函数eval_step,首先将模型状态转换为评估状态(通过to_eval_state方法),然后调用trainer_lib.eval_step_single_learner函数执行评估步骤
# 同样接收模型状态、伪随机数生成器的键以及输入数据作为参数,返回评估相关的结果信息(例如损失值等)
def eval_step(states, prng_key, inputs):states = states.to_eval_state()return trainer_lib.eval_step_single_learner(jax_task, states, prng_key, inputs)# 对初始的随机数生成器的键(key)进行分割,生成三个新的随机数生成器的键,分别用于后续的训练、评估以及其他可能的操作
# 这样可以保证在不同的步骤中使用不同的随机数流,便于控制随机性和复现实验结果
key, train_key, eval_key = jax.random.split(key, 3)# 根据本地设备数量(jax.local_device_count()),将训练用的随机数生成器的键(train_key)分割成多个子键,每个子键对应一个设备
# 用于在分布式训练中为每个设备提供独立的随机数种子,确保随机性在不同设备上的正确应用
train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())# 同理,将评估用的随机数生成器的键(eval_key)也按照本地设备数量进行分割,为每个设备在评估过程中提供独立的随机数种子
eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())# 使用jax.pmap对训练步函数(train_step)进行并行化处理,指定按照'batch'轴进行并行,使得训练可以在多个设备上并行执行,提高训练效率
p_train_step = jax.pmap(train_step, axis_name='batch')# 同样地,对评估步函数(eval_step)进行并行化处理,使其能在多个设备上并行执行评估操作,也是按照'batch'轴进行并行
p_eval_step = jax.pmap(eval_step, axis_name='batch')# 对初始的模型状态(jax_model_states)进行复制操作,以适配分布式训练环境,使得每个设备都有一份相同的初始模型状态副本
# 后续每个设备可以基于这份副本进行独立的参数更新等操作,最终再进行汇总等处理
replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)# 获取复制后的模型状态中的变量部分(mdl_vars),方便后续在训练和评估过程中对模型参数等变量进行操作和访问
replicated_jax_vars = replicated_jax_states.mdl_vars# 初始化最优验证集损失值为一个较大的数(1e7),在训练过程中,一旦发现更小的验证集损失值,就会更新这个最优值,并保存对应的模型状态
best_eval_loss = 1e7# 记录当前已经执行的训练步数,初始化为0,随着训练循环的进行,每执行一次训练步就会加1,用于判断是否达到定期评估的步数等条件
step_count = 0# 早停机制相关的耐心值(patience),初始化为0,代表目前还没有出现验证集损失不再下降的情况
# 当验证集损失连续多次(由PATIENCE变量定义)没有下降时,就会触发早停机制,提前结束训练
patience = 0# 设定总的训练轮次(epoch)数量,这里设置为100,表示模型将对整个训练数据集完整遍历100次,可根据实际情况调整该值
NUM_EPOCHS = 100# 设定早停机制中的耐心值,即验证集损失连续多少次没有下降就触发早停,这里设置为5,意味着如果连续5次评估验证集损失都没有变小,就停止训练
PATIENCE = 5# 定义每经过多少个训练步就进行一次模型在验证集上的评估操作,这里设置为1000步评估一次,用于定期监控模型在验证集上的性能表现
TRAIN_STEPS_PER_EVAL = 1000# 指定保存模型检查点的目录路径,训练过程中,当发现当前模型在验证集上的性能更好(验证集损失更小)时,会将模型状态保存到这个目录下
CHECKPOINT_DIR = '/home/senrajat_google_com/vehicle_sales_finetune'# 定义一个函数reshape_batch_for_pmap,用于根据设备数量对批次数据进行形状重塑,使其能正确地在分布式训练环境下分配到各个设备上
# 具体操作是将输入张量的第一个维度(通常是批次大小维度)按照设备数量进行划分,重新调整张量的形状
def reshape_batch_for_pmap(batch, num_devices):def _reshape(input_tensor):bsize = input_tensor.shape[0]residual_shape = list(input_tensor.shape[1:])nbsize = bsize // num_devicesreturn jnp.reshape(input_tensor, [num_devices, nbsize] + residual_shape)return jax.tree.map(_reshape, batch)# 外层循环,按照设定的训练轮次(NUM_EPOCHS)进行循环训练,每个epoch代表对整个训练数据集的一次完整遍历
for epoch in range(NUM_EPOCHS):# 打印当前所处的训练轮次信息,方便在训练过程中查看训练进度,flush=True用于立即输出信息,不缓冲print(f"__________________Epoch: {epoch}__________________", flush=True)# 获取训练集批次数据的迭代器,用于在当前epoch内逐个批次地遍历训练数据train_its = train_batches.as_numpy_iterator()# 判断如果早停的耐心值(patience)达到设定的阈值(PATIENCE),则触发早停机制,结束训练if patience >= PATIENCE:print("Early stopping.", flush=True)break# 内层循环,遍历当前epoch的每个训练批次数据for batch in tqdm(train_its):train_losses = []# 再次检查早停条件,若满足则提前停止当前批次的训练if patience >= PATIENCE:print("Early stopping.", flush=True)break# 调用函数处理训练批次数据,主要是对数据形状进行调整,使其符合模型训练输入要求tbatch = process_train_batch(batch)# 根据设备数量对处理后的批次数据进行重塑,以便在分布式训练环境下能正确分配到各个设备上进行并行计算tbatch = reshape_batch_for_pmap(tbatch, num_devices)# 执行分布式训练的一个训练步,传入当前模型状态、训练随机数种子以及处理好的批次数据# 返回更新后的模型状态以及包含训练损失等信息的输出结果(step_fun_out)replicated_jax_states, step_fun_out = p_train_step(replicated_jax_states, train_prng_seed, tbatch)# 将当前训练步的损失值添加到训练损失列表(train_losses)中,后续可以用于计算平均训练损失等操作train_losses.append(step_fun_out.loss[0])# 判断当前训练步数是否达到了设定的定期评估步数(TRAIN_STEPS_PER_EVAL),如果达到则进行模型在验证集上的评估操作if step_count % TRAIN_STEPS_PER_EVAL == 0:# 打印当前训练步数下的平均训练损失值,方便查看训练过程中模型在训练集上的损失变化情况,flush=True用于立即输出信息print(f"Train loss at step {step_count}: {np.mean(train_losses)}",flush=True,)# 清空训练损失列表,为下一个评估周期准备,避免累计之前的损失值影响下一次平均损失的计算train_losses = []# 打印提示信息,表示开始进行模型在验证集上的评估操作print("Starting eval.", flush=True)# 获取验证集批次数据的迭代器,用于在验证过程中逐个批次地遍历验证数据val_its = val_batches.as_numpy_iterator()# 初始化用于存储每个验证批次损失值的列表,用于后续计算平均验证集损失eval_losses = []# 遍历验证集的每个批次数据for ev_batch in tqdm(val_its):# 调用函数处理验证批次数据,对数据形状进行调整,使其符合模型评估输入要求ebatch = process_eval_batch(ev_batch)# 根据设备数量对处理后的验证批次数据进行重塑,适配分布式评估环境ebatch = reshape_batch_for_pmap(ebatch, num_devices)# 执行分布式评估的一个评估步,传入当前模型状态、评估随机数种子以及处理好的验证批次数据# 返回包含评估损失等信息的输出结果(这里只关心损失值,所以用下划线忽略其他返回信息)_, step_fun_out = p_eval_step(replicated_jax_states, eval_prng_seed, ebatch)# 将当前验证批次的损失值添加到验证损失列表(eval_losses)中eval_losses.append(step_fun_out.loss[0])# 计算平均验证集损失值,通过对验证损失列表中的所有损失值求平均得到mean_loss = np.mean(eval_losses)# 打印当前训练步数下的平均验证集损失值,方便查看模型在验证集上的性能表现,flush=True用于立即输出信息print(f"Eval loss at step {step_count}: {mean_loss}", flush=True)# 判断当前平均验证集损失值是否小于之前记录的最优验证集损失值(best_eval_loss),或者是否为NaN(表示出现异常情况)# 如果满足条件,说明当前模型在验证集上的性能更好,需要保存当前的模型状态作为最优模型状态if mean_loss < best_eval_loss or np.isnan(mean_loss):# 更新最优验证集损失值为当前的平均验证集损失值best_eval_loss = mean_loss# 打印提示信息,表示正在保存模型检查点print("Saving checkpoint.")# 对复制后的模型状态进行处理,将其转换为适合保存的格式(可能涉及去除一些分布式相关的冗余信息等操作)jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(replicated_jax_states)# 调用函数保存模型检查点,将处理后的模型状态保存到指定的目录(CHECKPOINT_DIR)下,并且设置覆盖已存在的同名检查点checkpoints.save_checkpoint(jax_state_for_saving, CHECKPOINT_DIR, overwrite=True)# 将早停机制的耐心值重置为0,因为当前模型性能有提升,重新开始计算耐心值patience = 0# 删除已经保存的模型状态变量,释放内存空间,避免内存占用过多del jax_state_for_saving# 手动触发垃圾回收,及时回收不再使用的内存,优化内存使用情况gc.collect()# 如果当前平均验证集损失值没有小于最优验证集损失值,说明模型在验证集上的性能没有提升,则增加早停机制的耐心值else:patience += 1# 打印当前的耐心值,方便查看早停机制的触发进度情况print(f"patience: {patience}")# 每执行完一个训练步,训练步数加1,用于跟踪训练的进度以及判断是否达到定期评估等条件step_count += 1# 以下代码用于加载根据验证集损失选出的最优微调后的模型检查点,并在测试集上对其进行评估,计算平均绝对误差(MAE)来衡量模型性能
## Loading and evaluating the best (according to validation loss) finetuned checkpoint# 调用函数从指定的目录(CHECKPOINT_DIR)中恢复之前保存的最优模型状态,将其赋值给train_state变量,用于后续的模型参数更新和评估操作
train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)# 打印恢复的模型状态对应的训练步数信息,可用于查看加载的是哪个阶段保存的模型状态
print(train_state.step)# 将微调后模型的参数更新到原TimesFM模型中,具体是将恢复的模型状态中的核心层参数赋值给原TimesFM模型的对应参数部分
# 使得原模型可以使用微调后的参数进行预测等操作,用于在测试集上评估微调后的模型性能
tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']# 对TimesFM模型执行即时编译(jit)相关的解码操作,可能是为了优化模型在后续预测过程中的性能,加快预测速度
tfm.jit_decode()# 初始化用于存储测试集上每个批次预测结果与实际结果的平均绝对误差(MAE)的列表,后续将通过循环计算并填充该列表
mae_losses = []# 遍历测试集的每个批次数据,用于计算在整个测试集上的平均绝对误差(MAE)
for batch in tqdm(test_batches.as_numpy_iterator()):# 获取当前批次的输入数据(通常是历史时间序列数据等),作为模型预测的输入past = batch[0]# 获取当前批次的实际值(真实的目标数据,例如实际销售量、销售价格等),用于与模型预测结果进行对比计算误差actuals = batch[3]# 使用更新参数后的TimesFM模型进行预测,传入当前批次的输入数据以及一些相关的辅助参数(这里辅助参数都设置为0,具体含义可能取决于模型的定义)# 返回预测结果(forecasts)以及其他可能的相关信息(这里用下划线忽略)_, forecasts = tfm.forecast(list(past), [0] * past.shape[0])# 对预测结果进行维度处理,选取与实际值维度对应的部分,确保两者可以正确地进行误差计算(这里假设实际值和预测值在维度上需要进行一定的对齐操作)forecasts = forecasts[:, 0 : actuals.shape[1], 5]# 计算当前批次预测结果与实际结果的平均绝对误差(MAE),通过计算预测值与实际值差值的绝对值的平均值得到# 将每个批次的MAE值添加到mae_losses列表中,后续可以通过求平均得到整个测试集上的平均MAE值mae_losses.append(np.abs(forecasts - actuals).mean())print(f"MAE: {np.mean(mae_losses)}")
请注意:
- 上述代码中的路径等相关设置(如
CHECKPOINT_DIR
、数据文件路径等)需要根据你的实际运行环境进行调整,确保可以正确读写文件以及保存和加载模型检查点。 - 代码中关于模型的一些超参数(如学习率、训练轮数、批处理大小等)都是示例值,你可能需要根据车辆销售数据的特点、模型训练情况等进行多次试验和调整,以获得更好的预测性能。
- 假设数据文件
vehicle_sales.csv
的格式是比较规范的,能被pandas
的read_csv
方法正常读取,如果实际数据有特殊格式(例如包含标题行、特定的日期格式、缺失值表示等情况),可能需要进一步对数据读取部分进行修改完善。