Keras中model.predict()与model()的区别

server/2024/9/25 8:23:51/

文章目录

  • 一、函数详解
  • 二、加速测试
    • 2.1、model.predict(x=input_data) —— 时耗:0.09967 秒
    • 2.2、model.predict(x=input_data, batch_size=8) —— 时耗:0.12919 秒
    • 2.3、model.predict(tf.data.Dataset.from_tensors(input_data)) —— 时耗:0.08310 秒
    • 2.4、model(x=input_data, training=False) —— 时耗:0.01395 秒

一、函数详解

在 Keras 中,获取模型的预测结果的两种方式:

  • keras_model() 直接调用模型对象将 Keras 模型对象当作函数一样调用,并将输入数据作为参数传递给它,从而直接获取预测结果。

    • 优缺点:(1)支持动态图计算;(2)只支持单样本预测;(3)只支持Tensor类型的输入数据;(4)输出数据为Tensor类型;
    • 适用范围:大规模数据;实时处理;预测速度快
  • keras_model.predict() 方法:predict() 方法是 Keras 模型对象的一个函数,用于进行推理并获取预测结果。

    • 优缺点:(1)不支持动态图计算;(2)支持批量样本预测;(3)支持Tensor和NumPy类型的输入数据;(4)输出数据为NumPy类型;(5)需要一次性将所有数据加载到内存中,因此对于大型数据集,可能会导致内存不足。
    • 适用范围:小规模数据;对内存占用不敏感;预测速度慢

在 PyTorch 中,只有一种方法获取模型的预测结果:pytorch_model()

keras_modelpredictx_19">1.1、keras_model.predict(x)

python">"""###################################################################
函数说明: keras_model.predict(x, batch_size=None, verbose=0, steps=None, callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False)
输入参数: (1)x:                       输入数据。可以是 NumPy 数组、tf.data.Dataset 对象或字典等。具体取决于模型的输入层的期望。(2)batch_size:              批处理大小。如果未指定,将使用默认的批处理大小。批处理大小表示模型一次性处理的样本数量,可以影响内存使用和预测速度。(3)verbose:                 是否显示进度信息(0表示不显示任何信息、1表示显示进度条、2表示每个epoch显示一行)(4)steps:                   指定预测结束的步数(批次数)。如果未指定,将一直进行预测,直到输入数据用尽。(5)callbacks:               可选的回调函数列表。在预测过程中的不同阶段触发不同的回调函数,用于自定义行为。(6)max_queue_size:          指定生成器队列的最大大小。对于生成器提供的数据,此参数可以控制内存使用。(7)workers:                 指定生成器的工作进程数量。仅在使用生成器提供数据时才相关。(8)use_multiprocessing:     布尔值,表示是否使用多进程进行数据生成。默认为 False。如果设置为 True,则会使用 workers 个进程进行数据生成。返回参数:numpy数组 
###################################################################"""

keras_modelx_38">1.2、keras_model(x)

python">"""###################################################################
函数说明: keras_model(x, training=None)
输入参数: (1)x:      				输入数据。可以是 Numpy 数组、Tensor 对象或者其他可以被模型接受的输入数据。这里输入数据的形状和数据类型需要和模型的输入层相匹配。(2)training:  			布尔值,表示模型是否处于训练模式。True(训练模式): 		如:启用Dropout、	启用Batch NormalizationFalse(推理或预测模式): 	如:不启用Dropout、	不启用Batch Normalization返回参数:tensor张量
###################################################################"""

二、加速测试

keras里predict函数,预测速度慢的优化方法
keras里predict很慢,300倍减少predict运行时间的优化方法

2.1、model.predict(x=input_data) —— 时耗:0.09967 秒

2.2、model.predict(x=input_data, batch_size=8) —— 时耗:0.12919 秒

2.3、model.predict(tf.data.Dataset.from_tensors(input_data)) —— 时耗:0.08310 秒

2.4、model(x=input_data, training=False) —— 时耗:0.01395 秒

