必威体育Betway必威体育官网
当前位置:首页 > IT技术

tf.gather用法

时间:2019-06-21 00:45:15来源:IT技术作者:seo实验室小编阅读:76次「手机版」
 

gather

tf.gather(params, indices, validate_indices=None, name=None, axis=0)

Gather slices from params axis axis according to indices.

从’params"中,按照axis坐标和indices标注的元素下标,把这些元素抽取出来组成新的tensor.

测试代码:

    import tensorflow as tf
    import numpy as np
    print("\n先测试一维张量\n")
    t=np.random.randint(1,10,5)
    g1=tf.gather(t,[2,1,4])
    sess=tf.Session()
    print(t)
    print(sess.run(g1))
    print("\n再测试二维张量\n")
    t=np.random.randint(1,10,[4,5])
    g2=tf.gather(t,[1,2,2],axis=0)
    g3=tf.gather(t,[1,2,2],axis=1)
    print(t)
    print(sess.run(g2))
    print(sess.run(g3))

结果如下:

先测试一维张量
[7 4 7 1 3]
[7 4 3]
再测试二维张量
[[5 5 7 4 3]
 [8 7 6 5 2]
 [6 9 4 4 8]
 [7 3 3 2 2]]
[[8 7 6 5 2]
 [6 9 4 4 8]
 [6 9 4 4 8]]
[[5 7 7]
 [7 6 6]
 [9 4 4]
 [3 3 3]]

相关阅读

分享到:

栏目导航

推荐阅读

热门阅读