PyG结合MP api 实现深度学习对材料性能预测的简单案例分析

devtools/2025/3/4 0:50:45/

使用 PyTorch Geometric (PyG) 和 Material Project API (MP API),结合深度学习技术预测材料的性能。选择预测材料的带隙(band gap) 作为任务,带隙是材料的一个关键性能,影响其导电性和光学性质。通过这个案例,可以将掌握从数据获取到模型预测的完整流程,并理解深度学习在材料科学中的作用。


1. 案例背景

  • Material Project API (MP API): Materials Project 是一个开放的材料数据库,提供了丰富的晶体结构和性能数据。通过其 API,我们可以轻松获取这些数据,用于训练和测试模型。

  • PyTorch Geometric (PyG): PyG 是一个基于 PyTorch 的图神经网络(GNN)库,特别适合处理图结构数据。在材料科学中,晶体结构可以表示为图(原子为节点,化学键为边),PyG 是实现性能预测的理想工具。

  • 任务: 使用 MP API 获取材料数据,结合 PyG 构建 GNN 模型,预测材料的带隙。


2. 环境准备

在开始之前,请确保安装以下 Python 库:

bash

 

 

pip install mp_api pymatgen torch torch-geometric
  • mp_api: 用于访问 Materials Project 数据库。

  • pymatgen: 用于处理晶体结构数据。

  • torch: PyTorch 深度学习框架。

  • torch-geometric: PyTorch Geometric,处理图神经网络。

此外,你需要从 Materials Project 官网 注册并获取一个 API 密钥,将其保存以供后续使用。


3. 数据获取

我们将使用 MP API 获取一组材料的晶体结构和带隙数据。这里以 10 种材料为例。

python

 

 

from mp_api.client import MPRester# 替换为你的 API 密钥
API_KEY = "your_api_key"# 使用 MPRester 获取数据
with MPRester(API_KEY) as mpr:# 查询 10 种材料的结构和带隙materials = mpr.materials.summary.search(fields=["material_id", "structure", "band_gap"],num_elements=(1, 5),  # 限制元素数量以简化limit=10  # 获取 10 种材料)# 提取材料 ID、结构和带隙
material_ids = [mat.material_id for mat in materials]
structures = [mat.structure for mat in materials]
band_gaps = [mat.band_gap for mat in materials]# 输出验证
for mid, bg in zip(material_ids, band_gaps):print(f"Material ID: {mid}, Band Gap: {bg} eV")

说明:

  • MPRester 是 MP API 的客户端工具,用于查询数据。

  • search 方法指定了返回字段(material_id、structure、band_gap),并限制了元素数量和返回数量。

  • 我们获取了每种材料的 ID、晶体结构和带隙值。


4. 数据预处理:将晶体结构转换为图数据

晶体结构可以表示为图,原子作为节点,化学键作为边。我们需要将晶体结构转换为 PyG 可用的图数据格式。

python

 

 

from pymatgen.core import Structure
from torch_geometric.data import Data
import torchdef structure_to_graph(structure: Structure, cutoff: float = 4.0):"""将晶体结构转换为图数据。参数:structure: pymatgen 的 Structure 对象cutoff: 边距离阈值(Å)返回:Data: PyG 的图数据对象"""# 节点特征:原子序数atomic_numbers = [site.specie.number for site in structure.sites]x = torch.tensor(atomic_numbers, dtype=torch.float).view(-1, 1)# 获取所有原子对的距离(考虑周期性边界)all_neighbors = structure.get_all_neighbors(cutoff, include_index=True)edge_index = []edge_attr = []for i, neighbors in enumerate(all_neighbors):for neighbor in neighbors:j = neighbor[2]  # 邻居原子的索引dist = neighbor[1]  # 距离edge_index.append([i, j])edge_attr.append(dist)# 转换为 tensoredge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()edge_attr = torch.tensor(edge_attr, dtype=torch.float).view(-1, 1)return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)# 将所有结构转换为图数据
graph_data_list = [structure_to_graph(struct) for struct in structures]# 为每个图数据添加带隙标签
for data, bg in zip(graph_data_list, band_gaps):data.y = torch.tensor([bg], dtype=torch.float)

