Cuantización y pruning comprimen un modelo existente. Knowledge distillation hace algo diferente: un modelo grande actúa de maestro (teacher) y entrena a uno pequeño (student) transfiriéndole su conocimiento de forma más rica que con simples etiquetas. DistilBERT nació así.
Imagina que un médico con 30 años de experiencia quiere enseñar a un estudiante de medicina. Tiene dos opciones:
La segunda opción transmite mucho más información: el experto no solo dice qué ganó, sino cómo de parecidas son las diferentes opciones entre sí. Eso es knowledge distillation: el modelo grande (teacher) entrena al modelo pequeño (student) no con etiquetas duras (0 o 1), sino con su distribución completa de probabilidades.
Frase de ejemplo: "La película no era perfecta, pero me entretuvo bastante."
HARD LABELS (etiquetas del dataset):
Positivo: 1 Neutro: 0 Negativo: 0
→ Solo dice que es "positivo". No dice nada más.
SOFT LABELS del teacher (modelo BERT-large):
Positivo: 0.72 Neutro: 0.24 Negativo: 0.04
→ Dice: "es bastante positivo, pero con bastante componente neutro"
→ La frase tiene "no era perfecta" que apunta a neutro, pero "me entretuvo"
que apunta a positivo. El teacher captura esa ambigüedad.
¿Qué aprende el student de cada señal?
Con hard labels: aprende que esta frase → clase "positivo". Punto.
Con soft labels: aprende que esta frase → espacio cercano a positivo,
con influencia de neutro, y casi nada de negativo.
Aprende la GEOMETRÍA del espacio semántico.
Otro ejemplo: "El actor fue increíble. La trama, un desastre."
Hard labels: Neutro: 1 (label del dataset)
Soft labels del teacher: Positivo: 0.38 Neutro: 0.41 Negativo: 0.21
→ El dataset dice "neutro", pero el teacher ve que hay elementos positivos
(el actor) y negativos (la trama). Esa riqueza se transfiere al student.
Hay un problema con las soft labels del teacher: un modelo bien entrenado a menudo produce predicciones muy seguras. Por ejemplo: "positivo: 0.999, neutro: 0.001, negativo: 0.000". Esa distribución no es mucho más informativa que las hard labels — casi todo el peso está en una clase.
Para solucionar esto se introduce la temperatura T. La temperatura "suaviza" la distribución: hace que los valores grandes bajen un poco y los pequeños suban un poco, revelando las relaciones de similitud que de otra forma estarían ocultas.
Matemáticamente, la temperatura se aplica dividiendo los logits (los números antes del softmax) por T antes de calcular el softmax. Un logit es el número crudo que sale de la última capa del modelo, antes de convertirlo en probabilidad.
Supón que el teacher produce estos logits para 3 clases:
logits = [8.5, 3.2, 0.1]
(pos) (neu) (neg)
Con T=1 (softmax normal):
exp(8.5/1) = exp(8.5) = 4914.8
exp(3.2/1) = exp(3.2) = 24.5
exp(0.1/1) = exp(0.1) = 1.11
suma = 4914.8 + 24.5 + 1.11 = 4940.4
p_positivo = 4914.8 / 4940.4 = 0.9948 (casi certeza total)
p_neutro = 24.5 / 4940.4 = 0.0050
p_negativo = 1.11 / 4940.4 = 0.0002
Con T=2 (temperatura 2):
exp(8.5/2) = exp(4.25) = 70.1
exp(3.2/2) = exp(1.60) = 4.95
exp(0.1/2) = exp(0.05) = 1.05
suma = 70.1 + 4.95 + 1.05 = 76.1
p_positivo = 70.1 / 76.1 = 0.921
p_neutro = 4.95 / 76.1 = 0.065
p_negativo = 1.05 / 76.1 = 0.014
Con T=4 (temperatura 4):
exp(8.5/4) = exp(2.125) = 8.38
exp(3.2/4) = exp(0.800) = 2.23
exp(0.1/4) = exp(0.025) = 1.025
suma = 8.38 + 2.23 + 1.025 = 11.635
p_positivo = 8.38 / 11.635 = 0.720
p_neutro = 2.23 / 11.635 = 0.192
p_negativo = 1.025 / 11.635 = 0.088
Comparación:
T=1: [0.9948, 0.0050, 0.0002] ← casi como hard labels
T=2: [0.921, 0.065, 0.014] ← algo de información en neutro y negativo
T=4: [0.720, 0.192, 0.088] ← mucho más informativa: captura las relaciones
→ Con T=4, el student aprende que "negativo" y "neutro" son más similares
entre sí que cualquiera de los dos respecto a "positivo" para esta frase.
La función de pérdida combina dos señales de aprendizaje para el student:
Desglosemos cada parte:
La cross-entropy es la pérdida de clasificación habitual que ya conoces del curso de redes neuronales. Mide cuánto difiere la predicción del student de la etiqueta correcta del dataset (hard label). Esto asegura que el student aprenda la tarea correctamente.
L_KL — KL Divergence con el teacherKL Divergence (divergencia de Kullback-Leibler) mide cuán diferente es una distribución de probabilidades de otra. Si el teacher da [0.72, 0.19, 0.09] y el student da [0.60, 0.25, 0.15], la KL divergence mide qué tan lejos está el student del teacher. Durante el entrenamiento, el student intenta minimizar esa distancia — aprender a parecerse al teacher. Tanto el student como el teacher usan sus versiones con temperatura T.
α (alpha) — el balance entre las dos señalesα controla cuánto peso le damos a cada pérdida. α=0.5 es el valor más común: mitad y mitad. α más cercano a 1 significa "aprende más de las etiquetas reales". α más cercano a 0 significa "aprende más del teacher". Para destilación pura del conocimiento, valores como α=0.3 o α=0.5.
El factor T²Cuando usas temperatura T, los gradientes de la KL divergence se vuelven T² veces más pequeños matemáticamente. El factor T² compensa eso para que la escala de la pérdida KL sea comparable a la de la cross-entropy.
Continuando el ejemplo anterior (temperatura T=2):
Teacher con T=2: p_teacher = [0.921, 0.065, 0.014]
Student hace su predicción, también con T=2: p_student = [0.750, 0.160, 0.090]
La KL Divergence de p_student respecto a p_teacher es:
KL(p_teacher || p_student) = Σ p_teacher(i) × log(p_teacher(i) / p_student(i))
Para cada clase:
Clase 0 (positivo): 0.921 × log(0.921 / 0.750) = 0.921 × log(1.228)
= 0.921 × 0.2053 = 0.1891
Clase 1 (neutro): 0.065 × log(0.065 / 0.160) = 0.065 × log(0.4063)
= 0.065 × (−0.9004) = −0.0585 ← negativo porque p_student > p_teacher
Clase 2 (negativo): 0.014 × log(0.014 / 0.090) = 0.014 × log(0.1556)
= 0.014 × (−1.8605) = −0.0261 ← negativo por la misma razón
KL total = 0.1891 + (−0.0585) + (−0.0261) = 0.1045
Nota: KL ≥ 0 siempre. Si el resultado parece negativo es porque estamos usando
la variante "no simétrica" — en PyTorch F.kl_div espera log-probabilidades del student.
PyTorch usa: KL = Σ p_teacher × (log(p_teacher) − log(p_student))
En el código real:
loss_kl = F.kl_div(
input=F.log_softmax(student_logits / T, dim=-1), ← log-probs del student
target=F.softmax(teacher_logits / T, dim=-1), ← probs del teacher
reduction='batchmean'
) × (T ** 2)
El teacher procesa el batch y produce soft labels (con temperatura T). El student procesa el mismo batch y produce sus predicciones. La pérdida total combina KL divergence (con el teacher) y cross-entropy (con las etiquetas reales).
DistilBERT es el ejemplo más famoso de knowledge distillation en NLP. Fue creado por HuggingFace en 2019 destilando BERT-base (110M parámetros, 12 capas) a un modelo más pequeño (66M parámetros, 6 capas).
BERT-base (teacher): 12 capas Transformer, 768 dims, 110M params
DistilBERT (student): 6 capas Transformer, 768 dims, 66M params
¿Cómo se inicializó el student?
No desde cero — tomaron las capas pares de BERT: 2, 4, 6, 8, 10, 12.
Esto da al student un punto de partida inteligente, no pesos aleatorios.
¿Qué funciones de pérdida usaron?
1. L_MLM: Masked Language Modeling (la tarea previa de BERT)
→ el student predice tokens enmascarados, como en el preentrenamiento
2. L_cos: cosine embedding loss entre las representaciones internas
→ el student aprende a producir representaciones similares a las del teacher
3. L_CE: cross-entropy sobre los logits del teacher (soft labels con T=1)
Resultados del paper:
Benchmark GLUE (promedio de tareas NLP):
BERT-base: 79.5 puntos
DistilBERT: 77.0 puntos ← 97% del rendimiento de BERT
Velocidad de inferencia:
BERT-base: 100% (referencia)
DistilBERT: 60% del tamaño → 40% más rápido en CPU
Memoria:
BERT-base: ~440 MB (float32)
DistilBERT: ~264 MB (float32) → 40% menos
DistilBERT: ~66 MB (int8) → 85% menos que BERT original
Conclusión del paper (textual):
"DistilBERT retiene 97% of BERT performance while being 40% smaller
and 60% faster." — Sanh et al., 2019
A diferencia de la cuantización y el pruning, la destilación requiere implementar un bucle de entrenamiento custom porque necesitamos pasar el batch por DOS modelos (teacher y student) y calcular la pérdida combinada.
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
# Teacher: BERT-base (ya fine-tuned en SST-2 — clasificación de sentimiento)
teacher = AutoModelForSequenceClassification.from_pretrained(
"textattack/bert-base-uncased-SST-2"
).to(device)
teacher.eval() # el teacher nunca se entrena — modo evaluación
for param in teacher.parameters():
param.requires_grad = False # congelar todos los pesos del teacher
# Student: DistilBERT desde cero (o con pesos preentrenados, pero SIN fine-tune)
student = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased",
num_labels=2
).to(device)
# student.train() por defecto — se va a entrenar
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
print(f"Teacher params: {sum(p.numel() for p in teacher.parameters()):,}")
print(f"Student params: {sum(p.numel() for p in student.parameters()):,}")
def distillation_loss(
student_logits, # salida cruda del student (antes de softmax)
teacher_logits, # salida cruda del teacher (antes de softmax)
hard_labels, # etiquetas reales del dataset (0 o 1)
T=2.0, # temperatura — suaviza las distribuciones
alpha=0.5 # balance entre CE (alpha) y KL (1-alpha)
):
# Pérdida 1: Cross-entropy con etiquetas reales (conocimiento supervisado)
loss_ce = F.cross_entropy(student_logits, hard_labels)
# Pérdida 2: KL divergence con el teacher (conocimiento del experto)
# Aplicar temperatura T a ambos modelos
student_probs_T = F.log_softmax(student_logits / T, dim=-1) # log-probs del student
teacher_probs_T = F.softmax(teacher_logits / T, dim=-1) # probs del teacher
loss_kl = F.kl_div(
input=student_probs_T,
target=teacher_probs_T,
reduction="batchmean" # dividir por el tamaño del batch
) * (T ** 2) # factor T² para compensar la escala de gradientes
# Combinar las dos pérdidas
total_loss = alpha * loss_ce + (1 - alpha) * loss_kl
return total_loss, loss_ce.item(), loss_kl.item()
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
## Dataset
dataset = load_dataset("sst2")
def tok_fn(b): return tokenizer(b["sentence"], truncation=True, max_length=128)
ds = dataset.map(tok_fn, batched=True)
ds = ds.rename_column("label", "labels")
ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_loader = DataLoader(
ds["train"].select(range(5000)),
batch_size=32,
collate_fn=DataCollatorWithPadding(tokenizer),
shuffle=True
)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-5)
## Bucle de entrenamiento con destilación
T, alpha = 2.0, 0.5
student.train()
for epoch in range(4):
total_loss = total_ce = total_kl = 0
n_batches = 0
for batch in train_loader:
batch = {k: v.to(device) for k, v in batch.items()}
hard_labels = batch.pop("labels")
# Forward pass del teacher (sin gradientes — está congelado)
with torch.no_grad():
teacher_out = teacher(**batch)
teacher_logits = teacher_out.logits
# Forward pass del student (con gradientes)
student_out = student(**batch)
student_logits = student_out.logits
# Calcular pérdida de destilación
loss, ce, kl = distillation_loss(
student_logits, teacher_logits, hard_labels, T=T, alpha=alpha
)
# Backward y actualización
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_ce += ce
total_kl += kl
n_batches += 1
print(f"Epoch {epoch+1}: loss={total_loss/n_batches:.4f} | CE={total_ce/n_batches:.4f} | KL={total_kl/n_batches:.4f}")
| Situación | Técnica recomendada | Por qué |
|---|---|---|
| Fine-tuning de modelo 7B+ con GPU limitada | QLoRA | 4-bit + LoRA: máximo ahorro de VRAM con mínima pérdida. |
| Acelerar inferencia de un modelo ya entrenado | Cuantización PTQ (int8) | Simple, rápido, pérdida mínima. bitsandbytes en 3 líneas. |
| Reducir el tamaño del modelo para despliegue | Cuantización int4 o int8 | 4× a 8× menos memoria. Para móvil o edge: GGUF. |
| Hardware con soporte sparse (A100, H100) | Structured pruning + cuantización | Combinar ambas para máxima aceleración real. |
| Crear un modelo pequeño desde cero | Knowledge distillation | El student aprende más que con el dataset solo. |
| Redistribuir un modelo fine-tuned compartiendo poco | LoRA | Solo 2-5 MB de adaptadores vs 260 MB del modelo completo. |
Observa cómo cambian las probabilidades del teacher con distintas temperaturas y cómo el student intenta aproximarse a ellas.
Logits crudos del teacher (antes de softmax):
En la próxima y última lección: Proyecto final — un pipeline completo de principio a fin combinando todas las técnicas del curso.