PyTorch入门实践:COVID-19 病例预测 (回归)
文章目录
更多Pytorch内容欢迎查看 快速入门Pytorch-CSDN博客
任务描述
根据美国特定州过去5天的调查结果,预测第5天新检测阳性病例的百分比。
数据简介:
- 在这种情况下,数据包含在
.csv
文件中 - 每行代表一个数据样本,包含118个特征(id + 37个州+ 16个特征 * 5天)
- 一行的最后一个元素是它的标签
功能函数
导入需要的Python包
# 数值、矩阵操作
import math
import numpy as np
# 数据读取与写入
import pandas as pd
import os
import csv
# 进度条
# from tqdm import tqdm
# 如果是使用notebook 推荐使用以下(颜值更高 : ) )
from tqdm.notebook import tqdm
# Pytorch 深度学习张量操作框架
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
# 绘制pytorch的网络
from torchviz import make_dot
# 学习曲线绘制
from torch.utils.tensorboard import SummaryWriter
一些重要的方法(随机种子设置、数据拆分、模型预测)
# 定义一个函数来设置随机种子,以确保实验的可复现性
def same_seed(seed):
"""
设置随机种子(便于复现)
"""
# 设置CUDA的确定性,确保每次运行的结果是确定的
torch.backends.cudnn.deterministic = True
# 关闭CUDA的benchmark模式,因为这与确定性运行模式冲突
torch.backends.cudnn.benchmark = False
# 设置NumPy的随机种子
np.random.seed(seed)
# 设置PyTorch的随机种子
torch.manual_seed(seed)
# 如果CUDA可用,则为GPU设置随机种子
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# 打印设置的种子值
print(f'Set Seed = {
seed}')
# 定义一个函数来将数据集随机拆分为训练集和验证集
def train_valid_split(data_set, valid_ratio, seed):
"""
数据集拆分成训练集(training set)和 验证集(validation set)
"""
# 计算验证集的大小
valid_set_size = int(valid_ratio * len(data_set))
# 训练集的大小是数据集总大小减去验证集大小
train_set_size = len(data_set) - valid_set_size
# 使用PyTorch的random_split函数来拆分数据集,传入随机种子以确保可复现性
train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
# 将拆分得到的数据集转换为NumPy数组格式并返回
return np.array(train_set), np.array(valid_set)
# 定义一个函数来进行模型的预测
def predict(test_loader, model, device):
# 将模型设置为评估模式
model.eval()
# 初始化一个列表来存储预测结果
preds = []
# 遍历测试数据集
for x in tqdm(test_loader):
# 将数据移动到指定的设备上(CPU或GPU)
x = x.to(device)
# 使用with torch.no_grad()来禁止计算梯度,因为在预测模式下不需要计算梯度
with torch.no_grad():
# 进行前向传播以获得预测结果
pred = model(x)
# 将预测结果从GPU移回CPU,并将其从计算图中分离出来
preds.append(pred.detach().cpu())
# 将所有批次的预测结果拼接成一个NumPy数组,并返回
preds = torch.cat(preds, dim=0).numpy()
return preds
数据加载
自定义数据集加载类
# 定义一个COVID19数据集类,继承自PyTorch的Dataset类
class COVID19Dataset(Dataset):
"""
x: np.ndarray 特征矩阵.
y: np.ndarray 目标标签, 如果为None,则是预测的数据集
"""
def __init__(self, x, y=None):
# 如果y不是None,则将y转换为PyTorch的FloatTensor类型,