
I'm new to AI and Python, and I'm trying to train on only one batch so that the model deliberately overfits. I found the code: iter(train_loader).next()
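`iter(train_loader).next()` calls a Python 2-style `.next()` method on the loader's iterator. On recent PyTorch versions this can raise an `AttributeError`; the Python 3 idiom is the builtin `next(iter(train_loader))`. A minimal sketch, assuming a hypothetical `dataset_train` built from random tensors shaped like the question's 3x128x128 images with 10 classes (`num_workers` is left at its default of 0 to keep it simple):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for dataset_train: 96 random 3x128x128 "images", labels in 0..9
dataset_train = TensorDataset(torch.randn(96, 3, 128, 128), torch.randint(0, 10, (96,)))
train_loader = DataLoader(dataset_train, batch_size=48, shuffle=True)

# iter() starts a new pass over the loader; next() pulls the first batch of that pass
inputs, labels = next(iter(train_loader))
print(inputs.shape)  # torch.Size([48, 3, 128, 128])
print(labels.shape)  # torch.Size([48])
```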

But I'm not sure where to put it in my code. Even if I did, how can I check after each iteration that I'm training on the same batch?

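import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# dataset_train (not shown): 3x128x128 images with 10 classes, inferred from the nn.Linear layer below
train_loader = torch.utils.data.DataLoader(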
    dataset_train,
    batch_size=48,
    shuffle=True,
    num_workers=2
)

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(128*128*3,10)
)


nepochs = 3
statsrec = np.zeros((3,nepochs))

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)


for epoch in range(nepochs):  # loop over the dataset multiple times

    running_loss = 0.0
    n = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        
        # Zero the parameter gradients
        optimizer.zero_grad() 

        # Forward, backward, and update parameters
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
    
        # accumulate loss
        running_loss += loss.item()
        n += 1
    
    ltrn = running_loss/n
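    # stats() is a helper not shown here that returns (loss, accuracy) over a loader.
    # Note: it is called on train_loader, so these "test" numbers are measured on training data.
    ltst, atst = stats(train_loader, net)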
    statsrec[:,epoch] = (ltrn, ltst, atst)
    print(f"epoch: {epoch} training loss: {ltrn: .3f}  test loss: {ltst: .3f} test accuracy: {atst: .1%}")

Please give me a hint.

1 Answer


If you want to train on a single batch, remove the loop over your DataLoader:

for i, data in enumerate(train_loader, 0):
    inputs, labels = data

Instead, get the first element of the train_loader iterator once, before looping over the epochs. Otherwise next is called on every iteration and you train on a different batch each epoch:

inputs, labels = next(iter(train_loader))  # fetched once, outside the epoch loop
for epoch in range(nepochs):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
    # ... record statistics as before, e.g. with loss.item()
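A self-contained sketch of this approach, assuming a hypothetical random stand-in for `dataset_train` and more epochs than the question's 3 (50 here, an arbitrary choice) so the effect is visible. If the loop really reuses one batch, the training loss should fall well below its starting value and accuracy on that batch should approach 100%, which is the overfitting the question is after:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # reproducible random data and weights

# Hypothetical stand-in for dataset_train: random images and labels
dataset_train = TensorDataset(torch.randn(96, 3, 128, 128), torch.randint(0, 10, (96,)))
train_loader = DataLoader(dataset_train, batch_size=48, shuffle=True)

net = nn.Sequential(nn.Flatten(), nn.Linear(128 * 128 * 3, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

inputs, labels = next(iter(train_loader))  # one fixed batch, fetched once

nepochs = 50  # assumption: more than 3 so the drop in loss is easy to see
losses = []
for epoch in range(nepochs):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())  # loss measured before this epoch's update

# Accuracy on the same batch the model was trained on
with torch.no_grad():
    train_acc = (net(inputs).argmax(dim=1) == labels).float().mean().item()

print(f"first loss: {losses[0]:.3f}  last loss: {losses[-1]:.3f}  batch accuracy: {train_acc:.1%}")
```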
  • Thanks!! How can I check that I'm running on the same batch? Is there a print function I can use?
    – ShB
    Commented Feb 22, 2021 at 21:50
  • inputs might be too cumbersome to print, but you could look at the labels with print(labels).
    – Ivan
    Commented Feb 22, 2021 at 22:02
  • Thanks for your help!
    – ShB
    Commented Feb 22, 2021 at 22:35
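Building on the print(labels) suggestion, one way to check this programmatically is to keep a copy of the first batch's labels and compare against it each epoch with `torch.equal`, which returns `True` only when shapes and all values match. The sketch below (again with a hypothetical random `dataset_train`) contrasts fetching once with calling `next(iter(train_loader))` inside the loop, which creates a fresh, reshuffled iterator and so yields a different batch each time:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical stand-in for dataset_train
dataset_train = TensorDataset(torch.randn(96, 3, 128, 128), torch.randint(0, 10, (96,)))
train_loader = DataLoader(dataset_train, batch_size=48, shuffle=True)

# Fetch once and snapshot the labels to compare against later
inputs, labels = next(iter(train_loader))
first_labels = labels.clone()

fixed_batch_same = []
for epoch in range(3):
    print(f"epoch {epoch} labels: {labels[:8].tolist()}")  # first 8 labels, for eyeballing
    fixed_batch_same.append(torch.equal(labels, first_labels))

# Re-fetching every epoch: each new iterator reshuffles, so the batch changes
refetched_same = []
for epoch in range(3):
    _, new_labels = next(iter(train_loader))
    refetched_same.append(torch.equal(new_labels, first_labels))

print(fixed_batch_same)  # [True, True, True]
print(refetched_same)
```

Comparing labels is a cheap fingerprint; comparing a cloned copy of inputs with `torch.equal` is a stricter check.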
