【从0开始】使用Flax NNX API 构建简单神经网络并训练

news/2025/2/11 4:40:23/

与 Linen API 不同,NNX 使用起来对初学者更加简单,跟 PyTorch 的体验更加接近。

任务

使用MLP拟合简单函数:
y = 2 x 2 + 1 y=2x^2+1 y=2x2+1

代码

import jax.numpy as jnp
import jax.random as jrm
import optax as ox
from jax import Array
from flax import nnx
from typing import Generatorclass Network(nnx.Module):"""def a simple MLP"""def __init__(self, in_dim: int, out_dim: int, rng: nnx.Rngs, hidden_dim: int):super().__init__()self.linear1 = nnx.Linear(in_dim, hidden_dim, rngs=rng)self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rng)self.linear3 = nnx.Linear(hidden_dim, out_dim, rngs=rng)def __call__(self, x) -> Array:x = self.linear1(x)x = nnx.relu(x)x = self.linear2(x)x = nnx.relu(x)x = self.linear3(x)return xdef make_dataset(X: Array, Y: Array, batch: int, seed: int = 0
) -> Generator[tuple[jnp.ndarray, jnp.ndarray], None, None]:"dataset sample function"combined = jnp.stack((X, Y), axis=1)[..., None]key = jrm.key(seed)while True:selected = jrm.choice(key, combined, shape=(batch,))yield selected[:, 0], selected[:, 1]def loss_fn(model: Network, batch):x, y = batchpredicted = model(x)return ox.l2_loss(predicted, y).mean()# hyper parameter
seed = 0
batch = 16# make dataset
X = jnp.arange(0, 10, 0.005)
Y = 2 * X**2 + 1.0# build model & optimizer
model = Network(1, 1, hidden_dim=20, rng=nnx.Rngs(seed))
optimizer = nnx.Optimizer(model, ox.adamw(0.001, 0.90))# train
for i, (x, y) in enumerate(make_dataset(X, Y, batch)):loss, grads = nnx.value_and_grad(loss_fn)(model, (x, y))optimizer.update(grads)print(i, loss)if i >= 6000:break

依赖如下

absl-py==2.1.0
chex==0.1.88
etils==1.11.0
flax==0.10.2
fsspec==2025.2.0
humanize==4.11.0
importlib-resources==6.5.2
jax==0.5.0
jaxlib==0.5.0
markdown-it-py==3.0.0
mdurl==0.1.2
ml-dtypes==0.5.1
msgpack==1.1.0
nest-asyncio==1.6.0
numpy==2.2.2
opt-einsum==3.4.0
optax==0.2.4
orbax-checkpoint==0.11.2
protobuf==3.20.3
pygments==2.19.1
pyyaml==6.0.2
rich==13.9.4
scipy==1.15.1
simplejson==3.19.3
tensorstore==0.1.71
toolz==1.0.0
typing-extensions==4.12.2
zipp==3.21.0

http://www.ppmy.cn/news/1571061.html

相关文章

Java面试题-计算机网络

文章目录 1.介绍一下TCP/IP五层模型?2.**什么是TCP三次握手、四次挥手?**1.三次握手建立连接2.四次握手断开连接 **3.HTTPS和HTTP的区别是什么?**4.**浏览器输入www.taobao.com回车之后发生了什么**?1.URL解析,对URL进…

Windows系统中常用的命令

随着Windows系统的不断改进,维护系统时有时候会因为新系统的更新而找不到对应的模块或者相关的信息入口,这个时候,记住一些命令就可以起到很好的帮助作用。 比如,windows11中的网络属性的修改,可能习惯了windows10或者…

一个00后的自述:不好好学习的我后悔了

普通人家的孩子不读书,以后你能做什么? 以下是一个00后的自述: 我是2000年出生的,父亲是建筑工人,母亲是农民,我就是一个普通人家的孩子。 小时候,其实我的学习成绩也是不错的,但…

机器学习:定义、原理、应用与未来(万字总结)

机器学习:定义、原理、应用与未来 一、机器学习是什么 机器学习作为人工智能领域的核心技术,正以前所未有的速度改变着我们的生活和工作方式。从智能语音助手到自动驾驶汽车,从个性化推荐系统到医疗诊断辅助,机器学习的应用无处…

(一)DeepSeek大模型安装部署-Ollama安装

大模型deepseek安装部署 (一)、安装ollama curl -fsSL https://ollama.com/install.sh | sh sudo systemctl start ollama sudo systemctl enable ollama sudo systemctl status ollama(二)、安装ollama遇到网络问题,请手动下载 ollama-linux-amd64.tgz curl -L …

【AI编程助手系列】国产AI编程工具 DeepSeek+Cline+VSCode 快速集成

文章目录 前言一、deepseek 介绍二、deepseek 优势三、什么是 Cline?3.1 安装与配置3.1.1 安装 Cline 插件3.1.2 获取 DeepSeek API Key3.1.3 配置 Cline 四、总结 前言 🤖 DeepSeek 是一个强大的 API 平台,提供了丰富的功能和数据&#xff…

计算机网络-SSH实验-密码验证

前面我们学习了SSH连接的几个阶段,这次来实际配置,在用户认证阶段支持口令认证和密钥认证等方式,今天先来学习简单的口令认证。 一、SSH基础使用-口令认证 拓扑: 拓扑描述:使用网络桥接到本地终端:本地主…

【windows系统】02-windows server 2022系统安装

一、环境准备 1、下载VMware workstation 17 具体步骤参考 我的另一篇文章:https://blog.csdn.net/adminabcd/article/details/145480529?spm1001.2014.3001.5501 2、下载windows server 2022的镜像文件 访问网址 https://next.itellyou.cn/Original/#cbpProdu…