PASO 5 / 10

Multi-Head Attention

Si self-attention te dio UNA forma de mirar la frase, multi-head te da VARIAS en paralelo. Cada "cabeza" es una self-attention completa pero con sus propias matrices, capturando un tipo de relación distinto.

🎯 De un vistazo

Multi-head es hacer self-attention varias veces en paralelo, cada vez con un "enfoque" distinto, y combinar los resultados.

PROPÓSITO

¿Para qué sirve?

Ejecutar varias self-attentions en paralelo ("cabezas"), cada una con sus propias matrices, para capturar distintos tipos de relación a la vez.

APORTE

¿Qué aporta al modelo?

Múltiples perspectivas simultáneas (sintáctica, semántica, posicional...). Combinadas, dan una comprensión mucho más rica de la frase que una sola atención.

NECESIDAD

¿Por qué es indispensable?

Una sola atención solo puede capturar UN tipo de relación por capa. Pero las frases tienen muchos tipos de relación al mismo tiempo — necesitamos varias "miradas" en paralelo.

↓ A continuación, el detalle con ejemplos y visualizaciones ↓

1. La idea en una sola frase

Multi-head attention = self-attention hecho varias veces en paralelo, con matrices distintas, y después combinar los resultados.

Eso es prácticamente todo lo que tenés que entender. Lo demás son detalles de cómo se hace la combinación. Si ya entendiste el paso 4 (self-attention), ya entendiste la mayor parte del paso 5.

2. ¿Por qué hacerlo varias veces?

Self-attention con UN solo set de matrices (WQ, WK, WV) puede aprender un tipo de relación entre tokens. Pero las palabras de una frase tienen muchos tipos de relaciones simultáneamente.

Mirá esta frase y todas las relaciones que tiene:

"El perro grande come rápido porque tiene hambre"
  • Relación sujeto–verbo: "perro" ↔ "come", "perro" ↔ "tiene".
  • Relación adjetivo–sustantivo: "grande" → "perro".
  • Relación adverbio–verbo: "rápido" → "come".
  • Relación causal: "come" ↔ "porque" ↔ "tiene".
  • Relación de proximidad: cada palabra con sus vecinas inmediatas.

Una sola self-attention NO puede capturar todas estas relaciones al mismo tiempo. Tendría que elegir cuál priorizar — por ejemplo, sujeto-verbo a costa de adjetivo-sustantivo.

Solución: en lugar de UNA self-attention, hacemos varias en paralelo (típicamente entre 8 y 96 en modelos reales). Cada una se llama una "cabeza" (head) y aprende un tipo distinto de relación. Después combinamos las salidas.

3. Analogía: varias cámaras filmando la misma escena

Imagina una escena de fútbol filmada por una sola cámara fija. Vas a ver solo lo que esa cámara captura desde su ángulo. Pierdes información: lo que pasa fuera del cuadro, los detalles del rostro de los jugadores, etc.

Ahora imagina la misma escena filmada por varias cámaras al mismo tiempo: una panorámica, una de primer plano, una desde el arco, una aérea con dron. Cada cámara da una perspectiva distinta. Al juntarlas (en el editing), obtenés una comprensión mucho más rica de la jugada.

📷 Single-head (1 cámara)

  • UN ángulo, UNA perspectiva.
  • Captura UN tipo de patrón.
  • Simple, pero limitado.

🎥 Multi-head (varias cámaras)

  • VARIOS ángulos, VARIAS perspectivas en paralelo.
  • Cada cabeza captura un tipo de patrón distinto.
  • Combinadas, dan una comprensión rica de la frase.

4. La arquitectura completa (vista de pájaro)

El flujo es así:

Input: 3 vectores de dim 4
↓ (se le pasa el MISMO input a las dos cabezas)
Cabeza 1
sus propias W_Q¹, W_K¹, W_V¹
→ self-attention
→ salida de dim 2
Cabeza 2
sus propias W_Q², W_K², W_V²
→ self-attention
→ salida de dim 2
↓ concatenar (pegar las salidas)
Concat: 3 vectores de dim 4
(2 cabezas × dim 2 = dim 4)
↓ proyectar con W_O
Output final: 3 vectores de dim 4
Las 4 piezas nuevas que tenés que entender:
  1. Cada cabeza tiene sus propias matrices W (distintas a las de las otras cabezas).
  2. Las cabezas trabajan en dimensiones más chicas (la dim original dividida por el número de cabezas).
  3. Al final, las salidas se concatenan (se pegan una al lado de la otra).
  4. Hay una matriz extra WO que "mezcla" la información de todas las cabezas.

5. ¿Qué tamaño tiene cada cabeza?

Esta parte es muy importante. La idea: no hacemos H cabezas de la dimensión completa (sería muy costoso). Hacemos H cabezas más chicas, cada una usando d_model / H dimensiones.

Las dimensiones en nuestro ejemplo de juguete

d_model = 4    (dimensión del input — viene del paso 3)
H = 2          (número de cabezas que vamos a usar)

Dimensión de cada cabeza:
  d_k = d_model / H = 4 / 2 = 2

Cada cabeza tendrá:
  W_Q^h shape = [d_model, d_k] = [4, 2]
  W_K^h shape = [d_model, d_k] = [4, 2]
  W_V^h shape = [d_model, d_k] = [4, 2]

Salida de cada cabeza: 3 vectores de dim 2.

Después de concatenar 2 cabezas:
  3 vectores de dim (2 + 2) = 3 vectores de dim 4

Matriz W_O para mezclar las cabezas:
  W_O shape = [d_model, d_model] = [4, 4]

Output final:
  3 vectores de dim 4 (misma forma que el input — importante!)

¿Por qué dividir la dimensión en lugar de mantenerla?

Imagina que tenemos d_model = 768 (típico en BERT). Sin dividir, cada cabeza tendría matrices de 768×768 = ~590k parámetros. Con 12 cabezas, serían ~7 millones de parámetros solo para una capa de atención. Demasiado.

Dividiendo: cada cabeza tiene matrices 768×64 = ~49k parámetros. Con 12 cabezas, totalizamos los mismos ~590k que tendría una sola cabeza grande, pero ahora con 12 perspectivas distintas en lugar de una. Misma cantidad de parámetros, mucho más expresividad.

Dimensiones en modelos reales

Modelo d_model num_heads d_k por cabeza
GPT-2 small7681264
GPT-2 medium10241664
GPT-3 / GPT-41228896128
LLaMA-2 7B409632128
Nuestro ejemplo422

6. Ejemplo trabajado: 2 cabezas sobre "el perro come"

Vamos a aplicar multi-head attention con 2 cabezas a nuestros 3 vectores. Las matrices son inventadas (en un modelo real se aprenderían).

Las matrices de las 2 cabezas (4×2 cada una)

Continuidad con el paso 4: estas matrices NO son nuevas. Son las dos mitades de las matrices 4×4 que usamos en el paso 4 (Self-Attention). La Cabeza 1 toma las columnas 0-1 de cada W, y la Cabeza 2 toma las columnas 2-3. Así se ve el concepto real: una matriz grande partida por columnas en H=2 cabezas más chicas.

H1Cabeza 1

Columnas 0-1 de las W del paso 4

H2Cabeza 2

Columnas 2-3 de las W del paso 4

Sub-paso A: cada cabeza calcula sus propios Q, K, V

Es exactamente como en el paso 4, pero usando las matrices de cada cabeza. Las salidas tienen dimensión 2 (no 4) porque las matrices son 4×2.

Cabeza 1: Q¹, K¹, V¹

Cabeza 2: Q², K², V²

Sub-paso B: cada cabeza hace su propia atención

Cada cabeza ejecuta los 5 sub-pasos del paso 4 con sus propios Q, K, V: producto punto, escalar (por √d_k = √2 ≈ 1.414), softmax, suma ponderada de V. Las matrices de atención salen distintas porque las matrices W son distintas.

