云计算百科
云计算领域专业知识百科平台

python打卡day38

Dataset和Dataloader类

  • Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  • Dataloader类
  • minist手写数据集的了解
  • 作业:了解下cifar数据集,尝试获取其中一张图片

    import torch
    import torchvision
    from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import numpy as np

    # 设置随机种子以确保结果可重复
    torch.manual_seed(42)

    transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) # CIFAR的标准化参数
    ])
    train_dataset = torchvision.datasets.CIFAR10(
    root='./dataCIFAR', # 数据存放的路径
    train=True, # 使用训练集
    download=True, # 如果没有数据,就下载
    transform=transform
    )

    # 定义类别
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck')

    # 随机选择一张图片
    idx = torch.randint(0, len(train_dataset), size=(1,))
    img, label = train_dataset[idx]

    # 反标准化函数
    def denormalize(x):
    mean = torch.tensor([0.4914, 0.4822, 0.4465])
    std = torch.tensor([0.2470, 0.2435, 0.2616])
    # CIFAR-10是彩色图像,需要对所有通道进行反标准化
    return x * std[:, None, None] + mean[:, None, None]

    # 显示图片
    plt.figure()
    plt.imshow(denormalize(img).permute(1, 2, 0)) # 调整通道顺序以正确显示彩色图像
    plt.title(f'Label: {classes[label]}')
    plt.axis('off')
    plt.show()

    # 3. 创建数据加载器
    train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
    )

    @浙大疏锦行

    赞(0)
    未经允许不得转载:网硕互联帮助中心 » python打卡day38
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!