添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
备案 控制台
学习
实践
活动
专区
工具
TVP
写文章
专栏首页 AI机器学习与深度学习算法 PyTorch入门笔记-分割split函数
1 0

海报分享

PyTorch入门笔记-分割split函数

split

torch.split(input, split_size_or_sections, dim = 0) 函数会将输入张量(input)沿着指定维度(dim)分割成特定数量的张量块,并返回元素为张量块的元素。 简单来说,可以将 torch.split 函数看成是 torch.chunk 函数的进阶版,因为 torch.split 不仅能够指定块数均匀分割(torch.chunk 只能指定块数均匀分割),而且能够指定分割每一块的长度。 torch.split 函数有三个参数:

  • tensor(Tensor)- 待分割的输入张量,此处的 tensor 参数和 torch.chunk 函数中的 input 参数类似,只需要注意使用关键字参数时候的参数名
  • split_size_or_sections(int)or(list(int))参数
    • 指定为 int 时,和 torch.chunk(input, chunks, dim = 0) 函数中的 chunks 参数功能一样;
    • 指定为 list(int) 时,list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定;
  • dim(int)- 进行分割的维度

torch.split 函数一共有两种分割形式,而这两种分割形式是由传入 split_size_or_sections 参数的类型所决定的。

指定为 int 时

当传入 torch.split 函数中的 split_size_or_sections 参数为整型时(int),torch.split 函数和 torch.chunk 函数所实现的功能一样,torch.split 函数中的 split_size_or_sections 参数和 torch.chunk 函数中的 chunks 参数等价。

“简单回顾上一小节介绍的 torch.chunk: 使用 torch.chunk 函数沿着 dim 维度将张量均匀的分割成 chunks 块,若式子

\frac{input.size(dim)}{chunks}

结果为:

  • 整数(整除),表示能够将其均匀的分割成 chunks 块,直接进行分割即可;
  • 浮点数(不能够整除),先按每块
\lceil \frac{input.size(dim)}{chunks} \rceil

\lceil \ \rceil

为向上取整)进行分割,余下的作为最后一块;

比如,将形状为

[2, 3]

的张量

B

,现在沿着第 1 个维度均匀的分割成 2 块。 B.size(1) = 3 、chunks = 2,即:

\frac{input.size(dim)}{chunks} = \frac{B.size(3)}{chunks}
=\frac{3}{2} = 1.5

1.5 不是整数,则将其向上取整

\lceil 1.5 \rceil = 2

,先将 3 按每块 2 个进行分割,余下的作为最后一块。

import torch
B = torch.arange(6).reshape(2, 3)
# 使用torch.chunk函数
result_chunk = torch.chunk(input = B,
                     chunks = 2,
                     dim = 1)
# 使用torch.split函数
result_split = torch.split(tensor = B,
                       split_size_or_sections = 2,
                       dim = 1)
print(B)
# tensor([[0, 1, 2],
#         [3, 4, 5]])
print(result_chunk)
# (tensor([[0, 1],
#          [3, 4]]), 
#  tensor([[2],
#          [5]]))
print(result_split)
# (tensor([[0, 1],
#          [3, 4]]), 
#  tensor([[2],
#          [5]]))

实验结果显示,当传入 torch.split 函数中的 split_size_or_sections 参数为整型时(int),torch.split 和 torch.chunk 两个函数完全一样。

指定为 list 时

当传入 torch.split 函数中的参数 split_size_or_sections 为列表类型时(具体来说应该是元素为 int 整型的 list 列表),list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定。

比如,将形状为

[2, 3]

的张量

B

,现在沿着第 1 个维度分割成 2 块,第一块长度为 1,而第二块长度为 2。使用 torch.split 函数,只需要为 split_size_or_sections 参数传入 [1, 2] 列表即可。

import torch
B = torch.arange(6).reshape(2, 3)
result = torch.split(tensor = B,
                 split_size_or_sections = [1, 2],
                 dim = 1)
print(B)
# tensor([[0, 1, 2],
#       [3, 4, 5]])
print(result)
# (tensor([[0],
#          [3]]), 
#  tensor([[1, 2],
#          [4, 5]]))

传入 split_size_or_sections 参数的 list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定,因此待分割维度的长度应该等于 list 中所有整型元素之和,否则程序会报错。

import torch
B = torch.arange(6).reshape(2, 3)
result = torch.split(tensor = B,
                 split_size_or_sections = [1, 4],
                 dim = 1)