El truco que usa casi todo el mundo en la práctica: en vez de entrenar desde cero (caro y necesita muchísimos datos), partes de una red ya entrenada y la adaptas a tu problema.
Entrenar una red potente desde cero requiere millones de ejemplos y mucho cómputo. Pero ya existen redes entrenadas sobre datasets enormes (ImageNet, corpus de texto…) que ya saben extraer características útiles. El transfer learning reutiliza ese conocimiento:
"Congelar" = poner requires_grad = False para que el optimizador no los toque.
Simulemos un "backbone" preentrenado y congelémoslo:
backbone = nn.Sequential(nn.Linear(10,8), nn.ReLU(), nn.Linear(8,5))
for p in backbone.parameters():
p.requires_grad = False # congelar TODO
print(sum(p.requires_grad for p in backbone.parameters()))
Cero parámetros entrenables: la red está "congelada", conserva lo que ya sabía.
La última capa de la red original predice las clases de su problema. La cambias por una nueva, adaptada a tu tarea. La capa nueva nace descongelada (entrenable):
backbone[2] = nn.Linear(8, 2) # nueva cabeza: 2 clases (tu problema)
entrenables = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(entrenables)
Ahora solo son entrenables los 18 pesos de la cabeza nueva (8×2 + 2 bias). El resto de la red aporta sus características ya aprendidas, sin cambiar. Entrenas muy poco y muy rápido.
En la práctica usarías un modelo preentrenado de verdad. El patrón es idéntico al de arriba:
from torchvision import models
modelo = models.resnet18(weights='DEFAULT') # red entrenada en ImageNet
for p in modelo.parameters():
p.requires_grad = False # congelar el backbone
modelo.fc = nn.Linear(modelo.fc.in_features, 2) # nueva cabeza: 2 clases
# entrenas normal: solo se ajustará modelo.fc
torchvision (pip install torchvision), que no
está en este entorno. Pero el mecanismo — congelar + reemplazar la cabeza — es
exactamente el que ejecutamos arriba con resultados reales.