可使用的 ESRGAN 超分模型

server/2024/9/23 9:01:00/

Kaggle中使用

python">!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git
python">import os
from huggingface_hub import hf_hub_download
import torch
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
model_path = hf_hub_download(repo_id="Shandypur/ESRGAN-4x-UltraSharp", filename="4x-UltraSharp.pth", repo_type="model")
weights = torch.load(model_path)  # pth -> pth转换
map_key = [["conv_body", "sub.23"],["body", "sub"],["rdb", "RDB"],["", "model.1."],["conv_first", "model.0"],["conv_up1", "model.3"],["conv_up2", "model.6"],["conv_hr", "model.8"],["conv_last", "model.10"],[".w", ".0.w"],[".b", ".0.b"]
]
state_dict = {}
for k in list(weights.keys()):v = weights[k]for m_k in map_key:k = k.replace(m_k[1], m_k[0]) state_dict[k] = v# 权重保存路径
model_path_pt = "./state_dict.pth"   
torch.save(state_dict, model_path_pt)import torch, time
from PIL import Image
import numpy as np
from RealESRGAN import RealESRGANdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型和权重
model = RealESRGAN(device, scale=4)
model.load_weights(model_path_pt ) # 推理
path_to_images = "/kaggle/input/esr-low-img/1Z412103326-1-1200.jpg"
image = Image.open(path_to_image).convert('RGB')
sr_image = model.predict(image)# 保存图片
sr_image.save(str(time.time()) + 'sr_image.png')

转换为onnx模型

python">!pip install onnxruntime
python"># to onnx 这里只是举个例子
import torchdummy_input = torch.randn(1, 3, 224, 224).to("cuda") # 输入示例
rdb = model.model.eval()  # 原作者自定义的RealESRGAN 下的继承了 torch.nn.Module 的 model 组件才可以被转换为onnx# onnx模型转换以及保存
onnx_model_path = "./RDB.onnx" 
torch.onnx.export(rdb, dummy_input, onnx_model_path
)# 使用onnx模型推理
import numpy as np  
import onnxruntime as ort  # 加载 ONNX 模型  
ort_session = ort.InferenceSession(onnx_model_path)  # 获取输入和输出的名称  
input_name = ort_session.get_inputs()[0].name  
output_name = ort_session.get_outputs()[0].name  # 准备输入数据  
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)  # 运行推理  
results = ort_session.run([output_name], {input_name: input_data})  # 处理推理结果  
output = results[0]  
# ...

转换为paddle模型

paddle中超分体验链接

python">!pip install paddlepaddle 
python"># to paddle
import paddle
import os
import torch
from huggingface_hub import hf_hub_downloados.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
model_path = hf_hub_download(repo_id="Shandypur/ESRGAN-4x-UltraSharp", filename="4x-UltraSharp.pth", repo_type="model")
weights = torch.load(model_path)# 键映射
map_key_pd = [["trunk_conv", "model.1.sub.23"],["RRDB_trunk", "model.1.sub"],["conv_first", "model.0"],["upconv1", "model.3"],["upconv2", "model.6"],["HRconv", "model.8"],["conv_last", "model.10"],[".w", ".0.w"],[".b", ".0.b"]
]state_dict = {}
for k in list(weights.keys()):v = weights[k].numpy()for m_k in map_key_pd:k = k.replace(m_k[1], m_k[0]) state_dict[k] = paddle.to_tensor(v)# state_dict.keys()state_dict_pd = {}
state_dict_pd["generator"] = state_dict# state_dict_pd["generator"].keys()
# 保存
model_path_pd = "./ERRGAN_UltralSharp_X4.pdparams"
paddle.save(state_dict_pd, model_path_pd)

http://www.ppmy.cn/server/17169.html

相关文章

Android音视频开发-AudioTrack

Android音视频开发-AudioTrack 本篇文章我们主要介绍下AudioTrack. 1: 简介 AudioTrack是Android平台上的一个类,用于播放音频数据. 它允许PCM音频缓冲区流式传输到音频接收器进行播放. 创建AudioTrack对象:可以通过构造函数创建AudioTrack对象&…

java后端项目:视积分抽奖平台

一、项目背景: 本次抽奖系统实现是在视频中内置一个线上活动抽奖系统,奖品是在一个时间段区间内均匀发布,用户可以在这个时间段内参与抽奖。 二、项目架构 活动抽奖平台采用微服务架构来完成,在功能上实现拆分为用户、网关、以及抽奖微服务,其中用户、网关是后台项目通…

AjaxAxios

Ajax 注:AJAX很少使用,现在都使用更简单的Axios所以只需要了解Ajax即可 概念 AJAX,全称“Asynchronous JavaScript and XML”(异步JavaScript和XML) 作用: 与服务器进行数据交换,通过Ajax可…

Open CASCADE学习|一个点的坐标变换

gp_Trsf 类是 Open CASCADE Technology (OCCT) 软件库中的一个核心类,用于表示和操作三维空间中的变换。以下是该类的一些关键成员和方法的介绍: 成员变量: scale: Standard_Real 类型,表示变换的缩放因子。 shape: gp_TrsfFor…

pytorch的mask-rcnn的模型参数解释

输入图像1920x1080,batch_size8为例. 训练阶段 loss_dict model(images,targets) 入参 images: List(Tensor(3,1920,1080))[8]targets: List(dict()[3])[8] dict详情见下表: keytypedtypesizeremarkboxesTensorfloat32(n,4)1the ground-truth boxes in [x1, y1, x2, y2] …

HarmonyOS NEXT应用开发之swiper指示器导航点位于swiper下方

介绍 本示例介绍通过分割swiper区域,实现指示器导航点位于swiper下方的效果。 效果预览图 使用说明 加载完成后swiper指示器导航点,位于显示内容下方。 实现思路 将swiper区域分割为两块区域,上方为内容区域,下方为空白区域。…

简单了解Ajax

什么是Ajax Ajax,全称 Asynchronous JavaScript and XML(异步的 JavaScript 和 XML),是一种用于创建更好更快以及交互性更强的网页应用的技术。它允许网页在不重新加载整个页面的情况下,与服务器交换数据并更新部分网…

c++ 模板和对象多态的结合情况举例

1.概要 模板和对象多态的结合情况举例 模板和多态的碰撞会擦出什么样的火花呢? 模板是忽略类型,啥意思,就是一个函数,来一个模板的对象,无论啥对象,我都按照一个套路处理。 多态呢,多态是根据…