论文辅助笔记:TEMPO 之 utils.py

ops/2024/9/24 21:23:56/

0 导入库

from typing import Tuple
import random
import numpy as np
import torch
from statsmodels.tsa.seasonal import STL

1 EarlyStopping

  • 提供了一个早停机制,用于在模型训练过程中监控验证集上的损失
  • 如果损失停止改进,则停止训练

1.1 __init__

class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0):self.patience = patience#早停的容忍度,如果连续 patience 次验证损失没有改善,则停止训练。self.verbose = verbose#决定是否输出详细信息self.counter = 0#记录连续未改善验证损失的次数self.best_score = None#用于存储目前为止最佳的验证损失分数self.early_stop = False#一个布尔值,指示是否应该停止训练self.val_loss_min = np.Inf#存储目前为止最小的验证损失self.delta = delta#一个阈值,用于决定损失的改善幅度

1.2 __call__ 在训练过程中监控验证损失

def __call__(self, val_loss, model, path):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model, path)#如果这是第一次调用 __call__,初始化 best_score 为 score 并保存模型。elif score < self.best_score + self.delta:self.counter += 1print(f"EarlyStopping counter: {self.counter} out of {self.patience}")if self.counter >= self.patience:self.early_stop = True'''如果 score < self.best_score + self.delta,则说明损失没有显著改善增加 counter 并检查是否超过 patience,如果超过则停止训练'''else:self.best_score = scoreself.save_checkpoint(val_loss, model, path)self.counter = 0'''如果 score > self.best_score + self.delta,更新 best_score 并保存模型然后将 counter 重置为零'''

1.3 save_checkpoint 在验证损失降低时保存模型

def save_checkpoint(self, val_loss, model, path):if self.verbose:print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...")torch.save(model.state_dict(), path + "/" + "checkpoint.pth")#使用 torch.save() 保存模型的状态字典self.val_loss_min = val_loss

2 StandardScaler

实现数据标准化

2.1 __init__

class StandardScaler:def __init__(self):self.mean = 0.0self.std = 1.0

2.2  fit

计算并更新 self.meanself.std

def fit(self, data):self.mean = data.mean(0)self.std = data.std(0)

 2.3  transform

   将数据转换为标准化形式

def transform(self, data):mean = (torch.from_numpy(self.mean).type_as(data).to(data.device)if torch.is_tensor(data)else self.mean)std = (torch.from_numpy(self.std).type_as(data).to(data.device)if torch.is_tensor(data)else self.std)'''mean 和 std 的类型转换:根据 data 是 torch.Tensor 还是 numpy 数组将 self.mean 和 self.std 转换为相应类型,以确保类型匹配'''return (data - mean) / std

 2.4 inverse_transform

将标准化后的数据还原

    def inverse_transform(self, data):mean = (torch.from_numpy(self.mean).type_as(data).to(data.device)if torch.is_tensor(data)else self.mean)std = (torch.from_numpy(self.std).type_as(data).to(data.device)if torch.is_tensor(data)else self.std)'''mean 和 std 的类型转换:根据 data 是 torch.Tensor 还是 numpy 数组将 self.mean 和 self.std 转换为相应类型,以确保类型匹配'''if data.shape[-1] != mean.shape[-1]:mean = mean[-1:]std = std[-1:]return (data * std) + mean'''通过 (data * std) + mean 将标准化后的数据还原为原始形式'''

3 decompose

使用STL,将时间序列分解为趋势、季节性和残差成分