Mapa de atención de la Cabeza 1

Cada fila = quien pregunta. Cada columna = a quién atiende. Suma de cada fila = 100%.

Mapa de atención de la Cabeza 2

Mira cómo este patrón es distinto al de la Cabeza 1.

Observación importante: los mapas de atención de cada cabeza son distintos. Eso es exactamente lo que queríamos: cada cabeza ve la frase desde su propio ángulo. En modelos entrenados de verdad, esos patrones serían interpretables (sujeto-verbo, adjetivo-sustantivo, etc.).

Sub-paso C: salidas de cada cabeza (3 vectores de dim 2 cada una)

Salida Cabeza 1

Salida Cabeza 2

Sub-paso D: concatenar las salidas

"Concatenar" significa pegar los vectores uno al lado del otro. Si la salida de la cabeza 1 para "perro" es [a, b] y la de la cabeza 2 es [c, d], el concat es [a, b, c, d].

Mini-repaso: ¿qué es concatenar?

Tres ejemplos simples:
  concat([1, 2], [3, 4])         = [1, 2, 3, 4]
  concat([a, b, c], [d])         = [a, b, c, d]
  concat([1, 2], [3, 4], [5, 6]) = [1, 2, 3, 4, 5, 6]

La dimensión final = suma de las dimensiones de cada parte.
En nuestro caso: dim 2 (cabeza 1) + dim 2 (cabeza 2) = dim 4.

El resultado del concat para nuestras 3 palabras

Sub-paso E: la proyección final con WO

El concat ya tiene la dim correcta (4), pero las 4 dimensiones todavía vienen "separadas": las primeras 2 son de la cabeza 1, las otras 2 de la cabeza 2. WO es una matriz que mezcla esa información para que las dimensiones finales combinen información de ambas cabezas.

output_final = Concat · WO

WO es una matriz [d_model × d_model] = [4 × 4] que también se aprende durante el entrenamiento.

El output final (después de WO)

Esto es lo que sale del bloque de multi-head attention. Mismo número de tokens (3) y misma dimensión (4) que el input — listo para el siguiente paso del Transformer.

7. ¿Qué aprenden las cabezas en modelos reales?

Los investigadores han estudiado los modelos entrenados (BERT, GPT) abriendo el capó y mirando qué patrón aprende cada cabeza. Lo que descubrieron es fascinante:

El modelo nunca fue programado para hacer esto. Estos patrones emergen automáticamente durante el entrenamiento, simplemente porque ayudan a predecir mejor la siguiente palabra. Una arquitectura simple + muchos datos = patrones lingüísticos complejos aprendidos solos.

🎮 Demo interactivo: el algoritmo completo paso a paso

Click "Siguiente paso" para ver cada fase de multi-head attention con números reales calculados en JS:

0. Input
1. Q, K, V por cabeza
2. Atención por cabeza
3. Salida de cada cabeza
4. Concatenar
5. Proyección W_O
Listo para comenzar
Apretá "Siguiente paso" para arrancar el recorrido.

✅ Resumen: lo que entra al siguiente paso

  1. Empezamos con los inputs del paso 3 (3 vectores de dim 4).
  2. Pasamos los inputs por 2 cabezas en paralelo, cada una con sus propias WQ, WK, WV.
  3. Cada cabeza ejecuta el algoritmo del paso 4 (Q·K, escalar, softmax, ·V) y produce su propia salida (3 vectores de dim 2).
  4. Concatenamos las salidas de las cabezas → 3 vectores de dim 4.
  5. Multiplicamos por WO para mezclar la info de las cabezas → 3 vectores de dim 4.

El output tiene la misma forma que el input (3 vectores de dim 4) — esto es crucial para los conexiones residuales que veremos en el paso 7.

En el Paso 6: Feed-Forward Network, cada uno de esos vectores va a pasar por una pequeña red neuronal que procesa cada token de forma independiente, agregando más capacidad de cómputo al modelo.