Con datos reales no entrenas con todo a la vez (no cabría en memoria, y entrena peor). Los partes en lotes (batches) y los recorres mezclados. PyTorch tiene dos herramientas justo para esto.
Entrenar con todos los datos en cada paso tiene dos problemas:
Un Dataset es un objeto que sabe cuántos ejemplos hay y cómo entregar el ejemplo
número i. Para datos que ya están en tensores, TensorDataset los une:
from torch.utils.data import TensorDataset, DataLoader
X = torch.arange(20).float().reshape(10, 2) # 10 ejemplos, 2 features
Y = torch.arange(10).float().reshape(10, 1) # 10 etiquetas
ds = TensorDataset(X, Y)
print(len(ds)) # cuántos ejemplos
print(ds[0]) # el primer (entrada, etiqueta)
ds[0] devuelve una tupla (x, y): la entrada y su etiqueta. Para datos
propios (CSV, imágenes…) se crea un Dataset personalizado, pero la idea es la misma.
El DataLoader envuelve un Dataset y lo sirve en lotes, opcionalmente
barajado. Es un iterable que recorres con un for:
dl = DataLoader(ds, batch_size=4, shuffle=False)
for i, (xb, yb) in enumerate(dl):
print(i, xb.shape, yb.shape)
Con 10 ejemplos y batch_size=4 salen 3 lotes: 4 + 4 + 2 (el último
lleva lo que sobra). Cada xb es un mini-tensor de 4 ejemplos listo para
net(xb).
shuffle=True en el DataLoader de entrenamiento (baraja cada
época, mejora el aprendizaje). En el de validación/prueba usa
shuffle=False (no hace falta y así los resultados son reproducibles).
El bucle de la lección 7 ahora tiene un for interno que recorre los lotes:
for epoca in range(num_epocas):
for xb, yb in dl: # ← un paso por LOTE
opt.zero_grad()
pred = net(xb)
loss = loss_fn(pred, yb)
loss.backward()
opt.step()
Dataset empaqueta (X, Y) y sabe entregar el ejemplo i;
DataLoader los sirve en lotes barajados. El entrenamiento pasa a ser
un bucle de épocas, y dentro un bucle de lotes. Con esto ya tienes todas las piezas: en la próxima
lección montamos un entrenamiento completo de verdad, con validación y guardado
del modelo.