在亚马逊云科技Amazon SageMaker上进行Stable Diffusion模型训练和推理

news/2025/2/13 6:36:32/

 Stable Diffusion Quick Kit是一个基于亚马逊云科技Amazon SageMaker进行Stable Diffusion模型快速部署的工具包,包括了一组示例代码、服务部署脚本、前端UI,可以帮助可以快速部署一套Stable Diffusion的原型服务。

 本文将介绍如何在SageMaker Training Job中加载进行Stable Diffusion XL(以下简称SDXL)的Dreambooth微调训练,及训练完成后使用Stable Diffusion WebUI开源框架进行模型部署和即时推理,实现训推一体的整体pipeline及业务流程。

 背景介绍

 Dreambooth微调训练

 Dreambooth是Stable Diffusion模型训练的一种方式,通过输入instance_prompt定义实体主体(e.g.人物或者实体物品)和instance images的fine-tuning图像,抽取原SD中UNet,VAE网络,将instance prompt和instance images图像与之绑定,以便后续生成的图片中只要有instance的prompt中的关键词,即可保持输入instance图片里面的主体实体,实现人物和物品生图时的高保真效果。

 Stable Diffusion WebUI

 Stable Diffusion WebUI是基于Stable Diffusion开发的一个开源的可视化软件,WebUI在Stable Diffusion txt2img,img2img生图基础上拓展了很多插件来增强Stable Diffusion的生图能力,比如Ultimate Upscale、Inpain等,使得开发者可以方便地通过界面拖拽或者API调用进行Stable Diffusion模型的加载和调用。

 相对于Diffuser SDK的模型推理,WebUI有更丰富的调用参数及更多的插件支持,因此同样模型的出图效果某些场景下会比Diffuser更好,这也是目前业界不少客户使用WebUI API方式进行推理生图的原因。

 训练+推理业务场景

 在遇到使用Stable Diffusion模型微调和推理的业务场景中,针对ToB端客户,通常会上传需要训练的图像,使用Dreambooth训练人物(如模特或者数字人)和商品(如箱包,衣服),然后针对训练好的模型批量生成海报/广告/logo等文案素材的图像,该过程并不需要像app应用一样实时交互的出图,而是一个离线异步的过程。

 这种情况下,可以在训练任务的算力机上,同时安装部署模型微调和模型推理的框架,利用SageMaker Training Job方式,将微调和推理放到一个job中,微调训练完成,即加载model进行推理出图,从而一次性完成模型微调(Dreambooth)+模型推理(WebUI API)整个完整pipeline工作,将推理的模型改造到训练任务中,而不用再单独部署模型的服务端点。

 同时,SageMaker Training Job支持Spot竞价实例,训练任务完成则推理出图也完成,机器资源释放,进一步帮助用户节约整体的成本。

 SageMaker Training Job中进行SDXL Dreambooth Fine-tune

 Dreambooth训练框架

 Stable Diffusion 1.x版本时,Dreambooth fine-tune有多种开源版本的微调框架,SDXL版本后,Diffuser官方在HuggingFace社区发布了基于LoRA的Dreambooth fine-tune框架,代码相对于原1.x版本更加简洁,且使用了更新的xformers加速框架,支持Flash Attension v2,其Pytorch版本也升级到了2.0以上。

 其中train_dreambooth_lora_sdxl.py就是微调训练Dreambooth的代码。

 SageMaker Training Job脚本

 在SageMaker Training Job中,可以clone上一章节的diffuser官方repo训练代码作为source训练脚本目录,并将其依赖的xformers,deepspeed等依赖打包在Docker训练镜像中,通过shell entrypoint方式在算力机上拉起其训练脚本。

 详细如下:

  • 准备source源代码目录并clone官方代码

  • 打包训练任务的docker镜像(使用Amazon预置的0.0+cuda118 HuggingFace DLC容器作为基础镜像,与diffuser官方pytorch/cuda版本保持一致)

  • dockerfile编写

  • build镜像并推送到Amazon ECR镜像仓库

  • 准备训练图像,这里我们使用官方示例dataset图像

  • 图像数据上传到$images_s3uri的S3路径,以便SageMaker Training Job拉取。

  • SageMaker Estimator拉起Training Job

  • 训练任务脚本编写,这里采用shell entrypoint方式,方便调用diffuser官方脚本,且传递环境变量。

 我们通过SageMaker提供的Pytorch的Estimator训练器SDK,拉起Training Job训练任务。

 Dreambooth训练调参

 SDXL Dreambooth Fine-tune的训练参数与之前1.x版本调参类似,这里把Diffuser框架及SageMaker新加的主要配置参数说明如下:

  • ‘images’:f”s3://{bucket}/dreambooth-xl/images/”:上一步骤中准备好的dreambooth微调图像数据,通过inputs参数指定S3路径,SageMaker会自动将该路径下训练图像上传到训练算力机的/opt/ml/data/input/images目录下

  • keep_alive_period_in_seconds:该参数是SageMaker Training Job的warmpool,设置后可以把下一次训练机器保持在该用户的一个资源池中,这样方便多个SDXL Dreambooth训练时的镜像拉起,节省耗时的开销

  • enable_xformers_memory_efficient_attention:启用xformers的flash attention关注度计算优化,加速训练过程

  • train_use_spot_instance:是否使用spot竞价实例进行训练,进一步节省成本

  • max_run:训练任务的最大运行时间

  • max_wait:等待竞价实例的最长时间,如果使用spot竞价实例该参数是必须的

