Python深度学习框架:PyTorch、Keras、Scikit-learn、TensorFlow如何使用?学会轻松玩转AI!

devtools/2024/11/29 15:49:32/

在这里插入图片描述

前言

我们先简单了解一下PyTorch、Keras、Scikit-learn和TensorFlow都是什么。
想象一下你要盖一座大房子。你需要砖头、水泥、工具等等,对吧?机器学习也是一样,需要一些工具来帮忙。PyTorch、Keras、Scikit-learn和TensorFlow就是四种不同的“工具箱”。

  • TensorFlow: 就像一个超级大的、功能强大的工具箱,里面什么工具都有,可以盖各种各样的房子,从简单的到超级复杂的都有。它很厉害,但是也比较复杂,需要多学习才能用好。
  • PyTorch: 这个工具箱也很好用,也很强大,但是它比TensorFlow更容易上手,像积木一样,可以一块一块地搭建你的“房子”。
  • Keras: 它不是一个独立的工具箱,更像是一个方便的“说明书”,可以让你更容易地使用TensorFlow或者其他一些工具箱。它让盖房子变得简单一些。
  • Scikit-learn: 这个工具箱专门用来盖一些比较简单的“小房子”。如果你只需要盖个小棚子,它就足够用了。它比较容易学习,适合初学者。

总的来说,这四个工具箱各有各的优点,适合不同的任务和学习阶段。 你想盖什么样子的“房子”(解决什么问题),就选择合适的工具箱。
接下来让我们去了解一下他们吧

PyTorch

PyTorch是由Facebook开发的开源深度学习框架,以其灵活性和动态计算图结构著称。它非常适合研究和实验,尤其适合那些需要反复修改模型结构的场景。

PyTorch官方文档

什么是PyTorch

想象一下你有一个会学习的玩具机器人。PyTorch就像给这个机器人编程序的积木。

这些积木可以让你教机器人认猫、认狗,甚至玩游戏! 你用积木搭建一个“学习机器”,然后给它看很多猫和狗的照片,告诉它哪些是猫哪些是狗。 机器人会慢慢学习,下次看到猫或狗就能认出来了。

PyTorch就是这些“积木”的集合,有很多种积木,可以让你搭建各种各样的学习机器,让它做各种各样的事情。 它就像一个超级强大的工具箱,帮助人们创造聪明的机器。 它很厉害,但用起来需要学习一些新的“语言”和方法。

核心特点

  1. 动态计算图:PyTorch支持动态图机制,允许在运行时动态修改模型结构,非常适合实验和研究。
  2. 强大的社区支持:PyTorch拥有丰富的文档和社区资源,适合开发者快速入门和进行复杂项目开发。
  3. GPU加速:支持GPU加速,提升模型训练速度。
知识点描述
super()函数用于初始化继承自nn.Module的参数,实现子类与父类方法的关联。
模型保存与加载支持整个网络加参数和仅参数两种保存形式,可以使用.pkl或.pth文件。
卷积相关包括卷积核参数共享、局部连接、深度可分离卷积等概念。
DataLoader用于数据加载,支持批量处理、随机打乱、自定义样本处理等。
初始化方式卷积层和全连接层权重采用He-Uniform初始化,bias采用(-1,1)均匀分布。

应用场景:

研究环境中,尤其是需要反复修改模型结构的实验场景。
计算机视觉、自然语言处理等领域。
核心组件:

  1. torch:核心库,包含张量操作、数学函数等。
  2. torch.nn:神经网络模块,提供卷积层、全连接层等。
  3. torch.optim:优化器模块,提供SGD、Adam等优化算法。

PyTorch - 线性回归

使用PyTorch实现一个简单的线性回归模型,拟合一条直线。

