import os
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import visdom

from BagData import test_dataloader, train_dataloader
from FCN import FCN8s, FCN16s, FCN32s, VGGNet


def train(model_name='FCN8s', epo_num=50, show_vgg_params=False,
          loss_func='bce', optimizer='sgd', lr=1e-2, momentum=0.7):
    vis = visdom.Visdom()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Available VGGNet backbones and FCN variants.
    vgg_model_options = {
        'VGG16': VGGNet(requires_grad=True, show_params=show_vgg_params),
        'VGG19': VGGNet(model='vgg19', requires_grad=True, show_params=show_vgg_params),
    }
    fcn_model_options = {
        'FCN8s': FCN8s(pretrained_net=vgg_model_options['VGG16'], n_class=2),
        'FCN16s': FCN16s(pretrained_net=vgg_model_options['VGG16'], n_class=2),
        'FCN32s': FCN32s(pretrained_net=vgg_model_options['VGG16'], n_class=2),
    }

    # Select the model requested by the caller.
    if model_name in fcn_model_options:
        fcn_model = fcn_model_options[model_name]
    else:
        raise ValueError(f"Invalid model name: {model_name}")
    fcn_model = fcn_model.to(device)

    # Select the loss function.
    if loss_func == 'bce':
        criterion = nn.BCELoss().to(device)
    elif loss_func == 'ce':
        criterion = nn.CrossEntropyLoss().to(device)
    else:
        raise ValueError(f"Invalid loss function: {loss_func}")

    # Select the optimizer together with its learning rate and momentum.
    if optimizer == 'sgd':
        opt = optim.SGD(fcn_model.parameters(), lr=lr, momentum=momentum)
    elif optimizer == 'adam':
        opt = optim.Adam(fcn_model.parameters(), lr=lr)
    else:
        raise ValueError(f"Invalid optimizer: {optimizer}")

    all_train_iter_loss = []
    all_test_iter_loss = []

    # Make sure the checkpoint directory exists before training starts.
    os.makedirs('checkpoints', exist_ok=True)

    prev_time = datetime.now()
    for epo in range(epo_num):

        train_loss = 0
        fcn_model.train()
        for index, (bag, bag_msk) in enumerate(train_dataloader):
            bag = bag.to(device)
            bag_msk = bag_msk.to(device)

            opt.zero_grad()
            output = fcn_model(bag)
            output = torch.sigmoid(output)  # map logits to [0, 1] before the loss
            loss = criterion(output, bag_msk)
            loss.backward()
            iter_loss = loss.item()
            all_train_iter_loss.append(iter_loss)
            train_loss += iter_loss
            opt.step()

            # Collapse the 2-channel prediction/mask to a 2D class map for visualization.
            output_np = output.cpu().detach().numpy().copy()
            output_np = np.argmin(output_np, axis=1)
            bag_msk_np = bag_msk.cpu().detach().numpy().copy()
            bag_msk_np = np.argmin(bag_msk_np, axis=1)

            if np.mod(index, 15) == 0:
                # print('epoch {}, {}/{}, train loss is {}'.format(epo, index, len(train_dataloader), iter_loss))
                vis.images(output_np[:, None, :, :], win='train_pred', opts=dict(title='train prediction'))
                vis.images(bag_msk_np[:, None, :, :], win='train_label', opts=dict(title='label'))
                vis.line(all_train_iter_loss, win='train_iter_loss', opts=dict(title='train iter loss'))

        test_loss = 0
        fcn_model.eval()
        with torch.no_grad():
            for index, (bag, bag_msk) in enumerate(test_dataloader):
                bag = bag.to(device)
                bag_msk = bag_msk.to(device)

                output = fcn_model(bag)
                output = torch.sigmoid(output)
                loss = criterion(output, bag_msk)
                iter_loss = loss.item()
                all_test_iter_loss.append(iter_loss)
                test_loss += iter_loss

                output_np = output.cpu().detach().numpy().copy()
                output_np = np.argmin(output_np, axis=1)
                bag_msk_np = bag_msk.cpu().detach().numpy().copy()
                bag_msk_np = np.argmin(bag_msk_np, axis=1)

                if np.mod(index, 15) == 0:
                    # print(r'Testing... Open http://localhost:8097/ to see test result.')
                    vis.images(output_np[:, None, :, :], win='test_pred', opts=dict(title='test prediction'))
                    vis.images(bag_msk_np[:, None, :, :], win='test_label', opts=dict(title='label'))
                    vis.line(all_test_iter_loss, win='test_iter_loss', opts=dict(title='test iter loss'))

        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        prev_time = cur_time

        print('epoch train loss = %f, epoch test loss = %f, %s'
              % (train_loss / len(train_dataloader), test_loss / len(test_dataloader), time_str))

        # Save a checkpoint every 10 epochs (including epoch 0).
        if np.mod(epo, 10) == 0:
            checkpoint_name = f'checkpoints/fcn_model_{model_name}_ep{epo}_lr{lr}_mom{momentum}.pt'
            torch.save(fcn_model, checkpoint_name)
            print(f'Saving checkpoint: {checkpoint_name}')


if __name__ == "__main__":
    # Call the training function with different hyperparameter settings.
    ## 1
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.02, momentum=0.7)
    ## 2
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.01, momentum=0.7)
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.02, momentum=0.7)
    # train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.03, momentum=0.7)
    train(model_name='FCN8s', epo_num=11, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.05, momentum=0.7)
    # train(model_name='FCN16s', epo_num=11, show_vgg_params=False, loss_func='ce', optimizer='adam', lr=1e-3)
    # train(model_name='FCN32s', epo_num=100, show_vgg_params=False, loss_func='bce', optimizer='sgd', lr=0.01, momentum=0.5)
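

# Minimal sketch (an addition, not part of the original training flow): reload a
# checkpoint written by train() above for inference. It assumes the checkpoint was
# produced by torch.save(model, path), so unpickling needs the FCN/VGGNet classes to
# be importable; the example path in the usage comment below is hypothetical.
def load_checkpoint(path, device='cpu'):
    """Load a full-model checkpoint and switch it to evaluation mode."""
    # On recent PyTorch versions, loading a fully pickled model may additionally
    # require passing weights_only=False to torch.load.
    model = torch.load(path, map_location=device)
    model.eval()
    return model

# Example usage (hypothetical file name; point it at a checkpoint that actually exists):
# model = load_checkpoint('checkpoints/fcn_model_FCN8s_ep10_lr0.05_mom0.7.pt')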