Hiciste backprop a mano en la lección 7. Ahora deja que PyTorch lo haga por ti — y comprueba con tus propios ojos que calcula EXACTAMENTE los mismos números, automáticamente.
Hasta ahora calculamos cada derivada a mano. En la práctica nadie hace eso: lo hace
autograd, el motor de diferenciación automática de PyTorch. Tú escribes solo el
forward; PyTorch construye un grafo de las operaciones y, al llamar
loss.backward(), recorre ese grafo hacia atrás aplicando la regla de la cadena —
exactamente lo que hiciste en la lección 7, pero solo.
requires_grad=True le dice a PyTorch "vigila este valor,
quiero su gradiente". Tras loss.backward(), cada tensor tendrá su gradiente en
.grad. Cero cálculo manual.
Usamos los mismos pesos canónicos y la misma entrada x = (1, 0):
import torch # Pesos canonicos, marcados como entrenables W1 = torch.tensor([[0.20,-0.30],[0.40,0.10]], requires_grad=True) b1 = torch.tensor([-0.10, 0.20], requires_grad=True) W2 = torch.tensor([[0.50],[-0.40]], requires_grad=True) b2 = torch.tensor([0.10], requires_grad=True) x = torch.tensor([1., 0.]) t = torch.tensor([1.]) # Forward (idéntico a la lección 4) a1 = torch.sigmoid(x @ W1 + b1) # capa oculta a2 = torch.sigmoid(a1 @ W2 + b2) # salida loss = 0.5 * (a2 - t)**2 # MSE ½
loss.backward() # autograd calcula TODOS los gradientes print(W2.grad) # gradiente de los pesos de salida print(b2.grad) print(W1.grad) # gradiente de los pesos de entrada print(b1.grad)
PyTorch imprime:
Comparemos, lado a lado, lo que calculaste a mano (lección 7) con lo que dio autograd:
| Gradiente | A mano (lección 7) | autograd (PyTorch) |
|---|---|---|
| dL/dW₂ | [−0.059533, −0.053868] | [−0.059533, −0.053868] |
| dL/db₂ | −0.113401 | −0.113401 |
| dL/dW₁ (de x₁) | [−0.014140, 0.011312] | [−0.014140, 0.011312] |
| dL/db₁ | [−0.014140, 0.011312] | [−0.014140, 0.011312] |
# Verificación automática assert torch.allclose(W2.grad.flatten(), torch.tensor([-0.059533, -0.053868]), atol=1e-5) # ✅ pasa — autograd == tu backprop a mano
En la última lección daremos el paso final: construir la red con las
herramientas de alto nivel de PyTorch (nn.Module, nn.Linear, un
optimizador) y entrenar XOR hasta que lo resuelva de verdad — en pocas líneas.