import torch. nn. functional as F
def crop_tensor_by_height_width(tensor, height_crop, width_crop):
    """Crop a 4-D tensor symmetrically along dimensions 2 (height) and 3 (width).

    In total ``height_crop`` rows and ``width_crop`` columns are removed.
    ``crop // 2`` comes off the top/left, and the remainder comes off the
    bottom/right. For an odd crop, the bottom/right loses one more than the
    top/left.

    If a crop is at least as large as its dimension, the result is empty
    along that dimension (no error is raised).

    Args:
        tensor (torch.Tensor): 4-D tensor; dimensions 2 and 3 are cropped.
        height_crop (int): total number of rows to remove; must be > 0.
        width_crop (int): total number of columns to remove; must be > 0.

    Returns:
        torch.Tensor: a slice (view) of ``tensor`` with the crop applied.

    Raises:
        AssertionError: if ``tensor`` is not 4-D or a crop is not positive.
    """
    assert len(tensor.shape) == 4, '输入的tensor应为4维'
    assert height_crop > 0 and width_crop > 0, 'crop应该大于0'

    height, width = tensor.shape[2], tensor.shape[3]
    # Floor half goes off the leading edge; the ceiling half off the trailing edge.
    top, left = height_crop // 2, width_crop // 2
    bottom_end = height - (height_crop - top)
    right_end = width - (width_crop - left)
    return tensor[:, :, top:bottom_end, left:right_end]


def _crop_or_pad_trailing_edges(tensor, amounts, pad_value):
    """Pad or crop the trailing edge of each of the last ``len(amounts)`` dims.

    ``amounts[i]`` applies to dimension ``tensor.dim() - len(amounts) + i``.

    - A positive amount appends that many ``pad_value`` entries at the end of
      the dimension.
    - A negative amount removes that many trailing entries, clamped so that at
      most the whole dimension is removed.
    - Zero leaves the dimension unchanged.
    """
    num_dims = len(amounts)
    sizes = tensor.shape[tensor.dim() - num_dims:]

    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # walk the amounts in reverse; only the "after" side is ever padded.
    pad = []
    for amount in reversed(amounts):
        pad.extend((0, max(amount, 0)))
    padded_tensor = F.pad(tensor, pad=tuple(pad), mode='constant', value=pad_value)

    if all(amount >= 0 for amount in amounts):
        return padded_tensor

    # A cropped dimension was not padded, so its padded size equals its
    # original size; keep the first (size + amount) entries, never fewer than 0.
    leading = [slice(None)] * (tensor.dim() - num_dims)
    trailing = [
        slice(None, max(size + amount, 0)) if amount < 0 else slice(None)
        for size, amount in zip(sizes, amounts)
    ]
    return padded_tensor[tuple(leading + trailing)]


def crop_or_pad_tensor_by_height_width(tensor, height_crop, width_crop, pad_value=0):
    """Crop or pad a 4-D tensor at the bottom (height) and right (width) edges.

    A positive value pads that many rows/columns filled with ``pad_value``.
    A negative value crops that many rows/columns; a crop larger than the
    dimension empties it. Zero leaves the dimension unchanged.

    Args:
        tensor (torch.Tensor): 4-D tensor, laid out as
            (batch_size, channels, height, width).
        height_crop (int): rows to pad (> 0) or crop (< 0) at the bottom.
        width_crop (int): columns to pad (> 0) or crop (< 0) on the right.
        pad_value (float or int): fill value for padded entries; defaults to 0.

    Returns:
        torch.Tensor: the cropped and/or padded tensor.

    Raises:
        AssertionError: if ``tensor`` is not 4-D.
    """
    assert len(tensor.shape) == 4, '输入的tensor应为4维'
    return _crop_or_pad_trailing_edges(tensor, (height_crop, width_crop), pad_value)


def crop_or_pad_tensor_by_depth_height_width(tensor, depth_crop, height_crop, width_crop, pad_value=0):
    """Crop or pad a 5-D tensor at the end of depth, bottom of height and right of width.

    A positive value pads that many entries filled with ``pad_value``.
    A negative value crops that many entries; a crop larger than the
    dimension empties it. Zero leaves the dimension unchanged.

    Args:
        tensor (torch.Tensor): 5-D tensor, laid out as
            (batch_size, channels, depth, height, width).
        depth_crop (int): slices to pad (> 0) or crop (< 0) at the end of depth.
        height_crop (int): rows to pad (> 0) or crop (< 0) at the bottom.
        width_crop (int): columns to pad (> 0) or crop (< 0) on the right.
        pad_value (float or int): fill value for padded entries; defaults to 0.

    Returns:
        torch.Tensor: the cropped and/or padded tensor.

    Raises:
        AssertionError: if ``tensor`` is not 5-D.
    """
    assert len(tensor.shape) == 5, '输入的tensor应为5维'
    return _crop_or_pad_trailing_edges(
        tensor, (depth_crop, height_crop, width_crop), pad_value
    )