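Preface
At last, PyTorch supports hardware acceleration on the Mac. My verdict in two words: so good!
Over the weekend I set up the environment on my own machine, which has the following configuration:
I then compared model run times with and without Mac hardware acceleration, using a convolution and a BERT model.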
Software installation
Run the command given on the official PyTorch website to install the build with Mac hardware acceleration:
https://pytorch.org/get-started/locally/
conda install pytorch torchvision torchaudio -c pytorch
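After installing, you can confirm that the build can actually use the Mac GPU by querying the MPS backend. The sketch below assumes PyTorch 1.12 or newer, where `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()` exist. The helper name `pick_device` is hypothetical.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: return the MPS device when usable, otherwise the CPU."""
    # is_built(): this PyTorch binary was compiled with MPS support.
    # is_available(): the current macOS version and hardware can actually run MPS.
    if torch.backends.mps.is_built() and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


if __name__ == "__main__":
    print("MPS built:", torch.backends.mps.is_built())
    print("MPS available:", torch.backends.mps.is_available())
    print("Using device:", pick_device())
```

If both checks print `True`, tensors and models can be moved to `"mps"` just as they would be moved to `"cuda"` on an NVIDIA machine.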
Quick test
Use a convolution to compare run times with and without hardware acceleration.
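import time

import torch

# Time 1000 forward passes of a 3x3 Conv2d on the Apple GPU (MPS backend)
dev = 'mps:0'
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
img = torch.randn(64, 10, 64, 64).to(dev)
t0 = time.time()
for i in range(1000):
    out = conv(img)
# MPS kernels run asynchronously; copying the result to the CPU waits for them to finish
out.cpu()
t1 = time.time()
print('Use mps, time:{}'.format(t1-t0))

# Same workload on the CPU for comparison
dev = 'cpu'
conv = torch.nn.Conv2d(10, 10, 3).to(dev)
img = torch.randn(64, 10, 64, 64).to(dev)
t0 = time.time()
for i in range(1000):
    out = conv(img)
t1 = time.time()
print('Use cpu, time:{}'.format(t1-t0))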
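The loop above starts the clock without any warm-up, so one-time setup costs (kernel compilation, memory allocation) land inside the measurement. A slightly more careful variant is sketched below; it is an assumption, not the author's method. It runs warm-up passes first, disables autograd with `torch.no_grad()`, copies the last output to the CPU so that queued MPS work has finished before the clock is read, and uses `time.perf_counter()`. The function name `time_conv` and its parameters are hypothetical.

```python
import time

import torch


def time_conv(dev: str, n_iters: int = 1000, n_warmup: int = 10) -> float:
    """Seconds spent on `n_iters` forward passes of a 3x3 Conv2d on `dev` (hypothetical helper)."""
    conv = torch.nn.Conv2d(10, 10, 3).to(dev)
    img = torch.randn(64, 10, 64, 64, device=dev)
    with torch.no_grad():  # inference only: no autograd graph is recorded
        # Warm-up passes absorb one-time costs before timing starts.
        out = conv(img)
        for _ in range(n_warmup):
            out = conv(img)
        out.cpu()  # copying to the CPU blocks until all queued work has run
        t0 = time.perf_counter()
        for _ in range(n_iters):
            out = conv(img)
        out.cpu()  # wait for the timed kernels; the copy adds a small constant cost
        return time.perf_counter() - t0
```

Usage: compare `time_conv("mps")` with `time_conv("cpu")`, calling the MPS version only when `torch.backends.mps.is_available()` returns `True`.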
Results
BERT test
The Hugging Face GLUE example code is used for this benchmark.
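Before running the full GLUE fine-tuning, a quicker check is to time forward passes of a BERT-sized model directly. The sketch below is a hypothetical stand-in, not the GLUE script used in this article. It assumes the `transformers` package is installed and builds a randomly initialised `BertModel` from `BertConfig()` (whose defaults match bert-base), so no weights are downloaded. Random token ids replace real GLUE inputs, and the function name `time_bert_forward` is hypothetical.

```python
import time
from typing import Optional

import torch
from transformers import BertConfig, BertModel


def time_bert_forward(
    dev: str,
    n_iters: int = 20,
    batch_size: int = 8,
    seq_len: int = 128,
    config: Optional[BertConfig] = None,
) -> float:
    """Seconds for `n_iters` forward passes of a randomly initialised BERT on `dev`."""
    # BertConfig() defaults to bert-base sizes; weights are random, nothing is downloaded.
    config = config if config is not None else BertConfig()
    model = BertModel(config).to(dev).eval()
    # Random token ids stand in for tokenized GLUE sentences.
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=dev)
    with torch.no_grad():
        out = model(input_ids=input_ids).last_hidden_state  # warm-up pass
        out.cpu()  # copying to the CPU waits for queued MPS work
        t0 = time.perf_counter()
        for _ in range(n_iters):
            out = model(input_ids=input_ids).last_hidden_state
        out.cpu()  # make sure the timed passes have finished
        return time.perf_counter() - t0
```

For example, compare `time_bert_forward("mps")` with `time_bert_forward("cpu")`. Passing a smaller `BertConfig` keeps the check fast on the CPU.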
Data preparation
Run the following script to download the data.
''' Script for downloading all GLUE data.
Note: for legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized,
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
You should then rename and place specific files in a folder (see below for an example).
mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA": 'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',