Parameter-Efficient Fine-Tuning (PEFT): A Hands-On Guide with LoRA
Author(s): BeastBoyJay. Originally published on Towards AI.

Imagine building a powerful AI model without needing massive computational resources. PEFT makes that possible, and I'll show you how with LoRA from scratch.

Introduction

Traditional fine-tuning challenges:

Fine-tuning large models sounds cool until reality hits. Imagine trying to sculpt a masterpiece but needing a giant crane just to lift your tools. That's what traditional fine-tuning feels like. You're working with millions (sometimes billions) of parameters, and the computational cost can skyrocket faster than your coffee bill during finals week.

Hardware struggles:
- Got a spare supercomputer lying around? Probably not.
- GPUs heat up like your phone during a marathon PUBG session.
- RAM gets maxed out faster than your Netflix binge in 4K.

Data dilemma:
- You need a ton of data, or your model behaves like a forgetful student on exam day.
- Gathering and cleaning that much data? A nightmare in itself.

Snail-speed training:
- Hit "run" and wait, and wait, and maybe even take a nap while your model chugs along.

Maintenance mayhem:
- Tiny tweaks mean re-training the whole colossal beast.
- A waste of time, energy, and your already-thin patience.

Need a solution:

PEFT is the answer to this bulky, traditional fine-tuning routine. Think of PEFT (Parameter-Efficient Fine-Tuning) as upgrading a car by just changing the tires instead of rebuilding the whole engine. Instead of retraining every parameter in a massive model, PEFT tweaks just the essential parts, saving time, resources, and sanity.

Why it rocks:
- Resource-smart: no supercomputer required.
- Time-saving: faster results with minimal effort.
- Scalable: handles large models like a pro.

What is PEFT?

PEFT (Parameter-Efficient Fine-Tuning) is like giving your AI model a performance boost by adjusting only the most important parameters, rather than retraining the entire thing. Think of it as overclocking your model without needing to upgrade the whole motherboard.

Why is PEFT necessary?

- Reduced training costs: instead of burning through a fortune in GPU time to retrain the whole model, PEFT lets you fine-tune with minimal resources, saving both cash and computing power.
- Faster adaptation to tasks: PEFT allows you to quickly adapt large models to new tasks by tuning only the necessary components, speeding up the training process without sacrificing accuracy.
- Minimal memory requirements: rather than keeping gradients and optimizer state for every parameter, PEFT touches only a tiny fraction of them, letting you work with large-scale models without draining your system.

How does PEFT work?

The core idea of PEFT is to freeze the vast majority of a pre-trained model's parameters and train only a small, carefully chosen subset (or a handful of newly added parameters). The model keeps its general knowledge, while the small trainable part absorbs the task-specific changes. The toy sketch below illustrates the idea.
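To make that concrete, here is a minimal, purely illustrative sketch. The tiny "backbone" and the layer sizes are made up for this example (they are not from the article); the point is simply that the big model is frozen and only a small piece is trained.

import torch.nn as nn

# A stand-in for a large pre-trained model (frozen).
backbone = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
for p in backbone.parameters():
    p.requires_grad = False  # freeze the "big" model

# A small task-specific module: the only part we actually train.
head = nn.Linear(512, 10)

trainable = sum(p.numel() for p in head.parameters())
total = trainable + sum(p.numel() for p in backbone.parameters())
print(f"training {trainable:,} of {total:,} parameters ({100 * trainable / total:.1f}%)")

Different PEFT techniques differ mainly in what that small trainable piece is and where it is attached.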
Types of PEFT techniques:

LoRA (Low-Rank Adaptation):

Let's talk about one of the coolest tricks in PEFT: LoRA. Imagine you've got a massive pre-trained model, like a Transformer, that's already packed with all sorts of knowledge. Instead of modifying everything in the model, LoRA lets you tweak just the essentials, specifically a few sneaky little low-rank matrices that help the model adapt to new tasks. The rest of the model stays frozen in time, like an immovable fortress, while LoRA does its magic.

So, how does LoRA work its sorcery?

Here's the gist of it. Say there's a weight matrix W in the model (maybe in the attention mechanism, where the model decides what's important in the input). LoRA comes in and says: why not express the change we want to make to W as the product of two much smaller matrices, B and A? Mathematically, the adapted weight looks like:

W' = W + BA

These matrices, A and B, are low-rank, which, in nerd terms, means they carry far fewer parameters than the original weight matrix: if W is d x k, then B is d x r and A is r x k, with the rank r much smaller than d and k. The magic? Because A and B are so much smaller, we have far fewer parameters to tune during fine-tuning.

But that's not all. Here's the real kicker: when it comes to fine-tuning, LoRA trains only the parameters of A and B. The rest of the massive model stays locked, untouched. It's like having the keys to just one door in a huge mansion: you're making minimal changes, but they're all targeted and impactful.

By doing this, you reduce the number of parameters you need to update during fine-tuning, which makes the whole process far more efficient. You get the same task-specific performance without the heavy lifting of retraining everything. It's like finding the shortcut in a maze: you still reach the goal, but with way less effort!

Adapters:

Let's talk about Adapters: not the kind you plug into your phone charger, but nifty little modules that slot into the transformer architecture like a perfect puzzle piece!

Imagine you've got a powerful pre-trained model, and you need to adapt it to a new task. Instead of retraining the entire thing, you introduce an adapter: a lightweight, task-specific module that fits neatly after each transformer block. The best part? You don't have to touch the core model at all. It's like adding a few extra gears to a well-oiled machine without dismantling the whole thing.

Here's the lowdown on how adapters work:

- Insertion into layers: think of an adapter as a mini-module that slides in after key layers in the transformer, right after the attention or feed-forward sub-layers. It usually consists of a couple of fully connected layers forming a bottleneck: the input and output dimensions match the original layer (because, let's face it, we don't want to mess with the model's flow), while the hidden dimension in between is much smaller. It's like a sleek, efficient middleman.
- Task-specific tuning: here's where the fun happens. When you fine-tune the model, only the adapter parameters are updated. The core model, packed with all its pre-trained knowledge, stays frozen, like a wise professor teaching you everything they know while you just add a little extra on top. The adapter absorbs the task-specific tweaks without messing up the original wisdom of the model.

The big win? The core model retains its massive, generalized knowledge while the adapter learns just enough to tackle the new task. It's like teaching a world-class musician a new song without changing their entire repertoire. Efficient, fast, and keeps things clean. A small sketch of such a module follows.
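Here is a minimal sketch of what such a bottleneck adapter might look like in PyTorch. It is not part of this article's MNIST walkthrough, and the dimensions (a hidden size of 768, a bottleneck of 64) are just illustrative choices:

import torch
import torch.nn as nn

class Adapter(nn.Module):
    """Bottleneck adapter: down-project, non-linearity, up-project, residual connection."""
    def __init__(self, d_model, bottleneck=64):
        super().__init__()
        self.down = nn.Linear(d_model, bottleneck)
        self.up = nn.Linear(bottleneck, d_model)
        self.act = nn.ReLU()

    def forward(self, x):
        # The residual connection keeps the original signal intact;
        # the adapter only adds a small task-specific correction on top.
        return x + self.up(self.act(self.down(x)))

# Example usage: slot it in after a transformer block's output (shapes are hypothetical).
adapter = Adapter(d_model=768)
hidden = torch.randn(2, 16, 768)   # (batch, seq_len, d_model)
out = adapter(hidden)
print(out.shape)                   # torch.Size([2, 16, 768])

Only the adapter's parameters would be handed to the optimizer; the surrounding transformer stays frozen.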
Prefix Tuning:

Let's get into the groove of Prefix Tuning, a clever, minimalist trick that adds just the right amount of guidance to steer a model without overhauling its entire structure. It's like giving your car a gentle nudge to take a different route without touching the engine. Cool, right?

Here's how Prefix Tuning works its magic:

- Learnable prefix: picture this: before the model processes the input text, you prepend a small, task-specific set of tokens; this is your prefix. It's like a little note that says, "Hey, focus on this while you're working!" These tokens are learnable, meaning you train them to carry the relevant task information. Importantly, the rest of the model's weights stay locked down, untouched.
- Controlling attention: the prefix isn't just a random add-on. These tokens guide the model's attention mechanism, telling it which parts of the input to focus on. It's like placing a signpost at the start of the road, directing the model on where to head next. So, when the model generates an output, it's subtly influenced by the prefix tokens, helping it stay on track for the specific task at hand.

The beauty of Prefix Tuning? Its simplicity. You're not retraining the entire model or altering its inner workings. Instead, you're nudging its attention just enough to guide it in the right direction for the task you need it to perform.

BitFit:

Let's dive into BitFit, a deceptively simple yet highly effective PEFT technique. It's like tweaking just the small dials on a well-tuned machine to get the perfect result. Instead of overhauling the entire system, BitFit focuses on the tiniest components to make a big impact.

How BitFit works:

- Bias tuning: imagine your model is a giant network of gears and levers (a.k.a. weights) that are already trained and doing their thing. Instead of retraining every gear, BitFit zooms in on the bias terms: the extra parameters added to the output of each layer. These bias terms are small adjustments that help shift the model's output in the right direction, and they make up only a tiny fraction of the model's parameters.
- Minimalist fine-tuning: only the bias terms are tuned, while the rest of the model's weights remain frozen. Bias terms are far fewer in number than the full set of weights, so you're making very targeted changes. It's like fine-tuning the volume on a speaker without touching the entire sound system: you still get the desired sound (or task performance) without fiddling with everything.

Why BitFit rocks: the real charm of BitFit is its efficiency. By focusing on just a few parameters, you can fine-tune a model for a specific task while keeping the computational load light. It's a great way to make tweaks without the heavy lifting of full model fine-tuning, which makes it fast and resource-friendly. In fact, the whole idea fits in a couple of lines of code, as the sketch below shows.
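A minimal, illustrative sketch of BitFit. The tiny network here is just a stand-in; any pre-trained nn.Module would work the same way:

import torch.nn as nn

tiny_model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))

# BitFit in essence: freeze every weight, leave only the bias terms trainable.
for name, param in tiny_model.named_parameters():
    param.requires_grad = name.endswith("bias")

trainable = sum(p.numel() for p in tiny_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in tiny_model.parameters())
print(f"bias-only trainable parameters: {trainable:,} of {total:,}")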
Implementing LoRA from scratch in PyTorch:

Now I will walk you through implementing LoRA from scratch, so that you get a deeper understanding of how it works.

Importing the necessary libraries:

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from tqdm import tqdm

Making the run deterministic:

_ = torch.manual_seed(0)

Training a small model:

Let's have some fun with LoRA! We'll start by building a small, simple model to classify those classic MNIST digits, the ones everyone loves to work with when learning machine learning. But here's the twist: instead of stopping at basic digit classification, we're going to take it up a notch.

We'll identify one digit our network struggles with (maybe it just doesn't vibe with the number 7?) and fine-tune the network with LoRA to make it smarter and better at recognizing that tricky number. It's going to be a cool mix of training, tweaking, and improving, perfect for seeing LoRA in action!

Loading the dataset:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST training set
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Model architecture:

class SimpleNN(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(SimpleNN, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

model = SimpleNN().to(device)

Training loop:

def train(train_loader, model, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        model.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, model, epochs=1)

After executing the above code, your small model is trained and ready for inference. But before that, let's keep a copy of the original weights (cloning them) so that later we can prove that fine-tuning with LoRA doesn't alter the original weights.

original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()

Now, testing the performance of the trained model:

def test():
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1

    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Output:

Accuracy: 0.954
wrong counts for the digit 0: 31
wrong counts for the digit 1: 17
wrong counts for the digit 2: 46
wrong counts for the digit 3: 74
wrong counts for the digit 4: 29
wrong counts for the digit 5: 7
wrong counts for the digit 6: 36
wrong counts for the digit 7: 80
wrong counts for the digit 8: 25
wrong counts for the digit 9: 116

As you can see, the worst-performing digit is 9.
LoRA implementation:

Define the LoRA parametrization as described in the paper. The full details on how PyTorch parametrizations work can be found in the official tutorial: https://pytorch.org/tutorials/intermediate/parametrizations.html

class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper:
        #   We use a random Gaussian initialization for A and zero for B,
        #   so ΔW = BA is zero at the beginning of training.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale ΔWx by α/r, where α is a constant in r.
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate
        #   if we scale the initialization appropriately.
        #   As a result, we simply set α to the first r we try and do not tune it.
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the bias.
    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks
    #   and freeze the MLP modules (so they are not trained in downstream tasks)
    #   both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

Before registering the parametrizations, count the parameters of the original network (this value is used by the sanity-check assert further down):

total_parameters_original = 0
for layer in [model.linear1, model.linear2, model.linear3]:
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()

parametrize.register_parametrization(
    model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

Display the number of parameters added by LoRA:

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# The non-LoRA parameter count must match the original network
assert total_parameters_non_lora == total_parameters_original

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

Output:

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters increment: 0.242%
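As a quick sanity check on that last number (my own arithmetic, not part of the article's code): with rank r = 1, each parametrized layer adds r * (fan_in + fan_out) parameters, which reproduces the 6,794 printed above.

# Sanity check: r * (fan_in + fan_out) summed over the three linear layers
r = 1
lora_params = r * (784 + 1000) + r * (1000 + 2000) + r * (2000 + 10)
print(lora_params)  # 6794, matching the output above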
Next, freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA. Then fine-tune the model on the digit 9, and only for 100 batches.

# Freeze the non-LoRA parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches
# (hoping that it will improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

After training the newly introduced LoRA weights, let's verify that the fine-tuning didn't alter the original weights, only the ones introduced by LoRA.

# Check that the frozen parameters are still unchanged by the fine-tuning
assert torch.all(model.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to model.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.linear1.weight, model.linear1.parametrizations.weight.original + (model.linear1.parametrizations.weight[0].lora_B @ model.linear1.parametrizations.weight[0].lora_A) * model.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, linear1.weight is the original one
assert torch.equal(model.linear1.weight, original_weights['linear1.weight'])

Testing the network with LoRA enabled (the digit 9 should now be classified better):

# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Output:

Accuracy: 0.924
wrong counts for the digit 0: 47
wrong counts for the digit 1: 27
wrong counts for the digit 2: 65
wrong counts for the digit 3: 240
wrong counts for the digit 4: 89
wrong counts for the digit 5: 32
wrong counts for the digit 6: 54
wrong counts for the digit 7: 137
wrong counts for the digit 8: 61
wrong counts for the digit 9: 9

Testing the network with LoRA disabled (the accuracy and error counts must be the same as the original network):

enable_disable_lora(enabled=False)
test()

Output:

Accuracy: 0.954
wrong counts for the digit 0: 31
wrong counts for the digit 1: 17
wrong counts for the digit 2: 46
wrong counts for the digit 3: 74
wrong counts for the digit 4: 29
wrong counts for the digit 5: 7
wrong counts for the digit 6: 36
wrong counts for the digit 7: 80
wrong counts for the digit 8: 25
wrong counts for the digit 9: 116
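One extra trick that is not part of this walkthrough but pairs naturally with it: once you are happy with the adapted model, PyTorch's parametrize utilities let you "bake" the LoRA update into the base weights, so inference no longer needs the parametrization at all. A sketch, reusing the model and enable_disable_lora defined above:

import torch.nn.utils.parametrize as parametrize

# Merge W + scale * (B @ A) into each weight and drop the parametrization.
# Note: after this, the LoRA on/off toggle is gone and the change is permanent for this model instance.
enable_disable_lora(enabled=True)
for layer in [model.linear1, model.linear2, model.linear3]:
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

If you instead want to keep several task-specific adaptations around, skip the merge and keep one small pair of A/B matrices per task.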
Conclusion:

The implementation we've walked through demonstrates the power and efficiency of LoRA in practice. Through our MNIST example, we've seen how LoRA can significantly improve model performance on a specific task (digit 9 recognition) while adding only 0.242% more parameters to the original model. This perfectly illustrates why PEFT techniques, and LoRA in particular, are becoming increasingly important in the AI landscape.

Key takeaways from our exploration:
- PEFT techniques like LoRA make fine-tuning accessible even with limited computational resources.
- By focusing on a small set of crucial parameters, we can achieve significant improvements in task-specific performance.
- The original model weights remain unchanged, allowing for multiple task-specific adaptations of the same base model.
- The implementation requires minimal code changes to existing architectures.

The future of AI model adaptation lies in such efficient techniques that balance performance with resource utilization. As models continue to grow in size and complexity, PEFT approaches will become even more crucial for practical applications.

GitHub repository:

I have created a project in which you can fine-tune a ResNet on your own dataset using the technique we have just learned. For the complete code and implementation details, visit: github.com/yourusername/peft-lora-guide
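To give a flavour of how the same parametrization carries over to a model like ResNet, here is a rough, illustrative sketch (not the repository's code). It reuses LoRAParametrization and linear_layer_parameterization from above, assumes a reasonably recent torchvision, and adapts only the final classification layer:

import torchvision.models as models
import torch.nn.utils.parametrize as parametrize

# Freeze a pretrained ResNet-18 backbone and add LoRA only to its final linear layer.
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for p in resnet.parameters():
    p.requires_grad = False

parametrize.register_parametrization(
    resnet.fc, "weight", linear_layer_parameterization(resnet.fc, device="cpu", rank=4)
)
# Only the LoRA matrices (lora_A, lora_B) are left trainable; the rest of the network stays frozen.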