添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
pytorch框架下模型默认的输入数据类型

pytorch框架下模型默认的输入数据类型

刚开始动手写数据预处理程序,模型和训练函数的时候,由于我的数据类型是 torch.float64 ,在运行的过程中总是报错 RuntimeError: expected scalar type Double but found Float, 以为是自己编写的过程中数据类型之间不匹配导致的,因此不断地进行数据类型转换,导致改的很乱。后来不断深入,感觉这不是问题所在,查找资料发现, 原来pytorch默认使用单精度float32训练模型,其主要原因为 :使用float16训练模型,模型效果会有损失,而使用double(float64)会有2倍的内存压力,且不会带来太多的精度提升, 因此默认使用单精度float32训练模型。

pytorch如何更改默认单精度float32训练模型,而改为 torch .float64对模型进行训练呢?

解决办法: 把模型的权重参数数据类型和输入数据类型全部设置为torch.float64。

使用 torch.set_default_dtype(torch.float64) 把模型参数转化为float64。

输入类型使用 tensor.type(torch.float64) 将输入数据类型转换为torch.float64。

文章被以下专栏收录