# Import necessary libraries
# Adapted from: https://github.com/tamangmilan/llama3/blob/main/build_llama3_from_scratch.ipynb
# and https://github.com/Atulit23/ai_from_scratch/blob/main/MLA.ipynb https://medium.com/@atulit23/implementing-multi-head-latent-attention-from-scratch-in-python-1e14d03fbc91
import torch
from torch import nn
from torch.nn import functional as F

import math
import numpy as np
import time
from dataclasses import dataclass
from typing import Optional, Tuple, List
import pandas as pd
from matplotlib import pyplot as plt
import os

from torch.profiler import profile, ProfilerActivity, record_function

torch.set_default_device('cuda')
torch.set_default_dtype(torch.bfloat16)
from xformers.ops.swiglu_op import SwiGLU

# Input block

### Step 1: Input Block ###

# We use the Tiny Shakespeare dataset with a character-level tokenizer. Parts of the character-level tokenizer below follow Andrej Karpathy's nanoGPT (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py), which explains it very well.
# Load the tiny_shakespeare data file (https://github.com/tamangmilan/llama3/blob/main/tiny_shakespeare.txt)

device: str = 'cuda' if torch.cuda.is_available() else 'cpu'   # Assign device to cuda or cpu based on availability

# Load the tiny_shakespeare data file. The explicit encoding keeps decoding
# independent of the platform's locale default.
with open('tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
  data = f.read()

# Vocabulary: every distinct character in the text, sorted so ids are deterministic.
vocab = sorted(set(data))

# Training a Llama 3 style model needs additional special tokens (<|begin_of_text|>,
# <|end_of_text|> and <|pad_id|>); append them after the characters.
vocab.extend(['<|begin_of_text|>','<|end_of_text|>','<|pad_id|>'])
vocab_size = len(vocab)

# Id <-> character maps used by the tokenizer's encode and decode functions.
itos = dict(enumerate(vocab))
stoi = {ch: i for i, ch in enumerate(vocab)}

# Tokenizer encode: map every character of the string `s` to its vocabulary id.
# Raises KeyError for a character that is not in the vocabulary.
def encode(s):
  return list(map(stoi.__getitem__, s))

# Tokenizer decode: map a sequence of vocabulary ids back to characters and join them.
# Raises KeyError for an id that is not in the vocabulary.
def decode(l):
  chars = [itos[idx] for idx in l]
  return ''.join(chars)

# Special-token ids as one-element int tensors on `device`, so they can be
# concatenated directly onto token windows when building batches.
token_bos, token_eos, token_pad = (
  torch.tensor([stoi[name]], dtype=torch.int, device=device)
  for name in ('<|begin_of_text|>', '<|end_of_text|>', '<|pad_id|>')
)



# Model stuff:

# Step 2: The Decoder Block
# Note: Because Llama 3 was developed by Meta, this code stays close to Meta's GitHub implementation for consistency and future compatibility,
# with only the changes needed to achieve our goal.

# Define the parameters dataclass: these parameters are used during model building, training and inference.
# Note: Because we want to see training and inference results quickly rather than reach high accuracy, most parameters are set much lower than in the full Llama 3 model.

@dataclass
class ModelArgs:
    """Hyper-parameters for model construction, training and batch sampling.

    Every field has a default, so the values can also be read as class
    attributes (``ModelArgs.device``) without building an instance.
    """
    dim: int = 1024              # embedding (model) dimension
    n_layers: int = 8            # number of decoder blocks
    n_heads: int = 16            # number of attention heads
    n_kv_heads: int = 8          # number of key/value heads (not read by the MLA attention in this file)
    vocab_size: int = len(vocab) # vocabulary size, taken from the module-level `vocab` when the class is defined
    multiple_of: int = 256       # intended for sizing the feed-forward hidden layer (Llama FeedForward rounding)
    ffn_dim_multiplier: Optional[float] = None  # optional scale for the feed-forward hidden layer
    norm_eps: float = 1e-5       # epsilon for RMSNorm
    rope_theta: float = 10000.0  # RoPE base; NOTE: RotaryEmbedding hard-codes 10000 and does not read this

    max_batch_size: int = 10     # batch size used when sampling training/evaluation batches
    max_seq_len: int = 256       # number of tokens per sampled sequence

    epochs: int = 1000           # number of training iterations (one optimizer step each)
    log_interval: int = 10       # evaluate and log every this many iterations
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'   # Assign device to cuda or cpu based on availability
    # The two MLA sizes below are annotated so they are real dataclass fields
    # (settable via the constructor); without annotations they were plain class
    # attributes that the generated __init__ silently ignored.
    d_rope: int = 16             # per-head width of the rotary query/key slice in MLA
    d_kv_comp: int = 64          # width of the compressed (latent) KV and query projections in MLA

### Deepseek Attn:

class RotaryEmbedding(nn.Module):
    """Produce rotary-position angles for a (partial) rotary embedding.

    The inverse frequencies are ``10000 ** (-k / (dim // 2))`` for
    ``k = 0, 2, 4, ...`` below ``dim // 2``. ``forward`` returns the
    position-by-frequency angles duplicated along the last axis, so the output
    width is ``2 * inv_freq.numel()`` (equal to ``dim // 2`` when ``dim`` is a
    multiple of 4).

    Args:
        dim: width of the rotary slice; must be even.
        scale: positions are divided by this factor before being multiplied by
            the inverse frequencies. Defaults to 40.
    """

    def __init__(self, dim, scale=40):
        super().__init__()
        assert dim % 2 == 0, "Dimension must be even for rotary embeddings"
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim//2, 2).bfloat16() / (dim//2)))
        self.register_buffer("inv_freq", inv_freq)
        # Honor the caller's scale. This used to be hard-coded to 40, silently
        # ignoring the argument; the default value is unchanged.
        self.scale = scale

    def forward(self, seq_len):
        """Return angles of shape ``[seq_len, 2 * inv_freq.numel()]`` for positions ``0 .. seq_len-1``."""
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq) / self.scale
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)

