tf.gather用法和理解
参考链接:https://www.jianshu.com/p/cfaa828e389c:https://blog.csdn.net/qq_32806793/article/details/85324531tf.gather()函数tf.gather该接口的作用:就是抽取出params的第axis维度上在indices里面所有的index(看后面的例子,就...
参考链接:
tf.gather()函数
tf.gather
该接口的作用:就是抽取出params的第axis维度上在indices里面所有的index(看后面的例子,就会懂)
tf.gather(
params,
indices,
validate_indices=None,
name=None,
axis=0
)
'''
Args:
params: A Tensor. The tensor from which to gather values. Must be at least rank axis + 1.
indices: A Tensor. Must be one of the following types: int32, int64. Index tensor. Must be in range [0, params.shape[axis]).
axis: A Tensor. Must be one of the following types: int32, int64. The axis in params to gather indices from. Defaults to the first dimension. Supports negative indexes.
name: A name for the operation (optional).
Returns:
A Tensor. Has the same type as params.
'''
说明
参数
- params: A Tensor.
- indices: A Tensor. types必须是: int32, int64. 里面的每一个元素大小必须在 [0, params.shape[axis])范围内.
- axis: 维度。沿着params的哪一个维度进行抽取indices
返回的是一个tensor
下面看一些例子
1、当indices=[0,2],axis=0
input =[ [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]],
[[[7, 7, 7], [8, 8, 8]],
[[9, 9, 9], [10, 10, 10]],
[[11, 11, 11], [12, 12, 12]]],
[[[13, 13, 13], [14, 14, 14]],
[[15, 15, 15], [16, 16, 16]],
[[17, 17, 17], [18, 18, 18]]]
]
print(tf.shape(input))
with tf.Session() as sess:
output=tf.gather(input, [0,2],axis=0)#其实默认axis=0
print(sess.run(output))
输出结果
[[[[ 1 1 1]
[ 2 2 2]]
[[ 3 3 3]
[ 4 4 4]]
[[ 5 5 5]
[ 6 6 6]]]
[[[13 13 13]
[14 14 14]]
[[15 15 15]
[16 16 16]]
[[17 17 17]
[18 18 18]]]]
解释:
右中括号就暂时不理会他先了。
第一个[ 是列表语法需要的括号,剩下的最里面的三个[[[是axis=0需要搜寻的中括号。这里一共有3个[[[。
indices的[0,2]即取第0个[[[和第2个[[[,也就是第0个和第2个三维立体。
2、当indices=[0,2],axis=1
input =[ [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]],
[[[7, 7, 7], [8, 8, 8]],
[[9, 9, 9], [10, 10, 10]],
[[11, 11, 11], [12, 12, 12]]],
[[[13, 13, 13], [14, 14, 14]],
[[15, 15, 15], [16, 16, 16]],
[[17, 17, 17], [18, 18, 18]]]
]
print(tf.shape(input))
with tf.Session() as sess:
output=tf.gather(input, [0,2],axis=1)#默认axis=0
print(sess.run(output))
输出结果
[[[[ 1 1 1]
[ 2 2 2]]
[[ 5 5 5]
[ 6 6 6]]]
[[[ 7 7 7]
[ 8 8 8]]
[[11 11 11]
[12 12 12]]]
[[[13 13 13]
[14 14 14]]
[[17 17 17]
[18 18 18]]]]
解释:
第一个[ 是列表语法需要的括号,先把这个干扰去掉,剩下的所有内侧的 [[ 是axis=1搜索的中括号。
然后[0,2]即再取每个[[[体内的第0个[[和第2个[[,也就是去每个三维体的第0个面和第2个面
3、当indices是多维时,输出形状为“rank(params) + rank(indices) - 1”,即params的维数+indices的维数-1。本人主要做点云,经常出现二维张量中提取张量的操作,这里探究一下。
import tensorflow as tf
input1 =[ [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]],
[[[7, 7, 7], [8, 8, 8]],
[[9, 9, 9], [10, 10, 10]],
[[11, 11, 11], [12, 12, 12]]],
[[[13, 13, 13], [14, 14, 14]],
[[15, 15, 15], [16, 16, 16]],
[[17, 17, 17], [18, 18, 18]]]
]
print(tf.shape(input1))
input2 = tf.reshape(input1,[-1,3])
print(tf.shape(input2))
index = tf.ones(shape=(2,3,2),dtype=tf.int32)
gather = tf.gather(input2,index)
with tf.Session() as sess:
output1=tf.gather(input1, [0,2],axis=0)#其实默认axis=0
print('output1:')
print(sess.run(output1))
print('\n')
print('index:\n')
print(sess.run(index))
print('gather:\n')
print(sess.run(gather))
输出:
Tensor("Shape_27:0", shape=(4,), dtype=int32)
Tensor("Shape_28:0", shape=(2,), dtype=int32)
output1:
[[[[ 1 1 1]
[ 2 2 2]]
[[ 3 3 3]
[ 4 4 4]]
[[ 5 5 5]
[ 6 6 6]]]
[[[13 13 13]
[14 14 14]]
[[15 15 15]
[16 16 16]]
[[17 17 17]
[18 18 18]]]]
index:
[[[1 1]
[1 1]
[1 1]]
[[1 1]
[1 1]
[1 1]]]
gather:
[[[[2 2 2]
[2 2 2]]
[[2 2 2]
[2 2 2]]
[[2 2 2]
[2 2 2]]]
[[[2 2 2]
[2 2 2]]
[[2 2 2]
[2 2 2]]
[[2 2 2]
[2 2 2]]]]
因为 index 的形状是 (2,3,2)的,所以提取出来的张量为:最外层为扩展的维度,去掉最外层的[ ],对应的为index的第一个纬度值。从例子看,为2。现在来看去掉[ ]后的第一个
[[[2 2 2]
[2 2 2]]
[[2 2 2]
[2 2 2]]
[[2 2 2]
[2 2 2]]]
其包括3个
[[2 2 2]
[2 2 2]]
也就是说对应index的第二个维度,这一页张量包含3个二维矩阵。将页这层再去掉[],来看
[[2 2 2]
[2 2 2]]
是个矩阵,对应index的第三个维度。那么两行2 2 2 代表什么呢?接着看index的最里面的值是一行两列 1 1,所以提取了两次张量最里面的第二行.
接着看
index2= [[1, 1],
[1, 3],
[5, 1]]
index3 = [1,3]
gather2 = tf.gather(input2,index2)
gather3 = tf.gather(input2,index3)
with tf.Session() as sess:
print('gather2:\n')
print(sess.run(gather2))
print('\n')
print('gather3:\n')
print(sess.run(gather3))
print('\n')
print('gather4:\n')
print(sess.run(gather4))
输出:
gather2:
[[[2 2 2]
[2 2 2]]
[[2 2 2]
[4 4 4]]
[[6 6 6]
[2 2 2]]]
gather3:
[[2 2 2]
[4 4 4]]
gather4:
[[2 2 2]
[4 4 4]
[7 7 7]]
总结来说,index中的数值控制取哪些行,index 的维度控制生成的形状。最终要取多少行,要依据index最内部的行向量是多少维的。换句话说,index最里层有多少个元素,提取出来的张量最里面的二维矩阵就有多少行。更直白地说,index最内部的行索引变为列索引,然后索引号替换为对应的input2中的行就好了(对二维张量的操作,其他的请自行探究)。
更多推荐
所有评论(0)