Cuando generamos "karpathy" letra por letra, la K y la V de "k" es la misma en cada paso. El KV cache las guarda una vez y las reutiliza. Resultado: la generación es mucho más rápida.
Imagina que estás estudiando para un examen. Cada vez que necesitas recordar qué dijo el capítulo 3, tienes dos opciones:
El segundo método no cambia el resultado — te lleva a la misma respuesta — pero es mucho más rápido.
Para entender por qué el cache importa, veamos qué pasa durante la generación sin él. Queremos generar "karpathy" letra por letra.
Paso 1: Tenemos "[BOS]"
→ Calcular Q, K, V para token [BOS]
→ Softmax de atención (1×1)
→ Producir 'k'
Paso 2: Tenemos "[BOS] k"
→ Recalcular K, V para [BOS] ← ⚠️ mismo resultado que antes
→ Calcular K, V para 'k' ← nuevo
→ Softmax de atención (2×2)
→ Producir 'a'
Paso 3: Tenemos "[BOS] k a"
→ Recalcular K, V para [BOS] ← ⚠️ redundante
→ Recalcular K, V para 'k' ← ⚠️ redundante
→ Calcular K, V para 'a' ← nuevo
→ Softmax de atención (3×3)
→ Producir 'r'
...
Paso T: Tenemos "[BOS] k a r p a t h"
→ Recalcular K, V para todos los T-1 tokens anteriores ← ⚠️⚠️⚠️
→ Calcular K, V para el token nuevo
→ Softmax de atención (T×T)
→ Producir el token T
Coste sin cache: O(T²) ← crece cuadráticamente con la longitud
La observación clave es simple: K y V de un token dependen solo de ese token y su posición. Si el token no cambia (y los tokens anteriores nunca cambian — los generamos y ya), sus K y V son idénticos en todos los pasos futuros.
Los recuadros oscuros (sin cache) representan cálculos redundantes. Con cache, en cada paso solo se computa el nuevo token; los anteriores se recuperan del cache sin recalcular.
En microgpt, el cache se implementa acumulando listas de K y V. Cuando llega un nuevo token, se calcula su K y V, se añaden al cache, y la atención trabaja sobre el cache completo.
# Estructura del cache: lista de (K_acumulada, V_acumulada) por capa
kv_cache = None # vacío al inicio
# Generamos token por token
for paso in range(max_tokens):
# Solo enviamos el ÚLTIMO token (no toda la secuencia)
# Porque los anteriores ya están en el cache
x = wte[ultimo_token] + wpe[posicion_actual] # [1, 16]
# En la atención multi-cabeza:
Q_nuevo = x @ W_q # solo para el token nuevo
K_nuevo = x @ W_k # solo para el token nuevo
V_nuevo = x @ W_v # solo para el token nuevo
# Concatenar con el cache acumulado
if kv_cache is not None:
K_total = concatenar(kv_cache.K, K_nuevo) # [T, 4] por cabeza
V_total = concatenar(kv_cache.V, V_nuevo) # [T, 4] por cabeza
else:
K_total, V_total = K_nuevo, V_nuevo
# Atención: Q_nuevo · K_total^T / √d
scores = Q_nuevo @ K_total.T / sqrt(4) # [1, T]
weights = softmax(scores) # [1, T]
out = weights @ V_total # [1, 4]
# Actualizar cache
kv_cache.K = K_total
kv_cache.V = V_total
# Continuar forward pass normal...
proximo_token = muestrear(logits)
Clave: enviamos [1, 16] en vez de [T, 16] →
el trabajo de Q, K, V es constante por paso, no crece con T
T × n_head × d_head × 2 valores en memoria. Para microgpt con T=16, n_head=4, d_head=4: solo 512 números — trivial. Para GPT-4 con contextos de 128K tokens y 96 capas, el cache puede pesar gigabytes.
Observa cómo el cache acumula K y V con cada nuevo token generado: