KV Caching

A guide to how KV caching is implemented for LLMs, including a practical example of implementing it for LLaMa 3.2 1B.
caching
LLMs
inferencing
Author

Abbie Petulante

Published

January 9, 2025

Understanding KV Caching

This notebook provides a hands-on exploration of KV caching in language model text generation, specifically using LLaMa 3.2 1B. We’ll examine how caching works, its benefits for inference speed, and its implications for model state management.


Caching: What Is It?

When a language model processes text, it doesn’t just look at one word at a time; it builds up a complex internal state that represents its “understanding” of the entire context. At each layer of the transformer, part of this state is a set of key and value vectors for every token. These vectors encode the relationships and patterns in the input text, and they are what KV caching stores.

The Traditional Generation Process

Without caching, here’s what happens every time you ask for a completion:

  1. The model takes your prompt (e.g., “The story begins with a”)
  2. Converts it to tokens
  3. Processes these tokens through all its layers, building up its internal state
  4. Uses this state to predict the next token
  5. Adds the new token to the input
  6. Repeats steps 3-5 until done.

In other words, after appending each newly generated token to the running sequence, the model recomputes everything from scratch on the new, longer input. It does this for every single token it generates.

This also means that if you want five different endings to the same prompt, the model has to run “The story begins with a” through this whole process five separate times, from scratch!

Enter Caching

Caching is like giving the model a short-term memory. Here’s how it works:

First time:
  • Process the prompt normally through the steps above
  • But save the internal state (the key-value pairs, for KV caching) after processing the prompt

Subsequent times:
  • Instead of reprocessing the prompt, load the saved state, which is already prepped to generate the next token
  • Start generating from there

In this notebook, we’ll demonstrate both of these ways of generating output from a model. We’ll also look at some implications for how we can use KV caching and a saved internal state of the model to get better, faster responses!

Installations, Imports, and Setup

First, let’s install the necessary packages.

# Install required packages
!pip install -q --upgrade transformers datasets
# Standard libraries
import re
import os
import time

# AI/ML Libraries
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
# Check PyTorch version and CUDA availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB
Available GPU memory: 39.56 GB

Loading and Testing our Model

Let’s start by loading a small LLM to demonstrate these concepts, and look at its output before we do anything else. We’ll use LLaMa 3.2 1B. This model is excellent for this example because:

  1. It is small enough to run on smaller GPUs
  2. It uses a relatively simple transformer architecture, making it easier to understand the core concepts
  3. Despite its small size, it produces coherent enough outputs to demonstrate the effects of caching on generation

Hugging Face Authentication

LLaMa 3.2 requires authentication with Hugging Face to access the model. You’ll need to:

  1. Have a Hugging Face account
  2. Accept the LLaMa 3.2 model terms of use on the Hugging Face model page
  3. Create an access token on Hugging Face (https://huggingface.co/settings/tokens)

After you have your access token and have accepted the terms, the code below will help you log in:

from huggingface_hub import login
import getpass

token = getpass.getpass("Enter your Hugging Face token: ")
login(token=token)

# Verify login
print("Login status: Authenticated with Hugging Face")
Enter your Hugging Face token: ··········
Login status: Authenticated with Hugging Face
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)

Before we dive into caching, let’s look at how the model sees text. We’ll create a simple function that tokenizes a string and shows the resulting token IDs, along with the text each ID decodes to.

def inspect_tokens(text):
    """Display token information for a given text."""
    tokens = tokenizer.encode(text, return_tensors="pt")
    print(f"Text: {text}")
    print(f"Number of tokens: {len(tokens[0])}")
    print("\nToken IDs:")
    print(tokens[0].tolist())
    print("\nDecoded tokens:")
    print([tokenizer.decode([t]) for t in tokens[0]])
    return tokens

# Example usage
sample_text = "The quick brown fox"
tokens = inspect_tokens(sample_text)
Text: The quick brown fox
Number of tokens: 5

Token IDs:
[128000, 791, 4062, 14198, 39935]

Decoded tokens:
['<|begin_of_text|>', 'The', ' quick', ' brown', ' fox']

Generation without Caching

First, let’s start by looking at how we’d generate an output without implementing any caching.

Hugging Face’s generate() method could do all of this for us in one call. We’ll still write our own function for the step-by-step generation, so that we can observe exactly how the process goes.

One important note: caching is actually enabled by default in Hugging Face transformers. Any normal call to model() will build a cache automatically unless use_cache is explicitly set to False.

For now, to illustrate our point, we’ll turn caching off with use_cache=False. Later, when we implement caching, we’ll explore some ways to get finer control over it, so it’s not all happening under the hood.

def generate_completion(prompt, max_length=100):
    """Generate a completion without caching."""

    #First, convert the prompt into tokens
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    generated_tokens = []

    start_time = time.time()

    with torch.no_grad():
        for _ in range(max_length):
            # generate the input for the original prompt
            logits = model(input_ids, use_cache=False).logits # caching is actually done by default, so we need to explicitly turn it off!

            predictions = torch.softmax(logits[:, -1, :], dim=-1)
            next_token_id = torch.argmax(predictions).unsqueeze(0)
            generated_tokens.append(next_token_id.item())

            # Append the new token to the input sequence for the next iteration
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

            if next_token_id.item() == tokenizer.eos_token_id:
                break
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    end_time = time.time()

    return generated_text, end_time - start_time
# Let's try it with a simple prompt. We'll make it long enough that we can see meaningful speedup.
prompt = '''Continue a short story about the fall of an ancient civilization. This civilization was once the greatest that the world had ever seen. It now only exists as ruins, and tourists
who go there in the modern day do not know much of anything about it. At it's height, the civilization was massive, had a thriving economy, beautiful gardens, and a
great culture. The civilization was so advanced that it was able to create a new language, which was spoken by all of the people in the civilization and crossed boundaries to even
be spoken by neighboring lands. Today, this language has been lost to time. The civilization lived in the closest thing to a utopia that the modern world had ever seen.
The rulers of the city were wise and benevolent. They were able to create a society that was peaceful and prosperous. The people of the city'''
completion, time_taken = generate_completion(prompt)  # the function returns (generated text, seconds taken)
print(f"Completion: {completion}")
print(f"Time taken: {time_taken:.2f} seconds")
Completion:  were happy and content. They were 
able to live in peace and harmony with each other. The people of the city were able to live in a way that was not only peaceful, but also prosperous. They were able to 
live in a way that was not only peaceful, but also prosperous. The people of the city were able to live in a way that was not only peaceful, but also prosperous. They were 
able to live in a way that was not only peaceful, but also prosperous
Time taken: 4.77 seconds

One Step Closer

In the above function, you can see that we start by tokenizing the prompt. We then run it through the model to prep it to generate a response.

Then, we generate the tokens that follow that input one by one to arrive at our final generation.

One easy idea gets us a step “closer” to caching: do that first pass through the prompt once, outside of the function, so that the function only has to pick up generating the additional tokens. That would at least save us the first pass through the initial prompt.

def generate_skipped_input_completion(starting_logits, starting_inputs, max_length=100):
    """Generate a completion, giving the logits to begin with."""
    input_ids = starting_inputs.to(device)
    generated_tokens = []

    with torch.no_grad():
        for i in range(max_length):
            if i == 0:
                logits = starting_logits # don't run the model anymore on the first iteration
            else:
                logits = model(input_ids, # in future passes, still run full set of input_ids
                               use_cache=False).logits # caching is actually done by default, so we need to explicitly turn it off!

            predictions = torch.softmax(logits[:, -1, :], dim=-1)
            next_token_id = torch.argmax(predictions).unsqueeze(0)
            generated_tokens.append(next_token_id.item())

            # Append the new token to the input sequence for the next iteration
            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=1)

            if next_token_id.item() == tokenizer.eos_token_id:
                break
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text
#Use the same prompt as before

start_time = time.time()
starting_inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)
with torch.no_grad():  # inference only, so skip building the autograd graph
    starting_logits = model(starting_inputs, use_cache=False).logits
end_time = time.time()
time_to_input = end_time - start_time

start_time = time.time()
completion = generate_skipped_input_completion(starting_logits, starting_inputs)
end_time = time.time()
time_to_generate = end_time - start_time

print(f"Completion: {completion}")
print(f"Time to make starting point: {time_to_input:.2f} seconds")
print(f"Time to generate: {time_to_generate:.2f} seconds")
print(f"Total time: {time_to_generate + time_to_input:.2f} seconds")
Completion:  were happy and content. They were 
able to live in peace and harmony with each other. The people of the city were able to live in a way that was not only peaceful, but also prosperous. They were able to 
live in a way that was not only peaceful, but also prosperous. The people of the city were able to live in a way that was not only peaceful, but also prosperous. They were 
able to live in a way that was not only peaceful, but also prosperous
Time to make starting point: 0.02 seconds
Time to generate: 4.10 seconds
Total time: 4.12 seconds

But, as we can see, saving just that initial pass through the prompt doesn’t save much time!

This is because, in this “version”, although we’ve saved the initial pass through the prompt, we still end up passing through the full prompt + new tokens for each new token generated, so we still end up processing that initial sequence over and over!

This is why caching doesn’t just save the model’s output; it saves the model’s internal state. With caching, we never run the initial prompt through the model again. That works because we’ve saved the intermediate values the model computed for the prompt (its keys and values), not just the logits it was prepped to output.

How Caching Works In-Depth

So, what does saving the internal state of the model really mean, and how do we do it in practice?

A Quick Introduction to Transformer Architecture

Before we understand caching, we need to understand how transformers process sequences. In a transformer like Llama, text flows through the model in several stages:

  1. Tokenization: Text → Token IDs
  2. Token Embeddings: Token IDs → Vectors
  3. Multiple Transformer Layers: Each containing:
    • Self-attention mechanism
    • Feed-forward neural networks

The Self-Attention Mechanism: Where Caching Happens

The self-attention portion is where the caching can happen. Let’s look more specifically at what happens in one of these layers.

Step 1: Query, Key, Value Creation

For each token in the sequence, the model creates three vectors:
  • Query (\(\widehat{Q}\)): What the current token is looking for
  • Key (\(\widehat{K}\)): What the token offers to others
  • Value (\(\widehat{V}\)): The actual information content

For example, for a simple sequence like “The cat sat”, you would need to calculate:

Token 1 (“The”):

\(Q_1\) = \(W_Q\) × \(x_1\)

\(K_1\) = \(W_K\) × \(x_1\)

\(V_1\) = \(W_V\) × \(x_1\)

Token 2 (“cat”):

\(Q_2\) = \(W_Q\) × \(x_2\)

\(K_2\) = \(W_K\) × \(x_2\)

\(V_2\) = \(W_V\) × \(x_2\)

Token 3 (“sat”):

\(Q_3\) = \(W_Q\) × \(x_3\)

\(K_3\) = \(W_K\) × \(x_3\)

\(V_3\) = \(W_V\) × \(x_3\)

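That is, we calculate a \(Q, K, V\) vector for each token in the sequence by applying the weight matrices \(W_Q, W_K, W_V\) to that token’s input vector \(x_i\). In the first layer, \(x_i\) is the token’s embedding; in later layers, it is the previous layer’s output for that token.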
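To make the notation concrete, here is a small PyTorch sketch of Step 1 for a three-token sequence. The sizes, the square weight matrices, and the random values are assumptions chosen for readability; a real layer uses learned weights and Llama 3.2 1B’s 2048-dimensional vectors. The tensor x is a hypothetical stand-in for the three token vectors.

```python
import torch

torch.manual_seed(0)
hidden_dim = 8   # toy size; Llama 3.2 1B uses 2048
seq_len = 3      # "The", "cat", "sat"

# Hypothetical stand-ins for the input vectors x_1, x_2, x_3 (one row per token)
x = torch.randn(seq_len, hidden_dim)

# Fixed projection weights (random here, learned in a real model)
W_Q = torch.randn(hidden_dim, hidden_dim)
W_K = torch.randn(hidden_dim, hidden_dim)
W_V = torch.randn(hidden_dim, hidden_dim)

# x @ W.T applies W to every row at once, so row i is W x_i (e.g. Q[0] is Q_1)
Q = x @ W_Q.T
K = x @ W_K.T
V = x @ W_V.T

print(Q.shape, K.shape, V.shape)  # torch.Size([3, 8]) torch.Size([3, 8]) torch.Size([3, 8])
```

Because the weights are shared, projecting one token on its own gives exactly its row of the batched result.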

Step 2: Attention Score Computation

Then, these vectors are stacked to form the \(Q, K, V\) matrices.

For instance:

\[ Q = \begin{bmatrix} Q_1 \\ Q_2 \\ Q_3 \end{bmatrix} \]

where this \(Q\) is a matrix of size seq_length x hidden_dim. It holds one \(\widehat{Q}\) vector per token in the sequence. Each vector’s length is set by the output size of \(W_Q\), a fixed dimension of the model.

In reality, each attention layer has multiple “heads” (multiple \(W_Q, W_K, W_V\)’s), so the dimensions are:

  • \(Q\): [num_heads, seq_length, head_dim]
  • \(K\): [num_heads, seq_length, head_dim]
  • \(V\): [num_heads, seq_length, head_dim]

Where:
  • head_dim = hidden_dim / num_heads
  • seq_length grows as we generate

So, in the above example, the K matrix for each head has shape [3, head_dim]: one row each for “The”, “cat”, and “sat”.
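Splitting a projected matrix into heads is just a reshape. Below is a minimal sketch with toy sizes; these are assumptions, not Llama’s actual dimensions. Real implementations also carry a leading batch dimension, giving [batch, num_heads, seq_length, head_dim].

```python
import torch

seq_len, num_heads, head_dim = 3, 4, 2
hidden_dim = num_heads * head_dim   # so head_dim = hidden_dim / num_heads

# Toy projected queries for "The cat sat": [seq_length, hidden_dim]
Q = torch.arange(seq_len * hidden_dim, dtype=torch.float32).reshape(seq_len, hidden_dim)

# [seq_length, hidden_dim] -> [seq_length, num_heads, head_dim] -> [num_heads, seq_length, head_dim]
Q_heads = Q.view(seq_len, num_heads, head_dim).transpose(0, 1)

print(Q_heads.shape)  # torch.Size([4, 3, 2])
# Head 0 works on the first head_dim features of every token
assert torch.equal(Q_heads[0], Q[:, :head_dim])
```

Each head therefore sees a different slice of every token’s projected vector, over the same sequence positions.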

When the model processes a sequence, it first computes these \(Q, K, V\) matrices for the input sequence.

Then, attention is calculated from these matrices as:

\[ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V \]

where \(d_k\) is the key dimension (head_dim).

This is what we really want out of an attention layer. The softmax term holds the attention scores, which say how much every token in the sequence should “pay attention” to every other token. Multiplying the scores by \(V\) mixes the tokens’ values accordingly. That contextualizes the input, so the model can generate the next token.
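The formula translates almost directly into code. This sketch assumes a causal mask, which decoder-only models like Llama apply so that each token only attends to itself and earlier tokens. It checks the hand-written version against PyTorch’s torch.nn.functional.scaled_dot_product_attention (available in PyTorch 2.0 and later). All tensors are random toy values.

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V with a causal mask, for [num_heads, seq_len, head_dim] inputs."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # [num_heads, seq_len, seq_len]
    seq_len = scores.shape[-1]
    # Token i may only look at tokens 0..i (lower-triangular mask)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    weights = torch.softmax(scores, dim=-1)             # attention scores; each row sums to 1
    return weights @ V                                  # weighted mix of value vectors

torch.manual_seed(0)
num_heads, seq_len, head_dim = 2, 3, 4
Q = torch.randn(num_heads, seq_len, head_dim)
K = torch.randn(num_heads, seq_len, head_dim)
V = torch.randn(num_heads, seq_len, head_dim)

out = attention(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
assert torch.allclose(out, ref, atol=1e-5)
```

Because of the mask, the first token can only attend to itself, so its output is exactly its own value vector.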

Why Caching Matters: The Computational Challenge

So, attention needs to know about all of the tokens in a sequence. Its result depends on what’s in the sequence, so it makes sense that we need to recalculate it as the sequence grows.

But what about the pieces that come before? When we process “the cat”, the model computes all of \(Q_1, K_1, V_1, Q_2, K_2, V_2\). For “the cat sat”, it computes all of \(Q_1, K_1, V_1, Q_2, K_2, V_2, Q_3, K_3, V_3\). You can see how, for long prompts, this quickly becomes a lot.

Without caching, when parsing a sequence, the model must:

  1. Compute Q, K, V for the current token
  2. Recompute Q, K, V for ALL previous tokens
  3. Compute attention scores for ALL combinations
  4. Process through ALL layers again

For a sequence of length \(S\), the attention scores alone take \(O(S^2)\) computations for EACH new token!

But why regenerate all of the \(Q, K, V\) vectors for earlier parts of the sequence? \(W_Q, W_K,\) and \(W_V\) are fixed weight matrices. Attention in a model like Llama is also causal: each token can only attend to itself and earlier tokens. As a result, a token’s input vector at every layer doesn’t change when later tokens are added. So the \(Q, K, V\) matrices only ever grow by new rows, and the calculation for \(Q_1, K_1, V_1\) gives the same result every time.
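This claim is easy to check numerically. The sketch below builds two hypothetical single-head layers with random weights and residual connections. It leaves out normalization and the feed-forward block to keep things short. It then shows that appending a third token leaves the first two tokens’ keys unchanged in both layers, including the second layer, whose inputs already depend on attention.

```python
import math
import torch

torch.manual_seed(0)
d = 8  # toy hidden size

def causal_attention(q, k, v):
    scores = q @ k.T / math.sqrt(q.shape[-1])
    n = scores.shape[0]
    mask = torch.ones(n, n, dtype=torch.bool).tril()   # token i sees tokens 0..i
    return torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ v

# Two toy layers, each with its own fixed (W_Q, W_K, W_V)
layers = [tuple(torch.randn(d, d) for _ in range(3)) for _ in range(2)]

def keys_per_layer(x):
    """Run x ([seq_len, d]) through the toy layers and collect each layer's K."""
    keys = []
    h = x
    for W_Q, W_K, W_V in layers:
        keys.append(h @ W_K)
        h = h + causal_attention(h @ W_Q, h @ W_K, h @ W_V)  # residual connection
    return keys

tokens = torch.randn(3, d)               # stand-ins for "The", "cat", "sat"
keys_short = keys_per_layer(tokens[:2])  # process "The cat"
keys_long = keys_per_layer(tokens)       # process "The cat sat"

for k_short, k_long in zip(keys_short, keys_long):
    # K_1 and K_2 are identical whether or not "sat" has been appended
    assert torch.allclose(k_short, k_long[:2], atol=1e-5)
```

If the mask were removed, earlier tokens could see later ones and this check would generally fail. Causality is what makes the cache valid.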

Attention at Inference Time

Before we go on, we need to clear up a nuance about how generating the next token (doing inference) works. At inference time, attention is calculated slightly differently than when batch-processing a whole, already-known sequence like we just laid out above, which is what you would do during training.

When it comes to generating the next new word, we need to get the attention score, which contextualizes the current word to all others that came before it. But consider how this calculation works.

\(Q, K,\) and \(V\) are matrices. So, if:

\[ Q = \begin{bmatrix} q_{the,1} & q_{the,2} & q_{the,3} \\ q_{cat,1} & q_{cat,2} & q_{cat,3} \end{bmatrix} \]

\[ K^T = \begin{bmatrix} k_{the,1} & k_{cat,1} \\ k_{the,2} & k_{cat,2} \\ k_{the,3} & k_{cat,3} \end{bmatrix} \]

Then, \[ Q \times K^T = \begin{bmatrix} (q_{the,1} \times k_{the,1} + q_{the,2} \times k_{the,2} + q_{the,3} \times k_{the,3}) & (q_{the,1} \times k_{cat,1} + q_{the,2} \times k_{cat,2} + q_{the,3} \times k_{cat,3}) \\ (q_{cat,1} \times k_{the,1} + q_{cat,2} \times k_{the,2} + q_{cat,3} \times k_{the,3}) & (q_{cat,1} \times k_{cat,1} + q_{cat,2} \times k_{cat,2} + q_{cat,3} \times k_{cat,3}) \end{bmatrix} \]

Now, recall what \(Q, K,\) and \(V\) are meant to represent. \(Q\), the “query”, is what “asks” about the token in question. \(K\), the “key”, says what information a token has to offer. \(V\), the “value”, stores the actual information to give.

We don’t need to ask any questions (i.e., store any “Q” element) for a word we’ve already generated: there’s nothing more to “ask” or “understand” about a token in the past. During training, computing Q vectors for all tokens is important, because the model needs to learn how each token influences and is influenced by every other token in the sequence. During inference, though, we only care about how our new token should relate to what came before. We just need the current token’s Q vector to ask ‘how should I pay attention to all previous tokens?’ by using it with the cached K and V values.

And this is evident in the matrix: each row contains all combinations for one given Q. In the above example, the first row tells us about “The” and the second row tells us about “cat”. Computing the row for the current token, “cat”, doesn’t depend on the query values of the previous word “The”. A similar observation holds when this result is multiplied by V: each output row only uses its own row of attention scores.

In practice, at generation time, what this means is that we only need K’s and V’s for every token that came before to properly contextualize the current Q. Our Q matrix will actually only be made up of the Q vector for the current token.

So, with KV caching, we store the K and V values to avoid recomputing them. We don’t need to store Q, since only the current token’s Q is used to produce the next token.
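Here is a quick numerical check of this argument, using random toy tensors for a single head. Attention computed with only the newest token’s query reproduces the last row of full causal attention. That last row is the only one used to predict the next token.

```python
import math
import torch

torch.manual_seed(0)
seq_len, head_dim = 3, 4
Q = torch.randn(seq_len, head_dim)
K = torch.randn(seq_len, head_dim)
V = torch.randn(seq_len, head_dim)

# Full causal attention: what a no-cache forward pass computes for every row
scores = Q @ K.T / math.sqrt(head_dim)
mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
full = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ V

# Inference step: only the newest token's query, against all K and V
q_new = Q[-1:]                                              # shape [1, head_dim]
new_only = torch.softmax(q_new @ K.T / math.sqrt(head_dim), dim=-1) @ V

# The last row is all that's needed to predict the next token, and it matches
assert torch.allclose(new_only, full[-1:], atol=1e-6)
```

No mask is needed in the second computation, because the newest token is allowed to attend to every position before it.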

KV-Caching

KV caching is a method that stores our already computed \(K\) and \(V\) vectors so they can be reused in future generation steps. We don’t recompute everything; we just store what we know we’ll use again, and pull it up when we need it.

When using KV-Caching, we store the Key and Value matrices for each layer. For each new token:

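  1. Only compute \(Q, K, V\) for the new token
  2. Concatenate \(K, V\) with the cached versions (for instance like this, although we’ll see an alternate approach below):
K_new = torch.cat([K_cached, K_token], dim=2)  # dim=2 is the sequence dimension of [batch, num_heads, seq_len, head_dim]
V_new = torch.cat([V_cached, V_token], dim=2)
  3. Compute attention using the new token’s \(Q\) against the full K_new and V_new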
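Putting these steps together, here is a minimal sketch of a hand-rolled KV cache for one toy attention layer. It uses random weights and sizes chosen for illustration, and it is not how transformers implements its cache internally. It feeds tokens one at a time, grows K_cached and V_cached along the sequence dimension with torch.cat, and checks that the result matches recomputing attention over the whole sequence.

```python
import math
import torch

torch.manual_seed(0)
batch, num_heads, head_dim = 1, 2, 4
hidden = num_heads * head_dim
W_Q, W_K, W_V = (torch.randn(hidden, hidden) for _ in range(3))

def split_heads(t):
    """[batch, seq, hidden] -> [batch, num_heads, seq, head_dim]"""
    b, s, _ = t.shape
    return t.view(b, s, num_heads, head_dim).transpose(1, 2)

def attend(q, k, v, causal):
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    if causal:
        n = scores.shape[-1]
        mask = torch.ones(n, n, dtype=torch.bool).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

x = torch.randn(batch, 5, hidden)  # five toy token vectors

# Reference: project the whole sequence at once and use causal attention
full = attend(split_heads(x @ W_Q), split_heads(x @ W_K), split_heads(x @ W_V), causal=True)

# Cached: one token per step; only the new token is projected, K/V grow along dim=2
K_cached = V_cached = None
step_outputs = []
for i in range(x.shape[1]):
    x_tok = x[:, i:i + 1]                  # [batch, 1, hidden]
    Q_token = split_heads(x_tok @ W_Q)
    K_token = split_heads(x_tok @ W_K)
    V_token = split_heads(x_tok @ W_V)
    K_cached = K_token if K_cached is None else torch.cat([K_cached, K_token], dim=2)
    V_cached = V_token if V_cached is None else torch.cat([V_cached, V_token], dim=2)
    # The newest token may attend to every cached position, so no mask is needed
    step_outputs.append(attend(Q_token, K_cached, V_cached, causal=False))

cached = torch.cat(step_outputs, dim=2)    # [batch, num_heads, seq, head_dim]
assert torch.allclose(cached, full, atol=1e-5)
```

Each step projects a single token, yet the outputs match the full recomputation. This equality is what lets caching save work without changing the model’s predictions.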

Before Caching (Processing “The cat sat”)

  1. Process “The” → compute \(Q_1, K_1, V_1\)
  2. Process “The cat” → recompute \(Q_1, K_1, V_1\) and compute \(Q_2, K_2, V_2\)
  3. Process “The cat sat” → recompute \(Q_1, K_1, V_1, Q_2, K_2, V_2\) and compute \(Q_3, K_3, V_3\)

With Caching

  1. Process “The”
     → Compute \(Q_1, K_1, V_1\)
     → Cache \(K_1, V_1\)
  2. Process “cat”
     → Retrieve \(K_1, V_1\) from cache
     → Only compute \(Q_2, K_2, V_2\)
     → Cache \(K_1, K_2, V_1, V_2\)
  3. Process “sat”
     → Retrieve \(K_1, K_2, V_1, V_2\) from cache
     → Only compute \(Q_3, K_3, V_3\)
     → Cache \(K_1, K_2, K_3, V_1, V_2, V_3\)

and so on.
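Counting the \(K, V\) computations in this walkthrough (per layer) makes the difference concrete. Here is a tiny sketch, assuming each prefix is processed one step at a time as listed; kv_computations is a hypothetical helper name.

```python
def kv_computations(num_tokens):
    """Per-layer count of (K, V) computations when a sequence is processed one token at a time."""
    # Without a cache, step i recomputes K, V for all i tokens seen so far
    without_cache = sum(range(1, num_tokens + 1))
    # With a cache, each token's K, V is computed exactly once and then reused
    with_cache = num_tokens
    return without_cache, with_cache

print(kv_computations(3))     # "The cat sat": (6, 3)
print(kv_computations(1000))  # (500500, 1000)
```

The uncached count grows quadratically with sequence length, while the cached count grows linearly.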

Memory Requirements

One potential issue with KV-Caching is that it trades memory for speed, since we now need to store all of those K and V values.

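For a model with:
  • \(h_{kv}\) key/value heads (in standard multi-head attention this equals the number of attention heads \(h\); models using grouped-query attention share each K, V head across several query heads)
  • \(d_h = d / h\) head dimension, where \(d\) is the model dimension
  • \(s\) sequence length

The cache size per layer is: \[ size_{layer} = 2 \times h_{kv} \times s \times d_h \times sizeof(float16) \]

Then, for \(L\) layers:

\[ size_{total} = L \times 2 \times h_{kv} \times s \times d_h \times \textit{sizeof(float16)} \]

The factor of 2 comes from storing both K and V.

For Llama-3.2-1B with a 1000-token sequence, reading the numbers off the architecture printed above:

  • 16 layers
  • 32 attention heads, but only 8 key/value heads (k_proj and v_proj output 512 = 8 × 64 features)
  • 64 head dimension (2048 / 32)
  • float16 (2 bytes)

→ Total cache size = 16 × 2 × 8 × 1000 × 64 × 2 bytes ≈ 32.8 MB per 1000 tokens (twice that in float32)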
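The formula is simple enough to wrap in a helper. This sketch plugs in the Llama 3.2 1B values from the printed architecture: 16 decoder layers, with k_proj and v_proj each outputting 512 = 8 × 64 features, i.e. 8 key/value heads of dimension 64. kv_cache_bytes is a hypothetical helper name, and the default of 2 bytes per value assumes float16.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """size_total = L x 2 x h_kv x s x d_h x sizeof(dtype); the 2 counts K and V."""
    return num_layers * 2 * num_kv_heads * seq_len * head_dim * bytes_per_value

size = kv_cache_bytes(num_layers=16, num_kv_heads=8, head_dim=64, seq_len=1000)
print(f"{size:,} bytes = {size / 1e6:.1f} MB")   # 32,768,000 bytes = 32.8 MB
```

For a loaded model, the same numbers can be read from its config. For example: model.config.num_hidden_layers, model.config.num_key_value_heads, and model.config.hidden_size // model.config.num_attention_heads (attribute names as used by transformers’ LlamaConfig).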

Performance Impact

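For a prompt of length \(S\) and a generation length \(G\), a useful first estimate is to count the token positions pushed through the model. The projections and feed-forward layers cost the same for every position, so this count tracks most of the work:

Without Caching:

  • For each new token, we recompute \(K, V\) vectors for all previous tokens
  • The entire sequence is processed each time: \(S\) positions, then \(S + 1\), and so on up to \(S + G - 1\)
  • Total positions ≈ \(G \times S + \frac{G^2}{2}\)

With Caching:

  • Initial processing of the prompt: \(S\) positions
  • For each new token: just one new position (its query still attends over all cached \(K, V\))
  • Total positions ≈ \(S + G\)

The attention scores follow the same pattern: roughly \(G \times S^2\) without caching versus \(S^2 + G \times S\) with it.

The speedup grows as the prompt length (\(S\)) increases, leveling off near \(G\) once the prompt is much longer than the generation. With \(G = 20\), this idealized count gives:

| Prompt Length | Generation Length | Idealized Speedup | Example |
|---|---|---|---|
| 10 tokens | 20 tokens | ~13x | A short sentence |
| 100 tokens | 20 tokens | ~18x | A paragraph |
| 1000 tokens | 20 tokens | ~20x | A long document |

Measured wall-clock speedups are smaller than this count suggests, because a GPU processes all the positions of a long input in parallel rather than one at a time.

This improvement occurs because:

  1. Without caching, each new token requires reprocessing the entire history
  2. With caching, each new token only requires computing its own \(K, V\) vectors
  3. The longer the prompt, the more redundant computation we avoid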
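A simple way to estimate the speedup is to count the token positions run through the model with and without a cache. This sketch assumes cost is proportional to that count, ignoring GPU parallelism and the attention term. positions_processed is a hypothetical helper name.

```python
def positions_processed(prompt_len, gen_len):
    """Token positions run through the model to generate gen_len tokens."""
    # No cache: step t re-runs the prompt plus the t tokens generated so far
    without_cache = sum(prompt_len + t for t in range(gen_len))
    # Cache: the prompt is processed once, then one position per later step
    with_cache = prompt_len + (gen_len - 1)
    return without_cache, with_cache

for s in (10, 100, 1000):
    without_cache, with_cache = positions_processed(s, 20)
    print(f"S={s}, G=20: {without_cache} vs {with_cache} positions -> ~{without_cache / with_cache:.1f}x")
# S=10, G=20: 390 vs 29 positions -> ~13.4x
# S=100, G=20: 2190 vs 119 positions -> ~18.4x
# S=1000, G=20: 20190 vs 1019 positions -> ~19.8x
```

Raising gen_len raises the ceiling on the ratio, since each extra uncached step re-runs the whole prompt.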

For real-world applications like chatbots or document processing where prompts can be thousands of tokens long, KV caching becomes essential for reasonable inference speed.

KV Caching in Code

Now, let’s edit our generation function to include this KV caching.

Adding Explicit Cache Management in Transformers Library

As we stated above, caching is built in and enabled by default in Hugging Face’s transformers library. However, there are also ways to take much more control over the cache, which we’ll explore in this implementation. Using explicit Cache classes like DynamicCache provides several advantages:

1. Cache Reusability

  • You can save a cache state and reuse it for multiple different generations
  • Useful for generating different endings from the same prompt
  • Helps avoid recomputing prompt processing multiple times

2. Cache Control

  • Choose different cache implementations (Dynamic, Static, Sliding Window)
  • Control memory usage with different cache strategies
  • Explicitly manage when caches are cleared or updated

3. Advanced Use Cases

  • Sliding Window Attention: Limit memory usage for long sequences
  • Quantized Caching: Reduce memory footprint with quantization
  • Cross-Attention Caching: Useful for encoder-decoder models

4. Debugging and Inspection

  • Examine cache contents directly
  • Monitor memory usage
  • Debug attention patterns

You can read more about different ways to implement caching in the Hugging Face Cache documentation.
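Cache reusability comes with one practical wrinkle. A DynamicCache is updated in place as the model runs, so reusing one prompt cache for several generations generally means copying it first; the Hugging Face cache documentation shows copy.deepcopy for this. The sketch below illustrates the idea with a hypothetical ToyKVCache class standing in for the real cache. It is not the transformers API.

```python
import copy
import torch

class ToyKVCache:
    """Hypothetical minimal cache that, like DynamicCache, is mutated in place."""
    def __init__(self):
        self.keys, self.values = [], []

    def update(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def seq_len(self):
        return len(self.keys)

torch.manual_seed(0)
dim = 4
W_K, W_V = torch.randn(dim, dim), torch.randn(dim, dim)

def extend(cache, token_vectors):
    """Compute K, V for each new token vector and append them to the cache."""
    for x in token_vectors:
        cache.update(x @ W_K, x @ W_V)
    return cache

prompt_cache = extend(ToyKVCache(), torch.randn(6, dim))  # shared prompt: processed once

# Each ending starts from its own copy, so the shared prompt cache stays untouched
ending_a = extend(copy.deepcopy(prompt_cache), torch.randn(2, dim))
ending_b = extend(copy.deepcopy(prompt_cache), torch.randn(3, dim))
print(prompt_cache.seq_len(), ending_a.seq_len(), ending_b.seq_len())  # 6 8 9

# Without the copy, the first ending's tokens would leak into every later one
leaky = extend(prompt_cache, torch.randn(2, dim))
print(leaky is prompt_cache, prompt_cache.seq_len())  # True 8
```

Copying trades a little memory for never having to reprocess the shared prompt. That is the idea behind the “five different endings” example from the introduction.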

def generate_kv_cached_completion(prompt, max_length=100):
    """
    Generate completion using explicit KV caching with DynamicCache.
    This gives us more control over cache management compared to the model's default caching.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Initialize DynamicCache - this allows us to:
    # 1. Explicitly manage what's cached
    # 2. Reuse the cache across multiple generations
    # 3. Inspect cache contents if needed
    past_key_values = DynamicCache()

    with torch.no_grad():
        # Initial forward pass - process the prompt
        # past_key_values here will store K,V pairs for the prompt
        outputs = model(
            input_ids,
            use_cache=True,
            past_key_values=past_key_values,  # Pass our managed cache
            return_dict=True
        )

        generated_sequence = input_ids
        generated_text = []

        # Generate tokens one at a time
        for _ in range(max_length):
            # Get logits for next token prediction
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(torch.softmax(next_token_logits, dim=-1)).unsqueeze(0).unsqueeze(0)

            # Keep track of the sequence and generated tokens
            generated_sequence = torch.cat([generated_sequence, next_token], dim=-1)
            generated_text.append(next_token.item())

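            # Stop before spending a forward pass on the end-of-sequence token
            if next_token.item() == tokenizer.eos_token_id:
                break

            # Forward pass for just the new token. The DynamicCache already holds K,V
            # for every earlier position and is updated in place with this token's K,V.
            outputs = model(
                next_token,
                use_cache=True,
                past_key_values=past_key_values,
                return_dict=True
            )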

    return tokenizer.decode(generated_text, skip_special_tokens=True)

Now, let’s see how much this speeds up the generation of our story.

# Using a long prompt:
prompt = """The last library on Earth wasn't a building - it was a person. Her name was Sarah Chen, and she was the final recipient of the Memory Archive Protocol,
a desperate procedure developed in the last days before the global web collapsed.
The process had encoded the contents of humanity's greatest digital archives directly into her neural pathways.
Now, ten years after the collapse, she wandered the dusty remains of Silicon Valley, her mind a vast repository of everything from ancient
philosophical texts to modern scientific papers, from classic literature to social media's last posts. Each night, she transcribed a small portion of her knowledge onto carefully preserved paper,
racing against time and her own mortality to preserve what remained of human knowledge.
But on this particular morning, as she wrote in her small, fortified sanctuary, Sarah realized something had changed.
Some of the memories were starting to move on their own, rearranging themselves, evolving into something new. She was simultaneously transported into the memories
and experiencing them in third person. She saw the words dance on the page in time with seeing what the words meant happen in front of her.
It was all out of order. Confusing. She tried to get a handle on what was happening. She steadied herself and focused, tried to put her attention to the here and now. But it was hard to fight it.
She thought about"""

# Test non-cached version
completion, non_cached_time = generate_completion(prompt)
print(f"Completion: {completion}")
print(f"Time taken: {non_cached_time:.2f} seconds")

# Test cached version
start_time = time.time()
completion = generate_kv_cached_completion(prompt)
end_time = time.time()
cached_time = end_time - start_time

print(f"Completion: {completion}")
print(f"Time taken: {cached_time:.2f} seconds")

print(f"\nSpeedup: {non_cached_time/cached_time:.2f}x")
Completion:  the last time she had been here, ten years ago. She had been in the library, and she had been in the library, and she had been in the library. 
She had been in the library, and she had been in the library, and she had been in the library. 
She had been in the library, and she had been in the library, and she had been in the library. 
She had been in the library, and she had been in the library, and
Time taken: 5.70 seconds
Completion:  the last time she had been here, ten years ago. She had been in the library, and she had been in the library, and she had been in the library. 
She had been in the library, and she had been in the library, and she had been in the library. 
She had been in the library, and she had been in the library, and she had been in the library. 
She had been in the library, and she had been in the library, and
Time taken: 1.87 seconds

Speedup: 3.05x

So, we’ve reduced our time considerably!

One final note: A question you might be asking is “Why am I getting the same response every time, and does that have to do with storing the internal state?”

But no! Even though KV caching stores those \(K, V\) values, the cache isn’t where randomness comes from. We get the same response every time because of how we chose the next token:

next_token = torch.argmax(torch.softmax(next_token_logits, dim=-1)).unsqueeze(0).unsqueeze(0)

So, we forced the generation to pick what the model thinks is the “best” next token every time (greedy decoding), making the calculation deterministic. This is useful for getting the most accurate speed comparisons, but not necessary. We could have changed that line to:

next_token = torch.multinomial(torch.softmax(next_token_logits / temperature, dim=-1), num_samples=1)  # already shaped [1, 1]

Here, temperature (a number you define, e.g. temperature = 0.7) controls the amount of randomness. torch.multinomial() then samples from the predicted distribution instead of always choosing the maximum.
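For completeness, here is a small self-contained sketch of temperature sampling on toy logits; the logits and the sample_next_token name are made up for illustration. Note that torch.multinomial on a [1, vocab] probability tensor already returns shape [1, 1], the same shape the notebook’s greedy line produces.

```python
import torch

def sample_next_token(next_token_logits, temperature=0.7):
    """Sample one token id per row of [batch, vocab] logits; returns shape [batch, 1]."""
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])   # toy logits over a 4-token vocabulary

samples = [sample_next_token(logits).item() for _ in range(10)]  # usually 0 at this temperature
near_greedy = sample_next_token(logits, temperature=1e-3)        # tiny temperature ~ argmax
greedy = torch.argmax(torch.softmax(logits, dim=-1)).unsqueeze(0).unsqueeze(0)  # notebook's greedy line

print(near_greedy.shape, near_greedy.item(), greedy.item())  # torch.Size([1, 1]) 0 0
```

As the temperature shrinks, the distribution collapses onto the highest-scoring token and sampling approaches greedy decoding. Higher temperatures spread probability over more tokens.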

Prompt Caching

You should now understand KV caching in depth, and it’s a great way to speed up generation when processing a single prompt.

However, prompts, and even more so segments of prompts, are reused extremely often. System prompts, for instance, are included in the prompt every single time a chatbot is called.

And sometimes, these segments of prompts are mixed and matched. “You are friendly and kind” might appear right at the beginning of a prompt, or after “You are a helpful assistant”.

This paper introduced the idea of a “prompt cache” - a way to store how the model would process segments of prompts for easy, fast retrieval. Since this technique builds on KV caching, it’s a natural next step! But it’s a lot to cover in this notebook, which is already dense.

So… please see my notebook on prompt caching to explore the concept in depth!