def decompose(x: torch.Tensor, period: int = 7
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:#x:输入的一维时间序列,类型为 torch.Tensor,形状为 (1, seq_len)x = x.squeeze(0).cpu().numpy()'''首先调用 squeeze(0) 将 x 的第一个维度去掉然后通过 cpu().numpy() 将 x 转换为 numpy 数组,以便 STL 分解函数使用'''decomposed = STL(x, period=period).fit()'''调用 STL(x, period=period).fit() 对 x 进行分解,并返回分解结果 decomposed其中包含了 trend(趋势)、seasonal(季节性)和 resid(残差)成分'''trend = decomposed.trend.astype(np.float32)seasonal = decomposed.seasonal.astype(np.float32)residual = decomposed.resid.astype(np.float32)'''将 decomposed 中的各个成分转换为 numpy 数组,并转为 float32 类型'''return (torch.from_numpy(trend).unsqueeze(0),torch.from_numpy(seasonal).unsqueeze(0),torch.from_numpy(residual).unsqueeze(0),)'''将它们转换为 torch.Tensor并使用 unsqueeze(0) 将其包装为 (1, seq_len) 的张量,以匹配输入张量的形状'''

4 set_seed

为 Python 中的各种随机生成器设置种子

def set_seed(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)


http://www.ppmy.cn/ops/33427.html

相关文章

VScode添加c/c++头文件路径

1.设置工作区include path方法&#xff1a; 命令面板 -> 输入c/c 修改配置文件&#xff0c;添加路径&#xff1a; 2.全局路径&#xff1a; 设置 - > 搜索include path

怎么把jpg图片变成gif?参考这个方法一键制作

Jpg图片如何变成gif图片&#xff1f;Jpg、gif都是最常用的图片格式&#xff0c;想要将这两种格式的图片互相转化的时候要怎么操作呢&#xff1f;想要将jpg图片变成gif方法很简单&#xff0c;只需要使用gif图片制作&#xff08;https://www.gif5.net/&#xff09;工具-GIF5工具网…

浏览器中不能使用ES6的扩展语法...报错

浏览器大多数已经支持ES6&#xff08;ECMAScript 2015&#xff09;的扩展语法&#xff08;...&#xff09;&#xff0c;包括Chrome、Firefox、Safari和Edge等。然而&#xff0c;如果你在某些浏览器中遇到无法使用扩展语法的问题&#xff0c;可能是由以下原因导致的&#xff1a;…

局域网唤醒平台:UpSnap

简介&#xff1a;UpSnap是一个简单的唤醒局域网网络应用程序。UpSnap为每个用户、每个设备提供了唯一的访问权限。虽然管理员拥有所有权限&#xff0c;但他们可以为用户分配特定的权限&#xff0c;如显示/隐藏设备、访问设备编辑、删除和打开/关闭设备电源。 历史攻略&#xf…

JavaEE >> Spring MVC(2)

接上文 本文介绍如何使用 Spring Boot/MVC 项目将程序执行业务逻辑之后的结果返回给用户&#xff0c;以及一些相关内容进行分析解释。 返回静态页面 要返回一个静态页面&#xff0c;首先需要在 resource 中的 static 目录下面创建一个静态页面&#xff0c;下面将创建一个静态…

Docker私有镜像仓库搭建 带图形化界面的

搭建镜像仓库可以基于Docker官方提供的DockerRegistry来实现。 官网地址&#xff1a;https://hub.docker.com/_/registry 先配置私服的信任地址: # 打开要修改的文件 vi /etc/docker/daemon.json # 添加内容&#xff1a; "insecure-registries":["http://192.…

gitee本地项目上传

1.先生成SSH密钥 ssh-keygen -t rsa -C "trueelegance163.com" 2.gitee配置公钥&#xff0c;设置对应的公钥名称比如:localtest; Git 全局设置:若刚安装完Git&#xff0c;需要进行Git的配置【若已配置完成&#xff0c;此步骤可以跳过】 git config --global user…

一篇文章带你深入了解“指针”

一篇文章带你深入了解“指针” 内存和地址了解指针指针类型const修饰指针指针的运算指针与整数之间的运算指针与指针之间的运算指针的关系运算 void* 指针传值调用和传址调用数组和指针的关系野指针野指针的形成原因规避野指针 二级指针字符指针指针数组数组指针数组传参一维数…