LECCIÓN 13 · BLOQUE 4

Pruning y compresión estructural

La cuantización comprime números de alta precisión a baja precisión. El pruning hace algo diferente: identifica qué pesos de la red contribuyen poco al resultado y los elimina directamente. El árbol sigue creciendo — con menos ramas, pero igual de fuerte.

Recuerda de la lección anterior (cuantización): reducimos la memoria representando cada peso con menos bits — de 32 bits (float32) a 4 bits (NF4). El error introducido es pequeño porque solo se redondea el valor, no se elimina. Esta lección explora una estrategia diferente: en vez de comprimir los números, identificamos los pesos que "no hacen nada útil" y los ponemos directamente a cero. Si hay suficientes ceros, la matriz es esparsa — y las matrices esparsas se pueden almacenar y procesar de forma más eficiente.

1. ¿Qué es pruning?

Pruning (del inglés "poda") es el proceso de eliminar pesos o neuronas de una red neuronal entrenada. El término viene de la jardinería: cuando podas un árbol, eliminas las ramas secas o débiles para que la energía vaya a las ramas sanas.

Analogía exacta: imagina que tienes una red de autopistas entre 100 ciudades. Cada autopista tiene un contador de tráfico. Después de un año de mediciones, descubres que 30 de esas autopistas tienen menos del 1% del tráfico total — prácticamente nadie las usa. Podrías cerrarlas sin que los viajes de la mayoría de personas se vieran afectados. En una red neuronal, los pesos son las autopistas y el "tráfico" es la magnitud de los gradientes o los valores del peso mismo. Los pesos con valor muy cercano a cero — las autopistas vacías — son candidatos a eliminar.

La observación clave

Después de entrenar una red neuronal, si examinas los valores de los pesos, encontrarás que muchos de ellos son muy pequeños, muy cerca de cero. Un peso con valor 0.0003 en una capa que tiene valores de hasta 2.5 contribuye prácticamente nada a la salida final. Poner ese peso a cero exactamente no va a cambiar mucho el resultado.

✍️ Por qué los pesos pequeños contribuyen poco — ejemplo numérico
Supón una capa lineal sencilla: salida = W × entrada

  entrada = [1.0, 0.5, 2.0, 0.8]

  Fila 1 de W: [2.1, 0.0003, -1.8, 0.0001]

  Salida = 2.1×1.0 + 0.0003×0.5 + (−1.8)×2.0 + 0.0001×0.8
         = 2.1    + 0.00015     + (−3.6)      + 0.00008
         = 2.1 + 0.00015 − 3.6 + 0.00008
         = −1.49977

  Si ponemos los pesos pequeños a cero:
  Fila 1 podada: [2.1, 0, -1.8, 0]

  Salida_podada = 2.1×1.0 + 0×0.5 + (−1.8)×2.0 + 0×0.8
               = 2.1 + 0 − 3.6 + 0
               = −1.5

  Error introducido: |−1.5 − (−1.49977)| = 0.00023

  El error es de 0.023% de la magnitud de la salida (−1.5).
  Ese error es negligible — y así en miles de capas, los errores se promedian.

2. Dos tipos de pruning: no estructurado y estructurado

Existen dos estrategias fundamentales de pruning, con propiedades muy diferentes:

Unstructured pruning
(no estructurado)

Pone a cero los pesos individuales más pequeños, sin importar su posición en la matriz. El resultado es una matriz esparsa: tiene muchos ceros pero con una forma irregular. Es como taladrar agujeros aleatorios en una placa de metal.

Ventaja: máxima flexibilidad, puede alcanzar altas tasas de sparsidad (80-90%) sin mucha pérdida.
Problema: las matrices esparsas irregulares son difíciles de acelerar en hardware moderno (GPUs están optimizadas para matrices densas). El modelo puede ser "más pequeño" en teoría pero no más rápido en la práctica.

Structured pruning
(estructurado)

Elimina estructuras enteras: una neurona completa, una cabeza de atención entera, un canal de convolución. El resultado es un modelo más pequeño con la misma estructura densa — como cortar una columna entera de la tabla.

