本篇介绍Pytorch的基础数据类型,判断方式以及常用向量

1|0基础数据类型


  • torch.Tensor是一种包含单一数据类型元素的多维矩阵。
  • 目前在1.2版本中有9种类型。

pytorch数据类型

  • 同python相比,pytorch没有string类型;
  • 由于pytorch是面向计算的,对于字符这种通常通过编码下手;
  • 怎样用数字的形式去表示语言(字符串) : NLP -> one-hot 或 Embedding(Word2vec,glove)

2|0判断数据类型


  1. 打印数据类型:a.type()
  2. 打印的是基本的数据类型,没有提供额外的信息:type(a)
  3. 合法性检验:isinstance(a, torch.FloatTensor)
1
2
3
4
5
6
7
8
In[2]: import torch
In[3]: a = torch.randn(2,3) //两维 , 每个数字是由随机的正态分布来初始化的,均值是0 方差是1
In[4]: a.type() // 方法一:打印数据类型
Out[4]: 'torch.FloatTensor'
In[5]: type(a) // 方法二:较少
Out[5]: torch.Tensor
In[6]: isinstance(a, torch.FloatTensor) // 方法三:合法性检验
Out[6]: True

同一个tensor部署在cpu和gpu时的数据类型是不一样的