SageMaker Training Job中安装部署Stable Diffusion WebUI

 如上文所述,训练完成后可以直接使用fine-tuned模型进行推理出图,这里采用Stable Diffusion WebUI进行推理,需要在training job训练算力机上安装部署开源的WebUI组件,将模型目录同步到WebUI的model location下,然后调用WebUI API text2img/img2img出图,详细如下:

 docker镜像脚本

 由于是在training job中进行推理,扩充训练任务的dockerfile镜像文件,将Stable Diffusion WebUI组件及依赖同样的方式和上文中training的dockerfile打包到一起:

 WebUI启动脚本

 使用上述章节同样的build & push脚本,将docker镜像打包推送,然后在统一训练和推理的entry point脚本中启动训练任务,任务完成后启动WebUI。

 SageMaker Training Job中对Fine-tuned Dreambooth Model进行推理

 在start_sd_webui.py脚本启动WebUI服务器之后,即可使用WebUI API进行txt2img/img2img的推理调用,其推理API与官方参数一致。

 由于在同一台训练算力机上,其URI为localhost(0.0.0.0)对应端口及API路径前缀。

 总结

 本文介绍了在Quick Kit中使用SageMaker Training Job对SDXL模型进行Dreambooth微调,并且可以在训练完成后对fine-tuned后的模型使用Stable Diffusion WebUI进行推理,实现从训练到推理的一体化操作,满足客户对于快速训练人物或商品实体并批量推理出图的需求。


http://www.ppmy.cn/news/1165392.html

相关文章

spring cloud Eureka集群模式搭建(IDEA中运行)

spring cloud Eureka集群模式搭建(IDEA中运行) 新建springboot 工程工程整体目录配置文件IDEA中部署以jar包形式启动总结 新建springboot 工程 新建一个springboot 工程,命名为:eureka_server。 其中pom.xml文件为: …

Spring Securit OAuth 2.0整合—核心的接口和类

目录 一、ClientRegistration 二、ClientRegistrationRepository 三、OAuth2AuthorizedClient 四、OAuth2AuthorizedClientRepository 和 OAuth2AuthorizedClientService 五、OAuth2AuthorizedClientManager 和 OAuth2AuthorizedClientProvider 一、ClientRegistration C…

TS使用echarts柱状图鼠标放上去并弹出

效果 代码 <template><div><Chart style"width: 100%; height: 344px" :option"chartOption" /></div> </template><script lang"ts" setup>import { ref } from vue;import { ToolTipFormatterParams } f…

el-input无法输入的问题和表单验证失败问题(亲测有效)-开发bug总结4

大部分无法输入的问题&#xff1a;基本都是没有进行v-model双向数据绑定&#xff0c;这个很好解决。 本人项目中遇到的bug问题如下&#xff1a; 点击添加&#xff0c;表单内可输入用户名 和 用户姓名&#xff0c;但有时会偶发出现无法这两个input框里面无法输入内容。 原因&a…

【小米】Linux 实习生

下午不准备去图书馆自习来着&#xff0c;中午就狠狠地多睡了一个小时&#xff0c;三点起床靠在椅子上剥柚子&#xff0c;太爽了&#xff0c;这秋天的下午。“邮件&#xff1a;小米公司邀请你预约面试时间”.......... 我擦&#xff0c;投了一个月了&#xff0c;认真准备的时候…

eNSP-OSPF协议其他区域不与骨干区域相连解决方法2

隧道技术 AR1 [ar1]int g0/0/0 [ar1-GigabitEthernet0/0/0]ip add 192.168.1.1 24 [ar1-GigabitEthernet0/0/0]quit [ar1]ospf [ar1-ospf-1]area 0 [ar1-ospf-1-area-0.0.0.0]net 192.168.1.0 0.0.0.255 [ar1-ospf-1-area-0.0.0.0]quit AR2 [ar2]int g0/0/0 [ar2-GigabitEthe…

Windows 事件日志监控

Windows 事件日志是记录 Microsoft 系统上发生的所有活动的文件&#xff0c;在 Windows 环境中&#xff0c;将记录系统上托管的系统、安全性和应用程序的事件&#xff0c;事件日志提供包含有关事件的详细信息&#xff0c;包括日期、时间、事件 ID、源、事件类型和发起它的用户。…

万宾科技智能井盖传感器特点介绍

当谈论城市基础设施的管理和安全时&#xff0c;井盖通常不是第一项引人注目的话题。然而&#xff0c;传统井盖和智能井盖传感器之间的差异已经引起了城市规划者和工程师的广泛关注。这两种技术在功能、管理、安全和成本等多个方面存在着显著的差异。 WITBEE万宾智能井盖传感器E…