说明:

  • structure_to_graph 函数将晶体结构转换为图:

    • 节点特征 (x): 使用原子的原子序数作为特征。

    • 边索引 (edge_index): 表示原子之间的连接。

    • 边特征 (edge_attr): 使用原子间距离作为特征。

  • get_all_neighbors 获取指定距离(cutoff)内的邻居原子,考虑周期性边界。

  • 每个图数据对象 Data 包含节点特征、边索引、边特征和带隙标签 y。


5. 构建图神经网络(GNN)模型

我们使用 PyG 构建一个简单的 GNN 模型来预测带隙。

python

 

 

import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_poolclass BandGapPredictor(nn.Module):def __init__(self, num_features: int, hidden_channels: int):super().__init__()self.conv1 = GCNConv(num_features, hidden_channels)self.conv2 = GCNConv(hidden_channels, hidden_channels)self.lin = nn.Linear(hidden_channels, 1)def forward(self, data):x, edge_index, batch = data.x, data.edge_index, data.batchx = self.conv1(x, edge_index).relu()x = self.conv2(x, edge_index).relu()x = global_mean_pool(x, batch)  # 全局平均池化return self.lin(x)

说明:

  • GCNConv: 图卷积层,用于学习节点间的特征交互。

  • global_mean_pool: 将所有节点的嵌入聚合为图级嵌入。

  • 模型包含两层图卷积和一个线性层,最终输出带隙预测值。


6. 数据集划分和训练准备

将数据划分为训练集和测试集,并创建数据加载器。

python

 

 

from torch_geometric.data import DataLoader# 划分训练集和测试集(80% 训练,20% 测试)
train_data = graph_data_list[:8]
test_data = graph_data_list[8:]# 创建数据加载器
train_loader = DataLoader(train_data, batch_size=2, shuffle=True)
test_loader = DataLoader(test_data, batch_size=2, shuffle=False)

说明:

  • DataLoader 用于批量加载图数据。

  • batch_size=2 表示每个批次包含 2 个图。


7. 模型训练

训练模型以最小化预测带隙和实际带隙之间的均方误差。

python

 

 

