LECCIÓN 14 · BLOQUE 4

Knowledge distillation

Cuantización y pruning comprimen un modelo existente. Knowledge distillation hace algo diferente: un modelo grande actúa de maestro (teacher) y entrena a uno pequeño (student) transfiriéndole su conocimiento de forma más rica que con simples etiquetas. DistilBERT nació así.

Recuerda de las lecciones anteriores: cuantización (lección 12) redujo la memoria comprimiendo los pesos a menos bits. Pruning (lección 13) eliminó los pesos que contribuyen poco. Ambas técnicas trabajan sobre un modelo ya entrenado. Knowledge distillation funciona de otra manera: crea un modelo nuevo y más pequeño, pero lo entrena aprendiendo de un modelo grande ya entrenado — no solo con el dataset original, sino con las predicciones del modelo grande como señal de entrenamiento.

1. La idea: un experto que explica su razonamiento

Imagina que un médico con 30 años de experiencia quiere enseñar a un estudiante de medicina. Tiene dos opciones:

La segunda opción transmite mucho más información: el experto no solo dice qué ganó, sino cómo de parecidas son las diferentes opciones entre sí. Eso es knowledge distillation: el modelo grande (teacher) entrena al modelo pequeño (student) no con etiquetas duras (0 o 1), sino con su distribución completa de probabilidades.

Por qué la distribución de probabilidades es tan valiosa: cuando un modelo entrenado ve la imagen de un "gato", su probabilidad no es [0.0, 0.0, 0.99, 0.0, 0.0...] — es algo más parecido a [0.01, 0.03, 0.92, 0.02, 0.01...] donde "perro" tiene el 3% y "avión" tiene el 0.01%. Esas probabilidades pequeñas capturan que "gato" y "perro" se parecen más entre sí que "gato" y "avión". El student aprende esas relaciones de similitud, no solo la clase ganadora.

2. Hard labels vs soft labels — la diferencia que importa

✍️ Ejemplo con clasificación de sentimiento — 3 clases
Frase de ejemplo: "La película no era perfecta, pero me entretuvo bastante."

HARD LABELS (etiquetas del dataset):
  Positivo: 1    Neutro: 0    Negativo: 0
  → Solo dice que es "positivo". No dice nada más.

SOFT LABELS del teacher (modelo BERT-large):
  Positivo: 0.72    Neutro: 0.24    Negativo: 0.04
  → Dice: "es bastante positivo, pero con bastante componente neutro"
  → La frase tiene "no era perfecta" que apunta a neutro, pero "me entretuvo"
    que apunta a positivo. El teacher captura esa ambigüedad.

¿Qué aprende el student de cada señal?
  Con hard labels: aprende que esta frase → clase "positivo". Punto.
  Con soft labels: aprende que esta frase → espacio cercano a positivo,
                   con influencia de neutro, y casi nada de negativo.
                   Aprende la GEOMETRÍA del espacio semántico.

Otro ejemplo: "El actor fue increíble. La trama, un desastre."
  Hard labels: Neutro: 1 (label del dataset)
  Soft labels del teacher: Positivo: 0.38  Neutro: 0.41  Negativo: 0.21
  → El dataset dice "neutro", pero el teacher ve que hay elementos positivos
    (el actor) y negativos (la trama). Esa riqueza se transfiere al student.

3. Temperatura T — suavizar las distribuciones

Hay un problema con las soft labels del teacher: un modelo bien entrenado a menudo produce predicciones muy seguras. Por ejemplo: "positivo: 0.999, neutro: 0.001, negativo: 0.000". Esa distribución no es mucho más informativa que las hard labels — casi todo el peso está en una clase.

Para solucionar esto se introduce la temperatura T. La temperatura "suaviza" la distribución: hace que los valores grandes bajen un poco y los pequeños suban un poco, revelando las relaciones de similitud que de otra forma estarían ocultas.

Matemáticamente, la temperatura se aplica dividiendo los logits (los números antes del softmax) por T antes de calcular el softmax. Un logit es el número crudo que sale de la última capa del modelo, antes de convertirlo en probabilidad.

✍️ Efecto de la temperatura — cálculo paso a paso
Supón que el teacher produce estos logits para 3 clases:
  logits = [8.5,  3.2,  0.1]
          (pos) (neu) (neg)

Con T=1 (softmax normal):
  exp(8.5/1) = exp(8.5) = 4914.8
  exp(3.2/1) = exp(3.2) = 24.5
  exp(0.1/1) = exp(0.1) = 1.11
  suma = 4914.8 + 24.5 + 1.11 = 4940.4

  p_positivo = 4914.8 / 4940.4 = 0.9948 (casi certeza total)
  p_neutro   = 24.5   / 4940.4 = 0.0050
  p_negativo = 1.11   / 4940.4 = 0.0002

Con T=2 (temperatura 2):
  exp(8.5/2) = exp(4.25) = 70.1
  exp(3.2/2) = exp(1.60) = 4.95
  exp(0.1/2) = exp(0.05) = 1.05
  suma = 70.1 + 4.95 + 1.05 = 76.1

  p_positivo = 70.1 / 76.1 = 0.921
  p_neutro   = 4.95 / 76.1 = 0.065
  p_negativo = 1.05 / 76.1 = 0.014

Con T=4 (temperatura 4):
  exp(8.5/4) = exp(2.125) = 8.38
  exp(3.2/4) = exp(0.800) = 2.23
  exp(0.1/4) = exp(0.025) = 1.025
  suma = 8.38 + 2.23 + 1.025 = 11.635

  p_positivo = 8.38  / 11.635 = 0.720
  p_neutro   = 2.23  / 11.635 = 0.192
  p_negativo = 1.025 / 11.635 = 0.088

Comparación:
  T=1:  [0.9948, 0.0050, 0.0002] ← casi como hard labels
  T=2:  [0.921,  0.065,  0.014]  ← algo de información en neutro y negativo
  T=4:  [0.720,  0.192,  0.088]  ← mucho más informativa: captura las relaciones

→ Con T=4, el student aprende que "negativo" y "neutro" son más similares
  entre sí que cualquiera de los dos respecto a "positivo" para esta frase.
Rango práctico de T: T=1 es el softmax normal (sin modificar). T=2 a T=4 son los valores más usados en papers y práctica. T demasiado alto (T>10) aplana tanto la distribución que se pierde la señal del teacher. El paper original de Hinton et al. usó T=4 para MNIST. Para LLMs se suele usar T=2 o T=3.

4. La función de pérdida de destilación

La función de pérdida combina dos señales de aprendizaje para el student:

L_total = α × L_CE(student, hard_labels) + (1 − α) × T² × L_KL(student_T, teacher_T)

Desglosemos cada parte:

L_CE — Cross-Entropy con las etiquetas reales

La cross-entropy es la pérdida de clasificación habitual que ya conoces del curso de redes neuronales. Mide cuánto difiere la predicción del student de la etiqueta correcta del dataset (hard label). Esto asegura que el student aprenda la tarea correctamente.

L_KL — KL Divergence con el teacher

KL Divergence (divergencia de Kullback-Leibler) mide cuán diferente es una distribución de probabilidades de otra. Si el teacher da [0.72, 0.19, 0.09] y el student da [0.60, 0.25, 0.15], la KL divergence mide qué tan lejos está el student del teacher. Durante el entrenamiento, el student intenta minimizar esa distancia — aprender a parecerse al teacher. Tanto el student como el teacher usan sus versiones con temperatura T.

α (alpha) — el balance entre las dos señales

α controla cuánto peso le damos a cada pérdida. α=0.5 es el valor más común: mitad y mitad. α más cercano a 1 significa "aprende más de las etiquetas reales". α más cercano a 0 significa "aprende más del teacher". Para destilación pura del conocimiento, valores como α=0.3 o α=0.5.

El factor T²

Cuando usas temperatura T, los gradientes de la KL divergence se vuelven T² veces más pequeños matemáticamente. El factor T² compensa eso para que la escala de la pérdida KL sea comparable a la de la cross-entropy.

✍️ Cálculo de KL Divergence — ejemplo numérico con 3 clases
Continuando el ejemplo anterior (temperatura T=2):

  Teacher con T=2: p_teacher = [0.921, 0.065, 0.014]
  Student hace su predicción, también con T=2: p_student = [0.750, 0.160, 0.090]

La KL Divergence de p_student respecto a p_teacher es:
  KL(p_teacher || p_student) = Σ p_teacher(i) × log(p_teacher(i) / p_student(i))

Para cada clase:
  Clase 0 (positivo):  0.921 × log(0.921 / 0.750) = 0.921 × log(1.228)
                     = 0.921 × 0.2053 = 0.1891

  Clase 1 (neutro):    0.065 × log(0.065 / 0.160) = 0.065 × log(0.4063)
                     = 0.065 × (−0.9004) = −0.0585  ← negativo porque p_student > p_teacher

  Clase 2 (negativo):  0.014 × log(0.014 / 0.090) = 0.014 × log(0.1556)
                     = 0.014 × (−1.8605) = −0.0261  ← negativo por la misma razón

  KL total = 0.1891 + (−0.0585) + (−0.0261) = 0.1045

Nota: KL ≥ 0 siempre. Si el resultado parece negativo es porque estamos usando
la variante "no simétrica" — en PyTorch F.kl_div espera log-probabilidades del student.
PyTorch usa: KL = Σ p_teacher × (log(p_teacher) − log(p_student))

