添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

PyTorch 1.8.0 torch.sparse类

稀疏矩阵适合用于处理大数据计算问题,在节省存储空间方面常常有较好的表现,因此在很多工程项目中都有稀疏矩阵的参与

1. 构建稀疏矩阵

  • 官方直接torch构建(coo类型的稀疏矩阵)
i = torch.LongTensor([[0, 1, 1],[2, 0, 2]])   #row, col
v = torch.FloatTensor([3, 4, 5])    #data
torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()   #torch.Size()
  • 调用sp.coo_matrix(),构建方式同上,传入数据为np.array,通常用于数据预处理
row = np.array(row_list, dtype=np.int32)
col = np.array(col_list, dtype=np.int32)
data = np.oneslike(data_list, dtype=np.float32)
n_nodes = max + 1
tmp_adj = sp.coo_matrix((data, (row, col)), shape=(n_nodes, n_nodes))

转换成tensor,具体流程同上,把稀疏矩阵中每个元素都转换成tensor即可。

2. 稀疏矩阵计算

  • 该版本稀疏矩阵类支持稀疏矩阵和稀疏矩阵的乘积torch.sparse.mm(sparse, sparse/dense);(1.8.0支持,之前版本不支持)
  • 矩阵元素乘torch.mul(sparse,sparse),此处两个sparse的row,col,size需要一致。
  • 稀疏矩阵支持转置。Sparse.matrix.t()
  • 稀疏矩阵支持整行索引,支持Sparse.matrix[row_index];稀疏矩阵不支持具体位置位置索引Sparse.matrix[row_index,col_index]。
a = torch.sparse.FloatTensor(torch.tensor([[0,1,2],[2,3,4]]), torch.tensor([1,1,1]), torch.Size([5,5]))
a[0] ->tensor(indices=tensor([[2]]),values=tensor([1]),size=(5,), nnz=1, layout=torch.sparse_coo)
  • 稀疏矩阵相加 (bug1)

稀疏矩阵在gpu上调用torch.add(),结果是两个稀疏矩阵的拼接,size会变成原来的二倍,相同索引会被复制两遍,但在cpu上调用torch.add(),结果虽然是拼接,但是相同索引处的值会相加。

torch.add(a,a)   
->tensor(indices=tensor([[0, 1, 2], [2, 3, 4]]),values=tensor([2, 2, 2]),
size=(5, 5), nnz=3, layout=torch.sparse_coo)
b=a.cuda()
torch.add(b,b)   
-> tensor(indices=tensor([[0, 1, 2, 0, 1, 2],[2, 3, 4, 2, 3, 4]]),values=tensor([1, 1, 1, 1, 1, 1]),
device='cuda:0', size=(5, 5), nnz=6, layout=torch.sparse_coo)
a1=torch.sparse.FloatTensor(torch.tensor([[0,3,2],[2,3,2]]), torch.tensor([1,1,1]), torch.Size([5,5]))
torch.add(a,a1)    
->tensor(indices=tensor([[0, 1, 2, 3, 2],[2, 3, 4, 3, 2]]),