# 初始化模型、优化器和损失函数
model = BandGapPredictor(num_features=1, hidden_channels=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()# 训练函数
def train():model.train()total_loss = 0for data in train_loader:optimizer.zero_grad()out = model(data)loss = criterion(out, data.y.view(-1, 1))loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(train_loader)# 训练模型
for epoch in range(50):loss = train()if (epoch + 1) % 10 == 0:print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

说明:

  • 使用 Adam 优化器和均方误差(MSE)损失函数。

  • 训练 50 个周期,每 10 个周期输出一次损失。


8. 模型评估

在测试集上评估模型的性能。

python

 

 

def test(loader):model.eval()total_error = 0with torch.no_grad():for data in loader:out = model(data)error = torch.abs(out - data.y.view(-1, 1)).sum().item()total_error += errorreturn total_error / len(loader.dataset)# 测试模型
test_error = test(test_loader)
print(f"Test MAE: {test_error:.4f} eV")

说明:

  • test 函数计算平均绝对误差(MAE)。

  • model.eval() 切换到评估模式,torch.no_grad() 关闭梯度计算。


9. 预测新材料的带隙

使用训练好的模型预测一个新材料的带隙。

python

 

 

# 获取一个新材料(例如 mp-150)
with MPRester(API_KEY) as mpr:new_material = mpr.materials.summary.search(material_ids=["mp-150"], fields=["structure"])
new_structure = new_material[0].structure# 转换为图数据
new_data = structure_to_graph(new_structure)# 预测
model.eval()
with torch.no_grad():prediction = model(new_data)
print(f"Predicted band gap for mp-150: {prediction.item():.4f} eV")

说明:

  • 获取新材料的晶体结构并转换为图数据。

  • 使用训练好的模型预测带隙并输出结果。


10. 总结与深度学习的意义

通过这个案例,我们完成了从数据获取到模型预测的完整流程:

  1. 使用 MP API 获取材料数据。

  2. 将晶体结构转换为图数据。

  3. 构建并训练 GNN 模型。

  4. 评估模型并预测新材料的带隙。

深度学习在材料性能预测中的作用:

  • 自动化特征提取: GNN 可以从晶体结构中自动提取特征,避免手动设计复杂的描述符。

  • 高效预测: 训练好的模型可以快速预测新材料的性能,加速材料筛选和设计。

  • 广泛适用性: 该方法可扩展到其他性能预测任务(如弹性模量、热导率等),适用于大规模数据集。


http://www.ppmy.cn/devtools/164324.html

相关文章

leetcode第39题组合总和

原题出于leetcode第39题https://leetcode.cn/problems/combination-sum/description/题目如下: 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 ,并以…

基于人工智能/机器学习的SPICE建模与参数提取基准

来源 Benchmarks for SPICE Modeling and Parameter Extraction Based on AI/ML(TED) 摘要 在过去的几十年里,使用数值方法进行SPICE建模或对现有SPICE模型参数进行表征(提取)的提交论文数量显著增加。许多此类文章…

Centos7服务器防火墙设置教程

Centos7服务器防火墙设置教程 系统环境:Centos7 首先,确保你的系统上安装了 firewalld。通常,在 CentOS 7 上,firewalld 已经预装。如果没有安装,可以通过以下命令安装: sudo yum install firewalld 启动…

【R语言】PCA主成分分析

使用R语言手动实现PCA主成分分析计算&#xff0c;通过计算协方差矩阵计算出数据的主成分得分&#xff0c;根据的分最高的特征进行得分图的绘制 # 读取数据raw_data <- read.csv("R可视化/data.csv", header TRUE, fileEncoding "GBK")new_data <-…

Yolo11实战:基于YOLOv11的半自动化数据标注技术实践

摘要 在人工智能项目开发中,数据标注的耗时性与高成本已成为制约模型迭代效率的核心瓶颈。本文以YOLOv11的COCO预训练模型为技术基础,系统阐述半自动化标注流程的设计与实现,旨在通过**“模型推理-人工校验-迭代优化”**的闭环机制,显著提升标注效率与数据质量。 一、技术…

Spring Boot 与 MyBatis 数据库操作

一、核心原理 Spring Boot 的自动配置 通过 mybatis-spring-boot-starter 自动配置 DataSource&#xff08;连接池&#xff09;、SqlSessionFactory 和 SqlSessionTemplate。 扫描 Mapper 接口或指定包路径&#xff0c;生成动态代理实现类。 MyBatis 的核心组件 SqlSessionF…

3.【基于深度学习YOLOV11的车辆类型检测系统】

文章目录 研究背景主要工作内容一、系统核心功能介绍及效果演示演示&#xff1a;软件主要功能&#xff1a;检测界面各大板块说明&#xff1a;检测区域&#xff1a;结果显示&#xff1a;主要功能说明:&#xff08;1&#xff09;图片检测说明&#xff08;2&#xff09;图片批量检…

变电站蓄电池在线监测系统(论文+源码)

1系统方案设计 本次课题为变电站蓄电池在线监测系统的设计&#xff0c;其系统架构如图3.1所示&#xff0c;包括了主控制器STC89C52单片机&#xff0c;液晶显示器LCD1602,模数转换器ADC0832&#xff0c;电流传感器ACS712&#xff0c;分压电阻&#xff0c;蜂鸣器以及温度传感器。…