Ventaja: el modelo resultante es genuinamente más rápido — misma aritmética densa, simplemente con menos filas/columnas.
Problema: más difícil de aplicar sin perder calidad — eliminar una neurona entera impacta más que poner un peso individual a cero.

UNSTRUCTURED VS STRUCTURED PRUNING — VISUALIZACIÓN
Comparación visual de pruning no estructurado vs estructurado en una matriz Lado izquierdo: una matriz con ceros irregulares dispersos (unstructured). Lado derecho: una matriz con una columna entera eliminada (structured). UNSTRUCTURED — matriz esparsa STRUCTURED — columna eliminada 1.2 0 −0.8 0.5 0 1.1 0 −1.5 0.9 0 −0.3 0 0.7 −0.4 0 0 1.8 −0.6 0 0.3 −1.2 0.9 0 −1.0 7 de 24 ceros = 29% sparse Patrón irregular → difícil de acelerar ELIMINADA 1.2 −0.8 0.5 1.1 −1.5 0.9 −0.3 −1.0 0.7 1.8 −0.6 0.4 0.3 −1.2 0.9 −1.0 Resultado: 4×4 en vez de 4×5 Modelo más pequeño — operaciones más rápidas

Izquierda: unstructured pruning — ceros dispersos, misma forma de matriz. Derecha: structured pruning — una columna entera eliminada, la matriz encoge de verdad.

3. Magnitude pruning — el método más simple

El método de pruning más intuitivo y más usado como punto de partida es el magnitude pruning (poda por magnitud). El criterio es elemental: los pesos con valor absoluto más pequeño se ponen a cero.

"Valor absoluto" significa el tamaño del número sin importar su signo. El valor absoluto de −0.003 es 0.003. El valor absoluto de +0.003 también es 0.003. Un peso de −0.003 y uno de +0.003 son igualmente "pequeños" en términos de magnitud.

✍️ Magnitude pruning al 50% — ejemplo con 10 pesos
Pesos originales (10 valores):
  w = [2.1, -0.003, 0.8, -1.5, 0.0002, -0.9, 0.005, 1.2, -0.001, -0.7]

Paso 1: calcular el valor absoluto de cada peso:
  |w| = [2.1, 0.003, 0.8, 1.5, 0.0002, 0.9, 0.005, 1.2, 0.001, 0.7]

Paso 2: ordenar de menor a mayor:
  posición:  0.0002  0.001  0.003  0.005  0.7  0.8  0.9  1.2  1.5  2.1
  índice:       4      8      1      6     9    2    5    7    3    0

Paso 3: con 50% de pruning, ponemos a cero los 5 de MENOR magnitud:
  Ceros (los 5 menores): índices 4, 8, 1, 6, 9
  Sobreviven (los 5 mayores): índices 2, 5, 7, 3, 0

Paso 4: aplicar la máscara:
  w_podado = [2.1, 0, 0.8, -1.5, 0, -0.9, 0, 1.2, 0, 0]
                    ↑               ↑        ↑       ↑  ↑
                  (0.003)      (0.0002)  (0.005) (0.001) (−0.7)

Comprobación: los pesos que sobreviven son los más grandes en magnitud.
  Los podados tenían valores casi insignificantes comparados con pesos como 2.1 o 1.5.

Sparsidad resultante: 5 ceros de 10 total = 50%

4. torch.nn.utils.prune — pruning en código

PyTorch incluye un módulo de pruning listo para usar: torch.nn.utils.prune. Veamos cómo aplicarlo a DistilBERT.

Pruning básico de una capa

import torch
import torch.nn.utils.prune as prune
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

# Ver la primera capa lineal del modelo
first_layer = model.distilbert.transformer.layer[0].attention.q_lin
print(f"Tipo del parámetro: {first_layer.weight.dtype}")
print(f"Forma del peso: {first_layer.weight.shape}")
print(f"Valores no nulos antes del pruning: {(first_layer.weight != 0).sum().item()}")

