深度学习:利用随机数据更快地测试一个新的模型在自己数据格式很复杂的时候

server/2024/11/17 2:38:48/

技巧:

比如下面一个新的模型deeponet,我自己的数据很复杂,这里在代码最后用用随机生成的数据,两分钟就完成了代码的测试成功。 

import torch
import torch.nn as nn
import torch.optim as optim# 带偏置项的 DeepONet 结构,包括 Branch 和 Trunk 网络
class DeepONet(nn.Module):def __init__(self, branch_input_dim, trunk_input_dim, hidden_dim):super(DeepONet, self).__init__()# Branch 网络,用于处理输入点云的特征(例如位移量、压强)self.branch_net = nn.Sequential(nn.Linear(branch_input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim))# Trunk 网络,用于处理时间和空间坐标 [x, y, z, t]self.trunk_net = nn.Sequential(nn.Linear(trunk_input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim))# 偏置项 biasself.bias = nn.Parameter(torch.zeros(1))  # 可训练的偏置项# 最终的输出层,预测位移或压强等物理状态self.fc_output = nn.Linear(hidden_dim, 3)def forward(self, point_features, coord_time):# Branch网络的输出branch_output = self.branch_net(point_features)# Trunk网络的输出trunk_output = self.trunk_net(coord_time)# 将 Branch 和 Trunk 的输出结合,计算最终的输出combined = branch_output * trunk_outputoutput = self.fc_output(combined) + self.bias  # 加上偏置项return output# 数据准备
# 输入的数据格式:
# point_features:3D点云的物理特征(例如位移量 pointDisplacement、压强 p)
# coord_time:空间位置和时间 [x, y, z, t]# 示例数据的维度设置
branch_input_dim = 3  # 例如 [pointDisplacement, p, ...] 
trunk_input_dim = 4   # [x, y, z, t]
hidden_dim = 64       # 隐藏层维度,可根据需求调整# 模型初始化
model = DeepONet(branch_input_dim, trunk_input_dim, hidden_dim)# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练流程
def train(model, point_features, coord_time, target, epochs=1000):for epoch in range(epochs):optimizer.zero_grad()# 前向传播output = model(point_features, coord_time)# 计算损失loss = criterion(output, target)# 反向传播和优化loss.backward()optimizer.step()if epoch % 100 == 0:print(f"Epoch {epoch}, Loss: {loss.item()}")# 示例数据,实际应用时需要替换为真实数据
N = 1000  # 样本数量
point_features = torch.randn(N, branch_input_dim)  # 3D点云的物理特征
coord_time = torch.randn(N, trunk_input_dim)       # [x, y, z, t]
target = torch.randn(N, 3)                         # 目标物理状态# 训练模型
train(model, point_features, coord_time, target, epochs=1000)# 推理:给定新的时空点,预测物理状态
def predict(model, point_features, coord_time):model.eval()with torch.no_grad():prediction = model(point_features, coord_time)return prediction# 示例推理
new_point_features = torch.randn(1, branch_input_dim)
new_coord_time = torch.tensor([[0.5, 0.5, 0.5, 0.1]])  # 在 t=0.1 的 (0.5, 0.5, 0.5) 空间点
prediction = predict(model, new_point_features, new_coord_time)
print("Predicted state:", prediction)

输出如下:

Epoch 0, Loss: 1.0260347127914429
Epoch 100, Loss: 0.7669863104820251
Epoch 200, Loss: 0.5786211490631104
Epoch 300, Loss: 0.4749055504798889
Epoch 400, Loss: 0.41076529026031494
Epoch 500, Loss: 0.36538082361221313
Epoch 600, Loss: 0.39494913816452026
Epoch 700, Loss: 0.30206459760665894
Epoch 800, Loss: 0.2839098572731018
Epoch 900, Loss: 0.2648167908191681
Predicted state: tensor([[-0.2604,  0.2214,  0.5066]])Process finished with exit code 0


http://www.ppmy.cn/server/142533.html

相关文章

python os.path.basename(获取路径中的文件名部分) 详解

os.path.basename 是 Python 的 os 模块中的一个函数,用于获取路径中的文件名部分。它会去掉路径中的目录部分,只返回最后的文件名或目录名。 以下是 os.path.basename 的详细解释和使用示例: 语法 os.path.basename(path) 参数 path&…

算法——二分查找(leetcode704)

对于二分查找而言,首先我们得到的查找数组必须是一个有序数组,接着通过数组的两端得到左指针和右指针继而得到中间指针指向数组中间元素,将中间元素与目标值比较如果大于目标值舍弃数组中间元素右边的一半将右指针重置为中间指针下标-1中间指针重置为左右指针下标之和除以2&…

K8S 查看pod节点的磁盘和内存使用情况

查看某个节点的磁盘使用率: kubectl exec -it pod名称 -n 命名空间 – df -h 查询所有节点的已使用内存: kubectl top pods --all-namespaces | grep itsm 查询某个节点的总内存, kubectl describe pod itsr-domain-59f4ff5854-hzb68 --nam…

面试时问到软件开发原则,我emo了

今天去一个小公司面试,面试官是公司的软件总监,眼镜老花到看笔记本电脑困难,用win7的IE打开leetcode网页半天打不开,公司的wifi连接不上,用自己手机热点,却在笔记本电脑上找不到。还是我用自己的手机做热点…

Springboot RabbitMq 集成分布式事务问题

话不多说&#xff0c;直接上代码 先整体结构 pom依赖&#xff1a; <parent><artifactId>spring-boot-starter-parent</artifactId><groupId>org.springframework.boot</groupId><version>2.7.18</version></parent><depe…

【Python爬虫实战】轻量级爬虫利器:DrissionPage之SessionPage与WebPage模块详解

&#x1f308;个人主页&#xff1a;易辰君-CSDN博客 &#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、SessionPage &#xff08;一&#xff09;SessionPage 模块的基本功能 &#xff08;二&#xff09;基本使…

1.两数之和-力扣(LeetCode)

题目&#xff1a; 解题思路&#xff1a; 在解决这个问题之前&#xff0c;首先要明确两个点&#xff1a; 1、参数returnSize的含义是返回答案的大小&#xff08;数目&#xff09;&#xff0c;由于这里的需求是寻找数组中符合条件的两个数&#xff0c;那么当找到这两个数时&#…

Python 正则表达式基础教程:简单匹配

Python 正则表达式基础教程&#xff1a;简单匹配 正则表达式&#xff08;Regular Expression&#xff09;是一种用于匹配字符串模式的强大工具。在 Python 中&#xff0c;正则表达式广泛用于数据处理、文本分析等任务&#xff0c;能够帮助我们快速找到或替换特定的字符或字符串…