View source on GitHub
Gather slices from params into a Tensor with shape specified by indices.
tf.gather_nd(
params, indices, batch_dims=0, name=None
)
indices is a Tensor of indices into params. The index vectors are
arranged along the last axis of indices.
This is similar to tf.gather, in which indices defines slices into the
first dimension of params. In tf.gather_nd, indices defines slices into
the first N dimensions of params, where N = indices.shape[-1].
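To make the contrast with tf.gather concrete, here is a plain-Python sketch on nested lists. The helpers `gather_axis0` and `gather_first_n` are hypothetical names used only for illustration. They model the indexing behavior described above and are not how TensorFlow implements either op.

```python
def gather_axis0(params, indices):
    """Model of tf.gather on axis 0: each index picks one slice of the first dimension."""
    return [params[i] for i in indices]


def gather_first_n(params, index_vectors):
    """Model of tf.gather_nd for a flat list of index vectors.

    Each vector of length N walks the first N dimensions of params.
    """
    results = []
    for vector in index_vectors:
        item = params
        for i in vector:  # one step per indexed dimension
            item = item[i]
        results.append(item)
    return results


params = [['a', 'b', 'c'], ['d', 'e', 'f']]

# N = 1: the same rows that gather_axis0(params, [1, 0]) selects.
rows_gather = gather_axis0(params, [1, 0])
rows_gather_nd = gather_first_n(params, [[1], [0]])

# N = 2: each vector indexes both dimensions and yields a scalar.
scalars = gather_first_n(params, [[0, 2], [1, 0]])
```

With N = 1 the two helpers return the same rows. With N = 2 only the gather_nd model applies, because each index vector consumes both dimensions.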
Gathering scalars
In the simplest case the vectors in indices index the full rank of params:
>>> tf.gather_nd(
...     indices=[[0, 0],
...              [1, 1]],
...     params=[['a', 'b'],
...             ['c', 'd']]).numpy()
array([b'a', b'd'], dtype=object)
Here the result has one fewer axis than indices: each index vector
is replaced by the scalar it selects from params.
The shape relationship is:
index_depth = indices.shape[-1]
assert index_depth == params.shape.rank
result_shape = indices.shape[:-1]
If indices has a rank of K, it is helpful to think of indices as a
(K-1)-dimensional tensor of index vectors into params.
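The following plain-Python sketch illustrates that view on non-empty, regular nested lists. `nesting_depth` and `gather_scalars` are hypothetical helpers written for this illustration, not part of TensorFlow. The recursion walks the leading K-1 dimensions of indices and swaps each index vector for the scalar it points to.

```python
def nesting_depth(value):
    """Rank of a regular, non-empty nested list (0 for a scalar)."""
    depth = 0
    while isinstance(value, list):
        value = value[0]
        depth += 1
    return depth


def gather_scalars(params, indices):
    """Replace every index vector in `indices` with the scalar it points to."""
    if nesting_depth(indices) == 1:  # reached a single index vector
        # Full-rank indexing: index_depth must equal the rank of params.
        assert len(indices) == nesting_depth(params)
        item = params
        for i in indices:
            item = item[i]
        return item
    # Otherwise recurse over the leading (K-1) dimensions of indices.
    return [gather_scalars(params, sub) for sub in indices]


params = [['a', 'b'], ['c', 'd']]

flat = gather_scalars(params, [[0, 0], [1, 1]])      # indices shape [2, 2]
grid = gather_scalars(params, [[[0, 0]], [[0, 1]]])  # indices shape [2, 1, 2]
```

In both calls the result keeps the shape of indices with the last axis dropped: [2] for `flat` and [2, 1] for `grid`.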
Gathering slices
If the index vectors do not index the full rank of params then each location
in the result contains a slice of params. This example collects rows from a
matrix:
>>> tf.gather_nd(
...     indices=[[1],
...              [0]],
...     params=[['a', 'b', 'c'],
...             ['d', 'e', 'f']]).numpy()
array([[b'd', b'e', b'f'],
       [b'a', b'b', b'c']], dtype=object)
Here indices contains [2] index vectors, each with a length of 1.
The index vectors each refer to rows of the params matrix. Each
row has a shape of [3] so the output shape is [2, 3].
In this case, the relationship between the shapes is:
index_depth = indices.shape[-1]
outer_shape = indices.shape[:-1]
assert index_depth <= params.shape.rank
inner_shape = params.shape[index_depth:]
output_shape = outer_shape + inner_shape
It is helpful to think of the results in this case as tensors-of-tensors.
The shape of the outer tensor is set by the leading dimensions of indices,
while the shape of each inner tensor is the shape of a single slice.
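As a quick check of the shape rule, the sketch below applies it to plain tuples. `gather_nd_output_shape` is a hypothetical helper for illustration only and covers just the batch_dims=0 case.

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Shape of the gather_nd result when batch_dims is 0, per the rule above."""
    index_depth = indices_shape[-1]
    outer_shape = indices_shape[:-1]
    assert index_depth <= len(params_shape)
    inner_shape = params_shape[index_depth:]
    return outer_shape + inner_shape


# Rows from a 2x3 matrix, selected by two length-1 index vectors.
rows_shape = gather_nd_output_shape((2, 3), (2, 1))

# Full-depth index vectors: the inner shape is empty, so scalars are gathered.
scalars_shape = gather_nd_output_shape((2, 2), (2, 2))

# Extra leading dimensions on indices carry through to the result.
nested_shape = gather_nd_output_shape((2, 2, 2), (2, 1, 1))
```

The first call reproduces the [2, 3] output shape of the row-gathering example: the outer shape [2] comes from indices and the inner shape [3] is one row of params.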
Batches
Additionally, both params and indices can have M leading batch
dimensions that exactly match. In this case batch_dims must be set to M.
For example, to collect one row from each of a batch of matrices you could set the leading elements of the index vectors to be their location in the batch:
>>> tf.gather_nd(
...     indices=[[0, 1],
...              [1, 0],
...              [2, 4],
...              [3, 2],
...              [4, 1]],
...     params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]
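The index vectors in that example follow a simple pattern: the first element is the position in the batch, and the second is the row chosen within that batch element. A minimal sketch of building them in plain Python, where `per_batch_indices` is a hypothetical variable name for the per-element row choices:

```python
# One length-1 index vector per batch element (the row to pick from each matrix).
per_batch_indices = [[1], [0], [4], [2], [1]]

# Prefix each vector with its position in the batch.
full_indices = [[b] + vector for b, vector in enumerate(per_batch_indices)]
```

The resulting `full_indices` list matches the indices passed to tf.gather_nd above.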
The batch_dims argument lets you omit those leading location dimensions
from the index:
>>> tf.gather_nd(
...     batch_dims=1,
...     indices=[[1],
...              [0],
...              [4],
...              [2],
...              [1]],
...     params=tf.zeros([5, 7, 3])).shape.as_list()
[5, 3]
This is equivalent to calling a separate gather_nd for each location in the
batch dimensions.
>>> params = tf.zeros([5, 7, 3])
>>> indices = tf.zeros([5, 1])
>>> batch_dims = 1
>>>
>>> index_depth = indices.shape[-1]
>>> batch_shape = indices.shape[:batch_dims]
>>> assert params.shape[:batch_dims] == batch_shape
>>> outer_shape = indices.shape[batch_dims:-1]
>>> assert index_depth <= params.shape.rank
>>> inner_shape = params.shape[batch_dims + index_depth:]
>>> output_shape = batch_shape + outer_shape + inner_shape
>>> output_shape.as_list()
[5, 3]
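The "separate gather_nd per batch location" reading can be sketched in plain Python on nested lists. Both functions below are hypothetical models written for illustration and assume non-empty, regular inputs. They are not TensorFlow's implementation.

```python
def gather_nd(params, indices):
    """Plain-Python model of gather_nd with batch_dims=0 on nested lists."""
    if not isinstance(indices[0], list):  # a single index vector
        item = params
        for i in indices:
            item = item[i]
        return item
    return [gather_nd(params, sub) for sub in indices]


def batched_gather_nd(params, indices, batch_dims):
    """Model of batch_dims: pair up the leading dimensions, then gather per element."""
    if batch_dims == 0:
        return gather_nd(params, indices)
    assert len(params) == len(indices)  # batch dimensions must match exactly
    return [
        batched_gather_nd(p, i, batch_dims - 1)
        for p, i in zip(params, indices)
    ]


params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]

# One row from each matrix in the batch.
rows = batched_gather_nd(params, [[1], [0]], batch_dims=1)

# One scalar from each matrix, with an extra outer dimension on indices.
scalars = batched_gather_nd(params, [[[1, 0]], [[0, 1]]], batch_dims=1)
```

Each level of batch_dims strips one matching leading dimension from both arguments, so the innermost call sees an ordinary unbatched gather_nd problem.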
More examples
Indexing into a 3-tensor:
>>> tf.gather_nd(
...     indices=[[1]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'a1', b'b1'],
        [b'c1', b'd1']]], dtype=object)

>>> tf.gather_nd(
...     indices=[[0, 1], [1, 0]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
       [b'a1', b'b1']], dtype=object)

>>> tf.gather_nd(
...     indices=[[0, 0, 1], [1, 0, 1]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([b'b0', b'b1'], dtype=object)
The examples below cover the case where only indices has extra leading dimensions. If both params and indices have leading batch dimensions, use the batch_dims parameter to run gather_nd in batch mode.
Batched indexing into a matrix:
>>> tf.gather_nd(
...     indices=[[[0, 0]], [[0, 1]]],
...     params=[['a', 'b'], ['c', 'd']]).numpy()
array([[b'a'],
       [b'b']], dtype=object)
Batched slice indexing into a matrix:
>>> tf.gather_nd(
...     indices=[[[1]], [[0]]],
...     params=[['a', 'b'], ['c', 'd']]).numpy()
array([[[b'c', b'd']],
       [[b'a', b'b']]], dtype=object)
Batched indexing into a 3-tensor:
>>> tf.gather_nd(
...     indices=[[[1]], [[0]]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[[b'a1', b'b1'],
         [b'c1', b'd1']]],
       [[[b'a0', b'b0'],
         [b'c0', b'd0']]]], dtype=object)

>>> tf.gather_nd(
...     indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0'],
        [b'a1', b'b1']],
       [[b'a0', b'b0'],
        [b'c1', b'd1']]], dtype=object)

>>> tf.gather_nd(
...     indices=[[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'b0', b'b1'],
       [b'd0', b'c1']], dtype=object)
Examples with batched 'params' and 'indices':
>>> tf.gather_nd(
...     batch_dims=1,
...     indices=[[1],
...              [0]],
...     params=[[['a0', 'b0'],
...              ['c0', 'd0']],
...             [['a1', 'b1'],
...              ['c1', 'd1']]]).numpy()
array([[b'c0', b'd0'],
       [b'a1', b'b1']], dtype=object)

>>> tf.gather_nd(
...     batch_dims=1,
...     indices=[[[1]], [[0]]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[[b'c0', b'd0']],
       [[b'a1', b'b1']]], dtype=object)

>>> tf.gather_nd(
...     batch_dims=1,
...     indices=[[[1, 0]], [[0, 1]]],
...     params=[[['a0', 'b0'], ['c0', 'd0']],
...             [['a1', 'b1'], ['c1', 'd1']]]).numpy()
array([[b'c0'],
       [b'b1']], dtype=object)
See also tf.gather.
| Returns |
|---|
| A Tensor. Has the same type as params. |