import os

import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

from onehot import onehot

# Input normalization: convert to tensor and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class BagDataset(Dataset):
    """Segmentation dataset pairing images in *img_dir* with masks in *mask_dir*.

    Each item is ``(image, mask)``: the image is resized to *size* and passed
    through *transform* (if given); the mask is read as grayscale, binarized
    to {0, 1}, one-hot encoded to 2 classes, and returned as a FloatTensor
    of shape ``(2, H, W)``.
    """

    def __init__(self, transform=None, img_dir='bag_data',
                 mask_dir='bag_data_msk', size=(160, 160)):
        self.transform = transform
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.size = size
        # Cache the sorted file list once: the original called os.listdir()
        # on every __len__/__getitem__, which is slow and — because listdir
        # ordering is unspecified — made indexing non-deterministic.
        self.img_names = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]

        # Read the input image; cv2.imread returns None (not an exception)
        # on failure, so fail loudly with a clear message.
        imgA = cv2.imread(os.path.join(self.img_dir, img_name))
        if imgA is None:
            raise FileNotFoundError(
                'cannot read image: ' + os.path.join(self.img_dir, img_name))
        imgA = cv2.resize(imgA, self.size)

        # Read the mask as grayscale (flag 0), binarize, one-hot encode.
        imgB = cv2.imread(os.path.join(self.mask_dir, img_name), 0)
        if imgB is None:
            raise FileNotFoundError(
                'cannot read mask: ' + os.path.join(self.mask_dir, img_name))
        imgB = cv2.resize(imgB, self.size)
        imgB = (imgB / 255).astype('uint8')  # 255 -> 1, 0 -> 0
        imgB = onehot(imgB, 2)
        imgB = imgB.transpose(2, 0, 1)       # HWC -> CHW
        imgB = torch.FloatTensor(imgB)

        if self.transform:
            imgA = self.transform(imgA)
        return imgA, imgB


bag = BagDataset(transform)

# 90/10 train/test split of the full dataset.
train_size = int(0.9 * len(bag))
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(bag, [train_size, test_size])

# Loaders yield batch_size samples per step.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True,
                              num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True,
                             num_workers=1)

if __name__ == '__main__':
    for train_batch in train_dataloader:
        print(train_batch)
    for test_batch in test_dataloader:
        print(test_batch)