La cuantización comprime números de alta precisión a baja precisión. El pruning hace algo diferente: identifica qué pesos de la red contribuyen poco al resultado y los elimina directamente. El árbol sigue creciendo — con menos ramas, pero igual de fuerte.
Pruning (del inglés "poda") es el proceso de eliminar pesos o neuronas de una red neuronal entrenada. El término viene de la jardinería: cuando podas un árbol, eliminas las ramas secas o débiles para que la energía vaya a las ramas sanas.
Después de entrenar una red neuronal, si examinas los valores de los pesos, encontrarás que muchos de ellos son muy pequeños, muy cerca de cero. Un peso con valor 0.0003 en una capa que tiene valores de hasta 2.5 contribuye prácticamente nada a la salida final. Poner ese peso a cero exactamente no va a cambiar mucho el resultado.
Supón una capa lineal sencilla: salida = W × entrada
entrada = [1.0, 0.5, 2.0, 0.8]
Fila 1 de W: [2.1, 0.0003, -1.8, 0.0001]
Salida = 2.1×1.0 + 0.0003×0.5 + (−1.8)×2.0 + 0.0001×0.8
= 2.1 + 0.00015 + (−3.6) + 0.00008
= 2.1 + 0.00015 − 3.6 + 0.00008
= −1.49977
Si ponemos los pesos pequeños a cero:
Fila 1 podada: [2.1, 0, -1.8, 0]
Salida_podada = 2.1×1.0 + 0×0.5 + (−1.8)×2.0 + 0×0.8
= 2.1 + 0 − 3.6 + 0
= −1.5
Error introducido: |−1.5 − (−1.49977)| = 0.00023
El error es de 0.023% de la magnitud de la salida (−1.5).
Ese error es negligible — y así en miles de capas, los errores se promedian.
Existen dos estrategias fundamentales de pruning, con propiedades muy diferentes:
Pone a cero los pesos individuales más pequeños, sin importar su posición en la matriz. El resultado es una matriz esparsa: tiene muchos ceros pero con una forma irregular. Es como taladrar agujeros aleatorios en una placa de metal.
Ventaja: máxima flexibilidad, puede alcanzar altas tasas de sparsidad (80-90%) sin mucha pérdida.
Problema: las matrices esparsas irregulares son difíciles de acelerar en hardware moderno (GPUs están optimizadas para matrices densas). El modelo puede ser "más pequeño" en teoría pero no más rápido en la práctica.
Elimina estructuras enteras: una neurona completa, una cabeza de atención entera, un canal de convolución. El resultado es un modelo más pequeño con la misma estructura densa — como cortar una columna entera de la tabla.
Ventaja: el modelo resultante es genuinamente más rápido — misma aritmética densa, simplemente con menos filas/columnas.
Problema: más difícil de aplicar sin perder calidad — eliminar una neurona entera impacta más que poner un peso individual a cero.
Izquierda: unstructured pruning — ceros dispersos, misma forma de matriz. Derecha: structured pruning — una columna entera eliminada, la matriz encoge de verdad.
El método de pruning más intuitivo y más usado como punto de partida es el magnitude pruning (poda por magnitud). El criterio es elemental: los pesos con valor absoluto más pequeño se ponen a cero.
"Valor absoluto" significa el tamaño del número sin importar su signo. El valor absoluto de −0.003 es 0.003. El valor absoluto de +0.003 también es 0.003. Un peso de −0.003 y uno de +0.003 son igualmente "pequeños" en términos de magnitud.
Pesos originales (10 valores):
w = [2.1, -0.003, 0.8, -1.5, 0.0002, -0.9, 0.005, 1.2, -0.001, -0.7]
Paso 1: calcular el valor absoluto de cada peso:
|w| = [2.1, 0.003, 0.8, 1.5, 0.0002, 0.9, 0.005, 1.2, 0.001, 0.7]
Paso 2: ordenar de menor a mayor:
posición: 0.0002 0.001 0.003 0.005 0.7 0.8 0.9 1.2 1.5 2.1
índice: 4 8 1 6 9 2 5 7 3 0
Paso 3: con 50% de pruning, ponemos a cero los 5 de MENOR magnitud:
Ceros (los 5 menores): índices 4, 8, 1, 6, 9
Sobreviven (los 5 mayores): índices 2, 5, 7, 3, 0
Paso 4: aplicar la máscara:
w_podado = [2.1, 0, 0.8, -1.5, 0, -0.9, 0, 1.2, 0, 0]
↑ ↑ ↑ ↑ ↑
(0.003) (0.0002) (0.005) (0.001) (−0.7)
Comprobación: los pesos que sobreviven son los más grandes en magnitud.
Los podados tenían valores casi insignificantes comparados con pesos como 2.1 o 1.5.
Sparsidad resultante: 5 ceros de 10 total = 50%
PyTorch incluye un módulo de pruning listo para usar: torch.nn.utils.prune.
Veamos cómo aplicarlo a DistilBERT.
import torch
import torch.nn.utils.prune as prune
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
# Ver la primera capa lineal del modelo
first_layer = model.distilbert.transformer.layer[0].attention.q_lin
print(f"Tipo del parámetro: {first_layer.weight.dtype}")
print(f"Forma del peso: {first_layer.weight.shape}")
print(f"Valores no nulos antes del pruning: {(first_layer.weight != 0).sum().item()}")
# Aplicar magnitude pruning al 30% — poner a cero el 30% de los pesos más pequeños
prune.l1_unstructured(
first_layer, # módulo al que aplicar el pruning
name="weight", # qué parámetro podar (weight o bias)
amount=0.30, # fracción de pesos a poner a cero (0.30 = 30%)
)
print(f"Valores no nulos después del pruning: {(first_layer.weight != 0).sum().item()}")
# Calcular sparsidad real
total = first_layer.weight.numel()
zeros = (first_layer.weight == 0).sum().item()
print(f"Sparsidad: {zeros}/{total} = {zeros/total*100:.1f}%")
# Pruning global: poda el X% de pesos más pequeños de TODAS las capas juntas
# (más inteligente que podar cada capa por separado)
parameters_to_prune = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
parameters_to_prune.append((module, "weight"))
print(f"Capas lineales encontradas: {len(parameters_to_prune)}")
# Aplicar pruning global al 30%
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.30,
)
# Calcular sparsidad total del modelo
total_zeros = 0
total_params = 0
for module, name in parameters_to_prune:
zeros = float(torch.sum(module.weight == 0))
total = float(module.weight.nelement())
total_zeros += zeros
total_params += total
print(f"Sparsidad total: {total_zeros/total_params*100:.1f}%")
print(f"Pesos activos: {int(total_params - total_zeros):,} de {int(total_params):,}")
# Después del pruning, los ceros son una "máscara" temporal
# Para hacerlos permanentes (fusionar la máscara con el peso):
for module, name in parameters_to_prune:
prune.remove(module, name) # fusiona weight_mask con weight
print("Pruning permanente aplicado. Los ceros son ahora parte de los pesos.")
Una forma de structured pruning muy específica para Transformers es la poda de cabezas de atención. Recuerda del curso de Transformer que la atención multi-cabeza tiene varias cabezas que operan en paralelo — en BERT-base hay 12 cabezas por capa, en total 144 cabezas en el modelo.
Los investigadores han encontrado que no todas las cabezas son igual de importantes. Algunas cabezas aprenden patrones que otras cabezas ya están capturando — son redundantes. Eliminar esas cabezas redundantes da un modelo más pequeño y más rápido.
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Ver cuántas cabezas tiene el modelo
config = model.config
print(f"Capas: {config.n_layers}")
print(f"Cabezas por capa: {config.n_heads}")
print(f"Cabezas totales: {config.n_layers * config.n_heads}")
# Calcular la varianza de los pesos de atención para cada cabeza
# (en la práctica se usa con datos reales; aquí usamos los pesos directamente)
head_importance = []
n_heads = config.n_heads
head_dim = config.dim // config.n_heads # dimensión por cabeza: 768/12=64
for layer_idx in range(config.n_layers):
q_weight = model.distilbert.transformer.layer[layer_idx].attention.q_lin.weight
# Dividir el peso en n_heads bloques y calcular varianza de cada uno
q_reshaped = q_weight.view(n_heads, head_dim, -1)
for h in range(n_heads):
variance = q_reshaped[h].var().item()
head_importance.append({
'layer': layer_idx,
'head': h,
'variance': variance
})
# Ordenar por varianza (las de menor varianza son candidatas a eliminar)
head_importance.sort(key=lambda x: x['variance'])
print("\nLas 5 cabezas con MENOR varianza (candidatas a poda):")
for h in head_importance[:5]:
print(f" Capa {h['layer']}, Cabeza {h['head']}: varianza = {h['variance']:.6f}")
print("\nLas 5 cabezas con MAYOR varianza (más importantes):")
for h in head_importance[-5:]:
print(f" Capa {h['layer']}, Cabeza {h['head']}: varianza = {h['variance']:.6f}")
Esto depende del modelo y la tarea, pero hay reglas empíricas bien establecidas:
| Nivel de poda | Pesos eliminados | Impacto en accuracy | Recomendación |
|---|---|---|---|
| 10–30% | 1 de cada 3-10 pesos | Sin pérdida notable | Siempre seguro. Punto de partida. |
| 30–50% | 1 de cada 2-3 pesos | Pérdida muy pequeña (<1%) | Recomendado para producción. |
| 50–70% | La mitad o más | Pérdida pequeña (1-3%) | Válido si la velocidad importa más. |
| 70–90% | La gran mayoría | Pérdida notable (3-10%) | Solo con fine-tuning post-poda. |
| >90% | Casi todo | Pérdida significativa | Requiere lottery ticket + reentrenamiento. |
Genera una capa con pesos aleatorios y aplica magnitude pruning. Observa qué pesos sobreviven y cuántos quedan a cero.
En la próxima lección: Knowledge distillation — un modelo grande (teacher) le enseña a uno pequeño (student) no solo las respuestas correctas, sino también cómo de similares son las respuestas incorrectas.