def rotate_half(x):
    """Map ``[x1, x2] -> [-x2, x1]``, where ``x1``/``x2`` are the two halves of the last axis.

    This pairs channel ``i`` with channel ``i + d/2`` and rotates each pair by
    90 degrees, as used by rotary embeddings.
    """
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary(x, cos, sin):
    """Apply rotary embeddings to the leading ``cos.shape[-1]`` channels of ``x``.

    The first ``r = cos.shape[-1]`` channels are rotated as
    ``x_rot * cos + rotate_half(x_rot) * sin``. Any remaining channels pass
    through unchanged. Slicing (rather than ``Tensor.split``) lets this work for
    any ``x`` at least ``r`` wide. The old split-based unpacking raised when
    ``x`` was exactly ``r`` wide or wider than ``2 * r``.

    Args:
        x: tensor whose last dimension holds the channels to rotate.
        cos, sin: angle tensors broadcastable against ``x[..., :r]``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    rot_dim = cos.shape[-1]
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    return torch.cat([x_rot, x_pass], dim=-1)


class MemoryOptimizedMLA(nn.Module):
    """Causal multi-head latent attention (MLA) with low-rank KV and query compression.

    Per head:
    - Keys are a content part of width ``d_head - d_rope``, up-projected from
      the compressed latent ``c_kv``, concatenated with a rotary part of width
      ``d_rope`` projected directly from the input.
    - Queries are built the same way from a compressed query latent ``c_q``.
    - Values are up-projected from ``c_kv`` to the full ``d_head``.

    Args:
        config: hyper-parameters; reads ``dim``, ``n_heads``, ``d_rope`` and
            ``d_kv_comp``.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.d_head = config.dim // config.n_heads
        # Per-head width of the non-rotary ("content") part of q and k.
        self.split_dim = self.d_head - config.d_rope

        # Down-projections to the compressed latents.
        self.W_dkv = nn.Linear(config.dim, config.d_kv_comp)
        self.W_dq = nn.Linear(config.dim, config.d_kv_comp)

        # Up-projections from the latents. Values use the full d_head; keys and
        # queries use split_dim so that appending the d_rope-wide rotary part
        # brings them back to d_head.
        self.W_uk = nn.Linear(config.d_kv_comp, config.n_heads * self.split_dim)
        self.W_uv = nn.Linear(config.d_kv_comp, config.n_heads * self.d_head)
        self.W_uq = nn.Linear(config.d_kv_comp, config.n_heads * self.split_dim)

        # Rotary parts: queries from the query latent, keys directly from the input.
        self.W_qr = nn.Linear(config.d_kv_comp, config.n_heads * config.d_rope)
        self.W_kr = nn.Linear(config.dim, config.n_heads * config.d_rope)

        self.rotary = RotaryEmbedding(config.d_rope)
        self.output = nn.Linear(config.n_heads * self.d_head, config.dim)

    def forward(self, h, past_kv=None):
        """Run causal self-attention over ``h``.

        Args:
            h: input unpacked as ``[batch, seq_len, dim]``.
            past_kv: accepted for interface compatibility but ignored; no KV
                cache is consumed.

        Returns:
            ``(out, (c_kv, k_rot))``:
            - ``out`` is ``[batch, seq_len, dim]``.
            - ``c_kv`` is the compressed KV latent.
            - ``k_rot`` is the rotated key part, ``[batch, seq_len, n_heads, d_rope]``.
        """
        config = self.config
        batch_size, seq_len, _ = h.shape

        # KV compression: one shared low-rank latent, up-projected per head.
        c_kv = self.W_dkv(h)
        k_base = self.W_uk(c_kv).view(batch_size, seq_len, config.n_heads, self.split_dim)
        v = self.W_uv(c_kv).view(batch_size, seq_len, config.n_heads, self.d_head)

        # Query compression.
        c_q = self.W_dq(h)
        q_base = self.W_uq(c_q).view(batch_size, seq_len, config.n_heads, self.split_dim)
        q_rot = self.W_qr(c_q).view(batch_size, seq_len, config.n_heads, config.d_rope)

        # Rotary angles, broadcast over batch and heads: [1, seq, 1, rot_width].
        rotary_emb = self.rotary(seq_len)
        cos = torch.cos(rotary_emb).view(1, seq_len, 1, -1)
        sin = torch.sin(rotary_emb).view(1, seq_len, 1, -1)

        q_rot = apply_rotary(q_rot, cos, sin)
        k_rot = apply_rotary(
            self.W_kr(h).view(batch_size, seq_len, config.n_heads, config.d_rope),
            cos, sin,
        )

        # Per-head width is split_dim + d_rope == d_head.
        q = torch.cat([q_base, q_rot], dim=-1)
        k = torch.cat([k_base, k_rot], dim=-1)

        # scaled_dot_product_attention expects [batch, heads, seq, d_head].
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # is_causal=True applies the same upper-triangular mask as the former
        # explicit -inf mask. It avoids building a [seq, seq] tensor on every
        # call (on a hard-coded device) and lets PyTorch pick a fused kernel.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Move heads back next to d_head BEFORE flattening. Viewing the
        # [batch, heads, seq, d_head] tensor directly as [batch, seq, -1] mixed
        # different heads and sequence positions into each row, which also
        # leaked later positions into earlier ones.
        out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.output(out), (c_kv, k_rot)


