torch.full
是 PyTorch 中用于创建一个具有指定形状、填充值和数据类型的张量的函数。它非常适用于需要初始化特定数值的张量的情况,比如将所有元素填充为一个常量值。
函数定义
torch.full(size, fill_value, *, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
参数说明
- size:
tuple
,指定要创建张量的形状(如(2, 3)
表示 2 行 3 列的矩阵)。 - fill_value:
float
或int
,用来填充张量的数值。 - dtype:
torch.dtype
,张量的数据类型(如torch.float
、torch.int
等)。默认根据fill_value
推断。 - layout:
torch.layout
,张量的布局,默认是torch.strided
,表示标准的内存布局。 - device:
torch.device
,指定创建张量的设备(如"cpu"
或"cuda"
)。 - requires_grad:
bool
,是否需要计算梯度,通常在训练模型时使用,默认为False
。
示例代码
1. 创建一个填充为常数值的张量
import torch# 创建一个 2x3 的张量,所有元素填充为 7
tensor = torch.full((2, 3), 7)
print(tensor)
2. 指定数据类型
# 创建一个 2x2 的张量,所有元素为 3.14,指定数据类型为 float32
tensor = torch.full((2, 2), 3.14, dtype=torch.float32)
print(tensor)
3. 在 GPU 上创建张量
# 创建一个 3x3 的张量,填充值为 -1,在 CUDA 设备上
tensor = torch.full((3, 3), -1, device="cuda")
print(tensor)
4. 创建一个需要梯度的张量
# 创建一个 4x4 的张量,填充值为 0.5,并启用 requires_grad
tensor = torch.full((4, 4), 0.5, requires_grad=True)
print(tensor)
应用场景
- 用于初始化权重矩阵或其他模型参数。
- 创建掩码张量(mask tensor),例如填充为 0 或 1 的张量,适合在自然语言处理或图像处理任务中定义有效区域。
- 生成常量张量来辅助计算和操作。
torch.full
是一个非常灵活的初始化方法,在需要定义一个特定常量值的张量时非常实用。