论文辅助笔记:TEMPO 之 dataset.py

embedded/2024/10/18 19:24:04/

0 导入库

import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from .utils import StandardScaler, decompose
from .features import time_features

1 Dataset_ETT_hour

1.1 构造函数

class Dataset_ETT_hour(Dataset):def __init__(self,root_path,flag="train",size=None,features="S",data_path="ETTh1.csv",target="OT",scale=True,inverse=False,timeenc=0,freq="h",cols=None,period=24,):if size == None:self.seq_len = 24 * 4 * 4self.pred_len = 24 * 4else:self.seq_len = size[0]self.pred_len = size[1]#输入sequence和输出sequence的长度assert flag in ["train", "test", "val"]type_map = {"train": 0, "val": 1, "test": 2}self.set_type = type_map[flag]'''指定数据集的用途,可以是 "train"、"test" 或 "val",分别对应训练集、测试集和验证集'''self.features = features#指定数据集包含的特征类型,默认为 "S",表示单一特征self.target = target#指定预测的目标特征self.scale = scale#一个布尔值,用于确定数据是否需要归一化处理self.inverse = inverse#一个布尔值,用于决定是否进行逆变换self.timeenc = timeenc#用于确定是否对时间进行编码【原始模样 or -0.5~0.5区间】self.freq = freq#定义时间序列的频率,如 "h" 表示小时级别的频率self.period = period#定义时间序列的周期,默认为 24self.root_path = root_pathself.data_path = data_pathself.__read_data__()#用于读取并初始化数据集

1.2 __read_data__

def __read_data__(self):self.scaler = StandardScaler()#初始化一个 StandardScaler 对象,用于数据的标准化处理df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))#读取数据集文件,将其存储为 DataFrame 对象 df_rawborder1s = [0,12 * 30 * 24 - self.seq_len,12 * 30 * 24 + 4 * 30 * 24 - self.seq_len,]#定义了三个区间的起始位置,分别对应训练集、验证集和测试集border2s = [12 * 30 * 24,12 * 30 * 24 + 4 * 30 * 24,12 * 30 * 24 + 8 * 30 * 24,]#定义了每个区间的结束位置border1 = border1s[self.set_type]border2 = border2s[self.set_type]'''通过 self.set_type 确定当前数据集类型并从 border1s 和 border2s 中获取对应的起始和结束位置 border1 和 border2'''if self.features == "M" or self.features == "MS":cols_data = df_raw.columns[1:]df_data = df_raw[cols_data]elif self.features == "S":df_data = df_raw[[self.target]]'''选择特征数据:多特征 "M" 或 "MS":选择所有数据列,除去日期列。单一特征 "S":只选择目标特征列(由 self.target 指定)。'''if self.scale:train_data = df_data[border1s[0] : border2s[0]]self.scaler.fit(train_data.values)data = self.scaler.transform(df_data.values)else:data = df_data.values'''如果 self.scale 为 True,则执行数据归一化:train_data:选择训练集的数据,用于拟合 self.scaler。data:对整个 df_data 进行转换。'''df_stamp = df_raw[["date"]][border1:border2]df_stamp["date"] = pd.to_datetime(df_stamp.date)data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq)'''时间特征处理:提取日期列 df_stamp,并将其转换为时间特征:pd.to_datetime:将日期转换为 datetime 对象。time_features:用于生成时间特征。'''self.data_x = data[border1:border2]if self.inverse:self.data_y = df_data.values[border1:border2]else:self.data_y = data[border1:border2]self.data_stamp = data_stamp'''将转换后的数据和时间特征赋值给 self.data_x、self.data_y 和 self.data_stamp:self.data_x 取 data 中的对应区间数据。self.data_y 根据 self.inverse 决定是从 data 还是 df_data 中获取。self.data_stamp 取生成的时间特征。'''

1.3 __getitem__

def __getitem__(self, index):s_begin = index#设置序列的起始点s_end = s_begin + self.seq_len#计算序列的结束点r_begin = s_end#设置预测序列的起始点r_end = r_begin + self.pred_len#计算预测序列的结束点seq_x = self.data_x[s_begin:s_end]#从 data_x 中提取序列部分seq_y = self.data_y[r_begin:r_end]# 从 data_y 中提取预测部分[ground-truth]x = torch.tensor(seq_x, dtype=torch.float).transpose(1, 0)  # [1, seq_len]y = torch.tensor(seq_y, dtype=torch.float).transpose(1, 0)  # [1, pred_len](trend, seasonal, residual) = decompose(x, period=self.period)#对序列 x 进行时间序列分解,返回趋势、季节性和残差三部分components = torch.cat((trend, seasonal, residual), dim=0)  # [3, seq_len]#将分解后的三部分按 0 维(纵向)拼接,形成一个包含三种特征的张量return components, y

