FCN-test/BagData.py

58 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import cv2
from onehot import onehot
"""数据归一化"""
# Preprocessing applied to each input image: ToTensor turns an HxWxC image
# array into a CxHxW float tensor (uint8 input is scaled to [0, 1]), then
# Normalize standardises each channel with the usual ImageNet mean/std.
# NOTE(review): the ImageNet statistics are listed in RGB order, but images
# read with cv2.imread arrive in BGR order -- confirm the mismatch is intended.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class BagDataset(Dataset):
    """Paired image / binary-mask dataset for bag segmentation.

    Sample ``i`` is the ``i``-th file (in sorted name order) of ``img_dir``
    together with the mask of the same file name in ``mask_dir``.  Both are
    resized to ``size``; the mask is binarised and one-hot encoded.

    Args:
        transform: Optional callable applied to the resized input image
            (the BGR ``numpy`` array returned by ``cv2.imread``).
        img_dir: Directory holding the input images.
        mask_dir: Directory holding one mask per image, same file name.
        size: ``(width, height)`` passed to ``cv2.resize``.
    """

    def __init__(self, transform=None, img_dir='bag_data',
                 mask_dir='bag_data_msk', size=(160, 160)):
        self.transform = transform
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.size = tuple(size)
        # List the directory once, sorted so indices are deterministic,
        # instead of calling os.listdir on every __len__/__getitem__.
        self.img_names = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of files found in ``img_dir`` at construction."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is the resized input, passed through ``transform`` when one
        was given.  ``mask`` is a float tensor built from ``onehot(mask, 2)``
        with its last axis moved first (presumably shape ``(2, H, W)`` --
        confirm against ``onehot``).

        Raises:
            FileNotFoundError: if the image or its mask cannot be read.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # cv2.imread returns None (rather than raising) on a missing or
        # unreadable file; fail with a clear message instead of a cv2 error.
        # NOTE(review): colour images stay in BGR order; if ``transform``
        # normalises with RGB statistics, consider cv2.COLOR_BGR2RGB.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"cannot read image: {img_path}")
        image = cv2.resize(image, self.size)

        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            raise FileNotFoundError(f"cannot read mask: {mask_path}")
        mask = cv2.resize(mask, self.size)
        # Binarise with a mid-range threshold.  Truncating ``mask / 255`` to
        # uint8 would map every value below 255 (interpolated edges, JPEG
        # noise such as 254) to background.
        mask = (mask > 127).astype('uint8')
        mask = onehot(mask, 2)
        mask = mask.transpose(2, 0, 1)  # HWC -> CHW
        mask = torch.FloatTensor(mask)

        if self.transform is not None:
            image = self.transform(image)
        return image, mask
# Build the dataset at import time and split it 90% / 10% into training and
# test subsets.  The split uses a fixed seed so every process importing this
# module (e.g. a training run and a later evaluation run) gets the same
# partition, given the same dataset order, instead of a fresh random one
# that could leak test samples into training.
bag = BagDataset(transform)
train_size = int(0.9 * len(bag))  # 90% of the samples go to training
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(
    bag,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(0),
)
# Loaders yielding batch_size samples per step.  The test loader is not
# shuffled: evaluation does not need random order, and a fixed order keeps
# results comparable between runs.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)
if __name__ == '__main__':
    # Smoke test: walk the training loader, then the test loader, printing
    # every batch each one yields.
    for loader in (train_dataloader, test_dataloader):
        for batch in loader:
            print(batch)