A Gentle Introduction to Language Model Fine-tuning


import dataclasses

 

import tokenizers

import torch

import torch.nn as nn

import torch.nn.functional as F

from torch import Tensor

 

 

# Model architecture: identical to the one used in the training script

@dataclasses.dataclass

class LlamaConfig:

    """Define Llama model hyperparameters."""

    vocab_size: int = 50000

    max_position_embeddings: int = 2048

    hidden_size: int = 768

    intermediate_size: int = 4*768

    num_hidden_layers: int = 12

    num_attention_heads: int = 12

    num_key_value_heads: int = 3

 

class RotaryPositionEncoding(nn.Module):

    """Rotary position encoding."""

 

    def __init__(self, dim: int, max_position_embeddings: int) -> None:

        super().__init__()

        self.dim = dim

        self.max_position_embeddings = max_position_embeddings

        N = 10_000.0

        inv_freq = 1.0 / (N ** (torch.arange(0, dim, 2) / dim))

        # Duplicate the frequencies so cos/sin cover the full head dimension
        inv_freq = torch.cat((inv_freq, inv_freq), dim=0)

        position = torch.arange(max_position_embeddings)

        sinusoid_inp = torch.outer(position, inv_freq)

        self.register_buffer("cos", sinusoid_inp.cos())

        self.register_buffer("sin", sinusoid_inp.sin())

 

    def forward(self, x: Tensor) -> Tensor:

        batch_size, seq_len, num_heads, head_dim = x.shape

        device = x.device

        dtype = x.dtype

        cos = self.cos.to(device, dtype)[:seq_len].view(1, seq_len, 1, head_dim)

        sin = self.sin.to(device, dtype)[:seq_len].view(1, seq_len, 1, head_dim)

        # Rotate-half trick: split the head dimension, negate the second half, and swap
        x1, x2 = x.chunk(2, dim=-1)

        rotated = torch.cat((-x2, x1), dim=-1)

        return (x * cos) + (rotated * sin)
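
As a quick sanity check (this snippet is an addition, not part of the original script), the rotary encoder can be run on a dummy tensor shaped like the attention inputs below. RoPE only rotates pairs of channels, so the output keeps the input's shape and per-vector norms.

# Illustrative check: dim=64 and max_position_embeddings=2048 match LlamaConfig above
rope = RotaryPositionEncoding(dim=64, max_position_embeddings=2048)
q = torch.randn(2, 16, 12, 64)   # (batch, seq_len, num_heads, head_dim)
q_rot = rope(q)
print(q_rot.shape)               # torch.Size([2, 16, 12, 64])
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True: rotations preserve norms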

 

class LlamaAttention(nn.Module):

    """Grouped-query attention with rotary embeddings."""

 

    def __init__(self, config: LlamaConfig) -> None:

        super().__init__()

        self.hidden_size = config.hidden_size

        self.num_heads = config.num_attention_heads

        self.head_dim = self.hidden_size // self.num_heads

        self.num_kv_heads = config.num_key_value_heads

        assert (self.head_dim * self.num_heads) == self.hidden_size

 

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

 

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:

        bs, seq_len, dim = hidden_states.size()

 

        query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)

        key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

        value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

 

        attn_output = F.scaled_dot_product_attention(

            rope(query_states).transpose(1, 2),

            rope(key_states).transpose(1, 2),

            value_states.transpose(1, 2),

            is_causal=True,

            dropout_p=0.0,

            enable_gqa=True,

        )

 

        attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)

        return self.o_proj(attn_output)
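
The enable_gqa=True flag is what turns this into grouped-query attention: the 12 query heads share the 3 key/value heads in groups of 4 rather than each having its own. The sketch below is an illustration rather than part of the original script, and it assumes a PyTorch version recent enough for scaled_dot_product_attention to accept enable_gqa; it shows that the grouped call matches explicitly repeating each key/value head for its group of query heads.

# Illustrative comparison using the head counts from LlamaConfig (12 query heads, 3 KV heads)
bs, seq_len, head_dim = 2, 16, 64
num_heads, num_kv_heads = 12, 3
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_kv_heads, seq_len, head_dim)
v = torch.randn(bs, num_kv_heads, seq_len, head_dim)

out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(num_heads // num_kv_heads, dim=1),  # repeat 3 KV heads to 12
    v.repeat_interleave(num_heads // num_kv_heads, dim=1),
    is_causal=True,
)
print(torch.allclose(out_gqa, out_ref, atol=1e-4))  # expected: True, up to numerical precision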

 

class LlamaMLP(nn.Module):

    """Feed-forward network with SwiGLU activation."""

 

    def __init__(self, config: LlamaConfig) -> None:

        super().__init__()

        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)

        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)

        self.act_fn = F.silu

        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

 

    def forward(self, x: Tensor) -> Tensor:

        gate = self.act_fn(self.gate_proj(x))

        up = self.up_proj(x)

        return self.down_proj(gate * up)

 

class LlamaDecoderLayer(nn.Module):

    """Single transformer layer for a Llama model."""

 

    def __init__(self, config: LlamaConfig) -> None:

        super().__init__()

        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)

        self.self_attn = LlamaAttention(config)

        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)

        self.mlp = LlamaMLP(config)

 

    def forward(self, hidden_states: Tensor, rope: RotaryPositionEncoding) -> Tensor:

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        attn_outputs = self.self_attn(hidden_states, rope=rope)

        hidden_states = attn_outputs + residual

 

        residual = hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)

        return self.mlp(hidden_states) + residual

 

class LlamaModel(nn.Module):

    """The full Llama model without any pretraining heads."""

 

    def __init__(self, config: LlamaConfig) -> None:

        super().__init__()

        self.rotary_emb = RotaryPositionEncoding(

            config.hidden_size // config.num_attention_heads,

            config.max_position_embeddings,

        )

 

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        self.layers = nn.ModuleList([

            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)

        ])

        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

 

    def forward(self, input_ids: Tensor) -> Tensor:

        hidden_states = self.embed_tokens(input_ids)

        for layer in self.layers:

            hidden_states = layer(hidden_states, rope=self.rotary_emb)

        return self.norm(hidden_states)

 

class LlamaForPretraining(nn.Module):
    """Llama model with a language-modeling head for next-token prediction."""

    def __init__(self, config: LlamaConfig) -> None:

        super().__init__()

        self.base_model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

 

    def forward(self, input_ids: Tensor) -> Tensor:

        hidden_states = self.base_model(input_ids)

        return self.lm_head(hidden_states)
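
Before loading trained weights, a quick smoke test (again an addition, not from the original script) confirms the pieces fit together: for a batch of token IDs, the pretraining head should return one logit per vocabulary entry at every position.

# Illustrative smoke test with random token IDs and randomly initialized weights
config = LlamaConfig()
model = LlamaForPretraining(config)
dummy_ids = torch.randint(0, config.vocab_size, (2, 16))   # (batch, seq_len)
with torch.no_grad():
    logits = model(dummy_ids)
print(logits.shape)   # torch.Size([2, 16, 50000]) -- (batch, seq_len, vocab_size)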

 

 

def apply_repetition_penalty(logits: Tensor, tokens: list[int], penalty: float) -> Tensor:

    """Apply repetition penalty to the logits."""

    for tok in tokens:

        if logits[tok] > 0:

            logits[tok] /= penalty

        else:

            logits[tok] *= penalty

    return logits
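
A tiny example (an addition for illustration) makes the convention clear: the penalty divides positive logits and multiplies negative ones, so a token that was already generated becomes less likely regardless of the sign of its logit.

# Toy example: tokens 0 and 1 were generated before, token 2 was not
toy_logits = torch.tensor([2.0, -1.0, 0.5])
print(apply_repetition_penalty(toy_logits.clone(), tokens=[0, 1], penalty=1.1))
# tensor([ 1.8182, -1.1000,  0.5000])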

 

 

@torch.no_grad()

def generate(model, tokenizer, prompt, max_tokens=100, temperature=1.0, repetition_penalty=1.0,

             repetition_penalty_range=10, top_k=50, device=None) -> str:

    """Generate text autoregressively from a prompt.

 

    Args:

        model: The trained LlamaForPretraining model

        tokenizer: The tokenizer

        prompt: Input text prompt

        max_tokens: Maximum number of tokens to generate

        temperature: Sampling temperature (higher = more random)

        repetition_penalty: Penalty for repeating tokens

        repetition_penalty_range: Number of previous tokens to consider for repetition penalty

        top_k: Only sample from top k most likely tokens

        device: Device the model is loaded on

 

    Returns:

        Generated text

    """

    # Switch the model to evaluation mode (disables training-only behavior such as dropout)

    model.eval()

 

    # Get special token IDs

    bot_id = tokenizer.token_to_id("[BOT]")

    eot_id = tokenizer.token_to_id("[EOT]")

 

    # Tokenize the prompt into integer tensor

    prompt_tokens = [bot_id] + tokenizer.encode(" " + prompt).ids

    input_ids = torch.tensor([prompt_tokens], dtype=torch.int64, device=device)

 

    # Autoregressively generate tokens one at a time

    generated_tokens = []

    for _step in range(max_tokens):

        # Forward pass through model

        logits = model(input_ids)

 

        # Get logits for the last token

        next_token_logits = logits[0, -1, :] / temperature

 

        # Apply repetition penalty

        if repetition_penalty != 1.0 and len(generated_tokens) > 0:

            next_token_logits = apply_repetition_penalty(

                next_token_logits,

                generated_tokens[-repetition_penalty_range:],

                repetition_penalty,

            )

 

        # Apply top-k filtering

        if top_k > 0:

            top_k_logits = torch.topk(next_token_logits, top_k)[0]

            # Mask out everything below the k-th largest logit
            indices_to_remove = next_token_logits < top_k_logits[-1]

            next_token_logits[indices_to_remove] = float("-inf")

 

        # Sample from the filtered distribution

        probs = F.softmax(next_token_logits, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)

 

        # Early stop if EOT token is generated

        if next_token.item() == eot_id:

            break

 

        # Append the new token to input_ids for next iteration

        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

        generated_tokens.append(next_token.item())

 

    # Decode all generated tokens

    return tokenizer.decode(generated_tokens)

 

 

checkpoint = "llama_model_final.pth"   # saved model checkpoint

tokenizer = "bpe_50K.json"   # saved tokenizer

max_tokens = 100

temperature = 0.9

top_k = 50

penalty = 1.1

penalty_range = 10

 

# Load tokenizer and model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = tokenizers.Tokenizer.from_file(tokenizer)

config = LlamaConfig()

model = LlamaForPretraining(config).to(device)

model.load_state_dict(torch.load(checkpoint, map_location=device))

 

prompt = "Once upon a time, there was"

response = generate(

    model=model,

    tokenizer=tokenizer,

    prompt=prompt,

    max_tokens=max_tokens,

    temperature=temperature,

    top_k=top_k,

    repetition_penalty=penalty,

    repetition_penalty_range=penalty_range,

    device=device,

)

print(prompt)

print("-" * 20)

print(response)