python">import time
import numpy as np
import tensorflow as tfif __name__ == "__main__":print("Tensorflow版本 =", tf.__version__)############################################################################### (1)新建模型flag = 2if flag == 1:"""1.序列模型"""from tensorflow import kerasmodel = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(10,)),  # 定义第一个全连接层,包括 128 个神经元,使用 ReLU 激活函数,输入形状为 (10,)。keras.layers.Dense(64, activation='relu'),  # 定义第二个全连接层,包括 64 个神经元,使用 ReLU 激活函数。keras.layers.Dense(1, activation='sigmoid')  # 定义输出层,包括 1 个神经元,使用 Sigmoid 激活函数,适用于二分类问题。])input_data = np.random.rand(1000, 10)  # 随机生成输入数据,其中包含 1000 个样本,每个样本有 10 个特征。elif flag == 2:"""2.卷积模型"""from tensorflow.keras import layers, modelsmodel = models.Sequential([# 卷积层,包含 32 个卷积核(filters),每个卷积核大小为 (3, 3),使用 ReLU 激活函数。input_shape=(100, 100, 3) 表示输入图像的形状为 (100, 100, 3)layers.Conv2D(32, (3, 3), activation='relu', input_shape=(100, 50, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),# 将卷积层输出的多维数据展平为一维,为全连接层做准备。layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(1, activation='sigmoid')])input_data = np.random.rand(1000, 100, 50, 3)  # 随机生成输入数据,其中包含 1000 个样本,每个图像为100x100x3(RGB)。model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])  # 编译模型: 指在使用模型进行训练之前,配置模型参数。model.summary()  # 打印模型概要############################################################################### (2)多种预测方式的时耗for ii in range(3):clock = time.time()res1 = model.predict(x=input_data)elapsed_time = time.time() - clockprint(f"总共耗时1: {elapsed_time:.5f} 秒")clock = time.time()res2 = model.predict(x=input_data, batch_size=8)elapsed_time = time.time() - clockprint(f"总共耗时2: {elapsed_time:.5f} 秒")clock = time.time()test3 = model.predict(tf.data.Dataset.from_tensors(input_data))elapsed_time = time.time() - clockprint(f"总共耗时3: {elapsed_time:.5f} 秒")clock = time.time()res4 = model(x=input_data, training=False)elapsed_time = time.time() - clockprint(f"总共耗时4: {elapsed_time:.5f} 秒")print(" ")"""总共耗时1: 0.09967 秒总共耗时2: 0.12919 秒总共耗时3: 0.08310 秒总共耗时4: 0.01395 秒"""

http://www.ppmy.cn/server/20249.html

相关文章

手撕netty源码(一)- NioEventLoopGroup

文章目录 前言一、NIO 与 netty二、NioEventLoopGroup 对象的创建过程2.1 创建流程图2.2 EventExecutorChooser 的创建 前言 processOn文档跳转 本文是手撕netty源码系列的开篇文章,会先介绍一下netty对NIO关键代码的封装位置,主要介绍 NioEventLoopGro…

深入剖析图像平滑与噪声滤波

噪声 在数字图像处理中,噪声是指在图像中引入的不希望的随机或无意义的信号。它是由于图像采集、传输、存储或处理过程中的各种因素引起的。 噪声会导致图像质量下降,使图像失真或降低细节的清晰度。它通常表现为图像中随机分布的亮度或颜色变化&#…

TypeScript入门第一天,所有类型+基础用法+接口使用

表示逻辑值&#xff1a;true 和 false。在JavaScript和TypeScript里叫做boolean | | 数组类型 | 无 | 声明变量为数组。 // 在元素类型后面加上[] let arr: number[] [1, 2]; // 或者使用数组泛型&#xff0c;Array<元素类型> let arr: Array [1, 2]; | | 元组…

MySQL__索引

文章目录 &#x1f60a; 作者&#xff1a;Lion J &#x1f496; 主页&#xff1a; https://blog.csdn.net/weixin_69252724 &#x1f389; 主题&#xff1a; MySQL__索引&#xff09; ⏱️ 创作时间&#xff1a;2024年04月23日 ———————————————— 这里写目…

SpringBoot学习之SpringBoot3集成OpenApi(三十八)

Springboot升级到Springboot3以后,就彻底放弃了对之前swagger的支持,转而重新支持最新的OpenApi,今天我们通过一个实例初步看看OpenApi和Swagger之间的区别. 一、POM依赖 我的POM文件如下,仅作参考: <?xml version="1.0" encoding="UTF-8"?>…

盲人咖啡厅导航:科技之光点亮独立生活新里程

在这个繁华的世界中&#xff0c;咖啡厅不仅是人们社交聚会、休闲阅读的场所&#xff0c;更是无数人心灵栖息的一方天地。然而&#xff0c;对于视障群体而言&#xff0c;独自前往这样的公共场所往往面临重重挑战。幸运的是&#xff0c;一款名为蝙蝠避障专为盲人设计的辅助应用&a…

把私有数据接入 LLMs:应用程序轻松集成 | 开源日报 No.236

run-llama/llama_index Stars: 29.9k License: MIT llama_index 是用于 LLM 应用程序的数据框架。 该项目解决了如何最佳地利用私有数据增强 LLMs&#xff0c;并提供以下工具&#xff1a; 提供数据连接器&#xff0c;以摄取现有的数据源和各种格式&#xff08;API、PDF、文档…

Spring boot + Redis + Spring Cache 实现缓存

学习 Redis 的 value 有 5 种常用的数据结构 Redis 存储的是 key-value 结构的数据。key 是字符串类型&#xff0c;value 有 5 种常用的数据结构&#xff1a; Redis 的图形化工具 Another Redis Desktop Manager Spring Data Redis Redis 的 Java 客户端。 Spring Cache Spr…