Advanced Integer Indexing in PyTorch
We often need to pick out elements of a tensor at specific indices along one dimension, as in the following numpy example:
In [2]: import numpy as np
In [3]: a = np.random.rand(5, 4)
In [4]: a
Out[4]:
array([[0.73716428, 0.75346028, 0.83733549, 0.81293311],
       [0.41163123, 0.12830572, 0.12188361, 0.29422126],
       [0.69099127, 0.27298692, 0.14673925, 0.46489198],
       [0.25129696, 0.08969553, 0.35402478, 0.61180101],
       [0.19963404, 0.90344766, 0.12198926, 0.14525652]])
In [5]: idx = np.argmax(a, axis=-1)
In [6]: idx
Out[6]: array([2, 0, 0, 3, 1])
Now, we want element 2 from row 0, element 0 from row 1, element 0 from row 2, and so on. But naive indexing gives us entire columns:
In [7]: a[:, idx]
Out[7]:
array([[0.83733549, 0.73716428, 0.73716428, 0.81293311, 0.75346028],
       [0.12188361, 0.41163123, 0.41163123, 0.29422126, 0.12830572],
       [0.14673925, 0.69099127, 0.69099127, 0.46489198, 0.27298692],
       [0.35402478, 0.25129696, 0.25129696, 0.61180101, 0.08969553],
       [0.12198926, 0.19963404, 0.19963404, 0.14525652, 0.90344766]])
As mentioned in numpy's advanced indexing documentation, what we need here is integer indexing in all dimensions (not a : slice), instead of just the dimension of interest: "the shape of the resultant array will be the concatenation of the shape of the index array (or the shape that all the index arrays were broadcast to) with the shape of any unused dimensions (those not indexed) in the array being indexed".
In [8]: a[range(len(a)), idx]
Out[8]: array([0.83733549, 0.41163123, 0.69099127, 0.61180101, 0.90344766])
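Incidentally, numpy ships a built-in for exactly this argmax-then-index pattern: take_along_axis. A quick equivalent of the line above (the torch counterpart, torch.gather, is shown at the end):

# take_along_axis expects the index array to have the same number of dimensions
# as a, so we add back the axis that argmax reduced away, then squeeze it out
picked = np.take_along_axis(a, np.expand_dims(idx, axis=-1), axis=-1).squeeze(axis=-1)
# picked == array([0.83733549, 0.41163123, 0.69099127, 0.61180101, 0.90344766])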
This gets tedious for tensors with more dimensions. Can we write a convenience function for this operation that works in the general case? The crux is that the indices for all the other dimensions simply sweep through those dimensions' sizes, and can be generated by a "meshgrid" operation. Numpy has a special function ix_() for this purpose, but PyTorch does not.
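For reference, this is the kind of "open mesh" that ix_() builds in numpy; a small sketch with illustrative index lists:

rows, cols = np.ix_([0, 2], [1, 3])
# rows has shape (2, 1) and cols has shape (1, 2); broadcasting them together
# sweeps their cross product, so a[rows, cols] selects the 2x2 block formed by
# rows {0, 2} and columns {1, 3}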
So let us make this PyTorch function.
import torch
from typing import Tuple


def index_at_dim(dim_indices: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """Returns a tuple of index tensors which can be used to index a tensor, given
    indices DIM_INDICES of a particular dimension DIM of the tensor.

    Args:
        DIM_INDICES: the indices of a particular dimension DIM of a tensor
        DIM: the dimension at which the original tensor is indexed to produce DIM_INDICES

    Returns:
        ALL_DIM_INDICES: a tuple of index tensors, which can directly be used to index
        the original tensor.

    Notes:
        We do not need the original tensor or its full shape for this operation.
        Suppose that original_tensor.shape = (B, C, H, W) and
        dim_indices = original_tensor.argmax(dim=1), so dim_indices.shape = (B, H, W).
        The returned ALL_DIM_INDICES will have 4 elements, each of shape (B, H, W).
        The element at position DIM=1 in ALL_DIM_INDICES will be the input DIM_INDICES;
        the rest of the (B, H, W)-shaped elements will be produced by meshgrid.
        Finally, original_tensor[ALL_DIM_INDICES] will give the intended output of shape
        (B, H, W). In contrast, original_tensor[:, DIM_INDICES, :, :] will give an
        output of shape (B, B, H, W, H, W), following the standard broadcasted indexing
        rule in the numpy advanced indexing documentation at
        https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing :
        "the shape of the resultant array will be the concatenation of the shape of the
        index array (or the shape that all the index arrays were broadcast to) with the
        shape of any unused dimensions (those not indexed) in the array being indexed".
    """
    # perform meshgrid for all the other dimensions; indexing="ij" preserves the
    # per-dimension order and silences the meshgrid UserWarning on PyTorch >= 1.10
    all_dim_indices = list(
        torch.meshgrid(
            [torch.arange(i, device=dim_indices.device) for i in dim_indices.shape],
            indexing="ij",
        )
    )
    # decide where to insert dim_indices, handling negative values of dim;
    # all_dim_indices currently has (original ndim - 1) entries, hence the +1
    insertion_dim = dim if dim >= 0 else len(all_dim_indices) + dim + 1
    # insert dim_indices at its own dimension; return a tuple, the canonical
    # form for advanced indexing
    all_dim_indices.insert(insertion_dim, dim_indices)
    return tuple(all_dim_indices)
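Before moving to a full 4-D test, here is a quick sanity check on the 2-D case from the numpy example above; it also exercises the negative-dim handling. A minimal sketch with a fresh random tensor:

t = torch.rand(5, 4)
idx = t.argmax(dim=-1)
# index_at_dim builds (arange(5), idx), i.e. exactly the a[range(len(a)), idx]
# pattern from the numpy example
assert torch.equal(t[index_at_dim(idx, dim=-1)], t.max(dim=-1).values)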
Now let us test the function on a 4-D tensor.
In [9]: original_tensor = torch.rand((2, 3, 2, 3))
In [10]: original_tensor
Out[10]:
tensor([[[[0.6992, 0.5410, 0.9593],
          [0.5337, 0.1280, 0.1848]],

         [[0.9777, 0.3163, 0.6971],
          [0.5111, 0.6317, 0.7111]],

         [[0.3729, 0.9333, 0.4457],
          [0.8135, 0.6239, 0.3684]]],


        [[[0.7674, 0.0989, 0.7065],
          [0.5977, 0.8354, 0.9586]],

         [[0.8313, 0.7643, 0.0018],
          [0.1987, 0.0140, 0.9092]],

         [[0.3039, 0.8545, 0.1460],
          [0.0750, 0.1559, 0.9609]]]])
In [11]: dim_indices = original_tensor.argmax(dim=1)
In [12]: dim_indices
Out[12]:
tensor([[[1, 2, 0],
         [2, 1, 1]],

        [[1, 2, 0],
         [0, 0, 2]]])
In [13]: dim_indices.shape
Out[13]: torch.Size([2, 2, 3])
In [14]: original_tensor[:, dim_indices, :, :].shape
Out[14]: torch.Size([2, 2, 2, 3, 2, 3])
As the broadcasting rule quoted above predicts, the (2, 2, 3)-shaped index array contributes its own three axes in place of dimension 1, while the sliced dimensions of sizes 2, 2 and 3 keep theirs, giving a (2, 2, 2, 3, 2, 3) result instead of the (2, 2, 3) we want.
In [15]: all_dim_indices = index_at_dim(dim_indices, dim=1)
In [16]: len(all_dim_indices)
Out[16]: 4
In [17]: all_dim_indices[0].shape
Out[17]: torch.Size([2, 2, 3])
In [18]: original_tensor[all_dim_indices]
Out[18]:
tensor([[[0.9777, 0.9333, 0.9593],
         [0.8135, 0.6317, 0.7111]],

        [[0.8313, 0.8545, 0.7065],
         [0.5977, 0.8354, 0.9609]]])
In [19]: original_tensor[all_dim_indices].shape
Out[19]: torch.Size([2, 2, 3])
The result has the expected (2, 2, 3) shape and, since the indices came from argmax over dimension 1, it matches original_tensor.max(dim=1).values.
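As a final sanity check (and a built-in alternative for this common argmax-then-index pattern), torch.gather computes the same thing; a minimal sketch:

# gather, like take_along_axis, wants index.ndim == input.ndim, so we unsqueeze
# the reduced dimension and squeeze it back out afterwards
gathered = torch.gather(original_tensor, 1, dim_indices.unsqueeze(1)).squeeze(1)
assert torch.equal(gathered, original_tensor[all_dim_indices])
assert torch.equal(gathered, original_tensor.max(dim=1).values)

gather covers this common case; index_at_dim is still handy when you want the explicit index tensors themselves, for example to index several tensors with the same indices or to assign into the original tensor.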