
In PyTorch, a DataLoader splits a dataset into batches of a fixed size, with extra options such as shuffling, which one can then loop over.

But if I need the batch size to grow over time, for example the first 10 batches of size 50, the next 5 batches of size 100, and so on, what is the best way to do that?

I tried splitting the tensor and then concatenating the pieces:

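import torch

# Goal: 10 batches of 50 followed by 5 batches of 100 (10*50 + 5*100 = 1000 rows)
originalTensor = torch.randn(1000, 80)
split1 = torch.split(originalTensor, 500, dim=0)  # two chunks of 500 rows each
split2 = torch.split(split1[0], 50, dim=0)        # 10 batches of 50
split3 = torch.split(split1[1], 100, dim=0)       # 5 batches of 100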

After that, is there a way to pass the resulting batches to a DataLoader, or some other way to turn them directly into a generator (which might lose shuffling and other functionality)?

1 Answer


I think you can do that by passing a custom batch_sampler to your DataLoader instead of relying on the default one.
For instance:

from torch.utils.data import Sampler


class VaryingSizeBatchSampler(Sampler):
    r"""Wraps another sampler to yield a varying-size mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size_fn (callable): Takes the index of the current mini-batch
            (starting at 0) and returns its size.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than the size returned by ``batch_size_fn``.
    """

    def __init__(self, sampler, batch_size_fn, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        self.sampler = sampler
        self.batch_size_fn = batch_size_fn
        self.drop_last = drop_last
        self.batch_counter = 0  # never reset, so the size schedule carries over between epochs

    def __iter__(self):
        batch = []
        cur_batch_size = self.batch_size_fn(self.batch_counter)  # get current batch size
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == cur_batch_size:
                yield batch
                self.batch_counter += 1
                cur_batch_size = self.batch_size_fn(self.batch_counter)  # get current batch size
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        raise NotImplementedError('You need to implement it yourself!')
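
To show how such a sampler might be wired into a DataLoader, here is a minimal, self-contained sketch. The names `ScheduledBatchSampler` and `batch_size_fn` and the `__len__` logic are my own assumptions for illustration, not part of the answer above; the schedule (10 batches of 50, then batches of 100) is the one from the question.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, Sampler, TensorDataset


class ScheduledBatchSampler(Sampler):
    """Hypothetical compact variant of the sampler above, for illustration."""

    def __init__(self, sampler, batch_size_fn, drop_last=False):
        self.sampler = sampler
        self.batch_size_fn = batch_size_fn
        self.drop_last = drop_last

    def __iter__(self):
        batch_counter = 0  # local, so the schedule restarts every epoch
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size_fn(batch_counter):
                yield batch
                batch_counter += 1
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        # Count batches by replaying the schedule over the sampler length.
        remaining, n_batches = len(self.sampler), 0
        while remaining > 0:
            size = self.batch_size_fn(n_batches)
            if remaining < size and self.drop_last:
                break
            remaining -= size
            n_batches += 1
        return n_batches


def batch_size_fn(batch_index):
    # Schedule from the question: 10 batches of 50, then batches of 100.
    return 50 if batch_index < 10 else 100


dataset = TensorDataset(torch.randn(1000, 80))
# RandomSampler keeps shuffling; SequentialSampler would preserve order.
batch_sampler = ScheduledBatchSampler(RandomSampler(dataset), batch_size_fn)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

# Each item from a TensorDataset is a 1-tuple, so each batch unpacks to one tensor.
batch_sizes = [features.shape[0] for (features,) in loader]
```

Unlike the class above, this variant keeps the batch counter local to `__iter__`, so the schedule restarts every epoch. If the batch size should keep growing across epochs, store the counter on the instance instead. Note that `batch_sampler` cannot be combined with the `batch_size`, `shuffle`, `sampler`, or `drop_last` arguments of DataLoader, which is why shuffling is handled by the wrapped `RandomSampler`.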

That's really helpful! Thanks for providing substantial details.
– santoku, Sep 22, 2019 at 16:53
