split
torch.split(input, split_size_or_sections, dim = 0)
函数会将输入张量(input)沿着指定维度(dim)分割成特定数量的张量块,并返回元素为张量块的元素。
简单来说,可以将 torch.split 函数看成是 torch.chunk 函数的进阶版,因为 torch.split 不仅能够指定块数均匀分割(torch.chunk 只能指定块数均匀分割),而且能够指定分割每一块的长度。
torch.split 函数有三个参数:
-
tensor(Tensor)- 待分割的输入张量,此处的 tensor 参数和 torch.chunk 函数中的 input 参数类似,只需要注意使用关键字参数时候的参数名
-
split_size_or_sections(int)or(list(int))参数
-
指定为 int 时,和
torch.chunk(input, chunks, dim = 0)
函数中的 chunks 参数功能一样;
-
指定为 list(int) 时,list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定;
-
dim(int)- 进行分割的维度
torch.split 函数一共有两种分割形式,而这两种分割形式是由传入 split_size_or_sections 参数的类型所决定的。
指定为 int 时
当传入 torch.split 函数中的 split_size_or_sections 参数为整型时(int),torch.split 函数和 torch.chunk 函数所实现的功能一样,torch.split 函数中的 split_size_or_sections 参数和 torch.chunk 函数中的 chunks 参数等价。
“简单回顾上一小节介绍的 torch.chunk:
使用 torch.chunk 函数沿着 dim 维度将张量均匀的分割成 chunks 块,若式子
结果为:
-
整数(整除),表示能够将其均匀的分割成 chunks 块,直接进行分割即可;
-
浮点数(不能够整除),先按每块
(
为向上取整)进行分割,余下的作为最后一块;
”
比如,将形状为
的张量
,现在沿着第 1 个维度均匀的分割成 2 块。
B.size(1) = 3
、chunks = 2,即:
1.5 不是整数,则将其向上取整
,先将 3 按每块 2 个进行分割,余下的作为最后一块。
import torch
B = torch.arange(6).reshape(2, 3)
# 使用torch.chunk函数
result_chunk = torch.chunk(input = B,
chunks = 2,
dim = 1)
# 使用torch.split函数
result_split = torch.split(tensor = B,
split_size_or_sections = 2,
dim = 1)
print(B)
# tensor([[0, 1, 2],
# [3, 4, 5]])
print(result_chunk)
# (tensor([[0, 1],
# [3, 4]]),
# tensor([[2],
# [5]]))
print(result_split)
# (tensor([[0, 1],
# [3, 4]]),
# tensor([[2],
# [5]]))
实验结果显示,当传入 torch.split 函数中的 split_size_or_sections 参数为整型时(int),torch.split 和 torch.chunk 两个函数完全一样。
指定为 list 时
当传入 torch.split 函数中的参数 split_size_or_sections 为列表类型时(具体来说应该是元素为 int 整型的 list 列表),list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定。
比如,将形状为
的张量
,现在沿着第 1 个维度分割成 2 块,第一块长度为 1,而第二块长度为 2。使用 torch.split 函数,只需要为 split_size_or_sections 参数传入 [1, 2] 列表即可。
import torch
B = torch.arange(6).reshape(2, 3)
result = torch.split(tensor = B,
split_size_or_sections = [1, 2],
dim = 1)
print(B)
# tensor([[0, 1, 2],
# [3, 4, 5]])
print(result)
# (tensor([[0],
# [3]]),
# tensor([[1, 2],
# [4, 5]]))
传入 split_size_or_sections 参数的 list 中的每一个整数元素代表分割的块数,而每个块的长度由对应的整型元素决定,因此待分割维度的长度应该等于 list 中所有整型元素之和,否则程序会报错。
import torch
B = torch.arange(6).reshape(2, 3)
result = torch.split(tensor = B,
split_size_or_sections = [1, 4],
dim = 1)