1.3__len__

    def __len__(self):return len(self.data_x) - self.seq_len - self.pred_len + 1

1.4  inverse_transform

将数据进行逆转换,还原到原始尺度

    def inverse_transform(self, data):return self.scaler.inverse_transform(data)

2 Dataset_ETT_minute

基本上和hour 的一样,几个地方不一样:

  • __init__
    • data_path="ETTm1.csv",
    • freq="t",
    • period: int = 60,
  • __read_data__
    • border1s = [0,12 * 30 * 24 * 4 - self.seq_len,12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len,]
      border2s = [12 * 30 * 24 * 4,12 * 30 * 24 * 4 + 4 * 30 * 24 * 4,12 * 30 * 24 * 4 + 8 * 30 * 24 * 4,]


http://www.ppmy.cn/embedded/29094.html

相关文章

10_Scala控制抽象*了解

Scala控制抽象 2.抽象函数 抽象函数-> 函数没有输入只有返回 ->抽象方法的调用时候,不能有小括号 def test(f: > Unit): Unit {f //调用}3.控制抽象 控制抽象是一系列语句的聚集,是一种特殊的函数。 控制抽象也是函数的一种,它…

Docker容器:搭建LNMP架构

目录 前言 1、任务要求 2、Nginx 镜像创建 2.1 建立工作目录并上传相关安装包 2.2 编写 Nginx Dockerfile 脚本 2.3 准备 nginx.conf 配置文件 2.4 生成镜像 2.5 创建 Nginx 镜像的容器 2.6 验证nginx 3、Mysql 镜像创建 3.1 建立工作目录并上传相关安装包 3.2 编写…

微信小程序常用的api

基础API: wx.request:用于发起网络请求,支持GET、POST等方式,是获取网络数据的主要手段。wx.showToast:显示消息提示框,通常用于向用户展示操作成功、失败或加载中等状态。wx.showModal:显示模态…

4.Docker本地镜像发布至阿里云仓库、私有仓库、DockerHub

文章目录 0、镜像的生成方法1、本地镜像发布到阿里云仓库2、本地镜像发布到私有仓库3、本地镜像发布到Docker Hub仓库 Docker仓库是集中存放镜像的地方,分为公共仓库和私有仓库。 注册服务器是存放仓库的具体服务器,一个注册服务器上可以有多个仓库&…

AnolisOS8.8基于yum安装mariadb并进行授权管理

1 安装并启动MariaDB # 安装 dnf -y install mariadb-server # 设置开机启动并立即启动 systemctl enable --now mariadb2 配置root用户允许远程访问 注意&#xff1a;本机ip地址 一定要替换成自己mariadb服务的ip mysql<<eof grant all privileges on *.* to root本机…

数字化技术可以促进中国企业创新吗?

数字化技术可以显著促进中国企业的创新。数字化技术&#xff0c;包括人工智能&#xff08;AI&#xff09;、区块链&#xff08;Blockchain&#xff09;、云计算&#xff08;Cloud computing&#xff09;、大数据&#xff08;big Data&#xff09;等&#xff0c;被称为ABCD技术&…

Spring Cloud——Circuit Breaker上篇

Spring Cloud——Circuit Breaker上篇 一、分布式系统面临的问题1.服务雪崩2.禁止服务雪崩故障 二、Circuit Breaker三、resilience4j——服务熔断和降级1.理论知识2.常用配置3.案例实战&#xff08;1&#xff09;COUNT_BASED&#xff08;计数的滑动窗口&#xff09;&#xff0…

Python中的类(Class)详解——新手指南

在Python编程中&#xff0c;类&#xff08;Class&#xff09;是一个非常重要的概念&#xff0c;它允许程序员创建自己的对象类型。这些对象类型可以包含数据&#xff08;称为属性&#xff09;和函数&#xff08;称为方法&#xff09;&#xff0c;它们定义了这些对象的行为。本文…