python">import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 准备数据
# 生成一些线性数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]], dtype=torch.float32)# 2. 定义模型
class LinearRegressionModel(nn.Module):def __init__(self):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(1, 1)  # 输入和输出都是1维def forward(self, x):return self.linear(x)model = LinearRegressionModel()# 3. 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 4. 训练模型
num_epochs = 1000
for epoch in range(num_epochs):model.train()outputs = model(x_train)loss = criterion(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 5. 预测并可视化
model.eval()
predicted = model(x_train).detach().numpy()plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predicted, label='Fitted line')
plt.legend()
plt.show()

经过1000次迭代后,模型拟合的直线应该接近 y = 2x。控制台会输出损失值,最终损失值应该非常小(接近0)。图像上会显示原始数据点和拟合的直线。

TensorFlow

TensorFlow是由谷歌开发的深度学习框架,特别适用于生产环境,尤其是在大规模分布式系统中。它的设计初衷是服务于大规模计算任务,在速度和效率上有显著优势。

TensorFlow官方文档

什么是TensorFlow

想象一下你有一只很聪明的狗狗,你教它认猫和狗的图片。一开始它什么都不懂,但你每次给它看猫的图片就说“猫”,狗的图片就说“狗”。狗狗慢慢地就会学会区分猫和狗了,对吧?

TensorFlow就像一个超级厉害的训练狗狗的工具!它能让电脑像这只狗狗一样,通过看大量的图片(或者其他东西,比如文字、声音)来学习,然后自己学会区分不同的东西,甚至能预测一些事情。

比如,你可以用TensorFlow教电脑识别手写数字,或者翻译不同的语言,甚至能预测明天的天气!它就像一个神奇的工具箱,里面有很多方法能让电脑变得越来越聪明。 它需要很多很多的数据来学习,学习的过程就像教狗狗一样,需要反复练习。

简单来说,TensorFlow就是一个帮助电脑学习的超级工具,让电脑变得越来越聪明!

核心特点

  1. 静态计算图:TensorFlow的静态计算图使得模型在执行前就可以进行优化,提升效率。
  2. 广泛的部署工具:提供了从移动设备到服务器的全方位支持,具备强大的生产环境部署能力。
  3. 生态系统丰富:配套工具如TensorBoard、TensorFlow Lite和TensorFlow Serving,使得其生态系统非常完整。
知识点描述
静态计算图模型在执行前进行优化,提升效率。
TensorBoard可视化工具,用于查看模型结构、训练指标等。
TensorFlow Lite用于在移动设备和嵌入式设备上部署 TensorFlow 模型。
TensorFlow Serving用于生产环境中的模型部署和推理服务。

应用场景:

需要在生产环境中运行的大规模深度学习模型,如推荐系统、语音识别和自动驾驶等。

核心组件:

  1. tf.Tensor:张量对象,表示多维数组。
  2. tf.keras:高层API,简化模型构建。
  3. tf.data:数据输入管道,提供高效的数据加载和预处理。

TensorFlow - 二分类问题

使用TensorFlow实现一个简单的二分类模型,使用逻辑回归。

python">import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np# 1. 准备数据
x_train = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]])
y_train = np.array([[0], [0], [1], [1]])  # 标签为0或1# 2. 定义模型
model = Sequential([Dense(1, activation='sigmoid', input_shape=(2,))
])# 3. 编译模型
model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])# 4. 训练模型
model.fit(x_train, y_train, epochs=100, verbose=1)# 5. 预测
predictions = model.predict(x_train)
print(predictions)

模型会输出每个样本属于类别1的概率。随着训练次数增加,模型对训练数据的预测准确率会不断提高。在控制台上,你会看到损失值和准确率随着每个epoch的变化。

Keras

Keras是一个基于TensorFlow的高级神经网络API,设计初衷是为了简化深度学习的开发流程。它提供了简洁的接口,帮助用户快速构建复杂的深度学习模型。

Keras官方文档

什么是Keras

想象一下,你想要教一只小狗做一些事情,比如坐下、握手。 你不会一下子教它所有动作,而是先教它坐下,然后奖励它,再教它握手,再奖励它。

