在 Kaggle 中使用
# Install the sberbank-ai RealESRGAN package.
# Script-safe equivalent of the notebook magic `!pip install ...`; using
# sys.executable installs into the same interpreter that runs this code.
import subprocess
import sys

subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "git+https://github.com/sberbank-ai/Real-ESRGAN.git"]
)
# Download the 4x-UltraSharp ESRGAN checkpoint, rename its keys to the layout
# used by the RealESRGAN package, save it, then super-resolve one image.
import os

# huggingface_hub reads HF_ENDPOINT once, when it is first imported, so the
# mirror must be configured *before* the imports below (setting it afterwards
# is silently ignored).
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import time  # noqa: E402

import numpy as np  # noqa: E402,F401
import torch  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from PIL import Image  # noqa: E402
from RealESRGAN import RealESRGAN  # noqa: E402

# --- Convert the checkpoint (.pth -> .pth with renamed keys) ---
model_path = hf_hub_download(
    repo_id="Shandypur/ESRGAN-4x-UltraSharp",
    filename="4x-UltraSharp.pth",
    repo_type="model",
    # Also pass the mirror explicitly: if huggingface_hub was already imported
    # earlier in this kernel, changing HF_ENDPOINT above has no effect.
    endpoint=os.environ["HF_ENDPOINT"],
)
# map_location="cpu" lets the checkpoint load on machines without a GPU.
# NOTE(review): torch.load unpickles a file from a third-party repo; consider
# weights_only=True if your torch version does not already default to it.
weights = torch.load(model_path, map_location="cpu")

# Each pair is [new_substring, old_substring]. The pairs are applied in order
# to every key, so the more specific patterns ("sub.23") must come before the
# general ones ("sub").
map_key = [
    ["conv_body", "sub.23"],
    ["body", "sub"],
    ["rdb", "RDB"],
    ["", "model.1."],
    ["conv_first", "model.0"],
    ["conv_up1", "model.3"],
    ["conv_up2", "model.6"],
    ["conv_hr", "model.8"],
    ["conv_last", "model.10"],
    [".w", ".0.w"],
    [".b", ".0.b"],
]
state_dict = {}
for key, value in weights.items():
    for new, old in map_key:
        key = key.replace(old, new)
    state_dict[key] = value

# Where the converted weights are saved.
model_path_pt = "./state_dict.pth"
torch.save(state_dict, model_path_pt)

# --- Load the model with the converted weights ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RealESRGAN(device, scale=4)
model.load_weights(model_path_pt)

# --- Inference ---
path_to_images = "/kaggle/input/esr-low-img/1Z412103326-1-1200.jpg"
# Previously opened the undefined name `path_to_image` (NameError).
image = Image.open(path_to_images).convert("RGB")
sr_image = model.predict(image)

# Save the result; the timestamp prefix keeps repeated runs from overwriting
# each other.
sr_image.save(str(time.time()) + "sr_image.png")
转换为 ONNX 模型
# Install ONNX Runtime (CPU build).
# Script-safe equivalent of the notebook magic `!pip install onnxruntime`.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "onnxruntime"])
# Export to ONNX -- this is only an illustrative example.
# Relies on `model` (the RealESRGAN wrapper) created in the earlier cell.
import numpy as np
import onnxruntime as ort
import torch

# Only the torch.nn.Module inside the author's custom RealESRGAN wrapper
# (its `model` attribute) can be exported to ONNX.
rdb = model.model.eval()

# Create the example input on the same device as the network's weights, so
# the export also works on CPU-only machines instead of assuming "cuda".
export_device = next(rdb.parameters()).device
dummy_input = torch.randn(1, 3, 224, 224, device=export_device)

# Convert and save the ONNX model.
onnx_model_path = "./RDB.onnx"
torch.onnx.export(
    rdb,
    dummy_input,
    onnx_model_path,
    input_names=["input"],
    output_names=["output"],
    # Let batch size and spatial size vary so real images, not just 224x224
    # tensors, can be fed to the exported model.
    # NOTE(review): assumes the network is fully convolutional (RRDBNet);
    # confirm if the architecture changes.
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "height", 3: "width"},
    },
)

# --- Inference with the ONNX model ---
# The package installed above is the CPU build. Naming the provider
# explicitly also avoids the error GPU builds of onnxruntime raise when
# `providers` is omitted.
ort_session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

# Look up the graph's input and output names.
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Prepare input data (random placeholder in NCHW float32).
input_data = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Run inference and take the single requested output.
results = ort_session.run([output_name], {input_name: input_data})
output = results[0]
# ...
转换为 Paddle 模型
在 Paddle 中体验超分辨率（链接）
# Install PaddlePaddle.
# Script-safe equivalent of the notebook magic `!pip install paddlepaddle`.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "paddlepaddle"])
# Convert the 4x-UltraSharp ESRGAN checkpoint to a PaddlePaddle .pdparams file.
import os

# huggingface_hub reads HF_ENDPOINT when it is first imported, so the mirror
# must be set before the import below.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import paddle  # noqa: E402
import torch  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402

model_path = hf_hub_download(
    repo_id="Shandypur/ESRGAN-4x-UltraSharp",
    filename="4x-UltraSharp.pth",
    repo_type="model",
    # Explicit endpoint: effective even if huggingface_hub was imported
    # earlier in this kernel, before HF_ENDPOINT was set.
    endpoint=os.environ["HF_ENDPOINT"],
)
# map_location="cpu" so the tensors can be converted with .numpy() below.
weights = torch.load(model_path, map_location="cpu")

# Key mapping: each pair is [paddle_substring, torch_substring]. The pairs
# are applied in order, so "model.1.sub.23" must precede "model.1.sub".
map_key_pd = [
    ["trunk_conv", "model.1.sub.23"],
    ["RRDB_trunk", "model.1.sub"],
    ["conv_first", "model.0"],
    ["upconv1", "model.3"],
    ["upconv2", "model.6"],
    ["HRconv", "model.8"],
    ["conv_last", "model.10"],
    [".w", ".0.w"],
    [".b", ".0.b"],
]
state_dict = {}
for key, tensor in weights.items():
    array = tensor.detach().cpu().numpy()
    for new, old in map_key_pd:
        key = key.replace(old, new)
    state_dict[key] = paddle.to_tensor(array)

# Nest the weights under "generator".
# NOTE(review): presumably the layout PaddleGAN's checkpoint loader expects; confirm.
state_dict_pd = {"generator": state_dict}

# Save.
model_path_pd = "./ERRGAN_UltralSharp_X4.pdparams"
paddle.save(state_dict_pd, model_path_pd)