# Standard library
import os

# Third-party
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

# Project-local
from onehot import onehot
"""数据归一化"""
|
||
transform = transforms.Compose([
|
||
transforms.ToTensor(),
|
||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
|
||
|
||
class BagDataset(Dataset):
    """Paired image/mask dataset for bag segmentation.

    Input images are read from ``bag_data/`` and the matching binary masks
    (same file names) from ``bag_data_msk/``. Every sample is resized to
    160x160; the mask is binarized and one-hot encoded into 2 channels in
    (C, H, W) order.
    """

    def __init__(self, transform=None):
        # transform: optional callable applied to the input image
        # (e.g. a torchvision ToTensor + Normalize pipeline).
        self.transform = transform
        # Cache the sorted file listing once. The original code called
        # os.listdir() on every __len__/__getitem__ invocation: that is
        # slow and, because listdir order is arbitrary and platform
        # dependent, made idx -> sample mapping non-deterministic.
        self.img_names = sorted(os.listdir('bag_data'))

    def __len__(self):
        # Number of image files found in bag_data/ at construction time.
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]

        # NOTE(review): cv2.imread returns BGR channel order, while the
        # ImageNet Normalize statistics applied via `transform` assume
        # RGB — confirm whether a cv2.cvtColor(..., COLOR_BGR2RGB) is
        # intended before the transform.
        imgA = cv2.imread(os.path.join('bag_data', img_name))
        imgA = cv2.resize(imgA, (160, 160))  # unify sample size

        # Mask: grayscale read (flag 0), resized, scaled from {0, 255}
        # to {0, 1}, one-hot encoded to 2 classes, then moved from
        # (H, W, C) to (C, H, W).
        imgB = cv2.imread(os.path.join('bag_data_msk', img_name), 0)
        imgB = cv2.resize(imgB, (160, 160))
        imgB = (imgB / 255).astype('uint8')  # 255 -> 1, everything else floors to 0
        imgB = onehot(imgB, 2)
        imgB = imgB.transpose(2, 0, 1)
        imgB = torch.FloatTensor(imgB)

        if self.transform:
            imgA = self.transform(imgA)

        return imgA, imgB
bag = BagDataset(transform)

# Hold out 10% of the samples for evaluation; train on the remaining 90%.
train_size = int(0.9 * len(bag))
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(bag, [train_size, test_size])

# Both loaders draw one shuffled sample per batch using a single worker.
_loader_kwargs = dict(batch_size=1, shuffle=True, num_workers=1)
train_dataloader = DataLoader(train_dataset, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, **_loader_kwargs)
if __name__ == '__main__':
    # Smoke test: iterate both loaders once and print every batch.
    for batch in train_dataloader:
        print(batch)

    for batch in test_dataloader:
        print(batch)