init
This commit is contained in:
commit
f42c31ef6c
|
|
@ -0,0 +1,2 @@
|
|||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
__pycache__
|
||||
.vscode
|
||||
.ipynb_checkpoints
|
||||
.spyproject
|
||||
.idea
|
||||
.tmp
|
||||
assets
|
||||
bag_data
|
||||
bag_data_msk
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
from torchvision import transforms
|
||||
|
||||
import cv2
|
||||
from onehot import onehot
|
||||
|
||||
"""数据归一化"""
|
||||
# Per-channel statistics used to standardise the image tensor.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Input pipeline: convert the image to a tensor, then normalise each channel
# with the fixed mean / std above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])
|
||||
|
||||
class BagDataset(Dataset):
    """Dataset of (image, one-hot mask) pairs read from two directories.

    Every file name found in ``img_dir`` is expected to exist in ``mask_dir``
    as well; the mask is read as a single-channel image.

    Args:
        transform: optional callable applied to the resized image only
            (the mask is never passed through it).
        img_dir: directory holding the input images.
        mask_dir: directory holding the same-named masks.
        size: ``(width, height)`` passed to ``cv2.resize`` for both the
            image and the mask.
    """

    def __init__(self, transform=None, img_dir='bag_data',
                 mask_dir='bag_data_msk', size=(160, 160)):
        self.transform = transform
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.size = size
        # List the directory once: re-listing on every __len__/__getitem__
        # costs O(n) per sample, and os.listdir makes no ordering guarantee.
        # Sorting keeps the index -> file mapping deterministic.
        self.img_names = sorted(os.listdir(img_dir))

    def __len__(self):
        """Return the number of files found in ``img_dir``."""
        return len(self.img_names)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(imgA, imgB)`` where ``imgA`` is the resized (and, if a
            transform was given, transformed) image and ``imgB`` is a
            ``torch.FloatTensor`` holding the 2-class encoding of the mask
            with the class axis first.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or its mask.
        """
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name)

        # NOTE(review): cv2.imread yields BGR channel order; confirm that the
        # transform applied below expects that order.
        imgA = cv2.imread(img_path)
        if imgA is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image: %s' % img_path)
        imgA = cv2.resize(imgA, self.size)  # bring every sample to one size

        imgB = cv2.imread(mask_path, 0)  # flag 0 -> single-channel grayscale
        if imgB is None:
            raise FileNotFoundError('cannot read mask: %s' % mask_path)
        imgB = cv2.resize(imgB, self.size)
        # Dividing by 255 and truncating to uint8 maps only pixels equal to
        # 255 to label 1; every other value becomes label 0.
        imgB = imgB / 255
        imgB = imgB.astype('uint8')
        imgB = onehot(imgB, 2)
        imgB = imgB.transpose(2, 0, 1)  # (H, W, class) -> (class, H, W)
        imgB = torch.FloatTensor(imgB)

        if self.transform:
            imgA = self.transform(imgA)

        return imgA, imgB
|
||||
|
||||
# Full dataset, with the module-level normalisation transform applied to images.
bag = BagDataset(transform)

# Hold out 10 % of the samples for testing; the remaining 90 % are for training.
train_size = int(len(bag) * 0.9)
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(bag, [train_size, test_size])

# Loaders yield one sample per batch, reshuffled on every pass, using a single
# worker process each.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1,
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: walk the training loader, then the test loader, printing
    # every batch.
    for loader in (train_dataloader, test_dataloader):
        for batch in loader:
            print(batch)
|
||||
|
|
@ -0,0 +1,261 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
from torchvision.models.vgg import VGG
|
||||
|
||||
"""Architecture选择:FCN32s; FCN16s; FCN8s; FCNs"""
|
||||
class FCN32s(nn.Module):
    """FCN-32s head: decode only the deepest backbone feature map.

    ``pretrained_net`` is called on the input and must return a mapping that
    contains ``'x5'``. That tensor goes through five stride-2 transposed
    convolutions (each followed by ReLU, then batch norm) and a final 1x1
    convolution that produces ``n_class`` channels.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        # Decoder stage i maps widths[i - 1] -> widths[i] channels. Layers are
        # registered as deconv1, bn1, deconv2, bn2, ... in that order.
        widths = (512, 512, 256, 128, 64, 32)
        for stage in range(1, len(widths)):
            setattr(self, 'deconv%d' % stage, nn.ConvTranspose2d(
                widths[stage - 1], widths[stage], kernel_size=3, stride=2,
                padding=1, dilation=1, output_padding=1))
            setattr(self, 'bn%d' % stage, nn.BatchNorm2d(widths[stage]))

        # 1x1 convolution reducing the last decoder width to n_class channels.
        self.classifier = nn.Conv2d(widths[-1], n_class, kernel_size=1)

    def forward(self, x):
        """Return per-class score maps computed from the ``'x5'`` feature."""
        features = self.pretrained_net(x)
        score = features['x5']

        for stage in range(1, 6):
            upsample = getattr(self, 'deconv%d' % stage)
            normalise = getattr(self, 'bn%d' % stage)
            score = normalise(self.relu(upsample(score)))

        return self.classifier(score)
