注意 1D 与 2D 小波函数的区别（二维图像需使用 `wavedec2`/`waverec2`），并注意各实现默认参数的差异（尤其是边界延拓模式 `mode`）。参数设置一致后，三个版本（MATLAB、PyWavelets、ptwt）的结果可以对齐。
% MATLAB reference: low-pass "background" of an image obtained from a
% 7-level db6 2-D wavelet decomposition, keeping only the coarsest
% approximation band and zeroing every detail band.
load('wave_in.mat')  % provides res: image of 1536 x 1536
th = 1;              % NOTE(review): not used in this snippet
dlevel = 7;
wavename = 'db6';
% m: concatenated coefficient vector (C); n: bookkeeping matrix (S) whose
% first row is the [rows cols] size of the level-dlevel approximation.
[m, n] = wavedec2(res, dlevel, wavename);
% The approximation coefficients are the first prod(n(1,:)) entries of m.
% prod(n(1,:)) also works for non-square approximations, whereas the
% former n(1)*n(1) silently assumed a square one.
napprox = prod(n(1, :));
vec = zeros(size(m));
vec(1:napprox) = m(1:napprox);
background = waverec2(vec, n, wavename);
% Floor: values below 0.001 (including negatives) are set to zero.
background(background < 0.001) = 0;
Python version 1 (PyWavelets)
# Python version 1: port of the MATLAB low-pass "background" using PyWavelets.
import copy
import time

import numpy as np
import pywt
import torch

from mat_utils import load_mat

data_dict = load_mat("test_data/wave_in.mat")
res = data_dict["res"]  # res: image of 1536 x 1536
data_torch = torch.from_numpy(res)  # NOTE(review): unused in this version
wavelet = pywt.Wavelet('db6')  # NOTE(review): unused; wavedec2 below takes the name
wavename = 'db6'
dlevel = 7
# Same floor as the MATLAB version's `background(background<0.001)=0`;
# without it the results cannot align with the MATLAB output.
BACKGROUND_FLOOR = 0.001

# pywt's default mode is 'symmetric', which corresponds to MATLAB's default 'sym'.
coeffs_py = pywt.wavedec2(res, wavename, level=dlevel)
# Keep the first (approximation) entry and zero every detail band.
coeffs_py[1:] = [
    tuple(np.zeros_like(band) for band in details) for details in coeffs_py[1:]
]
background_py = pywt.waverec2(coeffs_py, wavename)
# waverec2 may return one extra row/column when a level had odd size;
# crop to the input shape so the result matches MATLAB's waverec2 output.
background_py = background_py[:res.shape[0], :res.shape[1]]
background_py[background_py < BACKGROUND_FLOOR] = 0
Python version 2 (PyTorch, via ptwt)
使用的库：PyTorch-Wavelet-Toolbox（ptwt，安装：`pip install ptwt`）
https://github.com/v0lta/PyTorch-Wavelet-Toolbox
# Python version 2: the same low-pass "background" computed with ptwt
# (PyTorch-Wavelet-Toolbox) on a torch tensor.
import time

import numpy as np
import ptwt  # pip install ptwt
import pywt
import torch

from mat_utils import load_mat

data_dict = load_mat("test_data/wave_in.mat")
res = data_dict["res"]  # res: image of 1536 x 1536
data_torch = torch.from_numpy(res)  # shares memory with `res`
wavename = 'db6'
dlevel = 7
wavelet = pywt.Wavelet(wavename)
# Same floor as the MATLAB version's `background(background<0.001)=0`;
# without it the results cannot align with the MATLAB output.
BACKGROUND_FLOOR = 0.001

# Pass mode='symmetric' explicitly instead of relying on ptwt's default, so
# the boundary extension matches pywt's default and MATLAB's 'sym'.
coeffs_pt = ptwt.wavedec2(data_torch, wavelet, level=dlevel, mode='symmetric')
# Keep the first (approximation) entry and zero every detail band.
coeffs_pt = (coeffs_pt[0],) + tuple(
    tuple(torch.zeros_like(band) for band in details)
    for details in coeffs_pt[1:]
)
# squeeze() drops singleton dims (presumably the batch dim ptwt adds).
background_pt = ptwt.waverec2(coeffs_pt, wavelet).squeeze()
# The reconstruction may be larger than the input; crop to the original size.
background_pt = background_pt[..., :res.shape[0], :res.shape[1]]
background_pt[background_pt < BACKGROUND_FLOOR] = 0
background_pt_np = background_pt.numpy()