### Transformer block:

## Step 2f: The Decoder Block. The class is named TransformerBlock to match the Meta Llama 3 codebase.

class TransformerBlock(nn.Module):
  """One pre-norm decoder block: x + Attn(RMSNorm(x)), then h + SwiGLU(RMSNorm(h))."""

  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
    self.attention = MemoryOptimizedMLA(args)
    self.ff_norm = nn.RMSNorm(args.dim, eps=args.norm_eps)
    # Hidden width follows the Llama FeedForward sizing rule:
    # 2/3 of 4*dim, optionally scaled by ffn_dim_multiplier, then rounded up to
    # a multiple of `multiple_of`. For the defaults (dim=1024, multiple_of=256)
    # this is 2816, the value previously hard-coded together with dim=1024, so
    # the parameter count still matches the FeedForward implementation. Unlike
    # the hard-coded version, it follows args.dim when that changes.
    hidden_dim = int(2 * (4 * args.dim) / 3)
    if args.ffn_dim_multiplier is not None:
      hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
    hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
    self.feedforward = SwiGLU(args.dim, hidden_dim, bias=False)

  def forward(self, x, past_kv=None):
    """Apply the attention and feed-forward sub-layers, each with a pre-norm residual.

    Args:
      x: hidden states, [bsz, seq_len, dim].
      past_kv: forwarded to the attention module.

    Returns:
      (out, new_kv): out has the same shape as x; new_kv is the attention
      module's second return value.
    """
    attn_out, new_kv = self.attention(self.attention_norm(x), past_kv)
    h = x + attn_out
    out = h + self.feedforward(self.ff_norm(h))
    return out, new_kv