|
||||
|
||||
|
||||
class FCN16s(nn.Module):
    """FCN-16s head: decode ``'x5'`` and fuse it with the ``'x4'`` skip.

    ``pretrained_net`` is called on the input and must return a mapping with
    the keys ``'x5'`` and ``'x4'``. The first upsampled map is added to
    ``'x4'`` before batch norm; the remaining four decoder stages have no
    skip connection.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()

        def upsample(c_in, c_out):
            # Every decoder stage uses the same stride-2 transposed conv.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                      padding=1, dilation=1, output_padding=1)

        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = upsample(512, 512)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = upsample(512, 256)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = upsample(256, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = upsample(128, 64)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = upsample(64, 32)
        self.bn5 = nn.BatchNorm2d(32)
        # 1x1 convolution reducing 32 channels to n_class.
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        """Return per-class score maps using the ``'x5'`` and ``'x4'`` features."""
        features = self.pretrained_net(x)
        deepest = features['x5']
        skip4 = features['x4']

        # Stage 1: the skip is added after ReLU and before batch norm.
        fused = self.relu(self.deconv1(deepest)) + skip4
        score = self.bn1(fused)

        # Stages 2-5: plain deconv -> ReLU -> batch norm.
        plain_stages = (
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
            (self.deconv4, self.bn4),
            (self.deconv5, self.bn5),
        )
        for deconv, bn in plain_stages:
            score = bn(self.relu(deconv(score)))

        return self.classifier(score)
|
||||
|
||||
|
||||
class FCN8s(nn.Module):
    """FCN-8s head: decode ``'x5'`` with skips from ``'x4'`` and ``'x3'``.

    ``pretrained_net`` is called on the input and must return a mapping with
    the keys ``'x5'``, ``'x4'`` and ``'x3'``. The first two upsampled maps are
    added to ``'x4'`` and ``'x3'`` respectively before batch norm; the last
    three decoder stages have no skip connection.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()

        def upsample(c_in, c_out):
            # Every decoder stage uses the same stride-2 transposed conv.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                      padding=1, dilation=1, output_padding=1)

        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = upsample(512, 512)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = upsample(512, 256)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = upsample(256, 128)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = upsample(128, 64)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = upsample(64, 32)
        self.bn5 = nn.BatchNorm2d(32)
        # 1x1 convolution reducing 32 channels to n_class.
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        """Return per-class score maps using the ``'x5'``, ``'x4'``, ``'x3'`` features."""
        features = self.pretrained_net(x)
        deepest = features['x5']
        skip4 = features['x4']
        skip3 = features['x3']

        # Stages 1-2: the skip is added after ReLU and before batch norm.
        score = self.bn1(self.relu(self.deconv1(deepest)) + skip4)
        score = self.bn2(self.relu(self.deconv2(score)) + skip3)

        # Stages 3-5: plain deconv -> ReLU -> batch norm.
        plain_stages = (
            (self.deconv3, self.bn3),
            (self.deconv4, self.bn4),
            (self.deconv5, self.bn5),
        )
        for deconv, bn in plain_stages:
            score = bn(self.relu(deconv(score)))

        return self.classifier(score)
