[原创] PyTorch做inference/prediction的时候如何使用GPU

话不多说,直接进入主题。

判断能不能使用GPU
可能有多种原因会导致不能使用GPU,比如PyTorch安装的是CPU版的,显卡驱动没有正确安装等。下面的 if 语句在正常的情况下会返回 True:

if torch.cuda.is_available():
    print('PyTorch can use GPU on current machine!')

文章来源:https://www.codelast.com/
 设置模型使用GPU

model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(your_model_file_path))
model.eval()  # 设置成evaluation模式
if torch.cuda.is_available():
    print('PyTorch can use GPU on current machine!')
    device = torch.device("cuda")
    model.to(device)

your_model_file_path 是模型文件的路径。

阅读更多

[原创] PyTorch模型 .pt,.pth,.pkl 的区别

我们经常会看到后缀名为 .pt,.pth,.pkl 的PyTorch模型文件,这几种模型文件在格式上有什么区别吗?
其实它们并不是在格式上有区别,而只是后缀上不同而已(仅此而已)。在用 torch.save() 函数保存模型文件的时候,各人有不同的喜好,有些人喜欢用 .pt 后缀,有些人喜欢用 .pth 或 .pkl。用相同的 torch.save() 语句保存出来的模型文件没有什么不同。
在PyTorch官方的文档/代码里,有用 .pt 的也有用 .pth 的
据某些文章的说法,一般惯例是使用 .pth,但是官方文档里貌似 .pt 更多,而且官方也不是很在意固定用一种,大家就自便吧。

阅读更多