机器学习的 Keras 就像是一个教小狗的工具。 它有很多种“指令”,可以告诉电脑“如果看到这样的东西,就应该做出这样的反应”。 就像你教小狗“看到球就坐下”,Keras 可以教电脑“看到图片里是猫,就判断是猫”。

Keras 帮我们把这些“指令”组织起来,让电脑更容易学习。 它就像一个好老师,让电脑学习得更快、更有效率。 所以,Keras 帮助电脑学习各种事情,比如识别图片、预测天气等等。

核心特点

  1. 简洁易用:提供了非常直观的API,用户可以快速上手,适合新手和中小型项目。
  2. 高度模块化:允许用户自由组合层、优化器、损失函数等,模型的可读性和可维护性较高。
  3. 与TensorFlow完美结合:在TensorFlow 2.x之后,Keras成为TensorFlow的官方高级API,集成更为紧密。
知识点描述
Sequential模型一种按顺序堆叠网络层的模型。
函数式模型用于构建更复杂的模型,支持分支和合并等操作。
编译模型使用.compile方法指定损失函数、优化器和评估指标。
训练模型使用.fit()方法在训练数据上进行迭代训练。

应用场景:

快速原型开发和中小型项目,特别是在自然语言处理和图像处理任务中。

核心组件:

  1. Sequential:顺序模型,用于搭建简单的神经网络。
  2. Model:函数式模型,用于搭建复杂的神经网络。
  3. layers:网络层模块,提供卷积层、全连接层等。

Keras - 图像分类(使用MNIST数据集)

使用Keras实现一个简单的图像分类模型,对MNIST数据集进行手写数字识别。

python">from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical# 1. 准备数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28 * 28).astype('float32') / 255
x_test = x_test.reshape(-1, 28 * 28).astype('float32') / 255
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)# 2. 定义模型
model = Sequential([Flatten(input_shape=(28 * 28,)),Dense(128, activation='relu'),Dense(10, activation='softmax')
])# 3. 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 4. 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32, verbose=1)# 5. 评估模型
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f'Test accuracy: {accuracy:.4f}')

模型会在MNIST数据集上进行训练,并在测试集上进行评估。控制台会输出训练过程中的损失值和准确率,最终会输出测试集上的准确率。

Scikit-learn

Scikit-learn是Python生态系统中最受欢迎的传统机器学习库,适用于数据预处理、分类、回归、聚类、降维等任务。它封装了经典的机器学习算法,具有简单易用的API和丰富的算法支持。

Scikit-learn官方文档

什么是Scikit-learn

想象一下,你有一堆积木,各种形状、颜色。你想用这些积木搭出不同的房子。

机器学习就像一个聪明的建筑师,它可以从这些积木(数据)中学习,找出规律,然后自己搭出房子(预测结果)。

Scikit-learn 是一个工具箱,里面有很多不同的积木(算法),可以帮助你搭建各种各样的房子。比如,你想知道哪块积木搭出来的房子最高,Scikit-learn 可以帮你找到。

简单来说,Scikit-learn 帮助你用数据训练机器学习模型,让机器学会如何预测或分类。 它有很多不同的工具,可以根据你想要搭的房子(预测的目标)选择合适的积木(算法)。 就像你用不同的积木搭出高楼、小房子一样,Scikit-learn 可以帮你用数据搭出各种各样的结果。

主要特点:

  1. 经典机器学习算法:提供监督学习、无监督学习的经典算法。
  2. 数据处理工具丰富:提供从数据预处理、特征选择到模型评估的全套工具。
  3. 与其他库兼容:与NumPy、Pandas等数据科学库无缝集成。
知识点描述
估计器(Estimator)包括fit()和predict()方法,用于训练模型和预测。
转换器(Transformer)用于数据预处理和数据转换,包括fit()、transform()和fit_transform()方法。
流水线(Pipeline)将多个数据处理步骤和模型训练封装在一起,方便重现实验结果。
特征抽取包括文本、图像等数据的特征抽取技术。
特征选择删除不重要的特征,降低模型复杂度。
降维使用PCA等方法降低数据维度,提取主要特征。

应用场景:

传统机器学习任务,如小型数据集上的分类、回归分析、聚类分析等。

核心组件:

  1. datasets:内置数据集模块,提供玩具数据集和真实世界数据集。
  2. preprocessing:数据预处理模块,提供归一化、标准化等功能。
  3. model_selection:模型选择模块,提供交叉验证、网格搜索等功能。

Scikit-learn内置数据集

数据集名称类型描述
Iris分类问题包含三种鸢尾花的四个特征,目标是根据这些特征预测鸢尾花的种类
Digits多分类问题包含手写数字的8x8像素图像,目标是识别这些图像对应的数字
Boston House Prices回归问题包含波士顿各个区域的房价和其他13个特征,目标是预测房价
Breast Cancer二分类问题包含乳腺肿瘤的30个特征,目标是预测肿瘤是良性还是恶性

Scikit-learn - 鸢尾花分类

使用Scikit-learn实现一个简单的分类模型,对鸢尾花数据集进行分类。

python">from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score# 1. 准备数据
iris = load_iris()
X = iris.data
y = iris.target# 2. 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 4. 定义并训练模型
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)# 5. 预测并评估模型
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Test accuracy: {accuracy:.4f}')

模型会对鸢尾花数据集进行分类,并在测试集上进行评估。控制台会输出测试集上的准确率。
在这里插入图片描述在这里插入图片描述在这里插入图片描述


http://www.ppmy.cn/devtools/137958.html

相关文章

线程的生命周期

线程的生命周期描述了线程从创建到消亡的整个过程,以及在这个过程中线程所经历的不同状态。以下是线程生命周期的详细解释: 一、新建(NEW) 当使用new关键字创建一个线程对象时,线程进入新建状态。此时,线…

Flink 之 Window 机制详解(上):基础概念与分类

《Flink 之 Window 机制详解(上):基础概念与分类》 一、引言 在当今大数据蓬勃发展的时代,Flink 作为一款卓越的分布式流处理和批处理框架,以其独特的架构和强大的功能在数据处理领域占据着重要地位。其底层基于流式…

SQL Server 中的游标:介绍、效率、使用场景及替代方法对比

在 SQL Server 中,游标(Cursor)是一种数据库对象,用于逐行处理查询结果集。虽然游标在某些场景下非常有用,但它们的性能往往不如集合操作(set-based operations)。本文将详细介绍游标的概念、使…

设计模式---单例模式

单例模式:确保一个类只有一个实例,并提供该实例的全局访问点, 本文介绍6中常用的实现方式 懒汉式-线程不安全 以下实现中,私有静态变量 uniqueInstance 被延迟实例化,这样做的好处是,如果没有用到该类,那么…

怀念食家巷平凉面点,重拾美好

在美食的长河中,总有一些味道能勾起我们内心深处最温暖的回忆。食家巷平凉面点,便是这样一种带着浓郁乡愁与美好记忆的传统美食。平凉,这座历史悠久的城市,孕育出了独具特色的面点文化。食家巷的平凉面点白饼、烤馍,传…

faiss VS ChromaDB

faiss faiss 是一个开源的机器学习库,由Facebook AI Research(FAIR)开发,主要用于高效的大规模向量搜索和聚类。 faiss 的核心优势在于它为高维向量空间中的数据提供了快速的近似最近邻搜索(ANNS)算法&am…

DevOps工程技术价值流:Jenkins驱动的持续集成与交付实践

一、Jenkins系统概述 Jenkins:开源CI/CD引擎的佼佼者 Jenkins,作为一款基于Java的开源持续集成(CI)与持续交付(CD)系统,凭借其强大的插件生态系统,成为DevOps实践中不可或缺的核心…

Linux入门系列--用户与权限

一、前言 1.注意: 【】用户是Linux系统工作中重要的一环,用户管理包括 用户 与 组账号 的管理 【】在Linux系统中,不论是由本机或是远程登录(SSH)系统,每个系统都必须拥有一个账号,并且对于不同的系统资源拥有不同的使…