
I am iterating over training samples in batches; however, the last batch always returns fewer samples.

Is it possible to specify the step size in PyTorch according to the current batch length?

For example, most batches are of size 64, but the last batch has only 6 samples.

If I do the usual routine:

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

It seems that the last 6 samples carry the same weight in the gradient update as the 64-sample batches, but in fact they should only carry about 1/10 of the weight because they contain fewer samples.

In MXNet I could specify the step size accordingly, but I don't know how to do it in PyTorch.

  • Other options are to temporarily reduce the optimizer's learning rate, to specify sum reduction on the loss if it supports it (equivalent to lejlot's answer), or to specify drop_last=True when initializing the DataLoader.
    – jodag
    Commented Aug 21, 2022 at 15:21
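
A rough sketch of the sum-reduction option mentioned in the comment above; this is an illustrative assumption about the setup, not code from the thread:

    import torch.nn as nn

    # With reduction='sum' the loss (and therefore the gradient) scales with
    # the number of samples in the batch, so a 6-sample batch contributes
    # roughly 1/10 of what a 64-sample batch does. The learning rate then
    # needs to be tuned for the sum-reduced scale.
    criterion = nn.CrossEntropyLoss(reduction='sum')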

2 Answers


You can define a custom loss function and then, for example, reweight it based on the batch size:

import torch.nn as nn

def reweighted_cross_entropy(my_outputs, my_labels):
    # compute batch size
    my_batch_size = my_outputs.size()[0]

    # the default CrossEntropyLoss averages over the batch
    original_loss = nn.CrossEntropyLoss()
    loss = original_loss(my_outputs, my_labels)

    # reweight accordingly: larger batches contribute proportionally more
    return my_batch_size * loss
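
A minimal usage sketch of how this reweighted loss could be dropped into a training loop; the toy data, model, and optimizer below are assumptions for illustration, not part of the answer:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy setup: 134 samples give batches of 64, 64 and 6
    X = torch.randn(134, 10)
    y = torch.randint(0, 3, (134,))
    loader = DataLoader(TensorDataset(X, y), batch_size=64)

    model = nn.Linear(10, 3)
    # the learning rate should be tuned for the batch-size-scaled loss
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for inputs, labels in loader:
        optimizer.zero_grad()
        loss = reweighted_cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()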

If you are using something like plain gradient descent, then it is easy to see that

(1/10 * lr) * grad(loss) = lr * grad(1/10 * loss)

so reweighting the loss is equivalent to reweighting your learning rate. This won't be exactly true for more complex optimisers, but it can be good enough in practice.
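
A quick sanity check of that equivalence for plain SGD; the toy tensors and the scale factor k below are assumptions for illustration only:

    import torch

    torch.manual_seed(0)
    w1 = torch.randn(3, requires_grad=True)
    w2 = w1.detach().clone().requires_grad_(True)
    x = torch.randn(3)
    k, lr = 0.1, 1e-2

    # Variant A: scale the loss by k, keep the learning rate
    (k * (w1 * x).sum() ** 2).backward()
    torch.optim.SGD([w1], lr=lr).step()

    # Variant B: keep the loss, scale the learning rate by k
    ((w2 * x).sum() ** 2).backward()
    torch.optim.SGD([w2], lr=lr * k).step()

    print(torch.allclose(w1, w2))  # True: the updates coincide for vanilla SGD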

  • This might do the trick, though it doesn't quite look pretty. I may want to check the MXNet source files to see how it is implemented there, especially when combined with something like AdamW that uses momentum. Nevertheless, when using batch norm we will probably have to discard the short batches anyway.
    – Anonymous
    Commented Aug 21, 2022 at 13:38

I suggest just ignoring the last batch. The PyTorch DataLoader has a parameter that implements this behavior:

drop_last = True #(False by default)
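
A short, self-contained sketch of that option; the toy dataset and batch size here are placeholder assumptions:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy dataset of 134 samples: with drop_last=True the trailing 6-sample
    # batch is discarded, so every batch has exactly 64 samples.
    dataset = TensorDataset(torch.randn(134, 10), torch.randint(0, 3, (134,)))
    loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
    print([len(labels) for _, labels in loader])  # [64, 64]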
