Lección 13

KV Cache: no recalcular lo que ya sabemos

Cuando generamos "karpathy" letra por letra, la K y la V de "k" es la misma en cada paso. El KV cache las guarda una vez y las reutiliza. Resultado: la generación es mucho más rápida.

1. El libro subrayado

Imagina que estás estudiando para un examen. Cada vez que necesitas recordar qué dijo el capítulo 3, tienes dos opciones:

El segundo método no cambia el resultado — te lleva a la misma respuesta — pero es mucho más rápido.

📚 K y V son el "post-it" de cada token. En la atención, la Key (K) y Value (V) de un token no cambian entre pasos de generación — solo dependen del token y su posición, que no se mueven. El KV cache los guarda la primera vez que se calculan y los reutiliza en cada paso siguiente.

2. El problema sin cache

Para entender por qué el cache importa, veamos qué pasa durante la generación sin él. Queremos generar "karpathy" letra por letra.

✍️ Generación sin KV cache — la ineficiencia
Paso 1: Tenemos "[BOS]"
  → Calcular Q, K, V para token [BOS]
  → Softmax de atención (1×1)
  → Producir 'k'

Paso 2: Tenemos "[BOS] k"
  → Recalcular K, V para [BOS]  ← ⚠️ mismo resultado que antes
  → Calcular K, V para 'k'      ← nuevo
  → Softmax de atención (2×2)
  → Producir 'a'

Paso 3: Tenemos "[BOS] k a"
  → Recalcular K, V para [BOS]  ← ⚠️ redundante
  → Recalcular K, V para 'k'   ← ⚠️ redundante
  → Calcular K, V para 'a'      ← nuevo
  → Softmax de atención (3×3)
  → Producir 'r'

...

Paso T: Tenemos "[BOS] k a r p a t h"
  → Recalcular K, V para todos los T-1 tokens anteriores ← ⚠️⚠️⚠️
  → Calcular K, V para el token nuevo
  → Softmax de atención (T×T)
  → Producir el token T

Coste sin cache: O(T²)  ← crece cuadráticamente con la longitud
      

3. La solución: guardar K y V

La observación clave es simple: K y V de un token dependen solo de ese token y su posición. Si el token no cambia (y los tokens anteriores nunca cambian — los generamos y ya), sus K y V son idénticos en todos los pasos futuros.

Comparación · Sin cache vs. con cache
Comparación de generación con y sin KV cache Sin cache, cada paso recalcula K y V para todos los tokens anteriores. Con cache, solo se calcula el token nuevo y se recuperan los anteriores del cache. Sin cache coste O(T²) BOS paso 1 BOS k paso 2 BOS k a paso 3 BOS k a r paso 4 ⚠️ Cada paso recalcula todo lo que ya se calculó antes Con KV cache coste O(T) KV cache (acumula K,V de tokens anteriores) BOS → K₀, V₀ guardados k → K₁, V₁ guardados a → K₂, V₂ guardados r ← solo calcular K₃, V₃ ✓ Atención con K₀,K₁,K₂,K₃ V₀,V₁,V₂ del cache V₃ recién calculado → mismo resultado, menos trabajo

Los recuadros oscuros (sin cache) representan cálculos redundantes. Con cache, en cada paso solo se computa el nuevo token; los anteriores se recuperan del cache sin recalcular.

4. Cómo funciona en el código

En microgpt, el cache se implementa acumulando listas de K y V. Cuando llega un nuevo token, se calcula su K y V, se añaden al cache, y la atención trabaja sobre el cache completo.

✍️ Generación con KV cache — pseudocódigo comentado
# Estructura del cache: lista de (K_acumulada, V_acumulada) por capa
kv_cache = None  # vacío al inicio

# Generamos token por token
for paso in range(max_tokens):
    # Solo enviamos el ÚLTIMO token (no toda la secuencia)
    # Porque los anteriores ya están en el cache
    x = wte[ultimo_token] + wpe[posicion_actual]  # [1, 16]

    # En la atención multi-cabeza:
    Q_nuevo = x @ W_q           # solo para el token nuevo
    K_nuevo = x @ W_k           # solo para el token nuevo
    V_nuevo = x @ W_v           # solo para el token nuevo

    # Concatenar con el cache acumulado
    if kv_cache is not None:
        K_total = concatenar(kv_cache.K, K_nuevo)  # [T, 4] por cabeza
        V_total = concatenar(kv_cache.V, V_nuevo)  # [T, 4] por cabeza
    else:
        K_total, V_total = K_nuevo, V_nuevo

    # Atención: Q_nuevo · K_total^T / √d
    scores = Q_nuevo @ K_total.T / sqrt(4)  # [1, T]
    weights = softmax(scores)               # [1, T]
    out = weights @ V_total                 # [1, 4]

    # Actualizar cache
    kv_cache.K = K_total
    kv_cache.V = V_total

    # Continuar forward pass normal...
    proximo_token = muestrear(logits)

Clave: enviamos [1, 16] en vez de [T, 16] →
  el trabajo de Q, K, V es constante por paso, no crece con T
      
El trade-off del cache: El KV cache cambia tiempo por memoria. Guardar K y V para una secuencia de longitud T requiere T × n_head × d_head × 2 valores en memoria. Para microgpt con T=16, n_head=4, d_head=4: solo 512 números — trivial. Para GPT-4 con contextos de 128K tokens y 96 capas, el cache puede pesar gigabytes.

🎮 Pruébalo: generación paso a paso

Observa cómo el cache acumula K y V con cada nuevo token generado:

5. Lo que aprendiste

En esta lección:
  • Sin cache, la generación cuesta O(T²): en cada paso se recalculan las K y V de todos los tokens anteriores.
  • Con KV cache, solo se calculan la K y V del nuevo token. Los anteriores se recuperan del cache. El coste es O(T).
  • Los resultados son matemáticamente idénticos — el cache no aproxima nada, es un atajo exacto.
  • El trade-off es memoria: el cache crece con la longitud de la secuencia. En modelos grandes con contextos largos, el cache puede ser más grande que los propios pesos del modelo.