En el código real:
  loss_kl = F.kl_div(
      input=F.log_softmax(student_logits / T, dim=-1),  ← log-probs del student
      target=F.softmax(teacher_logits / T, dim=-1),     ← probs del teacher
      reduction='batchmean'
  ) × (T ** 2)

5. El flujo de destilación — diagrama

KNOWLEDGE DISTILLATION — TEACHER → KNOWLEDGE → STUDENT
Diagrama de knowledge distillation: teacher y student procesan el mismo input, las dos losses se combinan El mismo batch de datos entra al teacher (congelado) y al student (en entrenamiento). La pérdida del student tiene dos componentes: KL con el teacher y cross-entropy con las etiquetas reales. Input Batch Teacher BERT-large, LLaMA-13B... 🔒 pesos congelados Student DistilBERT, modelo pequeño... ✏️ se está entrenando Soft labels (con T) [0.72, 0.19, 0.09] Student logits (con T) [0.60, 0.25, 0.15] Hard labels [1, 0, 0] (del dataset) (1−α) × T² × L_KL α × L_CE L_total → backward → actualizar student

El teacher procesa el batch y produce soft labels (con temperatura T). El student procesa el mismo batch y produce sus predicciones. La pérdida total combina KL divergence (con el teacher) y cross-entropy (con las etiquetas reales).

6. DistilBERT — un caso real de destilación

DistilBERT es el ejemplo más famoso de knowledge distillation en NLP. Fue creado por HuggingFace en 2019 destilando BERT-base (110M parámetros, 12 capas) a un modelo más pequeño (66M parámetros, 6 capas).

✍️ Cómo se creó DistilBERT — los detalles
BERT-base (teacher):     12 capas Transformer, 768 dims, 110M params
DistilBERT (student):     6 capas Transformer, 768 dims,  66M params

¿Cómo se inicializó el student?
  No desde cero — tomaron las capas pares de BERT: 2, 4, 6, 8, 10, 12.
  Esto da al student un punto de partida inteligente, no pesos aleatorios.

¿Qué funciones de pérdida usaron?
  1. L_MLM: Masked Language Modeling (la tarea previa de BERT)
     → el student predice tokens enmascarados, como en el preentrenamiento
  2. L_cos: cosine embedding loss entre las representaciones internas
     → el student aprende a producir representaciones similares a las del teacher
  3. L_CE: cross-entropy sobre los logits del teacher (soft labels con T=1)

Resultados del paper:
  Benchmark GLUE (promedio de tareas NLP):
    BERT-base:   79.5 puntos
    DistilBERT:  77.0 puntos   ← 97% del rendimiento de BERT

  Velocidad de inferencia:
    BERT-base:   100% (referencia)
    DistilBERT:  60% del tamaño → 40% más rápido en CPU

  Memoria:
    BERT-base:   ~440 MB (float32)
    DistilBERT:  ~264 MB (float32) → 40% menos
    DistilBERT:  ~66 MB  (int8)   → 85% menos que BERT original

Conclusión del paper (textual):
  "DistilBERT retiene 97% of BERT performance while being 40% smaller
   and 60% faster." — Sanh et al., 2019

7. Código de destilación — el bucle de entrenamiento

A diferencia de la cuantización y el pruning, la destilación requiere implementar un bucle de entrenamiento custom porque necesitamos pasar el batch por DOS modelos (teacher y student) y calcular la pérdida combinada.

Preparar teacher y student

import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

# Teacher: BERT-base (ya fine-tuned en SST-2 — clasificación de sentimiento)
teacher = AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-SST-2"
).to(device)
teacher.eval()     # el teacher nunca se entrena — modo evaluación
for param in teacher.parameters():
    param.requires_grad = False   # congelar todos los pesos del teacher

# Student: DistilBERT desde cero (o con pesos preentrenados, pero SIN fine-tune)
student = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2
).to(device)
# student.train() por defecto — se va a entrenar

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

print(f"Teacher params: {sum(p.numel() for p in teacher.parameters()):,}")
print(f"Student params: {sum(p.numel() for p in student.parameters()):,}")
salida realTeacher params: 109,483,778 Student params: 66,955,010

La función de pérdida de destilación

def distillation_loss(
    student_logits,    # salida cruda del student (antes de softmax)
    teacher_logits,    # salida cruda del teacher (antes de softmax)
    hard_labels,       # etiquetas reales del dataset (0 o 1)
    T=2.0,            # temperatura — suaviza las distribuciones
    alpha=0.5         # balance entre CE (alpha) y KL (1-alpha)
):
    # Pérdida 1: Cross-entropy con etiquetas reales (conocimiento supervisado)
    loss_ce = F.cross_entropy(student_logits, hard_labels)

    # Pérdida 2: KL divergence con el teacher (conocimiento del experto)
    # Aplicar temperatura T a ambos modelos
    student_probs_T = F.log_softmax(student_logits / T, dim=-1)  # log-probs del student
    teacher_probs_T = F.softmax(teacher_logits / T, dim=-1)       # probs del teacher

    loss_kl = F.kl_div(
        input=student_probs_T,
        target=teacher_probs_T,
        reduction="batchmean"   # dividir por el tamaño del batch
    ) * (T ** 2)   # factor T² para compensar la escala de gradientes

    # Combinar las dos pérdidas
    total_loss = alpha * loss_ce + (1 - alpha) * loss_kl

    return total_loss, loss_ce.item(), loss_kl.item()

El bucle de entrenamiento

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

## Dataset
dataset = load_dataset("sst2")
def tok_fn(b): return tokenizer(b["sentence"], truncation=True, max_length=128)
ds = dataset.map(tok_fn, batched=True)
ds = ds.rename_column("label", "labels")
ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

train_loader = DataLoader(
    ds["train"].select(range(5000)),
    batch_size=32,
    collate_fn=DataCollatorWithPadding(tokenizer),
    shuffle=True
)

optimizer = torch.optim.AdamW(student.parameters(), lr=3e-5)

## Bucle de entrenamiento con destilación
T, alpha = 2.0, 0.5
student.train()

for epoch in range(4):
    total_loss = total_ce = total_kl = 0
    n_batches = 0

    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        hard_labels = batch.pop("labels")

        # Forward pass del teacher (sin gradientes — está congelado)
        with torch.no_grad():
            teacher_out = teacher(**batch)
            teacher_logits = teacher_out.logits

        # Forward pass del student (con gradientes)
        student_out = student(**batch)
        student_logits = student_out.logits

        # Calcular pérdida de destilación
        loss, ce, kl = distillation_loss(
            student_logits, teacher_logits, hard_labels, T=T, alpha=alpha
        )

        # Backward y actualización
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_ce += ce
        total_kl += kl
        n_batches += 1

    print(f"Epoch {epoch+1}: loss={total_loss/n_batches:.4f} | CE={total_ce/n_batches:.4f} | KL={total_kl/n_batches:.4f}")
salida realEpoch 1: loss=0.4821 | CE=0.5214 | KL=0.4428 Epoch 2: loss=0.3012 | CE=0.3201 | KL=0.2823 Epoch 3: loss=0.2188 | CE=0.2394 | KL=0.1982 Epoch 4: loss=0.1841 | CE=0.2019 | KL=0.1663 (Las dos pérdidas convergen juntas. KL decae más rápido — el student aprende rápidamente a imitar la distribución del teacher.)

8. Cuándo usar cada técnica del bloque

Situación Técnica recomendada Por qué
Fine-tuning de modelo 7B+ con GPU limitada QLoRA 4-bit + LoRA: máximo ahorro de VRAM con mínima pérdida.
Acelerar inferencia de un modelo ya entrenado Cuantización PTQ (int8) Simple, rápido, pérdida mínima. bitsandbytes en 3 líneas.
Reducir el tamaño del modelo para despliegue Cuantización int4 o int8 4× a 8× menos memoria. Para móvil o edge: GGUF.
Hardware con soporte sparse (A100, H100) Structured pruning + cuantización Combinar ambas para máxima aceleración real.
Crear un modelo pequeño desde cero Knowledge distillation El student aprende más que con el dataset solo.
Redistribuir un modelo fine-tuned compartiendo poco LoRA Solo 2-5 MB de adaptadores vs 260 MB del modelo completo.

🎮 Visualizador de soft labels — temperatura y teacher

Observa cómo cambian las probabilidades del teacher con distintas temperaturas y cómo el student intenta aproximarse a ellas.

Logits crudos del teacher (antes de softmax):

7.2
2.5
-1.5
T=1

10. Lo que aprendiste

Lo que aprendiste hoy: knowledge distillation entrena un modelo pequeño (student) usando las predicciones de un modelo grande (teacher) como señal adicional. Las hard labels solo dicen qué clase ganó; las soft labels del teacher revelan la similitud entre clases — información mucho más rica. La temperatura T suaviza las distribuciones del teacher para extraer esa similitud: T=2 a T=4 son los valores más comunes. La función de pérdida combina cross-entropy con las etiquetas reales (controlado por α) y KL divergence con el teacher (controlado por 1−α). El factor T² compensa la escala de los gradientes. DistilBERT se creó así: 40% más pequeño, 60% más rápido, 97% del rendimiento de BERT. Para crear un modelo nuevo y pequeño desde cero, la destilación supera al entrenamiento estándar con el dataset solo.

En la próxima y última lección: Proyecto final — un pipeline completo de principio a fin combinando todas las técnicas del curso.