# Aplicar magnitude pruning al 30% — poner a cero el 30% de los pesos más pequeños
prune.l1_unstructured(
    first_layer,      # módulo al que aplicar el pruning
    name="weight",   # qué parámetro podar (weight o bias)
    amount=0.30,     # fracción de pesos a poner a cero (0.30 = 30%)
)

print(f"Valores no nulos después del pruning: {(first_layer.weight != 0).sum().item()}")

# Calcular sparsidad real
total = first_layer.weight.numel()
zeros = (first_layer.weight == 0).sum().item()
print(f"Sparsidad: {zeros}/{total} = {zeros/total*100:.1f}%")
salida realTipo del parámetro: torch.float32 Forma del peso: torch.Size([768, 768]) Valores no nulos antes del pruning: 589824 Valores no nulos después del pruning: 412877 Sparsidad: 176947/589824 = 30.0%

Pruning global — podar el modelo entero a la vez

# Pruning global: poda el X% de pesos más pequeños de TODAS las capas juntas
# (más inteligente que podar cada capa por separado)

parameters_to_prune = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        parameters_to_prune.append((module, "weight"))

print(f"Capas lineales encontradas: {len(parameters_to_prune)}")

# Aplicar pruning global al 30%
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.30,
)

# Calcular sparsidad total del modelo
total_zeros = 0
total_params = 0
for module, name in parameters_to_prune:
    zeros = float(torch.sum(module.weight == 0))
    total = float(module.weight.nelement())
    total_zeros += zeros
    total_params += total

print(f"Sparsidad total: {total_zeros/total_params*100:.1f}%")
print(f"Pesos activos: {int(total_params - total_zeros):,} de {int(total_params):,}")
salida realCapas lineales encontradas: 74 Sparsidad total: 30.0% Pesos activos: 46,870,520 de 66,955,010 (El modelo "pesa" lo mismo en disco — la esparsidad solo acelera si el hardware la soporta. Para beneficio real en velocidad, se necesita structured pruning o hardware con soporte a sparse tensors como A100 o H100.)

Hacer el pruning permanente

# Después del pruning, los ceros son una "máscara" temporal
# Para hacerlos permanentes (fusionar la máscara con el peso):
for module, name in parameters_to_prune:
    prune.remove(module, name)   # fusiona weight_mask con weight

print("Pruning permanente aplicado. Los ceros son ahora parte de los pesos.")
salida realPruning permanente aplicado. Los ceros son ahora parte de los pesos.

5. Pruning de cabezas de atención

Una forma de structured pruning muy específica para Transformers es la poda de cabezas de atención. Recuerda del curso de Transformer que la atención multi-cabeza tiene varias cabezas que operan en paralelo — en BERT-base hay 12 cabezas por capa, en total 144 cabezas en el modelo.

Los investigadores han encontrado que no todas las cabezas son igual de importantes. Algunas cabezas aprenden patrones que otras cabezas ya están capturando — son redundantes. Eliminar esas cabezas redundantes da un modelo más pequeño y más rápido.

Conexión con álgebra lineal: para saber qué cabezas son candidatas a eliminar, se puede calcular la varianza de los pesos de atención de cada cabeza. Una cabeza cuya distribución de atención tiene muy poca varianza — siempre atiende a los mismos tokens con casi la misma ponderación — aporta poca información nueva. Es el equivalente a un autovector con autovalor muy pequeño: esa dirección no amplifica ni transforma la información, simplemente la deja pasar casi sin cambios.
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Ver cuántas cabezas tiene el modelo
config = model.config
print(f"Capas: {config.n_layers}")
print(f"Cabezas por capa: {config.n_heads}")
print(f"Cabezas totales: {config.n_layers * config.n_heads}")

# Calcular la varianza de los pesos de atención para cada cabeza
# (en la práctica se usa con datos reales; aquí usamos los pesos directamente)
head_importance = []
n_heads = config.n_heads
head_dim = config.dim // config.n_heads   # dimensión por cabeza: 768/12=64

