联邦大模型微调

news/2024/12/29 3:45:14/

微调(Fine-tuning)是一种迁移学习的技术,用于在一个已经预训练好的模型基础上,通过进一步训练来适应特定的任务或数据集。微调可以在具有相似特征的任务之间共享知识,从而加快训练速度并提高模型性能。

微调步骤:

  1. 选择预训练模型:选择一个在大规模数据集上预训练好的模型,如ImageNet上的预训练的卷积神经网络(如ResNet、VGG等)。这些模型通常具有良好的特征提取能力。
  2. 冻结底层权重:将预训练模型的底层权重(通常是卷积层)固定住,不进行训练。这是因为底层权重通常学习到了通用的特征,可以被用于许多不同的任务。
  3. 替换顶层分类器:将预训练模型的顶层分类器(通常是全连接层)替换为适合特定任务的新的分类器。新的分类器的输出节点数量应该与任务的类别数相匹配。
  4. 解冻部分权重(可选):根据任务的复杂性和可用的训练数据量,可以选择解冻一些底层权重,以便更好地适应新的任务。这样可以允许底层权重进行微小的调整,以更好地适应新任务的特征。
  5. 进行训练:使用特定任务的训练数据集对新的分类器进行训练。可以使用较小的学习率进行训练,以避免对预训练模型的权重进行过大的更新。
  6. 评估和调整:在训练完成后,使用验证集或测试集评估模型的性能。根据评估结果,可以进行调整,如调整学习率、调整模型结构等。

PEFT:

PEFT(Performance Estimation and Modeling for Fine-Tuning)是一种用于微调任务的性能估计和建模方法。它的主要目的是帮助研究人员和从业者在微调过程中更好地理解和预测模型的性能,并进行更有效的模型选择和调优。

使用场景:

  1. 模型选择:在微调之前,通常需要选择一个合适的预训练模型。PEFT可以帮助评估和比较不同预训练模型在特定任务上的性能,从而选择最适合的模型。
  2. 超参数调优:微调过程中可能涉及到一些超参数的选择,如学习率、批量大小等。PEFT可以帮助预估不同超参数设置下模型的性能,并指导超参数的调优。
  3. 计算资源规划:微调通常需要大量的计算资源,如显存、GPU时间等。PEFT可以帮助估计不同模型和数据集规模下的计算资源需求,以便更好地规划和分配资源。
  4. 模型压缩和加速:在一些场景下,需要将模型压缩或加速,以便在资源受限的设备上进行推理。PEFT可以帮助评估不同压缩和加速技术对模型性能的影响,并指导模型优化的方向。

PEFT的关键步骤:

  1. 数据采样:从原始数据集中采样一小部分数据用于性能估计。这样可以减少计算开销,同时保持采样数据与原始数据集的分布一致性。
  2. 特征提取:使用预训练模型提取采样数据的特征表示。这些特征通常具有很好的表达能力,可以用于性能估计。
  3. 性能估计模型:基于采样数据的特征表示,建立一个性能估计模型。这个模型可以是简单的线性回归模型,也可以是更复杂的神经网络模型。
  4. 性能预测:使用性能估计模型对未知数据的性能进行预测。通过输入微调任务的特征表示,模型可以输出预测的性能指标,如准确率、F1分数等。

代码:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
import flwr as fl
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from datasets import load_dataset
from flwr.client.mod import fixedclipping_mod
from flwr.server.strategy import (DifferentialPrivacyClientSideFixedClipping
)
from utils.utils import * 
from utils.LLM import LLM_fl
from utils.LLM import get_fireworks_api_key,load_env
cfg = get_config("federated")print_config(cfg)

partitioner = IidPartitioner(num_partitions=cfg.flower.num_clients)
fds = FederatedDataset(dataset=cfg.dataset.name,partitioners={"train": partitioner}
)partition_zero = fds.load_partition(0) format_dataset(partition_zero)

(20个客户端均分数据集)

visualize_partitions(fds)

获取分词器(tokenizer)、数据整理器(data collator)和格式化提示函数(formatting prompts function):

(
tokenizer,
data_collator,
formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name, cfg.model.use_fast_tokenizer,cfg.train.padding_side,
)

预训练的分词器:EleutherAI/pythia-70m

save_path = "./my_fl_model"
client = fl.client.ClientApp(client_fn=gen_client_fn(fds,tokenizer,formatting_prompts_func,data_collator,cfg.model, cfg.train, save_path,),mods=[fixedclipping_mod] 
)

fds:

<flwr_datasets.federated_dataset.FederatedDataset object at 0x7fa271c2f790>

tokenizer:

为联邦学习服务器配置了一个 FedAvg 策略,并添加了差分隐私机制,同时指定了服务器运行的轮数和其他相关参数:

def server_fn(context: Context):# Define the Strategystrategy = fl.server.strategy.FedAvg(min_available_clients=cfg.flower.num_clients, # total clientsfraction_fit=cfg.flower.fraction_fit, # ratio of clients to samplefraction_evaluate=0.0, # No federated evaluation# A (optional) function used to configure a "fit()" roundon_fit_config_fn=get_on_fit_config(),# A (optional) function to aggregate metrics sent by clientsfit_metrics_aggregation_fn=fit_weighted_average,# A (optional) function to execute on the server after each round. # In this example the function only saves the global model.evaluate_fn=get_evaluate_fn( cfg.model,cfg.train.save_every_round,cfg.flower.num_rounds,save_path),)# Add Differential Privacysampled_clients = cfg.flower.num_clients*strategy.fraction_fitstrategy = DifferentialPrivacyClientSideFixedClipping(strategy, noise_multiplier=cfg.flower.dp.noise_mult,clipping_norm=cfg.flower.dp.clip_norm, num_sampled_clients=sampled_clients)# Number of rounds to run the simulationnum_rounds = cfg.flower.num_roundsconfig = fl.server.ServerConfig(num_rounds=num_rounds)return fl.server.ServerAppComponents(strategy=strategy, config=config) 

