In NumPy, I would do:
import numpy as np
a = np.zeros((4, 5, 6))
a = a[:, :, np.newaxis, :]
assert a.shape == (4, 5, 1, 6)
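For reference, NumPy offers a couple of other spellings that produce the same shape. This is a minimal sketch using `np.expand_dims` and `reshape`; the variable names `b`, `c`, `d` are only for illustration:

```python
import numpy as np

a = np.zeros((4, 5, 6))

# Three equivalent ways to insert a length-1 axis at position 2
b = a[:, :, np.newaxis, :]     # indexing with np.newaxis (an alias for None)
c = np.expand_dims(a, 2)       # explicit function, axis given as an integer
d = a.reshape(4, 5, 1, 6)      # spell out the full target shape

print(b.shape, c.shape, d.shape)
```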
How do I do the same in PyTorch?
import torch
a = torch.zeros(4, 5, 6)
a = a[:, :, None, :]
assert a.shape == (4, 5, 1, 6)
np.newaxis is just None, anyway.
Commented Dec 27, 2020 at 21:53
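Since `np.newaxis` is just `None`, the same indexing trick carries over to PyTorch. A small sketch, assuming a 3-D tensor as in the question, that also uses an Ellipsis so the leading colons need not be written out:

```python
import numpy as np
import torch

# np.newaxis is an alias for None, so both spellings mean the same thing
assert np.newaxis is None

a = torch.zeros(4, 5, 6)

# "..." stands for all leading dimensions, so the new axis is placed
# just before the last dimension without listing every colon
b = a[..., None, :]
print(b.shape)  # torch.Size([4, 5, 1, 6])
```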
a.unsqueeze(2) is much more effective and to the point.
"Effective"?
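On the "effective" question: as far as I can tell, neither form copies data. Both `a[:, :, None, :]` and `a.unsqueeze(2)` return views of `a`, so the difference is mostly readability. A quick check, writing through one view and reading through the other:

```python
import torch

a = torch.zeros(4, 5, 6)
b = a[:, :, None, :]   # indexing with None
c = a.unsqueeze(2)     # explicit method call

# Both results share storage with a: a write through b shows up in a and c
b[0, 0, 0, 0] = 1.0
print(a[0, 0, 0].item(), c[0, 0, 0, 0].item())  # 1.0 1.0
```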
You can add a new axis with torch.unsqueeze() or the equivalent Tensor.unsqueeze() method; the dim argument is the index of the new axis:
>>> a = torch.zeros(4, 5, 6)
>>> a = a.unsqueeze(2)
>>> a.shape
torch.Size([4, 5, 1, 6])
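The `dim` argument may also be negative, in which case it is counted from the end of the output shape. A short sketch of a few positions, plus the functional form where the tensor comes first and `dim` second:

```python
import torch

a = torch.zeros(4, 5, 6)

# Negative dim counts from the end of the resulting (4-D) shape
front = a.unsqueeze(0)    # torch.Size([1, 4, 5, 6])
last = a.unsqueeze(-1)    # torch.Size([4, 5, 6, 1])
second = a.unsqueeze(-2)  # torch.Size([4, 5, 1, 6])

# Functional form: input tensor first, then dim
func = torch.unsqueeze(a, 2)  # torch.Size([4, 5, 1, 6])

print(front.shape, last.shape, second.shape, func.shape)
```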
Or use the in-place version, Tensor.unsqueeze_():
>>> a = torch.zeros(4, 5, 6)
>>> a.unsqueeze_(2)
>>> a.shape
torch.Size([4, 5, 1, 6])
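The difference between the two versions is only whether the original tensor's shape changes. A minimal sketch contrasting them:

```python
import torch

a = torch.zeros(4, 5, 6)

# Out-of-place: returns a new view; a keeps its original shape
b = a.unsqueeze(2)
original_shape = a.shape
print(original_shape, b.shape)  # torch.Size([4, 5, 6]) torch.Size([4, 5, 1, 6])

# In-place: modifies a's own shape (the tensor itself is also returned)
a.unsqueeze_(2)
print(a.shape)  # torch.Size([4, 5, 1, 6])
```

The in-place form avoids rebinding the name, but the out-of-place form is usually easier to reason about when other code still expects the 3-D shape.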
import torch
x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)
# y is tensor([[1, 2, 3, 4]]), with shape (1, 4)
EDIT: see more details here: https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
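The same call with `dim=1` gives a column instead of a row, and `squeeze` removes a length-1 axis again. A small sketch building on the 1-D example above:

```python
import torch

x = torch.tensor([1, 2, 3, 4])

y = torch.unsqueeze(x, 0)   # shape (1, 4): a row with a leading axis
z = torch.unsqueeze(x, 1)   # shape (4, 1): a column

# squeeze(dim) drops the length-1 axis and recovers the original shape
back = y.squeeze(0)
print(y.shape, z.shape, back.shape)
print(torch.equal(back, x))  # True
```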