|
||||
|
||||
|
||||
class FCNs(nn.Module):
    """FCN head that fuses every backbone stage (``'x1'`` .. ``'x5'``).

    ``pretrained_net`` is called on the input and must return a mapping with
    the keys ``'x1'`` to ``'x5'``. Each of the first four decoder stages runs
    deconv -> ReLU -> batch norm and then adds the matching skip feature; the
    fifth stage has no skip and feeds the 1x1 classifier.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        # (in_channels, out_channels) of decoder stages 1..5, registered as
        # deconv1, bn1, deconv2, bn2, ... in that order.
        stage_channels = ((512, 512), (512, 256), (256, 128), (128, 64), (64, 32))
        for number, (c_in, c_out) in enumerate(stage_channels, 1):
            setattr(self, 'deconv%d' % number, nn.ConvTranspose2d(
                c_in, c_out, kernel_size=3, stride=2, padding=1,
                dilation=1, output_padding=1))
            setattr(self, 'bn%d' % number, nn.BatchNorm2d(c_out))

        # classifier is a 1x1 conv that reduces channels from 32 to n_class
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        """Return per-class score maps fusing all five backbone features."""
        features = self.pretrained_net(x)
        score = features['x5']
        # Skips in the order they are consumed by decoder stages 1..4.
        skips = [features[key] for key in ('x4', 'x3', 'x2', 'x1')]

        for number, skip in enumerate(skips, 1):
            deconv = getattr(self, 'deconv%d' % number)
            bn = getattr(self, 'bn%d' % number)
            # Here the skip is added after batch norm.
            score = bn(self.relu(deconv(score))) + skip

        score = self.bn5(self.relu(self.deconv5(score)))
        return self.classifier(score)
|
||||
|
||||
"""pretrained:是否使用预训练模型;model:backbone类型(vgg11;vgg13;vgg16;vgg19);其他参数使用默认设置"""
|
||||
class VGGNet(VGG):
    """VGG backbone that returns the intermediate feature map of every stage.

    Args:
        pretrained: when true, copy the weights of the torchvision model named
            by ``model`` (built with ``pretrained=True``) into this network.
        model: key into the module-level ``cfg`` and ``ranges`` tables
            (``'vgg11'``, ``'vgg13'``, ``'vgg16'`` or ``'vgg19'``).
        requires_grad: when false, every parameter is frozen.
        remove_fc: when true, delete the fully-connected ``classifier`` to
            save memory; ``forward`` never uses it.
        show_params: when true, print the name and size of each parameter.
    """

    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            # Resolve the torchvision constructor by attribute lookup instead
            # of building Python source and running it through exec().
            pretrained_vgg = getattr(models, model)(pretrained=True)
            self.load_state_dict(pretrained_vgg.state_dict())

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        # Delete the redundant fully-connected layer parameters to save memory.
        # This must happen after load_state_dict, which still expects them.
        if remove_fc:
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        """Run the feature layers and collect the output of each stage.

        Returns:
            dict mapping ``'x1'`` .. ``'xN'`` to the tensor produced after the
            layers of the corresponding ``(begin, end)`` span in
            ``self.ranges`` have been applied.
        """
        output = {}
        # Each span ends on a max-pool layer, e.g. for vgg16:
        # ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31))
        for idx, (begin, end) in enumerate(self.ranges):
            for layer in range(begin, end):
                x = self.features[layer](x)
            output["x%d" % (idx + 1)] = x

        return output
|
||||
|
||||
|
||||
# For every supported backbone: the half-open (begin, end) index spans into
# the ``features`` sequence. VGGNet.forward runs the layers of one span and
# stores the result as 'x1', 'x2', ... in span order.
ranges = {
    'vgg11': (
        (0, 3), (3, 6), (6, 11), (11, 16), (16, 21),
    ),
    'vgg13': (
        (0, 5), (5, 10), (10, 15), (15, 20), (20, 25),
    ),
    'vgg16': (
        (0, 5), (5, 10), (10, 17), (17, 24), (24, 31),
    ),
    'vgg19': (
        (0, 5), (5, 10), (10, 19), (19, 28), (28, 37),
    ),
}
|
||||
|
||||
# Vgg-Net config
|
||||
# Vgg网络结构配置
|
||||
# Layer recipe per backbone, consumed by make_layers: an integer adds a 3x3
# convolution with that many output channels, 'M' adds a 2x2 max-pool.
# Each row below is one stage ending in its max-pool.
cfg = {
    'vgg11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'vgg16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'vgg19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}
|
||||
|
||||
# make layers using Vgg-Net config(cfg)
|
||||
# 由cfg构建vgg-Net
|
||||
def make_layers(cfg, batch_norm=False):
    """Build the VGG feature extractor described by ``cfg``.

    Args:
        cfg: sequence whose entries are either ``'M'`` (append a 2x2,
            stride-2 max-pool) or an integer (append a 3x3, padding-1
            convolution with that many output channels followed by ReLU).
        batch_norm: when true, insert ``BatchNorm2d`` between each
            convolution and its ReLU.

    Returns:
        ``nn.Sequential`` holding the layers in ``cfg`` order; the first
        convolution takes 3 input channels.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        channels = spec

    return nn.Sequential(*modules)
