区分stable diffusion中的通道数与张量维度

server/2024/10/18 8:30:46/

区分stable diffusion中的通道数与张量形状

  • 1.通道数:
    • 1.1 channel = 3
    • 1.2 channel = 4
  • 2.张量形状
  • 3.应用
    • 3.1 问题
    • 3.2 举例
    • 3.3 张量可以理解为多维可变数组

前言:通道数与张量形状都在数值3和4之间变换,容易混淆。

1.通道数:

1.1 channel = 3

RGB 图像具有 3 个通道(红色、绿色和蓝色)。

1.2 channel = 4

Stable Diffusion has 4 latent channels。
如何理解卷积神经网络中的通道(channel)

2.张量形状

2.1 3D 张量

形状为 (C, H, W),其中 C 是通道数,H 是高度,W 是宽度。这适用于单个图像。

2.2 4D 张量

2.2.1 通常

形状为 (B, C, H, W),其中 B 是批次大小,C 是通道数,H 是高度,W 是宽度。这适用于多个图像(例如,批量处理)。

2.2.2 stable diffusion

在img2img中,将image用vae编码并按照timestep加噪:

		# This code copyed from diffusers.pipline_controlnet_img2img.py# 6. Prepare latent variableslatents = self.prepare_latents(image,latent_timestep,batch_size,num_images_per_prompt,prompt_embeds.dtype,device,generator,)

image的dim(维度)是3,而latents的dim为4。
让我们先看text2img的prepare_latents函数:

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latentsdef prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)if isinstance(generator, list) and len(generator) != batch_size:raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")if latents is None:latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)else:latents = latents.to(device)# scale the initial noise by the standard deviation required by the schedulerlatents = latents * self.scheduler.init_noise_sigmareturn latents

显然,shape已经规定了latents的dim(4)和排列顺序。
在img2img中:

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latentsdef prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")image = image.to(device=device, dtype=dtype)batch_size = batch_size * num_images_per_promptif image.shape[1] == 4:init_latents = imageelse:if isinstance(generator, list) and len(generator) != batch_size:raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"f" size of {batch_size}. Make sure the batch size matches the length of the generators.")elif isinstance(generator, list):init_latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]init_latents = torch.cat(init_latents, dim=0)else:init_latents = self.vae.encode(image).latent_dist.sample(generator)init_latents = self.vae.config.scaling_factor * init_latentsif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:# expand init_latents for batch_sizedeprecation_message = (f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"" your script to pass as many initial images as text prompts to suppress this warning.")deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)additional_image_per_prompt = batch_size // init_latents.shape[0]init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:raise ValueError(f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.")else:init_latents = torch.cat([init_latents], dim=0)shape = init_latents.shapenoise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)# get latentsinit_latents = self.scheduler.add_noise(init_latents, noise, timestep)latents = init_latentsreturn latents

3.应用

3.1 问题

new_map = texture.permute(1, 2, 0)
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3

该问题是张量形状的问题,跟通道数毫无关系。

3.2 举例

问:4D 张量:形状为 (B, C, H, W),其中C可以为3吗?
答:4D 张量的形状为 (B,C,H,W),其中 C 表示通道数。通常情况下,C 可以为 3,这对应于 RGB 图像的三个颜色通道(红色、绿色和蓝色)。

3.3 张量可以理解为多维可变数组

print("sample:", sample.shape)
print("sample:", sample[0].shape)
print("sample:", sample[0][0].shape)
>>
sample: torch.Size([10, 4, 96, 96])
sample: torch.Size([4, 96, 96])
sample: torch.Size([96, 96])

由此可见,可以将张量形状为torch.size([10, 4, 96, 96])理解为一个4维可变数组。


http://www.ppmy.cn/server/14476.html

相关文章

设计模式-状态模式在Java中的使用示例-信用卡业务系统

场景 在软件系统中,有些对象也像水一样具有多种状态,这些状态在某些情况下能够相互转换,而且对象在不同的状态下也将具有不同的行为。 为了更好地对这些具有多种状态的对象进行设计,我们可以使用一种被称之为状态模式的设计模式…

【Python-编程模式】

Python-编程模式 ■ 单例模式■ 工厂模式■■ ■ 单例模式 新建文件 str_tools.py 如下代码。 class StrTools:passstr_tool StrTools()在其他文件使用时导入该变量。 from str_tools_py import str_tool s1 str_tool s2 str_tool print(id(s1)) print(id(s2))■ 工厂模式…

20240330-1-词嵌入模型w2v+tf-idf

Word2Vector 1.什么是词嵌入模型? 把词映射为实数域向量的技术也叫词嵌⼊ 2.介绍一下Word2Vec 谷歌2013年提出的Word2Vec是目前最常用的词嵌入模型之一。Word2Vec实际是一种浅层的神经网络模型,它有两种网络结构,分别是连续词袋&#xff…

Hive 数据倾斜

1.什么是数据倾斜 数据倾斜:数据分布不均匀,造成数据大量的集中到一点,造成数据热点。主要表现为任务进度长时间维持在 99%或者 100%的附近,查看任务监控页面,发现只有少量 reduce 子任务未完成,因为其处理…

MATLAB命令

MATLAB是一个用于数值计算和数据可视化的交互式程序。您可以通过在命令窗口的MATLAB提示符 ‘>>’ 处键入命令来输入命令。 在本节中,我们将提供常用的通用MATLAB命令列表。 用于管理会话的命令 MATLAB提供了用于管理会话的各种命令。下表提供了所有此类命令…

【Linux】NFS网络文件系统搭建

一、服务端配置 #软件包安装 [roothadoop01 ~]# yum install rpcbind nfs-utils.x86_64 -y [roothadoop01 ~]# mkdir /share#配置文件修改 #格式为 共享资源路径 [主机地址] [选项] # [roothadoop01 ~]# vi /etc/exports /share 192.168.10.0/24(rw,sync,no_root_squash) #…

什么是Java中的Web服务?

Java中的Web服务是一种应用程序,它使用网络和基于Web的标准通信协议,如HTTP和XML,为客户端提供服务。Web服务允许不同的机器在不同的操作系统和编程语言之间进行交互,而无需考虑底层的技术细节。这种交互是通过交换简单的、标准化…

Day16-Java进阶-线程通信线程生命周期线程池单例设计模式

1. 线程通信 1.1 线程通信介绍 1.2 两条线程通信 package com.itheima.correspondence;public class CorrespondenceDemo1 {/*两条线程通信*/public static void main(String[] args) {Printer1 p new Printer1();new Thread(new Runnable() {Overridepublic void run() {syn…