This commit is contained in:
LYC 2024-04-01 17:01:58 +08:00
commit f42c31ef6c
9 changed files with 569 additions and 0 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

9
.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
__pycache__
.vscode
.ipynb_checkpoints
.spyproject
.idea
.tmp
assets
bag_data
bag_data_msk

57
BagData.py Normal file
View File

@ -0,0 +1,57 @@
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import cv2
from onehot import onehot
"""数据归一化"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
class BagDataset(Dataset):
    """Pairs of (bag image, binary mask) from 'bag_data' / 'bag_data_msk'.

    An image and its mask share the same file name across the two
    directories.  Images are resized to 160x160 and optionally transformed;
    masks are resized, binarized, one-hot encoded into 2 channels and
    returned as FloatTensors of shape (2, 160, 160).
    """

    def __init__(self, transform=None):
        self.transform = transform
        # FIX: the original called os.listdir('bag_data') on every
        # __len__/__getitem__ — wasteful, and os.listdir order is arbitrary,
        # so sample indices were not stable.  List once, sorted.
        self.img_names = sorted(os.listdir('bag_data'))

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        imgA = cv2.imread('bag_data/' + img_name)       # BGR sample image
        imgA = cv2.resize(imgA, (160, 160))             # unify sample size to 160x160
        imgB = cv2.imread('bag_data_msk/' + img_name, 0)  # grayscale mask
        imgB = cv2.resize(imgB, (160, 160))
        # Map {0, 255} -> {0, 1}.  NOTE(review): after interpolation only
        # pixels that stay exactly 255 survive as 1 — consider
        # cv2.INTER_NEAREST for the mask resize; confirm intended behavior.
        imgB = imgB / 255
        imgB = imgB.astype('uint8')
        imgB = onehot(imgB, 2)          # (160, 160, 2) one-hot encoding
        imgB = imgB.transpose(2, 0, 1)  # channels-first for PyTorch
        imgB = torch.FloatTensor(imgB)
        if self.transform:
            imgA = self.transform(imgA)
        return imgA, imgB
# Build the dataset and a random 90/10 train/test split with one loader each.
bag = BagDataset(transform)
train_size = int(0.9 * len(bag))  # training split is 90% of the samples
test_size = len(bag) - train_size
train_dataset, test_dataset = random_split(bag, [train_size, test_size])  # random train/test partition
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)  # yields batch_size samples per iteration
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=1)

if __name__ =='__main__':
    # Smoke test: iterate both loaders and print the raw batches.
    for train_batch in train_dataloader:
        print(train_batch)
    for test_batch in test_dataloader:
        print(test_batch)

261
FCN.py Normal file
View File

@ -0,0 +1,261 @@
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models.vgg import VGG
"""Architecture选择FCN32s; FCN16s; FCN8s; FCNs"""
class FCN32s(nn.Module):
def __init__(self, pretrained_net, n_class):
super().__init__()
self.n_class = n_class
self.pretrained_net = pretrained_net
self.relu = nn.ReLU(inplace=True)
self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn1 = nn.BatchNorm2d(512)
self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(32)
self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
def forward(self, x):
output = self.pretrained_net(x)
x5 = output['x5']
score = self.bn1(self.relu(self.deconv1(x5)))
score = self.bn2(self.relu(self.deconv2(score)))
score = self.bn3(self.relu(self.deconv3(score)))
score = self.bn4(self.relu(self.deconv4(score)))
score = self.bn5(self.relu(self.deconv5(score)))
score = self.classifier(score)
return score
class FCN16s(nn.Module):
def __init__(self, pretrained_net, n_class):
super().__init__()
self.n_class = n_class
self.pretrained_net = pretrained_net
self.relu = nn.ReLU(inplace=True)
self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn1 = nn.BatchNorm2d(512)
self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(32)
self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
def forward(self, x):
output = self.pretrained_net(x)
x5 = output['x5']
x4 = output['x4']
score = self.relu(self.deconv1(x5))
score = self.bn1(score + x4)
score = self.bn2(self.relu(self.deconv2(score)))
score = self.bn3(self.relu(self.deconv3(score)))
score = self.bn4(self.relu(self.deconv4(score)))
score = self.bn5(self.relu(self.deconv5(score)))
score = self.classifier(score)
return score
class FCN8s(nn.Module):
def __init__(self, pretrained_net, n_class):
super().__init__()
self.n_class = n_class
self.pretrained_net = pretrained_net
self.relu = nn.ReLU(inplace=True)
self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn1 = nn.BatchNorm2d(512)
self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(32)
self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
def forward(self, x):
output = self.pretrained_net(x)
x5 = output['x5']
x4 = output['x4']
x3 = output['x3']
score = self.relu(self.deconv1(x5))
score = self.bn1(score + x4)
score = self.relu(self.deconv2(score))
score = self.bn2(score + x3)
score = self.bn3(self.relu(self.deconv3(score)))
score = self.bn4(self.relu(self.deconv4(score)))
score = self.bn5(self.relu(self.deconv5(score)))
score = self.classifier(score)
return score
class FCNs(nn.Module):
def __init__(self, pretrained_net, n_class):
super().__init__()
self.n_class = n_class
self.pretrained_net = pretrained_net
self.relu = nn.ReLU(inplace=True)
self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn1 = nn.BatchNorm2d(512)
self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn2 = nn.BatchNorm2d(256)
self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn4 = nn.BatchNorm2d(64)
self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(32)
self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
# classifier is 1x1 conv, to reduce channels from 32 to n_class
def forward(self, x):
output = self.pretrained_net(x)
x5 = output['x5']
x4 = output['x4']
x3 = output['x3']
x2 = output['x2']
x1 = output['x1']
score = self.bn1(self.relu(self.deconv1(x5)))
score = score + x4
score = self.bn2(self.relu(self.deconv2(score)))
score = score + x3
score = self.bn3(self.relu(self.deconv3(score)))
score = score + x2
score = self.bn4(self.relu(self.deconv4(score)))
score = score + x1
score = self.bn5(self.relu(self.deconv5(score)))
score = self.classifier(score)
return score
"""pretrained:是否使用预训练模型model:backbone类型vgg11;vgg13;vgg16;vgg19其他参数使用默认设置"""
class VGGNet(VGG):
    """VGG backbone exposing the feature map after each of its 5 pool stages.

    Args:
        pretrained: load ImageNet weights for the chosen variant.
        model: backbone variant: 'vgg11' / 'vgg13' / 'vgg16' / 'vgg19'.
        requires_grad: if False, freeze every backbone parameter.
        remove_fc: drop VGG's fully-connected classifier head to save memory.
        show_params: print each parameter's name and size at construction.
    """

    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]
        if pretrained:
            # FIX: look the torchvision constructor up by name instead of
            # exec-ing a formatted code string — same effect, but safe,
            # debuggable, and analyzable.
            self.load_state_dict(getattr(models, model)(pretrained=True).state_dict())
        if not requires_grad:
            # FIX: self.parameters() — the idiomatic call; the former
            # super().parameters() relied on the same bound behavior.
            for param in self.parameters():
                param.requires_grad = False
        # delete redundant fully-connected layer params, can save memory
        if remove_fc:
            del self.classifier
        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        """Return {'x1'..'x5'}: activation after each max-pool stage."""
        output = {}
        # self.ranges holds (begin, end) layer indices per pooling stage,
        # e.g. ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)) for vgg16.
        for idx, (begin, end) in enumerate(self.ranges):
            for layer in range(begin, end):
                x = self.features[layer](x)
            output["x%d" % (idx + 1)] = x
        return output
# For each VGG variant: (begin, end) layer-index pairs into the Sequential
# built by make_layers, delimiting the block that ends at each of the five
# max-pool layers; VGGNet.forward uses these to tap features x1..x5.
ranges = {
    'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}
# Vgg-Net config
# VGG layer configuration: numbers are 3x3 conv output channels,
# 'M' inserts a 2x2 max-pool.
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
# make layers using Vgg-Net config(cfg)
# 由cfg构建vgg-Net
def make_layers(cfg, batch_norm=False):
    """Build a VGG feature extractor (nn.Sequential) from a config list.

    Args:
        cfg: list whose entries are either an int (output channels of a 3x3
            conv) or 'M' (a 2x2 max-pool).
        batch_norm: if True, insert BatchNorm2d after every conv.

    Returns:
        nn.Sequential of conv/(bn)/relu/pool layers for 3-channel input.
    """
    modules = []
    prev_channels = 3  # RGB input
    for item in cfg:
        if item == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(prev_channels, item, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(item))
        modules.append(nn.ReLU(inplace=True))
        prev_channels = item
    return nn.Sequential(*modules)
'''
VGG-16网络参数
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
'''
if __name__ == "__main__":
pass

82
README.md Normal file
View File

@ -0,0 +1,82 @@
# pytorch FCN easiest demo
> 不断更新中~
这个repo是在读论文[Fully Convolutional Networks for Semantic Segmentation](http://arxiv.org/abs/1411.4038)时的一个pytorch简单复现数据集很小是一些随机背景上的一些包的图片所有数据集大小一共不到80M如下图
![数据集示意图](assets/task.png)
> 关于此数据集详细信息,见[数据集](#数据集)
根据论文实现了FCN32s、FCN16s、FCN8s和FCNs
>部分代码参考了[这个repo](https://github.com/wkentaro/pytorch-fcn)
使用visdom可视化运行了20个epoch后的可视化如下图
![可视化1](assets/vis1.jpg)
![可视化2](assets/vis2.jpg)
## 1.如何运行
### 1.1 我的运行环境
* Windows 10
* CUDA 9.x (可选)
* Anaconda 3 numpy、os、datetime、matplotlib
* pytorch == 0.4.1 or 1.0
* torchvision == 0.2.1
* visdom == 0.1.8.5
* OpenCV-Python == 3.4.1
### 1.2 具体操作
* 打开终端,输入
```sh
python -m visdom.server
```
* 打开另一终端,输入
```sh
python train.py
```
* 若没有问题可以打开浏览器输入`http://localhost:8097/`来使用`visdom`可视化
### 1.3 训练细节
![训练细节](assets/train.jpg)
## 2. 数据集
* training data来自[这里](https://github.com/yunlongdong/FCN-pytorch-easiest/tree/master/last)ground-truth来自[这里](https://github.com/yunlongdong/FCN-pytorch-easiest/tree/master/last_msk)。
* 链接中提供的图片中部分ground-truth的有误而且部分有ground-truth的图片没有对应training data的图片将这些有错误的图片分别剔除重新编号排序之后剩余533张图片。
* 之后我随机选取了67张图片**旋转180度**一共在training data和ground-truth分别凑够600张图片0.jpg ~ 599.jpg
## 3. 可视化
* train prediction训练时模型的输出
* labelground-truth
* test prediction预测时模型的输出每次训练都会预测但预测数据不参与训练与backprop
* train iter loss训练时每一批batch的loss情况
* test iter loss测试时每一批batch的loss情况
## 4. 包含文件
### 4.1 [train.py](train.py)
* 训练网络与可视化
* 主函数
### 4.2 [FCN.py](FCN.py)
* FCN32s、FCN16s、FCN8s、FCNs网络定义
* VGGNet网络定义、VGG不同种类网络参数、构建VGG网络的函数
### 4.3 [BagData.py](BagData.py)
* 定义方便PyTorch读取数据的Dataset和DataLoader
* 定义数据的变换transform
### 4.4 [onehot.py](onehot.py)
* 图片的onehot编码

3
checkpoints/README.md Normal file
View File

@ -0,0 +1,3 @@
# checkpoints
The model's checkpoints are saved in this folder.

33
license.md Normal file
View File

@ -0,0 +1,33 @@
# 版权信息 / License
除非额外说明,本仓库的所有公开文档均遵循[署名-非商业性使用-相同方式共享 4.0 国际 (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.zh)许可协议。
您可以自由地:
* 共享 — 在任何媒介以任何形式复制、发行本作品
* 演绎 — 修改、转换或以本作品为基础进行创作
惟须遵守下列条件:
* 署名 — 您必须给出适当的署名,提供指向本许可协议的链接,同时标明是否(对原始作品)作了修改。您可以用任何合理的方式来署名,但是不得以任何方式暗示许可人为您或您的使用背书。
* 非商业性使用 — 您不得将本作品用于商业目的。
* 相同方式共享 — 如果您再混合、转换或者基于本作品进行创作,您必须基于与原先许可协议相同的许可协议 分发您贡献的作品。
---
Unless otherwise noted, all public documents in this repository are subject to [Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) license.
You are free to:
* Share — copy and redistribute the material in any medium or format
* Adapt — remix, transform, and build upon the material
Under the following terms:
* Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
* NonCommercial — You may not use the material for commercial purposes.
* ShareAlike — If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

8
onehot.py Normal file
View File

@ -0,0 +1,8 @@
import numpy as np
def onehot(data, n):
    """One-hot encode an integer label array.

    Args:
        data: integer ndarray of class indices in [0, n).
        n: number of classes.

    Returns:
        float ndarray of shape data.shape + (n,) with a 1 at each element's
        class channel and 0 elsewhere.
    """
    buf = np.zeros(data.shape + (n, ))
    # Flat index of (element position, class) for every element.
    nmsk = np.arange(data.size)*n + data.ravel()
    # BUG FIX: the original wrote buf.ravel()[nmsk-1] = 1, shifting every
    # mark one flat slot left (class 0 leaked into the previous element's
    # last channel; element 0 / class 0 wrapped to the final slot).
    # NOTE(review): train.py's np.argmin visualization appears tuned to the
    # old shifted encoding — re-check it against this corrected one-hot.
    buf.ravel()[nmsk] = 1
    return buf

114
train.py Normal file
View File

@ -0,0 +1,114 @@
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom
from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
def train(epo_num=50, show_vgg_params=False):
    """Train an FCN segmenter on the bag dataset, visualizing with visdom.

    Args:
        epo_num: number of epochs (full passes over the training set).
        show_vgg_params: if True, print every backbone parameter's shape.

    Side effects: prints progress, pushes images/curves to a running visdom
    server, and saves the whole model under checkpoints/ every 5 epochs.
    """
    vis = visdom.Visdom()  # visualization client (expects `python -m visdom.server` running)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU when CUDA is available
    """调用FCN.py中不同的vgg网络自选一个,进行定量实验分析"""
    vgg_model = VGGNet(requires_grad=True, show_params=show_vgg_params)  # backbone network for feature extraction
    """调用FCN.py中不同的FCN网络自选一个,进行定量实验分析"""
    fcn_model = FCNs(pretrained_net=vgg_model, n_class=2)  # segmentation architecture (FCN8s/16s/32s are alternatives)
    fcn_model = fcn_model.to(device)  # move the model to the selected device
    """更改损失函数为交叉熵nn.CELoss(), 进行定量实验分析"""
    # NOTE(review): BCE over sigmoid outputs; nn.BCEWithLogitsLoss would be
    # numerically more stable — confirm before swapping.
    criterion = nn.BCELoss().to(device)
    """更改模型优化器adam等自选一个并调整学习率lr和momentum参数,进行定量实验分析"""
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-2, momentum=0.7)
    all_train_iter_loss = []  # per-iteration training losses (for the loss curve)
    all_test_iter_loss = []   # per-iteration test losses
    # start timing
    prev_time = datetime.now()  # reference time for the per-epoch duration report
    for epo in range(epo_num):  # epo_num = number of passes over the dataset
        train_loss = 0  # accumulated loss for this epoch
        fcn_model.train()  # switch to training mode
        for index, (bag, bag_msk) in enumerate(train_dataloader):  # bag: samples, bag_msk: one-hot masks
            # bag.shape is torch.Size([4, 3, 160, 160])
            # bag_msk.shape is torch.Size([4, 2, 160, 160])
            bag = bag.to(device)          # move samples to device
            bag_msk = bag_msk.to(device)  # move labels to device
            optimizer.zero_grad()  # clear accumulated gradients
            output = fcn_model(bag)  # forward pass
            output = torch.sigmoid(output)  # output.shape is torch.Size([4, 2, 160, 160])
            loss = criterion(output, bag_msk)  # loss between prediction and one-hot mask
            loss.backward()  # backpropagation
            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)  # record for the loss curve
            train_loss += iter_loss  # accumulate epoch training loss
            optimizer.step()  # parameter update
            output_np = output.cpu().detach().numpy().copy()  # output_np.shape = (4, 2, 160, 160)
            # NOTE(review): argmin (not argmax) — appears paired with the
            # shifted encoding in onehot.py; confirm before changing either.
            output_np = np.argmin(output_np, axis=1)
            bag_msk_np = bag_msk.cpu().detach().numpy().copy()  # bag_msk_np.shape = (4, 2, 160, 160)
            bag_msk_np = np.argmin(bag_msk_np, axis=1)
            if np.mod(index, 15) == 0:
                print('epoch {}, {}/{},train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
                # vis.close()
                """结果可视化"""
                # Push current predictions/labels and loss curve to visdom.
                vis.images(output_np[:, None, :, :], win='train_pred', opts=dict(title='train prediction'))
                vis.images(bag_msk_np[:, None, :, :], win='train_label', opts=dict(title='label'))
                vis.line(all_train_iter_loss, win='train_iter_loss',opts=dict(title='train iter loss'))
        test_loss = 0
        fcn_model.eval()  # switch to evaluation mode
        with torch.no_grad():  # no gradients needed during evaluation
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)
                optimizer.zero_grad()
                output = fcn_model(bag)
                output = torch.sigmoid(output)  # output.shape is torch.Size([4, 2, 160, 160])
                loss = criterion(output, bag_msk)
                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss
                output_np = output.cpu().detach().numpy().copy()  # output_np.shape = (4, 2, 160, 160)
                output_np = np.argmin(output_np, axis=1)
                bag_msk_np = bag_msk.cpu().detach().numpy().copy()  # bag_msk_np.shape = (4, 2, 160, 160)
                bag_msk_np = np.argmin(bag_msk_np, axis=1)
                if np.mod(index, 15) == 0:
                    print(r'Testing... Open http://localhost:8097/ to see test result.')
                    # vis.close()
                    vis.images(output_np[:, None, :, :], win='test_pred', opts=dict(title='test prediction'))
                    vis.images(bag_msk_np[:, None, :, :], win='test_label', opts=dict(title='label'))
                    vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))
        """显示模型运行时间"""
        # Report how long this epoch took.
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time
        print('epoch train loss = %f, epoch test loss = %f, %s'
              %(train_loss/len(train_dataloader), test_loss/len(test_dataloader), time_str))
        """每5个epoch保存一次模型到checkpoints路径"""
        # Save a full-model checkpoint every 5 epochs.
        if np.mod(epo, 5) == 0:
            torch.save(fcn_model, 'checkpoints/fcn_model_{}.pt'.format(epo))
            print('saveing checkpoints/fcn_model_{}.pt'.format(epo))
if __name__ == "__main__":
train(epo_num=100, show_vgg_params=False) #调用训练函数开始训练训练epoch数为100