### Model definition:

## Step3: The Output Block
# This is the Llama 3 model. Again, the class name is maintained as Transformer to match with Meta Llama 3 model.

class Transformer(nn.Module):
  """Decoder-only language model: embedding -> n_layers TransformerBlocks -> RMSNorm -> vocab projection."""

  def __init__(self, params: ModelArgs):
    super().__init__()
    self.params = params
    self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
    # One decoder block per layer; the count comes from params.n_layers.
    self.layers = nn.ModuleList(TransformerBlock(args=params) for _ in range(params.n_layers))
    self.norm = nn.RMSNorm(params.dim, eps=params.norm_eps)
    self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

  def forward(self, x, targets=None):
    """Compute next-token logits and, when targets are given, the cross-entropy loss.

    Args:
      x: token ids, embedded as x[bsz, seq_len] -> h[bsz, seq_len, dim].
      targets: token ids aligned with x for training, or None for inference.

    Returns:
      (logits, loss): logits are float32 [bsz, seq_len, vocab_size]; loss is a
      scalar tensor, or None when targets is None.
    """
    h = self.tok_embeddings(x)

    # No KV cache is implemented, so every layer gets past_kv=None. Previously
    # the cache returned by one layer was fed into the NEXT layer, which is
    # never meaningful; the attention ignored it, so results are unchanged.
    for layer in self.layers:
      h, _ = layer(h)

    h = self.norm(h)
    logits = self.output(h).float()

    loss = None
    if targets is not None:
      loss = F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1))
    return logits, loss


## Step 4: Train Llama 3 Model:

# Encode the whole tiny_shakespeare text with the character tokenizer into one
# flat int tensor of token ids, then move it to the configured device.
dataset = torch.tensor(encode(data), dtype=torch.int)
dataset = dataset.to(ModelArgs.device)
print(f"dataset-shape: {dataset.shape}")

# Define function to generate batches from the given dataset
def get_dataset_batch(data, split, args:ModelArgs):
  """Sample a random batch of (input, next-token target) windows from one split.

  ``data`` is split 80% / 10% / 10% into "train" / "val" / "test". Any other
  ``split`` value falls back to "train".

  Args:
    data: 1-D tensor of token ids for the whole corpus.
    split: "train", "val" or "test".
    args: reads ``max_seq_len``, ``max_batch_size`` and ``device``.

  Returns:
    (x, y): two long tensors of shape [max_batch_size, max_seq_len].
    - x is the BOS token followed by ``max_seq_len - 1`` consecutive corpus tokens.
    - ``y[:, p]`` is the token that follows ``x[:, p]``.
  """
  seq_len = args.max_seq_len
  batch_size = args.max_batch_size
  device = args.device

  n = len(data)
  if split == "val":
    batch_data = data[int(0.8 * n): int(0.9 * n)]
  elif split == "test":
    batch_data = data[int(0.9 * n):]
  else:
    batch_data = data[:int(0.8 * n)]

  # Random window starts; the upper bound leaves room for a full window.
  ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
  # Slice with plain ints instead of iterating 0-d (possibly device) tensors.
  starts = ix.tolist()
  x = torch.stack([torch.cat([token_bos, batch_data[i:i+seq_len-1]]) for i in starts]).long().to(device)
  # Next-token targets: BOS -> data[i], data[i] -> data[i+1], ..., i.e. the
  # window data[i:i+seq_len]. The old targets (data[i+1:i+seq_len] + EOS) were
  # shifted one token too far, training the model to predict the token after
  # next. They also ended every random mid-text window with a spurious EOS.
  y = torch.stack([batch_data[i:i+seq_len] for i in starts]).long().to(device)

  return x, y

