SDXL总结

news/2024/10/16 2:22:35/

SDXL base部分的权重:https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main

diffusers库中的SDXL代码pipelines:

https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion_xl

参考:深入浅出完整解析Stable Diffusion XL(SDXL)核心基础知识 - 知乎 (zhihu.com)


 Stable Diffusion XL是一个二阶段的级联扩散模型(Latent Diffusion Model),包括Base模型和Refiner模型。其中Base模型的主要工作和Stable Diffusion 1.x-2.x一致,具备文生图(txt2img)、图生图(img2img)、图像inpainting等能力。在Base模型之后,级联了Refiner模型,对Base模型生成的图像Latent特征进行精细化提升,其本质上是在做图生图的工作

SDXL Base模型由U-Net、VAE以及CLIP Text Encoder(两个)三个模块组成

SDXL Refiner模型同样由U-Net、VAE和CLIP Text Encoder(一个)三个模块

1.VAE

VAE Encoder与VAE Decoder结构图 

VAE官方开源权重:https://huggingface.co/stabilityai/sdxl-vae

Stable Diffusion XL VAE模型与之前的Stable Diffusion系列并不兼容。如果在SDXL上使用之前系列的VAE,会生成充满噪声的图片。

Stable Diffusion XL VAE采用FP16精度时会出现数值溢出成NaNs的情况,导致重建的图像是一个黑图,所以必须使用FP32精度进行推理重建。

import cv2
import torch
import numpy as np
from diffusers import AutoencoderKL# 加载SDXL VAE模型: SDXL VAE模型可以通过指定subfolder文件来单独加载。
# SDXL VAE模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDXL模型,即可获得资源链接
VAE = AutoencoderKL.from_pretrained("/本地路径/sdxl-vae")
VAE.to("cuda") # 用OpenCV读取和调整图像大小
raw_image = cv2.imread("test_vae.png")
raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
raw_image = cv2.resize(raw_image, (1024, 1024))# 将图像数据转换为浮点数并归一化
image = raw_image.astype(np.float32) / 127.5 - 1.0# 调整数组维度以匹配PyTorch的格式 (N, C, H, W)
image = image.transpose(2, 0, 1)
image = image[None, :, :, :]# 转换为PyTorch张量
image = torch.from_numpy(image).to("cuda")# 压缩图像为Latent特征并重建
with torch.inference_mode():# 使用SDXL VAE进行压缩和重建latent = VAE.encode(image).latent_dist.sample()rec_image = VAE.decode(latent).sample# 后处理rec_image = (rec_image / 2 + 0.5).clamp(0, 1)rec_image = rec_image.cpu().permute(0, 2, 3, 1).numpy()# 反归一化rec_image = (rec_image * 255).round().astype("uint8")rec_image = rec_image[0]# 保存重建后图像cv2.imwrite("reconstructed_sdxl.png", cv2.cvtColor(rec_image, cv2.COLOR_RGB2BGR))

 2.Unet

SDXL Base部分的 U-Net的完整结构图

 Stable Diffusion XL中的Text Condition信息由两个Text Encoder提供(OpenCLIP ViT-bigG和OpenAI CLIP ViT-L),将两个Text Encoder提取的Token Embedding进行Contact,通过Cross Attention组件嵌入,作为K Matrix和V Matrix。与此同时,图片的Latent Feature作为Q Matrix

3.Text Encoder模型

Stable Diffusion XL分别提取两个Text Encoder的倒数第二层特征,并进行concat操作作为文本条件(Text Conditioning)。其中OpenCLIP ViT-bigG的特征维度为77x1280,而OpenAI CLIP ViT-L/14的特征维度是77x768,所以输入总的特征维度是77x2048(77是最大的token数,2048是SDXL的context dim),再通过Cross Attention模块将文本信息传入Stable Diffusion XL的训练过程与推理过程中。

Stable Diffusion XL与之前的系列相比使用了两个CLIP Text Encoder,分别是OpenCLIP ViT-bigG(694M)和OpenAI CLIP ViT-L/14(123.65M),从而大大增强了Stable Diffusion XL对文本的提取和理解能力,同时提高了输入文本和生成图片的一致性

SDXL OpenCLIP ViT-bigG的完整结构图

SDXL OpenCLIP ViT-bigG的文本编码过程:

from transformers import CLIPTextModel, CLIPTokenizer# 加载 OpenCLIP ViT-bigG Text Encoder模型和Tokenizer
# SDXL模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDXL模型,即可获得资源链接
text_encoder = CLIPTextModel.from_pretrained("/本地路径/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2").to("cuda")
text_tokenizer = CLIPTokenizer.from_pretrained("/本地路径/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2")# 将输入SDXL模型的prompt进行tokenize,得到对应的token ids特征
prompt = "1girl,beautiful"
text_token_ids = text_tokenizer(prompt,padding="max_length",max_length=text_tokenizer.model_max_length,truncation=True,return_tensors="pt"
).input_idsprint("text_token_ids' shape:",text_token_ids.shape)
print("text_token_ids:",text_token_ids)# 将token ids特征输入OpenCLIP ViT-bigG Text Encoder模型中输出77x1280的Text Embeddings特征
text_embeddings = text_encoder(text_token_ids.to("cuda"))[0] # 由于Text Encoder模型输出的是一个元组,所以需要[0]对77x1280的Text Embeddings特征进行提取
print("text_embeddings' shape:",text_embeddings.shape)
print(text_embeddings)---------------- 运行结果 ----------------
text_token_ids' shape: torch.Size([1, 77])
text_token_ids: tensor([[49406,   272,  1611,   267,  1215, 49407,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0,     0,     0,     0,0,     0,     0,     0,     0,     0,     0]])
text_embeddings' shape: torch.Size([1, 77, 1280])
tensor([[[-0.1025, -0.3104,  0.1660,  ..., -0.1596, -0.0680, -0.0180],[ 0.7724,  0.3004,  0.5225,  ...,  0.4482,  0.8743, -1.0429],[-0.3963,  0.0041, -0.3626,  ...,  0.1841,  0.2224, -1.9317],...,[-0.8887, -0.2579,  1.3508,  ..., -0.4421,  0.2193,  1.2736],[-0.9659, -0.0447,  1.4424,  ..., -0.4350, -0.1186,  1.2042],[-0.5213, -0.0255,  1.8161,  ..., -0.7231, -0.3752,  1.0876]]],device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

SDXL OpenAI CLIP ViT-L/14的完整结构图 

SDXL OpenAI CLIP ViT-L/14的文本编码过程: 

from transformers import CLIPTextModel, CLIPTokenizer# 加载 OpenAI CLIP ViT-L/14 Text Encoder模型和Tokenizer
# SDXL模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDXL模型,即可获得资源链接
text_encoder = CLIPTextModel.from_pretrained("/本地路径/stable-diffusion-xl-base-1.0", subfolder="text_encoder").to("cuda")
text_tokenizer = CLIPTokenizer.from_pretrained("/本地路径/stable-diffusion-xl-base-1.0", subfolder="tokenizer")# 将输入SDXL模型的prompt进行tokenize,得到对应的token ids特征
prompt = "1girl,beautiful"
text_token_ids = text_tokenizer(prompt,padding="max_length",max_length=text_tokenizer.model_max_length,truncation=True,return_tensors="pt"
).input_idsprint("text_token_ids' shape:",text_token_ids.shape)
print("text_token_ids:",text_token_ids)# 将token ids特征输入OpenAI CLIP ViT-L/14 Text Encoder模型中输出77x768的Text Embeddings特征
text_embeddings = text_encoder(text_token_ids.to("cuda"))[0] # 由于Text Encoder模型输出的是一个元组,所以需要[0]对77x768的Text Embeddings特征进行提取
print("text_embeddings' shape:",text_embeddings.shape)
print(text_embeddings)---------------- 运行结果 ----------------
text_token_ids' shape: torch.Size([1, 77])
text_token_ids: tensor([[49406,   272,  1611,   267,  1215, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407]])
text_embeddings' shape: torch.Size([1, 77, 768])
tensor([[[-0.3885,  0.0230, -0.0521,  ..., -0.4901, -0.3065,  0.0674],[-0.8424, -1.1387,  1.2767,  ..., -0.2598,  1.6289, -0.7855],[ 0.1751, -0.9847,  0.1881,  ...,  0.0657, -1.4940, -1.2612],...,[ 0.2039, -0.7298, -0.3206,  ...,  0.6751, -0.5814, -0.7320],[ 0.1921, -0.7345, -0.3039,  ...,  0.6806, -0.5852, -0.7228],[ 0.2112, -0.6438, -0.3042,  ...,  0.6628, -0.5576, -0.7583]]],device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

 以上都为SDXL的base模型


 4.Refiner模型

由于已经有U-Net(Base)模型生成了图像的Latent特征,所以Refiner模型的主要工作是在Latent特征进行小噪声去除和细节质量提升

Refiner模型和Base模型一样是基于Latent的扩散模型,也采用了Encoder-Decoder结构,和U-Net兼容同一个VAE模型。不过在Text Encoder部分,Refiner模型只使用了OpenCLIP ViT-bigG的Text Encoder,同样提取了倒数第二层特征以及进行了pooled text embedding的嵌入。

refine模型中的Unet结构:

 单独使用Stable Diffusion XL中的Base模型来生成图像: 

# 加载diffusers和torch依赖库
from diffusers import DiffusionPipeline
import torch# 加载Stable Diffusion XL Base模型(stable-diffusion-xl-base-1.0或stable-diffusion-xl-base-0.9)
pipe = DiffusionPipeline.from_pretrained("/本地路径/stable-diffusion-xl-base-1.0",torch_dtype=torch.float16, variant="fp16")
# "/本地路径/stable-diffusion-xl-base-1.0"表示我们需要加载的Stable Diffusion XL Base模型路径
# 大家可以关注Rocky的公众号WeThinkIn,后台回复:SDXL模型,即可获得SDXL模型权重资源链接
# "fp16"代表启动fp16精度。比起fp32,fp16可以使模型显存占用减半# 使用GPU进行Pipeline的推理
pipe.to("cuda")# 输入提示词
prompt = "Watercolor painting of a desert landscape, with sand dunes, mountains, and a blazing sun, soft and delicate brushstrokes, warm and vibrant colors"# 输入负向提示词,表示我们不想要生成的特征
negative_prompt = "(EasyNegative),(watermark), (signature), (sketch by bad-artist), (signature), (worst quality), (low quality), (bad anatomy), NSFW, nude, (normal quality)"# 设置seed,可以固定生成图像中的构图
seed = torch.Generator("cuda").manual_seed(42)# SDXL Base Pipeline进行推理
image = pipe(prompt, negative_prompt=negative_prompt,generator=seed).images[0]
# Pipeline生成的images包含在一个list中:[<PIL.Image.Image image mode=RGB size=1024x1024>]
#所以需要使用images[0]来获取list中的PIL图像# 保存生成图像
image.save("SDXL-Base.png")

将SDXL Base模型和SDXL Refiner模型级联来生成图像: 

from diffusers import DiffusionPipeline
import torch# 下面的五行代码不变
pipe = DiffusionPipeline.from_pretrained("/本地路径/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16")pipe.to("cuda")prompt = "Watercolor painting of a desert landscape, with sand dunes, mountains, and a blazing sun, soft and delicate brushstrokes, warm and vibrant colors"negative_prompt = "(EasyNegative),(watermark), (signature), (sketch by bad-artist), (signature), (worst quality), (low quality), (bad anatomy), NSFW, nude, (normal quality)"seed = torch.Generator("cuda").manual_seed(42)# 运行SDXL Base模型的Pipeline,设置输出格式为output_type="latent"
image = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=seed, output_type="latent").images# 加载Stable Diffusion XL Refiner模型(stable-diffusion-xl-refiner-1.0或stable-diffusion-xl-refiner-0.9)
pipe = DiffusionPipeline.from_pretrained("/本地路径/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16")
# "本地路径/stable-diffusion-xl-refiner-1.0"表示我们需要加载的Stable Diffusion XL Refiner模型,
# 大家可以关注Rocky的公众号WeThinkIn,后台回复:SDXL模型,即可获得SDXL模型权重资源链接pipe.to("cuda")# SDXL Refiner Pipeline进行推理
images = pipe(prompt=prompt, negative_prompt=negative_prompt, generator=seed, image=image).images# 保存生成图像
images[0].save("SDXL-Base-Refiner.png")


http://www.ppmy.cn/news/1504718.html

相关文章

【Wireshark 抓 CAN 总线】Wireshark 抓取 CAN 总线数据的实现思路

最近看到一个帖子 Wireshark 对接 Windows 系统命名管道&#xff0c;抓取数据 我突然想到一个很有意思的方式 你没看错 用 Wireshark 来抓取 CAN 总线数据 【其实 Wireshark 上有 CAN 总线的的解码器&#xff0c;不信你可以在表达式栏打 can 试下&#xff0c;是有这个解码器的】…

Python爬虫基础:爬取网页内容解析标题

当你需要从网页上获取数据并进行处理时&#xff0c;Python的BeautifulSoup和requests库是非常强大的工具。这些库可以帮助你发送HTTP请求&#xff0c;获取网页内容&#xff0c;并解析HTML以提取所需的信息。在这篇博客文章中&#xff0c;我们将详细介绍如何使用这些库从网页上获…

崆峒酥饼—旅游与访友的绝佳选择

当您踏上旅途&#xff0c;或是准备拜访亲朋好友&#xff0c;总在寻觅一份既能代表心意&#xff0c;又独具特色的礼物。而食家巷崆峒酥饼&#xff0c;无疑是您的不二之选。 崆峒酥饼&#xff0c;源自平凉的美食瑰宝&#xff0c;每一口都承载着浓厚的地方风情。它的外表金黄…

【Android】ContentProvider基本概念

ContentProvider Android权限机制详解 <manifest xmlns:android"http://schemas.android.com/apk/res/android"package"com.example.broadcasttest"> <uses-permission android:name"android.permission.RECEIVE_BOOT_COMPLETED" />…

亲子游戏 - 华为OD统一考试(D卷)

OD统一考试(D卷) 分值: 200分 题解: Java / Python / C++ 题目描述 宝宝和妈妈参加亲子游戏,在一个二维矩阵(N*N)的格子地图上,宝宝和妈妈抽签决定各自的位置,地图上每个格子有不同的糖果数量,部分格子有障碍物。 游戏规则是妈妈必须在最短的时间(每个单位时间只能走…

“常温”前端网站框架(四)-- 音乐播放器【附源码】

开篇&#xff08;请大家看完&#xff09;&#xff1a;此网站写给挚爱&#xff0c;后续页面还会慢慢更新&#xff0c;大家敬请期待~ ~ ~ 此前端框架&#xff0c;主要侧重于前端页面的视觉效果和交互体验。通过运用各种前端技术和创意&#xff0c;精心打造了一系列引人入胜的页面…

MySQL:表级锁

表级锁 Table Lock&#xff08;表锁&#xff09;是一种数据库锁&#xff08;Lock&#xff09;机制&#xff0c;用于控制并发访问数据库表的操作。当一个会话对表进行操作时&#xff0c;会自动获取相应的锁&#xff0c;以确保其他会话无法同时修改该表的数据&#xff0c;从而维…

拉刀基础知识——拉刀的种类

如前面所说&#xff1a;近期要围绕拉削和拉刀这个话题&#xff0c;分享一些相关的内容&#xff0c;从最基础的知识开始&#xff0c;为此还专门买了本旧书——《拉刀设计》入门学习。废话不多说&#xff0c;直接开始。 拉刀最早由冲头演变而来&#xff0c;用于加工方孔&#xf…