gather
tf.gather(params, indices, validate_indices=None, name=None, axis=0)
Gather slices from params
axis axis
according to indices
.
从’params"中,按照axis坐标和indices标注的元素下标,把这些元素抽取出来组成新的tensor.
测试代码:
import tensorflow as tf
import numpy as np
print("\n先测试一维张量\n")
t=np.random.randint(1,10,5)
g1=tf.gather(t,[2,1,4])
sess=tf.Session()
print(t)
print(sess.run(g1))
print("\n再测试二维张量\n")
t=np.random.randint(1,10,[4,5])
g2=tf.gather(t,[1,2,2],axis=0)
g3=tf.gather(t,[1,2,2],axis=1)
print(t)
print(sess.run(g2))
print(sess.run(g3))
结果如下:
先测试一维张量
[7 4 7 1 3]
[7 4 3]
再测试二维张量
[[5 5 7 4 3]
[8 7 6 5 2]
[6 9 4 4 8]
[7 3 3 2 2]]
[[8 7 6 5 2]
[6 9 4 4 8]
[6 9 4 4 8]]
[[5 7 7]
[7 6 6]
[9 4 4]
[3 3 3]]