Chapter 6 - Fine-tuning for classification
6.8 - Using the LLM as a spam classifier
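After fine-tuning and evaluating the model, we are now ready to classify spam messages (see the figure below). Let's use our fine-tuned GPT-based spam classification model.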
The following `classify_review` function follows data preprocessing steps similar to those we used in the `SpamDataset` implementation. After processing the text into token IDs, the function uses the model to predict an integer class label, similar to what we implemented in section 6.6, and then returns the corresponding class name:

```python
import torch


def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):
    model.eval()

    # Prepare inputs to the model
    input_ids = tokenizer.encode(text)
    supported_context_length = model.pos_emb.weight.shape[0]
    # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake
    # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)

    # Target length: max_length capped at the model's context length,
    # or the full context length if no max_length is given
    if max_length is None:
        target_length = supported_context_length
    else:
        target_length = min(max_length, supported_context_length)

    # Truncate sequences if they are too long
    input_ids = input_ids[:target_length]

    # Pad sequences to the target length
    input_ids += [pad_token_id] * (target_length - len(input_ids))
    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension

    # Model inference
    with torch.no_grad():
        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token
    predicted_label = torch.argmax(logits, dim=-1).item()

    # Return the classified result
    return "spam" if predicted_label == 1 else "not spam"
```
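To see the truncate-then-pad step in isolation, without loading a model, here is a minimal pure-Python sketch. The helper name `prepare_input_ids` is hypothetical (it is not part of the book's code); it assumes, as `classify_review` does, that padding uses the `<|endoftext|>` token ID 50256 and that the target length is `max_length` capped at the model's context length.

```python
def prepare_input_ids(token_ids, max_length, supported_context_length, pad_token_id=50256):
    """Hypothetical helper mirroring the truncate-then-pad step of classify_review."""
    # Use the model's context length when no max_length is given,
    # otherwise cap max_length at the context length
    if max_length is None:
        target_len = supported_context_length
    else:
        target_len = min(max_length, supported_context_length)
    # Truncate sequences that are too long
    ids = list(token_ids[:target_len])
    # Right-pad with the padding token ID up to the target length
    ids += [pad_token_id] * (target_len - len(ids))
    return ids


print(prepare_input_ids([10, 20, 30], max_length=5, supported_context_length=1024))
# [10, 20, 30, 50256, 50256]
print(prepare_input_ids(list(range(8)), max_length=5, supported_context_length=1024))
# [0, 1, 2, 3, 4]
```

Every input ends up with exactly the same length, which is why the evaluation text should be prepared with the same `max_length` that the training set used.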
Let's try it on a couple of examples:
```python
text_1 = (
    "You are a winner you have been specially"
    " selected to receive $1000 cash or a $2000 award."
)

print(classify_review(
    text_1, model, tokenizer, device, max_length=train_dataset.max_length
))
# Output:
# spam
```

```python
text_2 = (
    "Hey, just wanted to check if we're still on"
    " for dinner tonight? Let me know!"
)

print(classify_review(
    text_2, model, tokenizer, device, max_length=train_dataset.max_length
))
# Output:
# not spam
```
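Both predictions come from a single inference step: take the logits of the last output token and pick the class with the largest value. The pure-Python sketch below imitates `logits[:, -1, :]` followed by `torch.argmax(..., dim=-1)` on nested lists. The function names and the toy logit values are made up for illustration; they are not outputs of the fine-tuned model.

```python
def predict_label_from_logits(logits):
    """Hypothetical pure-Python version of the inference step in classify_review.

    logits: nested list shaped [batch][num_tokens][num_classes]
    """
    # Keep only the last output token of each sequence (like logits[:, -1, :])
    last_token_logits = [seq[-1] for seq in logits]
    # Argmax over the class dimension (like torch.argmax(..., dim=-1));
    # ties resolve to the lower index
    return [max(range(len(row)), key=row.__getitem__) for row in last_token_logits]


def label_to_name(label):
    # Same mapping as classify_review: 1 -> "spam", anything else -> "not spam"
    return "spam" if label == 1 else "not spam"


# Toy logits: one sequence of 3 tokens, 2 classes (made-up numbers)
toy_logits = [[[0.2, 0.1], [1.5, -0.3], [-0.4, 2.1]]]
labels = predict_label_from_logits(toy_logits)
print(labels, [label_to_name(lab) for lab in labels])
# [1] ['spam']
```

Only the last token is used because, under causal attention, it is the one position that has attended to every other token in the input.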
Finally, let's save the model in case we want to reuse it later without having to train it again:

```python
torch.save(model.state_dict(), "review_classifier.pth")
```
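Then, in a new session, we can load the model as follows:

```python
# `model` must already be instantiated with the same architecture
# (a GPTModel whose out_head was replaced by the two-class layer)
model_state_dict = torch.load("review_classifier.pth", map_location=device, weights_only=True)
model.load_state_dict(model_state_dict)
```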
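A state dict stores only parameter tensors, not the model class, so loading works only into a module that has the same parameter names and shapes. The self-contained sketch below shows the save/rebuild/load round trip with a small `torch.nn.Linear` standing in for the fine-tuned GPT model (an assumption made so the example runs without the chapter's code); it writes to a temporary directory rather than the working directory. The `weights_only` argument assumes a reasonably recent PyTorch release.

```python
import os
import tempfile

import torch

# A tiny stand-in for the fine-tuned classifier (assumption: any nn.Module behaves the same)
torch.manual_seed(123)
original = torch.nn.Linear(in_features=4, out_features=2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "review_classifier.pth")
    # Save only the parameters (the state dict), not the Python object
    torch.save(original.state_dict(), path)

    # "New session": rebuild the same architecture first, then load the weights
    restored = torch.nn.Linear(in_features=4, out_features=2)
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state_dict)

# Every parameter tensor should match exactly after the round trip
same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

For the GPT classifier, rebuilding "the same architecture" means creating `GPTModel(BASE_CONFIG)` and replacing `out_head` with the two-output linear layer before calling `load_state_dict`; otherwise the `out_head` weight shapes do not match.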
6.9 - Summary and takeaways
Summary
- There are different strategies for fine-tuning LLMs, including classification fine-tuning and instruction fine-tuning.
- Classification fine-tuning involves replacing the output layer of an LLM with a small classification layer.
- In the case of classifying text messages as “spam” or “not spam,” the new classification layer consists of only two output nodes. Previously, the number of output nodes was equal to the number of unique tokens in the vocabulary (i.e., 50,257).
- Instead of predicting the next token in the text as in pretraining, classification fine-tuning trains the model to output a correct class label—for example, “spam” or “not spam.”
- The model input for fine-tuning is text converted into token IDs, similar to pretraining.
- Before fine-tuning an LLM, we load the pretrained model as a base model.
- Evaluating a classification model involves calculating the classification accuracy (the fraction or percentage of correct predictions).
- Fine-tuning a classification model uses the same cross-entropy loss function as pretraining the LLM.
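The last two points can be made concrete with a two-class example. Cross-entropy for one example is the negative log of the softmax probability assigned to the correct class, averaged over the batch, and accuracy is the fraction of examples whose argmax class equals the label. This is a plain-Python sketch for illustration, not PyTorch's implementation; the function names and the toy batch are assumptions.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one example: -log(softmax(logits)[target])."""
    # Subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, targets):
    # Average over the batch, like the default reduction="mean" in PyTorch
    return sum(cross_entropy(z, t) for z, t in zip(batch_logits, targets)) / len(targets)


def accuracy(batch_logits, targets):
    # Fraction of examples whose argmax class matches the label (ties -> lower index)
    preds = [max(range(len(z)), key=z.__getitem__) for z in batch_logits]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Two-class logits (0 = "not spam", 1 = "spam") for a toy batch of 4 examples
batch_logits = [[2.0, -1.0], [0.5, 1.5], [0.0, 0.0], [-2.0, 1.0]]
targets = [0, 1, 1, 0]
print(accuracy(batch_logits, targets))  # 0.5
print(mean_cross_entropy(batch_logits, targets))
```

The only change compared with pretraining is the size of the class dimension: 2 here instead of the full vocabulary.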
Takeaways
- `./gpt_class_finetune.py` is the script for classification fine-tuning.
- Appendix E introduces parameter-efficient fine-tuning with LoRA.
gpt_class_finetune.py
```python
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# This is a summary file containing the main takeaways from chapter 6.

import urllib.error
import urllib.request
import zipfile
import os
from pathlib import Path
import time

import matplotlib.pyplot as plt
import pandas as pd
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader

from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt


def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=False):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    if test_mode:  # Try multiple times since CI sometimes has connectivity issues
        max_retries = 5
        delay = 5  # delay between retries in seconds
        for attempt in range(max_retries):
            try:
                # Downloading the file
                with urllib.request.urlopen(url, timeout=10) as response:
                    with open(zip_path, "wb") as out_file:
                        out_file.write(response.read())
                break  # if download is successful, break out of the loop
            except urllib.error.URLError as e:
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(delay)  # wait before retrying
                else:
                    print("Failed to download file after several attempts.")
                    return  # exit if all retries fail

    else:  # Code as it appears in the chapter
        # Downloading the file
        with urllib.request.urlopen(url) as response:
            with open(zip_path, "wb") as out_file:
                out_file.write(response.read())

    # Unzipping the file
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")


def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]

    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)

    # Combine ham "subset" with "spam"
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    return balanced_df


def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df


class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length

            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Pad sequences to the longest sequence
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        max_length = 0
        for encoded_text in self.encoded_texts:
            encoded_length = len(encoded_text)
            if encoded_length > max_length:
                max_length = encoded_length
        return max_length


def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches


def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
                            eval_freq, eval_iter, tokenizer):
    # Initialize lists to track losses and accuracies
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode

        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate loss gradients
            optimizer.step()  # Update model weights using loss gradients
            examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
            global_step += 1

            # Optional evaluation step
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
        print(f"Validation accuracy: {val_accuracy*100:.2f}%")
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)

    return train_losses, val_losses, train_accs, val_accs, examples_seen


def plot_values(epochs_seen, examples_seen, train_values, val_values, label="loss"):
    fig, ax1 = plt.subplots(figsize=(5, 3))

    # Plot training and validation values against epochs
    ax1.plot(epochs_seen, train_values, label=f"Training {label}")
    ax1.plot(epochs_seen, val_values, linestyle="-.", label=f"Validation {label}")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel(label.capitalize())
    ax1.legend()

    # Create a second x-axis for examples seen
    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis
    ax2.plot(examples_seen, train_values, alpha=0)  # Invisible plot for aligning ticks
    ax2.set_xlabel("Examples seen")

    fig.tight_layout()  # Adjust layout to make room
    plt.savefig(f"{label}-plot.pdf")
    # plt.show()


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description="Finetune a GPT model for classification"
    )
    parser.add_argument(
        "--test_mode",
        default=False,
        action="store_true",
        help=("This flag runs the model in test mode for internal testing purposes. "
              "Otherwise, it runs the model as it is used in the chapter (recommended).")
    )
    args = parser.parse_args()

    ########################################
    # Download and prepare dataset
    ########################################

    url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
    zip_path = "sms_spam_collection.zip"
    extracted_path = "sms_spam_collection"
    data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path, test_mode=args.test_mode)
    df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])
    balanced_df = create_balanced_dataset(df)
    balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})

    train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
    train_df.to_csv("train.csv", index=None)
    validation_df.to_csv("validation.csv", index=None)
    test_df.to_csv("test.csv", index=None)

    ########################################
    # Create data loaders
    ########################################
    tokenizer = tiktoken.get_encoding("gpt2")

    train_dataset = SpamDataset(
        csv_file="train.csv",
        max_length=None,
        tokenizer=tokenizer
    )

    val_dataset = SpamDataset(
        csv_file="validation.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    test_dataset = SpamDataset(
        csv_file="test.csv",
        max_length=train_dataset.max_length,
        tokenizer=tokenizer
    )

    num_workers = 0
    batch_size = 8

    torch.manual_seed(123)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
    )

    ########################################
    # Load pretrained model
    ########################################

    # Small GPT model for testing purposes
    if args.test_mode:
        BASE_CONFIG = {
            "vocab_size": 50257,
            "context_length": 120,
            "drop_rate": 0.0,
            "qkv_bias": False,
            "emb_dim": 12,
            "n_layers": 1,
            "n_heads": 2
        }
        model = GPTModel(BASE_CONFIG)
        model.eval()
        device = "cpu"

    # Code as it is used in the main chapter
    else:
        CHOOSE_MODEL = "gpt2-small (124M)"
        INPUT_PROMPT = "Every effort moves"

        BASE_CONFIG = {
            "vocab_size": 50257,     # Vocabulary size
            "context_length": 1024,  # Context length
            "drop_rate": 0.0,        # Dropout rate
            "qkv_bias": True         # Query-key-value bias
        }

        model_configs = {
            "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
            "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
            "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
            "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
        }

        BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

        assert train_dataset.max_length <= BASE_CONFIG["context_length"], (
            f"Dataset length {train_dataset.max_length} exceeds model's context "
            f"length {BASE_CONFIG['context_length']}. Reinitialize data sets with "
            f"`max_length={BASE_CONFIG['context_length']}`"
        )

        model_size = CHOOSE_MODEL.split(" ")[-1].lstrip("(").rstrip(")")
        settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")

        model = GPTModel(BASE_CONFIG)
        load_weights_into_gpt(model, params)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ########################################
    # Modify pretrained model
    ########################################

    for param in model.parameters():
        param.requires_grad = False

    torch.manual_seed(123)

    num_classes = 2
    model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes)
    model.to(device)

    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True

    for param in model.final_norm.parameters():
        param.requires_grad = True

    ########################################
    # Finetune modified model
    ########################################

    start_time = time.time()
    torch.manual_seed(123)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)

    num_epochs = 5
    train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
        model, train_loader, val_loader, optimizer, device,
        num_epochs=num_epochs, eval_freq=50, eval_iter=5,
        tokenizer=tokenizer
    )

    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")

    ########################################
    # Plot results
    ########################################

    # loss plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))
    plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)

    # accuracy plot
    epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))
    examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))
    plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy")
```
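The split sizes and batch counts produced by `random_split` and the three `DataLoader` objects follow from simple integer arithmetic. The sketch below reproduces that arithmetic without pandas or PyTorch. The dataset size of 1,494 rows is an illustrative assumption (747 spam messages plus an equal number of sampled ham messages); check `len(balanced_df)` for your own copy of the data. The helper names are hypothetical.

```python
import math


def split_sizes(n_rows, train_frac=0.7, validation_frac=0.1):
    # Same index arithmetic as random_split: int() truncates toward zero
    train_end = int(n_rows * train_frac)
    validation_end = train_end + int(n_rows * validation_frac)
    return train_end, validation_end - train_end, n_rows - validation_end


def num_batches(n_examples, batch_size, drop_last):
    # Number of batches a DataLoader yields:
    # floor division with drop_last=True, ceiling division otherwise
    if drop_last:
        return n_examples // batch_size
    return math.ceil(n_examples / batch_size)


n = 1494  # assumed size of the balanced dataset (illustrative value)
train_n, val_n, test_n = split_sizes(n)
print(train_n, val_n, test_n)  # 1045 149 300
print(num_batches(train_n, 8, drop_last=True),
      num_batches(val_n, 8, drop_last=False),
      num_batches(test_n, 8, drop_last=False))  # 130 19 38
```

Dropping the last incomplete training batch keeps every gradient update on exactly `batch_size` examples, while the validation and test loaders keep their final partial batch so that no example is skipped during evaluation.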