|
||||
|
||||
'''
|
||||
VGG-16网络参数
|
||||
|
||||
Sequential(
|
||||
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(1): ReLU(inplace)
|
||||
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(3): ReLU(inplace)
|
||||
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(6): ReLU(inplace)
|
||||
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(8): ReLU(inplace)
|
||||
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(11): ReLU(inplace)
|
||||
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(13): ReLU(inplace)
|
||||
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(15): ReLU(inplace)
|
||||
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(18): ReLU(inplace)
|
||||
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(20): ReLU(inplace)
|
||||
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(22): ReLU(inplace)
|
||||
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(25): ReLU(inplace)
|
||||
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(27): ReLU(inplace)
|
||||
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
(29): ReLU(inplace)
|
||||
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
)
|
||||
|
||||
'''
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Nothing is executed when this file is run directly.
    pass
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
# pytorch FCN easiest demo
|
||||
|
||||
> 不断更新中~
|
||||
|
||||
这个repo是在读论文[Fully Convolutional Networks for Semantic Segmentation](http://arxiv.org/abs/1411.4038)时的一个pytorch简单复现,数据集很小,是一些随机背景上的一些包的图片(所有数据集大小一共不到80M),如下图
|
||||
|
||||

|
||||
|
||||
> 关于此数据集详细信息,见[数据集](#数据集)
|
||||
|
||||
根据论文实现了FCN32s、FCN16s、FCN8s和FCNs
|
||||
|
||||
>部分代码参考了[这个repo](https://github.com/wkentaro/pytorch-fcn)
|
||||
|
||||
使用visdom可视化,运行了20个epoch后的可视化如下图:
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
## 1.如何运行
|
||||
|
||||
### 1.1 我的运行环境
|
||||
|
||||
* Windows 10
|
||||
* CUDA 9.x (可选)
|
||||
* Anaconda 3 (numpy、os、datetime、matplotlib)
|
||||
* pytorch == 0.4.1 or 1.0
|
||||
* torchvision == 0.2.1
|
||||
* visdom == 0.1.8.5
|
||||
* OpenCV-Python == 3.4.1
|
||||
|
||||
### 1.2 具体操作
|
||||
|
||||
* 打开终端,输入
|
||||
```sh
|
||||
python -m visdom.server
|
||||
```
|
||||
* 打开另一终端,输入
|
||||
```sh
|
||||
python train.py
|
||||
```
|
||||
* 若没有问题可以打开浏览器输入`http://localhost:8097/`来使用`visdom`可视化
|
||||
|
||||
### 1.3 训练细节
|
||||
|
||||

|
||||
|
||||
## 2. 数据集
|
||||
|
||||
* training data来自[这里](https://github.com/yunlongdong/FCN-pytorch-easiest/tree/master/last),ground-truth来自[这里](https://github.com/yunlongdong/FCN-pytorch-easiest/tree/master/last_msk)。
|
||||
* 链接中提供的图片中,部分ground-truth有误,而且部分有ground-truth的图片没有对应的training data图片,将这些有问题的图片分别剔除,重新编号排序之后剩余533张图片。
|
||||
* 之后我随机选取了67张图片**旋转180度**,一共在training data和ground-truth分别凑够600张图片(0.jpg ~ 599.jpg)。
|
||||
|
||||
## 3. 可视化
|
||||
|
||||
* train prediction:训练时模型的输出
|
||||
* label:ground-truth
|
||||
* test prediction:预测时模型的输出(每次训练都会预测,但预测数据不参与训练与backprop)
|
||||
* train iter loss:训练时每一批(batch)的loss情况
|
||||
* test iter loss:测试时每一批(batch)的loss情况
|
||||
|
||||
## 4. 包含文件
|
||||
|
||||
### 4.1 [train.py](train.py)
|
||||
|
||||
* 训练网络与可视化
|
||||
* 主函数
|
||||
|
||||
### 4.2 [FCN.py](FCN.py)
|
||||
|
||||
* FCN32s、FCN16s、FCN8s、FCNs网络定义
|
||||
* VGGNet网络定义、VGG不同种类网络参数、构建VGG网络的函数
|
||||
|
||||
### 4.3 [BagData.py](BagData.py)
|
||||
|
||||
* 定义方便PyTorch读取数据的Dataset和DataLoader
|
||||
* 定义数据的变换transform
|
||||
|
||||
### 4.4 [onehot.py](onehot.py)
|
||||
|
||||
* 图片的onehot编码
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# checkpoints
|
||||
|
||||
The model's checkpoints are created in this folder.
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
# 版权信息 / License
|
||||
|
||||
除非额外说明,本仓库的所有公开文档均遵循[署名-非商业性使用-相同方式共享 4.0 国际 (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.zh)许可协议。
|
||||
|
||||
您可以自由地:
|
||||
|
||||
* 共享 — 在任何媒介以任何形式复制、发行本作品
|
||||
* 演绎 — 修改、转换或以本作品为基础进行创作
|
||||
|
||||
|
||||
|
||||
惟须遵守下列条件:
|
||||
|
||||
* 署名 — 您必须给出适当的署名,提供指向本许可协议的链接,同时标明是否(对原始作品)作了修改。您可以用任何合理的方式来署名,但是不得以任何方式暗示许可人为您或您的使用背书。
|
||||
* 非商业性使用 — 您不得将本作品用于商业目的。
|
||||
* 相同方式共享 — 如果您再混合、转换或者基于本作品进行创作,您必须基于与原先许可协议相同的许可协议分发您贡献的作品。
|
||||
|
||||
---
|
||||
|
||||
Unless otherwise noted, all public documents in this repository are subject to [Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) license.
|
||||
|
||||
|
||||
You are free to:
|
||||
|
||||
* Share — copy and redistribute the material in any medium or format
|
||||
* Adapt — remix, transform, and build upon the material
|
||||
|
||||
|
||||
Under the following terms:
|
||||
|
||||
* Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
|
||||
* NonCommercial — You may not use the material for commercial purposes.
|
||||
* ShareAlike — If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import numpy as np
|
||||
|
||||
def onehot(data, n):
    """Return the one-hot encoding of an integer label array.

    Args:
        data: array-like of integer class labels in ``[0, n)``; any shape.
        n: number of classes (size of the new trailing axis).

    Returns:
        Float array of shape ``data.shape + (n,)`` in which
        ``result[..., k] == 1`` exactly where ``data == k`` and 0 elsewhere.

    Raises:
        ValueError: if a label lies outside ``[0, n)``.
    """
    labels = np.asarray(data)
    buf = np.zeros(labels.shape + (n, ))
    if labels.size == 0:
        return buf
    if labels.min() < 0 or labels.max() >= n:
        # An out-of-range label would otherwise land in a neighbouring
        # element's slots (or past the end of the buffer).
        raise ValueError('labels must lie in [0, %d)' % n)

    # Element i owns the flat slots [i*n, i*n + n); its label selects the
    # slot inside that run. No "- 1" here: subtracting one shifts every
    # write into the wrong slot and wraps the first one to the buffer's end.
    flat_index = np.arange(labels.size) * n + labels.ravel()
    buf.reshape(-1)[flat_index] = 1  # reshape of a fresh array is a view
    return buf
|
||||
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
from datetime import datetime
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import visdom
|
||||
|
||||
from BagData import test_dataloader, train_dataloader
|
||||
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
|
||||
|
||||
|
||||
def _to_index_image(batch):
    """Collapse a 4-D tensor to a 4-D NumPy index map with a single channel.

    Takes the argmin over axis 1 and re-inserts a length-1 axis there, which
    is the layout handed to ``vis.images`` below.
    """
    # NOTE(review): argmin (not argmax) picks the displayed channel; which
    # channel denotes foreground depends on the mask encoding - confirm.
    indices = np.argmin(batch.detach().cpu().numpy(), axis=1)
    return indices[:, None, :, :]


def _format_elapsed(delta):
    """Format a ``timedelta`` as ``'Time HH:MM:SS'``."""
    # total_seconds() includes whole days; the .seconds attribute does not.
    h, remainder = divmod(int(delta.total_seconds()), 3600)
    m, s = divmod(remainder, 60)
    return "Time %02d:%02d:%02d" % (h, m, s)


def train(epo_num=50, show_vgg_params=False):
    """Train an FCNs model on the bag dataset and stream progress to visdom.

    Each epoch runs one pass over ``train_dataloader`` followed by an
    evaluation pass over ``test_dataloader``. Every 15th batch the current
    prediction, label and per-iteration loss curve are sent to visdom, and
    every 5th epoch (starting with epoch 0) the whole model object is saved
    under ``checkpoints/``.

    Args:
        epo_num: number of passes over the training set.
        show_vgg_params: forwarded to ``VGGNet`` as ``show_params``.
    """
    import os  # function-scope import: only needed for the checkpoint folder

    vis = visdom.Visdom()  # visualisation client

    # Use the GPU when CUDA is available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Experiment hook: swap in another VGG variant from FCN.py here.
    vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params)  # backbone for feature extraction
    # Experiment hook: swap in another FCN head from FCN.py here.
    fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)  # segmentation architecture
    fcn_model = fcn_model.to(device)
    # Experiment hook: try a cross-entropy loss instead.
    criterion = nn.BCELoss().to(device)
    # Experiment hook: try another optimiser (e.g. Adam) or other lr / momentum.
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7)

    all_train_iter_loss = []  # loss of every training iteration so far
    all_test_iter_loss = []   # loss of every test iteration so far

    # torch.save does not create missing directories.
    os.makedirs('checkpoints', exist_ok=True)

    # start timing; reset at the end of every epoch
    prev_time = datetime.now()
    for epo in range(epo_num):

        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            # bag is the image batch, bag_msk the matching 2-channel mask.
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            optimizer.zero_grad()
            output = fcn_model(bag)
            output = torch.sigmoid(output)  # BCELoss expects probabilities
            loss = criterion(output, bag_msk)
            loss.backward()
            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss
            optimizer.step()

            if index % 15 == 0:
                print('epoch {}, {}/{},train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
                # Convert to NumPy only when something is actually drawn.
                vis.images(_to_index_image(output), win='train_pred', opts=dict(title='train prediction'))
                vis.images(_to_index_image(bag_msk), win='train_label', opts=dict(title='label'))
                vis.line(all_train_iter_loss, win='train_iter_loss', opts=dict(title='train iter loss'))

        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():  # no gradients are needed for evaluation
            for index, (bag, bag_msk) in enumerate(test_dataloader):

                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                output = fcn_model(bag)
                output = torch.sigmoid(output)
                loss = criterion(output, bag_msk)
                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                if index % 15 == 0:
                    print(r'Testing... Open http://localhost:8097/ to see test result.')
                    vis.images(_to_index_image(output), win='test_pred', opts=dict(title='test prediction'))
                    vis.images(_to_index_image(bag_msk), win='test_label', opts=dict(title='label'))
                    vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))

        # Report how long this epoch (training plus evaluation) took.
        cur_time = datetime.now()
        time_str = _format_elapsed(cur_time - prev_time)
        prev_time = cur_time

        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (train_loss / len(train_dataloader), test_loss / len(test_dataloader), time_str))

        # Save the model into checkpoints/ every 5 epochs.
        if epo % 5 == 0:
            torch.save(fcn_model, 'checkpoints/fcn_model_{}.pt'.format(epo))
            print('saveing checkpoints/fcn_model_{}.pt'.format(epo))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: start training for 100 epochs.
    train(epo_num=100, show_vgg_params=False)
|
||||
Loading…
Reference in New Issue