### Test: get_dataset function ###
"""
xs, ys = get_dataset_batch(dataset, split="train", args=ModelArgs)
print([(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))])
"""

# Define a evaluate loss function to calculate and store training and validation loss for logging and plotting
@torch.no_grad()
def evaluate_loss(model, args:ModelArgs, eval_iters=10):
  """Estimate the model's mean loss on the "train" and "val" splits.

  Draws ``eval_iters`` random batches per split with get_dataset_batch (from
  the module-level ``dataset``). Runs with gradients disabled and the model in
  eval mode, then restores the model's previous train/eval mode.

  Args:
    model: called as ``model(x=..., targets=...)``; its second return value
      must be a scalar loss tensor.
    args: hyper-parameters forwarded to get_dataset_batch.
    eval_iters: number of batches averaged per split (default 10, the value
      previously hard-coded).

  Returns:
    dict mapping "train" and "val" to the mean loss (a NumPy float).
  """
  was_training = model.training
  model.eval()
  try:
    out = {}
    for split in ["train", "val"]:
      losses = []
      for _ in range(eval_iters):
        xb, yb = get_dataset_batch(dataset, split, args)
        _, loss = model(x=xb, targets=yb)
        losses.append(loss.item())
      out[split] = np.mean(losses)
  finally:
    # Restore the caller's mode, instead of unconditionally switching to train
    # mode or leaving the model in eval mode if evaluation raises.
    model.train(was_training)
  return out

# Define a training function to perform model training
def train(model, optimizer, args:ModelArgs):
    """Train ``model`` for ``args.epochs`` steps with cosine learning-rate decay.

    Every ``args.log_interval`` steps, evaluate_loss is run. The validation
    loss and elapsed time are printed and written to a fresh log file named
    ``my_training_run<N>``, where N is the first index not already on disk.

    Returns:
        pandas.DataFrame with one row per logged step (columns from
        evaluate_loss's result dict).
    """
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
    epochs = args.epochs
    log_interval = args.log_interval
    device = args.device
    losses = []
    start_time = time.time()

    # Pick the first unused log file name: my_training_run0, my_training_run1, ...
    file_path = 'my_training_run'
    run_idx = 0
    while os.path.exists(file_path + str(run_idx)):
        run_idx += 1
    with open(file_path + str(run_idx), 'w') as logfile:
        for epoch in range(epochs):
            optimizer.zero_grad()

            xs, ys = get_dataset_batch(dataset, 'train', args)
            xs = xs.to(device)
            ys = ys.to(device)
            _, loss = model(x=xs, targets=ys)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if epoch % log_interval == 0:
                # Time since the previous log point. It is measured before this
                # evaluation, and start_time is reset after it, so evaluation
                # time is excluded.
                batch_time = time.time() - start_time
                metrics = evaluate_loss(model, args)
                losses += [metrics]
                outstr = f"Epoch {epoch} | val loss {metrics['val']:.3f} | Time {batch_time:.3f}"
                print(outstr)
                logfile.write(outstr + '\n')
                start_time = time.time()

        # Report the most recently logged validation loss. Nothing was logged
        # when epochs == 0, which previously raised IndexError here.
        if losses:
            valstr = f"validation loss: {losses[-1]['val']}"
            print(valstr)
            logfile.write(valstr + '\n')
    return pd.DataFrame(losses)

if __name__ == '__main__':
    ## Start training our Llama 3 model
    # Build one concrete ModelArgs instance and pass it everywhere. Previously
    # the dataclass itself was passed around, which only worked because every
    # field default is also readable as a class attribute.
    args = ModelArgs()
    model = Transformer(args).to(args.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0003)
    #with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
      #with record_function("model_inference"):
    train_results = train(model, optimizer, args)
    # Plot the logged train/val losses.
    train_results.plot()
    plt.show()
    #prof_key_average = prof.key_averages()
    # GENERATE:
    #prompts = "Consider you what services he has done"
    #output_tokens, output_texts = generate(model, prompts, args)
    #output_texts = output_texts[0].replace("<|begin_of_text|>", "")
    #print(output_texts)
