To include batch size in the basic PyTorch examples, the easiest and cleanest way is to use torch.utils.data.DataLoader and torch.utils.data.TensorDataset.
Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
DataLoader will take care of creating batches for you.
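As a small illustration of the batching described above, the sketch below wraps two toy tensors in a TensorDataset and iterates over them with a DataLoader. The sizes (10 samples, 3 features, batch size 4) are arbitrary values chosen for this example, and shuffling is turned off so the batches come out in order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 10 samples with 3 features each, plus one target per sample
# (sizes are arbitrary, chosen only for illustration)
x = torch.arange(30, dtype=torch.float32).reshape(10, 3)
y = torch.arange(10, dtype=torch.float32).reshape(10, 1)

# TensorDataset pairs up x[i] and y[i]; DataLoader groups them into batches
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Collect the shapes of every (x_batch, y_batch) pair the DataLoader yields
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in dataloader]
for shapes in batch_shapes:
    print(shapes)
```

Since 10 is not a multiple of 4, the third and last batch holds only the 2 remaining samples.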
Building on your question, here is a complete code snippet that iterates over a dataset of 10000 examples for 2 epochs with a batch size of 64:
import torch
from torch.utils.data import DataLoader, TensorDataset
# Create the dataset with N_SAMPLES samples
N_SAMPLES, D_in, H, D_out = 10000, 1000, 100, 10
x = torch.randn(N_SAMPLES, D_in)
y = torch.randn(N_SAMPLES, D_out)
# Define the batch size and the number of epochs
BATCH_SIZE = 64
N_EPOCHS = 2
# Use torch.utils.data to create a DataLoader
# that will take care of creating batches
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Define model, loss and optimizer
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Get the dataset size for printing (it is equal to N_SAMPLES)
dataset_size = len(dataloader.dataset)
# Loop over epochs
for epoch in range(N_EPOCHS):
    print(f"Epoch {epoch + 1}\n-------------------------------")
    # Loop over batches in an epoch using DataLoader
    for id_batch, (x_batch, y_batch) in enumerate(dataloader):
        y_batch_pred = model(x_batch)
        loss = loss_fn(y_batch_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Every 100 batches, print the loss for this batch
        # as well as the number of examples processed so far
        if id_batch % 100 == 0:
            loss, current = loss.item(), (id_batch + 1) * len(x_batch)
            print(f"loss: {loss:>7f} [{current:>5d}/{dataset_size:>5d}]")
The output should be something like:
Epoch 1
-------------------------------
loss: 643.433716 [   64/10000]
loss: 648.195435 [ 6464/10000]
Epoch 2
-------------------------------
loss: 613.619873 [   64/10000]
loss: 625.018555 [ 6464/10000]
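Because 10000 is not a multiple of 64, the last batch of each epoch is smaller than the others, which matters if later code assumes a fixed batch size. The sketch below reuses the sample count and batch size from the snippet above (the feature size of 8 is an arbitrary choice made here to keep it light) to check the number of batches per epoch, and shows the drop_last option of DataLoader, which skips the incomplete batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

N_SAMPLES, BATCH_SIZE = 10000, 64
# Feature size 8 is an arbitrary choice for this sketch
dataset = TensorDataset(torch.randn(N_SAMPLES, 8), torch.randn(N_SAMPLES, 1))

# Default behaviour: the last, incomplete batch is kept
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
n_batches = len(dataloader)
last_batch_size = N_SAMPLES - (n_batches - 1) * BATCH_SIZE
print(f"{n_batches} batches per epoch, last one has {last_batch_size} samples")

# With drop_last=True the incomplete batch is skipped
dataloader_full = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
)
n_full_batches = len(dataloader_full)
print(f"{n_full_batches} batches per epoch when drop_last=True")
```

With drop_last=True and shuffle=True, a different handful of samples is left out in each epoch, so no example is excluded permanently.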