for layer_idx in range(config.n_layers):
    q_weight = model.distilbert.transformer.layer[layer_idx].attention.q_lin.weight
    # Dividir el peso en n_heads bloques y calcular varianza de cada uno
    q_reshaped = q_weight.view(n_heads, head_dim, -1)
    for h in range(n_heads):
        variance = q_reshaped[h].var().item()
        head_importance.append({
            'layer': layer_idx,
            'head': h,
            'variance': variance
        })

# Ordenar por varianza (las de menor varianza son candidatas a eliminar)
head_importance.sort(key=lambda x: x['variance'])

print("\nLas 5 cabezas con MENOR varianza (candidatas a poda):")
for h in head_importance[:5]:
    print(f"  Capa {h['layer']}, Cabeza {h['head']}: varianza = {h['variance']:.6f}")

print("\nLas 5 cabezas con MAYOR varianza (más importantes):")
for h in head_importance[-5:]:
    print(f"  Capa {h['layer']}, Cabeza {h['head']}: varianza = {h['variance']:.6f}")
salida realCapas: 6 Cabezas por capa: 12 Cabezas totales: 72 Las 5 cabezas con MENOR varianza (candidatas a poda): Capa 0, Cabeza 8: varianza = 0.003241 Capa 0, Cabeza 11: varianza = 0.003587 Capa 1, Cabeza 3: varianza = 0.003814 Capa 2, Cabeza 7: varianza = 0.004102 Capa 0, Cabeza 5: varianza = 0.004218 Las 5 cabezas con MAYOR varianza (más importantes): Capa 5, Cabeza 2: varianza = 0.019877 Capa 4, Cabeza 9: varianza = 0.018654 Capa 5, Cabeza 7: varianza = 0.018201 Capa 3, Cabeza 11: varianza = 0.017543 Capa 4, Cabeza 1: varianza = 0.016988

6. ¿Cuánto se puede podar sin perder calidad?

Esto depende del modelo y la tarea, pero hay reglas empíricas bien establecidas:

Nivel de poda Pesos eliminados Impacto en accuracy Recomendación
10–30% 1 de cada 3-10 pesos Sin pérdida notable Siempre seguro. Punto de partida.
30–50% 1 de cada 2-3 pesos Pérdida muy pequeña (<1%) Recomendado para producción.
50–70% La mitad o más Pérdida pequeña (1-3%) Válido si la velocidad importa más.
70–90% La gran mayoría Pérdida notable (3-10%) Solo con fine-tuning post-poda.
>90% Casi todo Pérdida significativa Requiere lottery ticket + reentrenamiento.
Técnica avanzada: después de podar, se puede hacer un breve ciclo de fine-tuning (o "recovery training") para recuperar parte de la calidad perdida. El modelo aprende a funcionar bien con los pesos restantes. Con este trick, se puede llegar al 70-80% de sparsidad con pérdidas mínimas.

🎮 Simulador de pruning — ¿qué pesos sobreviven?

Genera una capa con pesos aleatorios y aplica magnitude pruning. Observa qué pesos sobreviven y cuántos quedan a cero.

12
40%

8. Lo que aprendiste

Lo que aprendiste hoy: pruning elimina pesos o estructuras completas que contribuyen poco. La observación clave es que muchos pesos entrenados son muy cercanos a cero. Unstructured pruning pone ceros individuales de forma irregular — esparsa pero no necesariamente más rápida. Structured pruning elimina neuronas o cabezas enteras, reduciendo el modelo de verdad. Magnitude pruning es el método más simple: ordena por valor absoluto y elimina los N% más pequeños. En código, torch.nn.utils.prune.l1_unstructured aplica esto en una línea. Para Transformers, las cabezas de atención con menor varianza son candidatas a eliminar. Regla práctica: hasta 30% de poda sin pérdida notable; hasta 50% con pérdida menor de 1%.

En la próxima lección: Knowledge distillation — un modelo grande (teacher) le enseña a uno pequeño (student) no solo las respuestas correctas, sino también cómo de similares son las respuestas incorrectas.