Reading MNIST with PyTorch
pytorch
Word count: 292 | Reading time ≈ 1 min

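There are generally two ways to use the MNIST dataset. One is the wrapper that torchvision already provides; this post explains the other: loading the MNIST files manually.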

1. Download

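First, download the MNIST dataset from the official website. It consists of four gzip-compressed files. The code below opens the .gz files directly with gzip, so keep them compressed rather than extracting them.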
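As a quick sanity check after downloading, a small helper can confirm that all four files are in place. The file names below are the standard ones published on the MNIST site; the helper name `missing_mnist_files` is hypothetical, and the check assumes the files are kept as .gz archives, which is what the loading code in section 2 expects.

```python
import os

# The four gzip-compressed IDX files distributed on the MNIST site
MNIST_FILES = (
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
)


def missing_mnist_files(folder):
    """Hypothetical helper: return the expected MNIST files not found in folder."""
    return [name for name in MNIST_FILES
            if not os.path.isfile(os.path.join(folder, name))]
```

An empty list means the folder is ready to be passed to the loader.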

2. Reading the data

Copy the code below into readdata.py, then load the data by passing in the path to the dataset folder and the file names.

import os
import gzip
import numpy as np
from torch.utils.data import Dataset


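def load_data(data_folder, data_name, label_name):
    """
    Load MNIST images and labels from the gzip-compressed IDX files.
        - data_folder: folder containing the MNIST files
        - data_name: MNIST image file name, e.g. 'train-images-idx3-ubyte.gz'
        - label_name: MNIST label file name, e.g. 'train-labels-idx1-ubyte.gz'
    """
    # 'rb' opens the file for reading binary data
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:
        # Skip the 8-byte header (magic number + number of items)
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        # Skip the 16-byte header (magic number, count, rows, columns)
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    return (x_train, y_train)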


class CustomDataset(Dataset):
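    """
        Read and initialize the MNIST data.
    """
    def __init__(self, folder, data_name, label_name, transform=None):
        # torch.load() is only an alternative for data previously saved with torch.save()
        # (it then returns torch.Tensor objects); it cannot read the raw .gz IDX files
        (train_set, train_labels) = load_data(folder, data_name, label_name)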
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):

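        # np.array copies the 28x28 uint8 image; arrays from np.frombuffer are read-only
        img, target = np.array(self.train_set[index]), int(self.train_labels[index])
        if self.transform is not None:
            # e.g. torchvision.transforms.ToTensor() yields a 1x28x28 float tensor in [0, 1]
            img = self.transform(img)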
        return img, target

    def __len__(self):
        return len(self.train_set)
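The loader above skips a fixed 8- or 16-byte header and hard-codes the 28x28 image size. Those numbers come from the IDX format: a 4-byte magic number (two zero bytes, a data-type code, and the number of dimensions) followed by one 4-byte big-endian size per dimension. A minimal sketch that reads the header instead of hard-coding it might look like this; `read_idx_gz` and `load_data_checked` are hypothetical names, and only the unsigned-byte type (code 0x08) used by MNIST is handled.

```python
import gzip
import os
import struct

import numpy as np


def read_idx_gz(path):
    """Hypothetical helper: read a gzip-compressed IDX file into a NumPy array.

    Assumes unsigned-byte data (type code 0x08), as in the MNIST files.
    """
    with gzip.open(path, 'rb') as f:
        data = f.read()
    # Magic number: two zero bytes, a data-type code, and the number of dimensions
    zero, dtype_code, ndim = struct.unpack('>HBB', data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError(f'{path}: not an unsigned-byte IDX file')
    # One 4-byte big-endian unsigned size per dimension
    shape = struct.unpack('>' + 'I' * ndim, data[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim  # 8 for labels (ndim=1), 16 for images (ndim=3)
    array = np.frombuffer(data, dtype=np.uint8, offset=offset)
    return array.reshape(shape)  # raises ValueError if the payload is truncated


def load_data_checked(data_folder, data_name, label_name):
    """Hypothetical variant of load_data that takes shapes from the headers."""
    images = read_idx_gz(os.path.join(data_folder, data_name))
    labels = read_idx_gz(os.path.join(data_folder, label_name))
    if images.shape[0] != labels.shape[0]:
        raise ValueError('image and label counts differ')
    return images, labels
```

Because `load_data_checked` returns the same `(images, labels)` pair as `load_data`, it could be swapped into `CustomDataset.__init__` without other changes; a non-IDX file, a truncated file, or image and label files with different item counts then raise an error instead of silently producing wrongly shaped data.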
Sep 09, 2024