深度学习框架_PyTorch_repeat()函数
PyTorch中的repeat()函数可以对张量进行复制。当参数只有两个时,第一个参数表示的是复制后的列数,第二个参数表示复制后的行数。当参数有三个时,第一个参数表示的是复制后的通道数,第二个参数表示的是复制后的列数,第三个参数表示复制后的行数。接下来我们举一个例子来直观理解一下:>>> x = torch.tensor([6,7,8])>>> x.r...
·
PyTorch中的repeat()函数可以对张量进行复制。
当参数只有两个时,第一个参数表示的是复制后的列数,第二个参数表示复制后的行数。
当参数有三个时,第一个参数表示的是复制后的通道数,第二个参数表示的是复制后的列数,第三个参数表示复制后的行数。
接下来我们举一个例子来直观理解一下:
>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,2)
tensor([[6, 7, 8, 6, 7, 8],
[6, 7, 8, 6, 7, 8],
[6, 7, 8, 6, 7, 8],
[6, 7, 8, 6, 7, 8]])
>>> x.repeat(4,2,1)
tensor([[[6, 7, 8],
[6, 7, 8]],
[[6, 7, 8],
[6, 7, 8]],
[[6, 7, 8],
[6, 7, 8]],
[[6, 7, 8],
[6, 7, 8]]])
>>> x.repeat(4,2,1).size()
torch.Size([4, 2, 3])
更多推荐
已为社区贡献18条内容
所有评论(0)