配置和启动一个联邦学习模拟,其中包含了服务器和客户端的应用,以及它们运行所需的资源。通过调用 run_simulation 函数,模拟开始执行,客户端和服务器之间将进行多轮通信,以协作训练一个共享的模型,同时保护数据的隐私:

client_resources = dict(cfg.flower.client_resources)
fl.simulation.run_simulation(server_app=server,client_app=client,num_supernodes=cfg.flower.num_clients,backend_config={"client_resources": client_resources,"init_args": backend_setup}
)

运行微调后的模型:

# Load the checkpoint
llm_eval = LLM_fl()# Load dataset
train_dataset = load_dataset(cfg.dataset.name, split='train')
train_dataset = format_dataset(train_dataset)# Select training example
example_index = 6data_point = train_dataset[example_index]# Print the prompt
llm_eval.eval(data_point['instruction'], verbose=True)

# Print the fine-tuned LLM response
llm_eval.print_response()

# Print the expected output from the medAlpaca dataset
ex_response = format_string(data_point['response'])
print(f"Expected output:\n\t{ex_response}")

结果可视化:

visualize_results(results=['7b/pretrained', '7b/cen_10', '7b/fl'])

向集中式精细调整模型提供相同数量的数据:

visualize_results(results=['7b/pretrained', '7b/cen_10','7b/cen_full', '7b/fl'],compact=True)

计算花销:

cfg = get_config("federated")compute_communication_costs(cfg, comm_bw_mbps=20)

参考:

【1】LLMs_interview_notes/大模型(LLMs)参数高效微调(PEFT)面/大模型(LLMs)参数高效微调(PEFT)面.md at main · naginoa/LLMs_interview_notes · GitHub

【2】


http://www.ppmy.cn/news/1558961.html

相关文章

Linux(Centos 7.6)基本信息查看

1.服务器硬件信息查看 1.1.服务器厂商、产品名称查看 dmidecode -s system-manufacturer&#xff1a;查看服务器厂商信息 dmidecode -s system-product-name&#xff1a;查看服务器产品名称信息 1.Windows使用VMware安装的Linux(Centos 7.6)后&#xff0c;服务器厂商、产品名…

Windows和Linux安全配置和加固

一.A模块基础设施设置/安全加固 A-1.登录加固 1.密码策略 a.最小密码长度不少于8个字符&#xff0c;将密码长度最小值的属性配置界面截图。 练习用的WindowsServer2008,系统左下角开始 > 管理工具 > 本地安全策略 > 账户策略 > 密码策略 > 密码最小长度&#…

如何在 Spring Boot 微服务中设置和管理多个数据库

在现代微服务架构中&#xff0c;通常需要与多个数据库交互的服务。这可能是由于各种原因&#xff0c;例如遗留系统集成、不同类型的数据存储需求&#xff0c;或者仅仅是为了优化性能。Spring Boot 具有灵活的配置和强大的数据访问库&#xff0c;可以轻松配置多个数据库。在本综…

为什么要在PHY芯片和RJ45网口中间加网络变压器

在PHY芯片和RJ45网口之间加入网络变压器是出于以下几个重要的考虑&#xff1a; 1. 电气隔离&#xff1a;网络变压器提供了电气隔离功能&#xff0c;有效阻断了PHY芯片与RJ45之间直流分量的直接连接。这样可以防止可能的电源冲突&#xff0c;降低系统故障的风险&#xff0c;并保…

如何阻止盗版软件在互联网上传播

阻止公司软件的盗版传播是一项复杂但重要的任务&#xff0c;可以通过技术、法律和管理手段相结合来实现。以下是一些有效的措施&#xff1a; 1. 技术措施 1.1 软件保护 使用软件加密&#xff1a;采用强大的代码混淆、加密技术和反篡改机制。硬件绑定&#xff1a;将软件激活与…

Selenium实践总结

1.使用显示等待而不是隐式等待 隐式等待可能会导致不可预测的测试行为&#xff0c;尤其是在动态 Web 应用程序中。显式等待&#xff0c;它允许您 等待特定条件发生后再继续测试&#xff0c;这种方法提供了更多的控制和可靠性。 WebDriverWait wait new WebDriverWait(drive…

网站服务器被攻击了怎么办?

当网站服务器被攻击时&#xff0c;可能会出现各种问题&#xff0c;如服务中断、数据泄露、恶意软件感染等。如果不及时采取措施&#xff0c;可能会给企业带来严重的损失。因此&#xff0c;当网站服务器被攻击时&#xff0c;企业需要采取以下措施来应对&#xff1a; 一、快速定…

springboot整合Elasticsearch介绍

上一篇博客介绍了elasticsearch及其安装部署&#xff08;https://chengpei.top/archives/elasticsearch-jieshao&#xff09;&#xff0c;这次就介绍了一下如何将ES和我们的springboot项目整合使用 连接工具 整合之前我们先介绍一款工具用于连接elasticsearch查询工具&#x…