import os

import torch

# NOTE: the project-specific classes and helpers used below (LinXiaoNetConfig,
# LinXiaoNet, AlphaLoss, MonteTree, TrainDataCache, init_logger, logger,
# load_checkpoint, save_checkpoint, generate_train_data, save_chess_record)
# are assumed to be imported from the repository's own modules.

if __name__ == '__main__':
    conf = LinXiaoNetConfig()
    conf.set_cuda(True)
    conf.set_input_shape(8, 8)
    conf.set_train_info(5, 16, 1e-2)
    conf.set_checkpoint_config(5, 'checkpoints/v2train')
    conf.set_num_worker(0)
    conf.set_log('log/v2train.log')
    # conf.set_pretrained_path('checkpoints/v2m4000/epoch_15')

    init_logger(conf.log_file)
    logger()(conf)

    device = 'cuda' if conf.use_cuda else 'cpu'
    # create the policy-value network
    model = LinXiaoNet(3)
    model.to(device)
    loss_func = AlphaLoss()
    loss_func.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=conf.init_lr, momentum=0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    # initialize the Monte Carlo search tree and the training data cache
    tree = MonteTree(model, device, chess_size=conf.input_shape[0], simulate_count=500)
    data_cache = TrainDataCache(num_worker=conf.num_worker)

    ep_num = 0
    chess_num = 0
    # train the network once every `train_every_chess` self-play games
    train_every_chess = 18
    # load a checkpoint if a pretrained path is configured
    if conf.pretrain_path is not None:
        model_data, optimizer_data, lr_schedule_data, data_cache, ep_num, chess_num = load_checkpoint(conf.pretrain_path)
        model.load_state_dict(model_data)
        optimizer.load_state_dict(optimizer_data)
        lr_schedule.load_state_dict(lr_schedule_data)
        logger()('successfully loaded pretrained: {}'.format(conf.pretrain_path))
    while True:
        logger()(f'self chess game no.{chess_num+1} start.')
        # play one self-play game and collect its move record
        chess_record = tree.self_game()
        logger()(f'self chess game no.{chess_num+1} end.')
        # build training samples from the self-play record
        train_data = generate_train_data(tree.chess_size, chess_record)
        # push the samples into the training data cache
        for sample in train_data:
            data_cache.push(sample)
        if chess_num % train_every_chess == 0:
            logger()('train start.')
            loader = data_cache.get_loader(conf.batch_size)
            model.train()
            for _ in range(conf.epoch_num):
                loss_record = []
                for bat_state, bat_dist, bat_winner in loader:
                    bat_state, bat_dist, bat_winner = bat_state.to(device), bat_dist.to(device), bat_winner.to(device)
                    optimizer.zero_grad()
                    prob, value = model(bat_state)
                    loss = loss_func(prob, value, bat_dist, bat_winner)
                    loss.backward()
                    optimizer.step()
                    loss_record.append(loss.item())
                logger()(f'train epoch {ep_num} loss: {sum(loss_record) / float(len(loss_record))}')
                ep_num += 1
                if ep_num % conf.checkpoint_save_every_num == 0:
                    save_checkpoint(
                        os.path.join(conf.checkpoint_save_dir, f'epoch_{ep_num}'),
                        ep_num, chess_num,
                        model.state_dict(), optimizer.state_dict(), lr_schedule.state_dict(),
                        data_cache
                    )
                lr_schedule.step()
            logger()('train end.')
        chess_num += 1
        save_chess_record(
            os.path.join(conf.checkpoint_save_dir, f'chess_record_{chess_num}.pkl'),
            chess_record
        )
        # break
    pass
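AlphaLoss is instantiated above but not defined in this section. Below is a minimal sketch of what such a loss typically computes, assuming the standard AlphaZero objective: mean squared error on the predicted value plus cross-entropy between the MCTS visit distribution and the predicted move probabilities. The tensor shapes and the class body are assumptions; the project's actual implementation may differ.

import torch
import torch.nn as nn


class AlphaLoss(nn.Module):
    """Hypothetical sketch of the AlphaZero objective: (z - v)^2 - pi^T log p."""

    def forward(self, prob, value, target_dist, target_winner):
        # value head: mean squared error between predicted value v and game outcome z
        value_loss = torch.mean((target_winner.view(-1) - value.view(-1)) ** 2)
        # policy head: cross-entropy between the MCTS visit distribution pi
        # and the predicted move probabilities p (epsilon added for numerical stability)
        policy_loss = -torch.mean(torch.sum(target_dist * torch.log(prob + 1e-8), dim=1))
        return value_loss + policy_loss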
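save_checkpoint and load_checkpoint are likewise external to this section. The sketch below is only consistent with the call sites in the training loop above, assuming the checkpoint is a single torch-serialized dict; the field names and file format are assumptions, not the project's actual helpers.

import os
import torch


def save_checkpoint(path, ep_num, chess_num, model_state, optimizer_state, lr_schedule_state, data_cache):
    # bundle everything needed to resume training into one file
    # (the data cache is pickled along with the state dicts)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        'ep_num': ep_num,
        'chess_num': chess_num,
        'model': model_state,
        'optimizer': optimizer_state,
        'lr_schedule': lr_schedule_state,
        'data_cache': data_cache,
    }, path)


def load_checkpoint(path):
    # return the tuple unpacked by the training script above
    ckpt = torch.load(path, map_location='cpu')
    return (ckpt['model'], ckpt['optimizer'], ckpt['lr_schedule'],
            ckpt['data_cache'], ckpt['ep_num'], ckpt['chess_num'])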