参考链接:

tf.gather,tf.range()的详解

tf.gather( )的用法

tf.gather()函数详解

tf.gather_nd和tf.gather的区别与联系

tf.gather()函数

tf.gather
该接口的作用:就是抽取出params的第axis维度上在indices里面所有的index(看后面的例子,就会懂)

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

'''
Args:
    params: A Tensor. The tensor from which to gather values. Must be at least rank axis + 1.
    indices: A Tensor. Must be one of the following types: int32, int64. Index tensor. Must be in range [0, params.shape[axis]).
    axis: A Tensor. Must be one of the following types: int32, int64. The axis in params to gather indices from. Defaults to the first dimension. Supports negative indexes.
    name: A name for the operation (optional).
Returns:
    A Tensor. Has the same type as params.
'''

说明
参数

  • params: A Tensor.
  • indices: A Tensor. types必须是: int32, int64. 里面的每一个元素大小必须在 [0, params.shape[axis])范围内.
  • axis: 维度。沿着params的哪一个维度进行抽取indices

返回的是一个tensor
 

下面看一些例子

1、当indices=[0,2],axis=0

input =[ [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]],

         [[[7, 7, 7], [8, 8, 8]],
         [[9, 9, 9], [10, 10, 10]],
         [[11, 11, 11], [12, 12, 12]]],

        [[[13, 13, 13], [14, 14, 14]],
         [[15, 15, 15], [16, 16, 16]],
         [[17, 17, 17], [18, 18, 18]]]
         ]

print(tf.shape(input))
with tf.Session() as sess:
    output=tf.gather(input, [0,2],axis=0)#其实默认axis=0
    print(sess.run(output))

输出结果
[[[[ 1 1 1]
[ 2 2 2]]

[[ 3 3 3]
[ 4 4 4]]

[[ 5 5 5]
[ 6 6 6]]]

[[[13 13 13]
[14 14 14]]

[[15 15 15]
[16 16 16]]

[[17 17 17]
[18 18 18]]]]

解释:

右中括号就暂时不理会他先了。
第一个[ 是列表语法需要的括号,剩下的最里面的三个[[[是axis=0需要搜寻的中括号。这里一共有3个[[[。
indices的[0,2]即取第0个[[[和第2个[[[,也就是第0个和第2个三维立体。

2、当indices=[0,2],axis=1

input =[ [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]],

         [[[7, 7, 7], [8, 8, 8]],
         [[9, 9, 9], [10, 10, 10]],
         [[11, 11, 11], [12, 12, 12]]],

        [[[13, 13, 13], [14, 14, 14]],
         [[15, 15, 15], [16, 16, 16]],
         [[17, 17, 17], [18, 18, 18]]]
         ]
print(tf.shape(input))
with tf.Session() as sess:
    output=tf.gather(input, [0,2],axis=1)#默认axis=0
    print(sess.run(output))

输出结果
[[[[ 1 1 1]
[ 2 2 2]]

[[ 5 5 5]
[ 6 6 6]]]

[[[ 7 7 7]
[ 8 8 8]]

[[11 11 11]
[12 12 12]]]

[[[13 13 13]
[14 14 14]]

[[17 17 17]
[18 18 18]]]]

解释:

第一个[ 是列表语法需要的括号,先把这个干扰去掉,剩下的所有内侧的 [[ 是axis=1搜索的中括号。
然后[0,2]即再取每个[[[体内的第0个[[和第2个[[,也就是去每个三维体的第0个面和第2个面

3、当indices是多维时,输出形状为“rank(params) + rank(indices) - 1”,即params的维数+indices的维数-1。本人主要做点云,经常出现二维张量中提取张量的操作,这里探究一下。

import tensorflow as tf
input1 =[ [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]],

         [[[7, 7, 7], [8, 8, 8]],
         [[9, 9, 9], [10, 10, 10]],
         [[11, 11, 11], [12, 12, 12]]],

        [[[13, 13, 13], [14, 14, 14]],
         [[15, 15, 15], [16, 16, 16]],
         [[17, 17, 17], [18, 18, 18]]]
         ]

print(tf.shape(input1))
input2 = tf.reshape(input1,[-1,3])
print(tf.shape(input2))
index = tf.ones(shape=(2,3,2),dtype=tf.int32)
gather = tf.gather(input2,index)
with tf.Session() as sess:
    output1=tf.gather(input1, [0,2],axis=0)#其实默认axis=0
    print('output1:')
    print(sess.run(output1))
    print('\n')
    print('index:\n')
    print(sess.run(index))
    print('gather:\n')
    print(sess.run(gather))

输出:

Tensor("Shape_27:0", shape=(4,), dtype=int32)
Tensor("Shape_28:0", shape=(2,), dtype=int32)
output1:
[[[[ 1  1  1]
   [ 2  2  2]]

  [[ 3  3  3]
   [ 4  4  4]]

  [[ 5  5  5]
   [ 6  6  6]]]


 [[[13 13 13]
   [14 14 14]]

  [[15 15 15]
   [16 16 16]]

  [[17 17 17]
   [18 18 18]]]]


index:

[[[1 1]
  [1 1]
  [1 1]]

 [[1 1]
  [1 1]
  [1 1]]]
gather:

[[[[2 2 2]
   [2 2 2]]

  [[2 2 2]
   [2 2 2]]

  [[2 2 2]
   [2 2 2]]]


 [[[2 2 2]
   [2 2 2]]

  [[2 2 2]
   [2 2 2]]

  [[2 2 2]
   [2 2 2]]]]

因为 index 的形状是 (2,3,2)的,所以提取出来的张量为:最外层为扩展的维度,去掉最外层的[ ],对应的为index的第一个纬度值。从例子看,为2。现在来看去掉[ ]后的第一个

[[[2 2 2]
[2 2 2]]

[[2 2 2]
[2 2 2]]

[[2 2 2]
[2 2 2]]]

其包括3个

[[2 2 2]
[2 2 2]]

也就是说对应index的第二个维度,这一页张量包含3个二维矩阵。将页这层再去掉[],来看

[[2 2 2]
[2 2 2]]

是个矩阵,对应index的第三个维度。那么两行2 2 2 代表什么呢?接着看index的最里面的值是一行两列 1 1,所以提取了两次张量最里面的第二行.

接着看

index2= [[1, 1],
         [1, 3],
         [5, 1]]
index3 = [1,3]
gather2 = tf.gather(input2,index2)
gather3 = tf.gather(input2,index3)
with tf.Session() as sess:
    print('gather2:\n')
    print(sess.run(gather2))
    print('\n')
    print('gather3:\n')
    print(sess.run(gather3))
    print('\n')
    print('gather4:\n')
    print(sess.run(gather4))

输出:
gather2:

[[[2 2 2]
  [2 2 2]]

 [[2 2 2]
  [4 4 4]]

 [[6 6 6]
  [2 2 2]]]


gather3:

[[2 2 2]
 [4 4 4]]

gather4:

[[2 2 2]
 [4 4 4]
 [7 7 7]]

总结来说,index中的数值控制取哪些行,index 的维度控制生成的形状。最终要取多少行,要依据index最内部的行向量是多少维的。换句话说,index最里层有多少个元素,提取出来的张量最里面的二维矩阵就有多少行。更直白地说,index最内部的行索引变为列索引,然后索引号替换为对应的input2中的行就好了(对二维张量的操作,其他的请自行探究)。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