import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom

from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet


def train(model_name='FCN8s', epo_num=50, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=1e-2, momentum=0.7):
    """Train the selected FCN variant on the bag dataset and log progress to visdom."""
    vis = visdom.Visdom()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Available VGG backbones and FCN variants
    vgg_model_options = {'VGG16': VGGNet(requires_grad=True, show_params=show_vgg_params),
                         'VGG19': VGGNet(model='vgg19', requires_grad=True, show_params=show_vgg_params)}
    fcn_model_options = {'FCN8s': FCN8s(pretrained_net=vgg_model_options['VGG16'], n_class=2),
                         'FCN16s': FCN16s(pretrained_net=vgg_model_options['VGG16'], n_class=2),
                         'FCN32s': FCN32s(pretrained_net=vgg_model_options['VGG16'], n_class=2)}

    # Select the model requested by the caller
    if model_name in fcn_model_options:
        fcn_model = fcn_model_options[model_name]
    else:
        raise ValueError(f"Invalid model name: {model_name}")
    fcn_model = fcn_model.to(device)

    # Select the loss function
    if loss_func == 'bce':
        criterion = nn.BCELoss().to(device)
    elif loss_func == 'ce':
        criterion = nn.CrossEntropyLoss().to(device)
    else:
        raise ValueError(f"Invalid loss function: {loss_func}")

    # Select the optimizer, learning rate and (for SGD) momentum
    if optimizer == 'sgd':
        opt = optim.SGD(fcn_model.parameters(), lr=lr, momentum=momentum)
    elif optimizer == 'adam':
        opt = optim.Adam(fcn_model.parameters(), lr=lr)
    else:
        raise ValueError(f"Invalid optimizer: {optimizer}")

    # Per-iteration losses, collected for the visdom loss curves
    all_train_iter_loss = []
    all_test_iter_loss = []

    prev_time = datetime.now()
    for epo in range(epo_num):
        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            opt.zero_grad()
            output = fcn_model(bag)
            output = torch.sigmoid(output)  # squash logits to [0, 1], as expected by BCELoss
            loss = criterion(output, bag_msk)
            loss.backward()
            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss
            opt.step()

            # Collapse the two-channel maps to per-pixel class indices for visualisation
            output_np = output.cpu().detach().numpy().copy()
            output_np = np.argmin(output_np, axis=1)
            bag_msk_np = bag_msk.cpu().detach().numpy().copy()
            bag_msk_np = np.argmin(bag_msk_np, axis=1)

            if np.mod(index, 15) == 0:
                # print('epoch {}, {}/{},train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
                vis.images(output_np[:, None, :, :], win='train_pred', opts=dict(title='train prediction'))
                vis.images(bag_msk_np[:, None, :, :], win='train_label', opts=dict(title='label'))
                vis.line(all_train_iter_loss, win='train_iter_loss', opts=dict(title='train iter loss'))

        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                output = fcn_model(bag)
                output = torch.sigmoid(output)
                loss = criterion(output, bag_msk)
                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                output_np = output.cpu().detach().numpy().copy()
                output_np = np.argmin(output_np, axis=1)
                bag_msk_np = bag_msk.cpu().detach().numpy().copy()
                bag_msk_np = np.argmin(bag_msk_np, axis=1)

                if np.mod(index, 15) == 0:
                    # print(r'Testing... Open http://localhost:8097/ to see test result.')
                    vis.images(output_np[:, None, :, :], win='test_pred', opts=dict(title='test prediction'))
                    vis.images(bag_msk_np[:, None, :, :], win='test_label', opts=dict(title='label'))
                    vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (train_loss / len(train_dataloader), test_loss / len(test_dataloader), time_str))

        # Save a checkpoint every 10 epochs
        if np.mod(epo, 10) == 0:
            os.makedirs('checkpoints', exist_ok=True)
            checkpoint_name = f'checkpoints/fcn_model_{model_name}_ep{epo}_lr{lr}_mom{momentum}.pt'
            torch.save(fcn_model, checkpoint_name)
            print(f'Saving checkpoint: {checkpoint_name}')

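
# A minimal sketch, not part of the original script: how a checkpoint written by
# train() could be reloaded for inference. train() pickles the whole model with
# torch.save(), so the FCN/VGGNet class definitions must be importable here; on
# newer PyTorch versions torch.load() may also need weights_only=False for such pickles.
def load_checkpoint(checkpoint_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location moves the stored tensors onto whatever device is available here
    model = torch.load(checkpoint_path, map_location=device)
    model.eval()  # inference mode: disable dropout and batch-norm updates
    return model
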
if __name__ == "__main__":
    # Call the training function with different hyperparameter settings

    ## 1
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.02, momentum=0.7)

    ## 2
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.01, momentum=0.7)
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.02, momentum=0.7)
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.03, momentum=0.7)
    train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.05, momentum=0.7)

    # train(model_name='FCN16s', epo_num=11, show_vgg_params=False, loss_func='ce', optimizer='adam', lr=1e-3)
    # train(model_name='FCN32s', epo_num=100, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.01, momentum=0.5)
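
    # Illustrative usage of the load_checkpoint() sketch above; the path simply mirrors
    # the naming pattern train() uses (here: the FCN8s run with lr=0.05, momentum=0.7).
    # model = load_checkpoint('checkpoints/fcn_model_FCN8s_ep10_lr0.05_mom0.7.pt')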