
Chapter 6 - Fine-tuning for classification

6.8 - Using the LLM as a spam classifier

  • Having fine-tuned and evaluated the model, we are now ready to classify spam messages. Let's put the fine-tuned GPT-based spam classification model to use.

    The following classify_review function implements data-preprocessing steps similar to those we used in the SpamDataset implementation. After converting the text into token IDs, the function uses the model to predict an integer class label, as in section 6.6, and then returns the corresponding class name:

    python">def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):model.eval()# Prepare inputs to the modelinput_ids = tokenizer.encode(text)supported_context_length = model.pos_emb.weight.shape[0]# Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake# It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)# Truncate sequences if they too longinput_ids = input_ids[:min(max_length, supported_context_length)]# Pad sequences to the longest sequenceinput_ids += [pad_token_id] * (max_length - len(input_ids))input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # add batch dimension# Model inferencewith torch.no_grad():logits = model(input_tensor)[:, -1, :]  # Logits of the last output tokenpredicted_label = torch.argmax(logits, dim=-1).item()# Return the classified resultreturn "spam" if predicted_label == 1 else "not spam"
    
  • Let's try it out on a couple of examples:

    python">text_1 = ("You are a winner you have been specially"" selected to receive $1000 cash or a $2000 award."
    )print(classify_review(text_1, model, tokenizer, device, max_length=train_dataset.max_length
    ))"""输出"""
    spam
    
    python">text_2 = ("Hey, just wanted to check if we're still on"" for dinner tonight? Let me know!"
    )print(classify_review(text_2, model, tokenizer, device, max_length=train_dataset.max_length
    ))"""输出"""
    not spam
    
  • Finally, let's save the model so that we can reuse it later without having to train it again:

    python">torch.save(model.state_dict(), "review_classifier.pth")
    

    Then, in a new session, we can load the model as follows:

    python">model_state_dict = torch.load("review_classifier.pth", map_location=device, weights_only=True)
    model.load_state_dict(model_state_dict)
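
    Note that load_state_dict only restores the weights; in a new session, the model object itself must be re-created first with the same architecture, including the two-class output head, before the weights can be loaded. A minimal sketch, assuming the GPTModel class and BASE_CONFIG dictionary from the chapter code are available:

    import torch
    from previous_chapters import GPTModel

    # Re-create the architecture exactly as it was fine-tuned:
    # the base GPT config plus the two-class output head
    model = GPTModel(BASE_CONFIG)
    model.out_head = torch.nn.Linear(BASE_CONFIG["emb_dim"], 2)

    # Load the fine-tuned weights into the re-created model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_state_dict = torch.load("review_classifier.pth", map_location=device, weights_only=True)
    model.load_state_dict(model_state_dict)
    model.to(device)
    model.eval()  # disable dropout for inference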
    

6.9 - Summary and takeaways

  • summary

    1. There are different strategies for fine-tuning LLMs, including classification fine-tuning and instruction fine-tuning.
    2. Classification fine-tuning involves replacing the output layer of an LLM with a small classification layer (see the code sketch after this list).
    3. In the case of classifying text messages as “spam” or “not spam,” the new classification layer consists of only two output nodes. Previously, the number of output nodes was equal to the number of unique tokens in the vocabulary (i.e., 50,257).
    4. Instead of predicting the next token in the text as in pretraining, classification fine-tuning trains the model to output a correct class label—for example, “spam” or “not spam.”
    5. The model input for fine-tuning is text converted into token IDs, similar to pretraining.
    6. Before fine-tuning an LLM, we load the pretrained model as a base model.
    7. Evaluating a classification model involves calculating the classification accuracy (the fraction or percentage of correct predictions).
    8. Fine-tuning a classification model uses the same cross entropy loss function as when pretraining the LLM.
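
    To make items 2 and 3 concrete, here is a minimal sketch of the head swap and the freezing strategy used in the chapter (model, BASE_CONFIG, and the attribute names trf_blocks / final_norm / out_head are those of the chapter's GPTModel):

    import torch

    # Freeze all pretrained parameters first
    for param in model.parameters():
        param.requires_grad = False

    # Replace the 50,257-way language-modeling head with a 2-way classification head
    num_classes = 2
    model.out_head = torch.nn.Linear(
        in_features=BASE_CONFIG["emb_dim"],  # e.g., 768 for GPT-2 small
        out_features=num_classes
    )  # new layers have requires_grad=True by default

    # Additionally unfreeze the last transformer block and the final layer norm,
    # which the chapter found sufficient for good classification accuracy
    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True
    for param in model.final_norm.parameters():
        param.requires_grad = True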
  • takeaways

    1. ./gpt_class_finetune.py is a standalone script containing the classification fine-tuning code (reproduced below).
    2. Appendix E introduces parameter-efficient fine-tuning with LoRA (see the sketch after this list).
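
    For reference, the core idea behind LoRA is to keep the pretrained weight matrices frozen and learn only a low-rank update: a linear layer computing x W is augmented to x W + alpha * (x A B), where A and B are small trainable matrices of rank r. A minimal sketch in the spirit of appendix E (class names and initialization details are illustrative and may differ from the appendix):

    import math
    import torch

    class LoRALayer(torch.nn.Module):
        # Trainable low-rank path: alpha * (x @ A @ B)
        def __init__(self, in_dim, out_dim, rank, alpha):
            super().__init__()
            self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
            torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
            self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))  # zeros: no change at start
            self.alpha = alpha

        def forward(self, x):
            return self.alpha * (x @ self.A @ self.B)

    class LinearWithLoRA(torch.nn.Module):
        # Wraps a frozen Linear layer and adds the trainable low-rank path
        def __init__(self, linear, rank, alpha):
            super().__init__()
            self.linear = linear
            self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

        def forward(self, x):
            return self.linear(x) + self.lora(x)

    Because only A and B are trained, the number of trainable parameters per layer drops from in_dim * out_dim to rank * (in_dim + out_dim).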
  • gpt_class_finetune.py

    python"># Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
    # Source for "Build a Large Language Model From Scratch"
    #   - https://www.manning.com/books/build-a-large-language-model-from-scratch
    # Code: https://github.com/rasbt/LLMs-from-scratch# This is a summary file containing the main takeaways from chapter 6.import urllib.request
    import zipfile
    import os
    from pathlib import Path
    import timeimport matplotlib.pyplot as plt
    import pandas as pd
    import tiktoken
    import torch
    from torch.utils.data import Dataset, DataLoaderfrom gpt_download import download_and_load_gpt2
    from previous_chapters import GPTModel, load_weights_into_gptdef download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):if data_file_path.exists():print(f"{data_file_path} already exists. Skipping download and extraction.")returnif test_mode:  # Try multiple times since CI sometimes has connectivity issuesmax_retries = 5delay = 5  # delay between retries in secondsfor attempt in range(max_retries):try:# Downloading the filewith urllib.request.urlopen(url, timeout=10) as response:with open(zip_path, "wb") as out_file:out_file.write(response.read())break  # if download is successful, break out of the loopexcept urllib.error.URLError as e:print(f"Attempt {attempt + 1} failed: {e}")if attempt < max_retries - 1:time.sleep(delay)  # wait before retryingelse:print("Failed to download file after several attempts.")return  # exit if all retries failelse:  # Code as it appears in the chapter# Downloading the filewith urllib.request.urlopen(url) as response:with open(zip_path, "wb") as out_file:out_file.write(response.read())# Unzipping the filewith zipfile.ZipFile(zip_path, "r") as zip_ref:zip_ref.extractall(extracted_path)# Add .tsv file extensionoriginal_file_path = Path(extracted_path) / "SMSSpamCollection"os.rename(original_file_path, data_file_path)print(f"File downloaded and saved as {data_file_path}")def create_balanced_dataset(df):# Count the instances of "spam"num_spam = df[df["Label"] == "spam"].shape[0]# Randomly sample "ham" instances to match the number of "spam" instancesham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)# Combine ham "subset" with "spam"balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])return balanced_dfdef random_split(df, train_frac, validation_frac):# Shuffle the entire DataFramedf = df.sample(frac=1, random_state=123).reset_index(drop=True)# Calculate split indicestrain_end = int(len(df) * train_frac)validation_end = train_end + int(len(df) * validation_frac)# Split the DataFrametrain_df = df[:train_end]validation_df = df[train_end:validation_end]test_df = df[validation_end:]return train_df, validation_df, test_dfclass SpamDataset(Dataset):def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):self.data = pd.read_csv(csv_file)# Pre-tokenize textsself.encoded_texts = [tokenizer.encode(text) for text in self.data["Text"]]if max_length is None:self.max_length = self._longest_encoded_length()else:self.max_length = max_length# Truncate sequences if they are longer than max_lengthself.encoded_texts = [encoded_text[:self.max_length]for encoded_text in self.encoded_texts]# Pad sequences to the longest sequenceself.encoded_texts = [encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))for encoded_text in self.encoded_texts]def __getitem__(self, index):encoded = self.encoded_texts[index]label = self.data.iloc[index]["Label"]return (torch.tensor(encoded, dtype=torch.long),torch.tensor(label, dtype=torch.long))def __len__(self):return len(self.data)def _longest_encoded_length(self):max_length = 0for encoded_text in self.encoded_texts:encoded_length = len(encoded_text)if encoded_length > max_length:max_length = encoded_lengthreturn max_lengthdef calc_accuracy_loader(data_loader, model, device, num_batches=None):model.eval()correct_predictions, num_examples = 0, 0if num_batches is None:num_batches = len(data_loader)else:num_batches = min(num_batches, len(data_loader))for i, (input_batch, target_batch) in enumerate(data_loader):if i < num_batches:input_batch, 
target_batch = input_batch.to(device), target_batch.to(device)with torch.no_grad():logits = model(input_batch)[:, -1, :]  # Logits of last output tokenpredicted_labels = torch.argmax(logits, dim=-1)num_examples += predicted_labels.shape[0]correct_predictions += (predicted_labels == target_batch).sum().item()else:breakreturn correct_predictions / num_examplesdef calc_loss_batch(input_batch, target_batch, model, device):input_batch, target_batch = input_batch.to(device), target_batch.to(device)logits = model(input_batch)[:, -1, :]  # Logits of last output tokenloss = torch.nn.functional.cross_entropy(logits, target_batch)return lossdef calc_loss_loader(data_loader, model, device, num_batches=None):total_loss = 0.if len(data_loader) == 0:return float("nan")elif num_batches is None:num_batches = len(data_loader)else:num_batches = min(num_batches, len(data_loader))for i, (input_batch, target_batch) in enumerate(data_loader):if i < num_batches:loss = calc_loss_batch(input_batch, target_batch, model, device)total_loss += loss.item()else:breakreturn total_loss / num_batchesdef evaluate_model(model, train_loader, val_loader, device, eval_iter):model.eval()with torch.no_grad():train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)model.train()return train_loss, val_lossdef train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,eval_freq, eval_iter, tokenizer):# Initialize lists to track losses and tokens seentrain_losses, val_losses, train_accs, val_accs = [], [], [], []examples_seen, global_step = 0, -1# Main training loopfor epoch in range(num_epochs):model.train()  # Set model to training modefor input_batch, target_batch in train_loader:optimizer.zero_grad()  # Reset loss gradients from previous batch iterationloss = calc_loss_batch(input_batch, target_batch, model, device)loss.backward()  # Calculate loss gradientsoptimizer.step()  # Update model weights using loss gradientsexamples_seen += input_batch.shape[0]  # New: track examples instead of tokensglobal_step += 1# Optional evaluation stepif global_step % eval_freq == 0:train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)train_losses.append(train_loss)val_losses.append(val_loss)print(f"Ep {epoch+1} (Step {global_step:06d}): "f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")# Calculate accuracy after each epochtrain_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")print(f"Validation accuracy: {val_accuracy*100:.2f}%")train_accs.append(train_accuracy)val_accs.append(val_accuracy)return train_losses, val_losses, train_accs, val_accs, examples_seendef plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):fig, ax1 = plt.subplots(figsize=(5, 3))# Plot training and validation loss against epochsax1.plot(epochs_seen, train_values, label=f"Training {label}")ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")ax1.set_xlabel("Epochs")ax1.set_ylabel(label.capitalize())ax1.legend()# Create a second x-axis for tokens seenax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axisax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticksax2.set_xlabel("Examples seen")fig.tight_layout()  # Adjust layout to make 
roomplt.savefig(f"{label}-plot.pdf")# plt.show()if __name__ == "__main__":import argparseparser = argparse.ArgumentParser(description="Finetune a GPT model for classification")parser.add_argument("--test_mode",default=False,action="store_true",help=("This flag runs the model in test mode for internal testing purposes. ""Otherwise, it runs the model as it is used in the chapter (recommended)."))args = parser.parse_args()######################################### Download and prepare dataset########################################url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"zip_path = "sms_spam_collection.zip"extracted_path = "sms_spam_collection"data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode)df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])balanced_df = create_balanced_dataset(df)balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)train_df.to_csv("train.csv", index=None)validation_df.to_csv("validation.csv", index=None)test_df.to_csv("test.csv", index=None)######################################### Create data loaders########################################tokenizer = tiktoken.get_encoding("gpt2")train_dataset = SpamDataset(csv_file="train.csv",max_length=None,tokenizer=tokenizer)val_dataset = SpamDataset(csv_file="validation.csv",max_length=train_dataset.max_length,tokenizer=tokenizer)test_dataset = SpamDataset(csv_file="test.csv",max_length=train_dataset.max_length,tokenizer=tokenizer)num_workers = 0batch_size = 8torch.manual_seed(123)train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers,drop_last=True,)val_loader = DataLoader(dataset=val_dataset,batch_size=batch_size,num_workers=num_workers,drop_last=False,)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,num_workers=num_workers,drop_last=False,)######################################### Load pretrained model######################################### Small GPT model for testing purposesif args.test_mode:BASE_CONFIG = {"vocab_size": 50257,"context_length": 120,"drop_rate": 0.0,"qkv_bias": False,"emb_dim": 12,"n_layers": 1,"n_heads": 2}model = GPTModel(BASE_CONFIG)model.eval()device = "cpu"# Code as it is used in the main chapterelse:CHOOSE_MODEL = "gpt2-small (124M)"INPUT_PROMPT = "Every effort moves"BASE_CONFIG = {"vocab_size": 50257,     # Vocabulary size"context_length": 1024,  # Context length"drop_rate": 0.0,        # Dropout rate"qkv_bias": True         # Query-key-value bias}model_configs = {"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},}BASE_CONFIG.update(model_configs[CHOOSE_MODEL])assert train_dataset.max_length <= BASE_CONFIG["context_length"], (f"Dataset length {train_dataset.max_length} exceeds model's context "f"length {BASE_CONFIG['context_length']}. 
Reinitialize data sets with "f"`max_length={BASE_CONFIG['context_length']}`")model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")model = GPTModel(BASE_CONFIG)load_weights_into_gpt(model, params)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")######################################### Modify and pretrained model########################################for param in model.parameters():param.requires_grad = Falsetorch.manual_seed(123)num_classes = 2model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)model.to(device)for param in model.trf_blocks[-1].parameters():param.requires_grad = Truefor param in model.final_norm.parameters():param.requires_grad = True######################################### Finetune modified model########################################start_time = time.time()torch.manual_seed(123)optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)num_epochs = 5train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(model, train_loader, val_loader, optimizer, device,num_epochs=num_epochs, eval_freq=50, eval_iter=5,tokenizer=tokenizer)end_time = time.time()execution_time_minutes = (end_time - start_time) / 60print(f"Training completed in {execution_time_minutes:.2f} minutes.")######################################### Plot results######################################### loss plotepochs_tensor = torch.linspace(0, num_epochs, len(train_losses))examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)# accuracy plotepochs_tensor = torch.linspace(0, num_epochs, len(train_accs))examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")

