# FCN-test/train.py
#
# 128 lines
# 5.7 KiB
# Python

from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom
from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
def train(model_name='FCN8s', epo_num=50, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=1e-2, momentum=0.7):
    """Train an FCN segmentation model and evaluate it after every epoch.

    Batches come from the module-level ``train_dataloader`` and
    ``test_dataloader``. Predictions, labels and the per-iteration loss curves
    are pushed to a visdom server every 15 iterations, and the whole model is
    saved under ``checkpoints/`` every 10 epochs (epochs 0, 10, 20, ...).

    Args:
        model_name: One of ``'FCN8s'``, ``'FCN16s'``, ``'FCN32s'``. Each is
            built on a VGG16 backbone with ``n_class=2``.
        epo_num: Number of epochs to run.
        show_vgg_params: Forwarded to ``VGGNet`` as ``show_params``.
        loss_func: ``'bce'`` (``nn.BCELoss`` on sigmoid outputs) or ``'ce'``
            (``nn.CrossEntropyLoss`` on raw logits).
        optimizer: ``'sgd'`` or ``'adam'``.
        lr: Learning rate for the optimizer.
        momentum: Momentum for SGD; ignored by Adam but still part of the
            checkpoint file name.

    Raises:
        ValueError: If ``model_name``, ``loss_func`` or ``optimizer`` is not
            one of the supported options.
    """
    # Local import keeps this block self-contained; ``os`` is only needed to
    # make sure the checkpoint directory exists.
    import os

    def vgg16_backbone():
        return VGGNet(requires_grad=True, show_params=show_vgg_params)

    # Factories are lazy so that only the requested network is instantiated
    # (building every VGG/FCN variant up front wastes time and memory).
    model_factories = {
        'FCN8s': lambda: FCN8s(pretrained_net=vgg16_backbone(), n_class=2),
        'FCN16s': lambda: FCN16s(pretrained_net=vgg16_backbone(), n_class=2),
        'FCN32s': lambda: FCN32s(pretrained_net=vgg16_backbone(), n_class=2),
    }
    loss_factories = {'bce': nn.BCELoss, 'ce': nn.CrossEntropyLoss}
    optimizer_factories = {
        'sgd': lambda params: optim.SGD(params, lr=lr, momentum=momentum),
        'adam': lambda params: optim.Adam(params, lr=lr),
    }

    # Validate everything before allocating models or connecting to visdom,
    # so a typo fails immediately with a clear message.
    if model_name not in model_factories:
        raise ValueError(f"Invalid model name: {model_name}")
    if loss_func not in loss_factories:
        raise ValueError(f"Invalid loss function: {loss_func}")
    if optimizer not in optimizer_factories:
        raise ValueError(f"Invalid optimizer: {optimizer}")

    vis = visdom.Visdom()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    fcn_model = model_factories[model_name]().to(device)
    criterion = loss_factories[loss_func]().to(device)
    opt = optimizer_factories[optimizer](fcn_model.parameters())

    # BCELoss expects probabilities, whereas CrossEntropyLoss applies
    # log-softmax itself and must receive the raw logits.
    # NOTE(review): the 'ce' path assumes the masks are per-class float
    # targets shaped like the model output - confirm against BagData.
    needs_sigmoid = loss_func == 'bce'

    # torch.save does not create missing parent directories.
    os.makedirs('checkpoints', exist_ok=True)

    all_train_iter_loss = []
    all_test_iter_loss = []

    prev_time = datetime.now()
    for epo in range(epo_num):
        # ---- training pass ----
        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            opt.zero_grad()
            logits = fcn_model(bag)
            output = torch.sigmoid(logits) if needs_sigmoid else logits
            loss = criterion(output, bag_msk)
            loss.backward()
            opt.step()

            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss

            # Only pay for the device-to-host copy when something is shown.
            if index % 15 == 0:
                _push_batch_to_visdom(vis, output, bag_msk, all_train_iter_loss, 'train')

        # ---- evaluation pass ----
        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                logits = fcn_model(bag)
                output = torch.sigmoid(logits) if needs_sigmoid else logits
                loss = criterion(output, bag_msk)

                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                if index % 15 == 0:
                    _push_batch_to_visdom(vis, output, bag_msk, all_test_iter_loss, 'test')

        # ---- epoch summary ----
        cur_time = datetime.now()
        # total_seconds() keeps whole days; timedelta.seconds would drop them.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (train_loss / len(train_dataloader), test_loss / len(test_dataloader), time_str))

        # Save a checkpoint every 10 epochs.
        if epo % 10 == 0:
            checkpoint_name = f'checkpoints/fcn_model_{model_name}_ep{epo}_lr{lr}_mom{momentum}.pt'
            torch.save(fcn_model, checkpoint_name)
            print(f'Saving checkpoint: {checkpoint_name}')


def _push_batch_to_visdom(vis, prediction, target, losses, phase):
    """Send one batch of predictions/labels and the loss curve to visdom.

    ``phase`` is ``'train'`` or ``'test'`` and selects the window names.
    Both tensors are reduced with argmin over axis 1 before display; sigmoid
    is monotonic, so this gives the same picture for logits and probabilities
    (barring floating-point saturation).
    """
    pred_np = np.argmin(prediction.detach().cpu().numpy(), axis=1)
    mask_np = np.argmin(target.detach().cpu().numpy(), axis=1)
    vis.images(pred_np[:, None, :, :], win=f'{phase}_pred', opts=dict(title=f'{phase} prediction'))
    vis.images(mask_np[:, None, :, :], win=f'{phase}_label', opts=dict(title='label'))
    vis.line(losses, win=f'{phase}_iter_loss', opts=dict(title=f'{phase} iter loss'))
if __name__ == "__main__":
    # Experiment launcher: edit the settings below to try other configurations.
    #
    # Settings tried so far (all with show_vgg_params=False):
    #   1. FCN8s,  11 epochs,  bce, sgd,  lr=0.02, momentum=0.7
    #   2. FCN8s,  11 epochs,  bce, sgd,  lr in {0.01, 0.02, 0.03, 0.05}, momentum=0.7
    #   -  FCN16s, 11 epochs,  ce,  adam, lr=1e-3
    #   -  FCN32s, 100 epochs, bce, sgd,  lr=0.01, momentum=0.5
    #
    # Active run: the lr=0.05 point of sweep 2.
    run_config = dict(
        model_name='FCN8s',
        epo_num=11,
        show_vgg_params=False,
        loss_func='bce',
        optimizer='sgd',
        lr=0.05,
        momentum=0.7,
    )
    train(**run_config)