GitHub链接
张量的结构操作.
创建张量
张量的创建方法类似 numpy 中创建 array 的方法.
a = torch.tensor([1,2,3],dtype = torch.float) b = torch.arange(1,10,step = 2) c = torch.linspace(0.0,2*3.14,10) d = torch.zeros((3,3))
a = torch.ones((3,3),dtype = torch.int) b = torch.zeros_like(a,dtype = torch.float)
torch.fill_(b,5)
torch.manual_seed(42) minval,maxval = 0, 10 a = minval + (maxval-minval)*torch.rand([5])
b = torch.normal(mean = torch.zeros(3,3), std = torch.ones(3,3))
mean,std = 2, 5 c = std*torch.randn((3,3)) + mean
d = torch.randperm(20)
I = torch.eye(3,3) print(I) t = torch.diag(torch.tensor([1,2,3]))
|
索引切片
几乎和 numpy 相同, 切片时支持缺省参数和省略号, 可以通过索引和切片对部分元素进行修改.
对于不规则的切片提取, 可以使用torch.index_select, torch.masked_select, torch.take.
如果要通过修改张量的某些元素得到新的张量, 可以使用torch.where, torch.masked_fill, torch.index_fill.
torch.manual_seed(42) minval,maxval = 0, 10 t = torch.floor(minval + (maxval-minval)*torch.rand([5,5])).int()
print(t[0]) print(t[-1])
print(t[1,3]) print(t[1][3])
print(t[1:4,:])
x = torch.Tensor([[1,2],[3,4]]) x.data[1,:] = torch.tensor([0.0,0.0])
a = torch.arange(27).view(3,3,3)
print(a[...,1])
|
对于不规则的切片提取, 可以使用torch.index_select, torch.masked_select, torch.take.
有 4 个班级, 每个班级 5 个学生, 每个学生 7 门科目成绩, 可以用一个 4×5×7 的张量来表示.
minval = 0 maxval = 100 scores = torch.floor(minval + (maxval-minval)*torch.rand([4,5,7])).int()
torch.index_select(scores,dim = 1,index = torch.tensor([0,2,4]))
q = torch.index_select(torch.index_select(scores,dim = 1,index = torch.tensor([0,2,4])),dim=2,index = torch.tensor([1,3,6]))
s = torch.take(scores,torch.tensor([0*5*7+0,2*5*7+3*7+1,3*5*7+4*7+6]))
g = torch.masked_select(scores,scores>=80)
ifpass = torch.where(scores>60,torch.tensor(1),torch.tensor(0))
torch.index_fill(scores,dim = 1,index = torch.tensor([0,2,4]),value = 100)
b = torch.masked_fill(scores,scores<60,60)
|
维度变换
相关函数有torch.reshape(改变形状, 或者调用张量的 view 方法), torch.squeeze(减少维度), torch.unsqueeze(增加维度), torch.transpose/torch.permute(交换维度).
torch.manual_seed(0) minval,maxval = 0,255 a = (minval + (maxval-minval)*torch.rand([1,3,3,2])).int() b = a.view([3,6]) c = torch.reshape(b,[1,3,3,2])
|
如果张量在某个维度上只有一个元素, 利用 torch.squeeze 可以消除这个维度. torch.unsqueeze 作用相反.
a = torch.tensor([[1.0,2.0]]) s = torch.squeeze(a) d = torch.unsqueeze(s,axis=0)
|
torch.transpose 可以交换张量的维度, 常用于图片存储格式变换.
如果是二维矩阵, 通常会调用矩阵的转置方法matrix.t(), 等价于torch.transpose(matrix,0,1).
minval=0 maxval=255
data = torch.floor(minval + (maxval-minval)*torch.rand([100,256,256,4])).int()
data_t = torch.transpose(torch.transpose(data,1,2),1,3) print(data_t.shape)
data_p = torch.permute(data,[0,3,1,2])
|
合并分割
可以用 torch.cat 方法和 torch.stack 方法将多个张量合并, 可以用 torch.split 方法把一个张量分割成多个张量.
torch.cat 是连接, 不会增加维度, torch.stack 是堆叠, 会增加维度.
a = torch.tensor([[1.0,2.0],[3.0,4.0]]) b = torch.tensor([[5.0,6.0],[7.0,8.0]]) c = torch.tensor([[9.0,10.0],[11.0,12.0]])
abc_cat = torch.cat([a,b,c],dim = 0) abc_stack = torch.stack([a,b,c],axis = 0)
torch.cat([a,b,c],axis = 1) torch.stack([a,b,c],axis = 1)
|
torch.split 是 torch.cat 的逆运算, 可以指定分割份数平均分割, 也可以通过指定每份的记录数量进行分割.
a,b,c = torch.split(abc_cat,split_size_or_sections = 2,dim = 0) p,q,r = torch.split(abc_cat,split_size_or_sections =[4,1,1],dim = 0)
|