[Reinforcement Learning + Combinatorial Optimization] Solving the TSP with SAC + PointerNetwork
The TSP reinforcement learning environment is described in a previous post: https://blog.csdn.net/weixin_41369892/article/details/131519384
Results first: trained on 20 cities, and the tours are not great (RL for combinatorial optimization is genuinely hard to tune).
The curve of average total distance over training does show that the policy improves gradually.
Code
"Environment.py"
import warnings
import numpy as np
import matplotlib.pyplot as plt


class TSPEnvironment:
    """
    __init__() params: num_cities, coordinate_dimension, box_size
    step() and reset() return: (coordinates, path, valid) -> state, reward, done
    """

    def __init__(self, num_cities, coordinate_dimension=2, box_size=1.0):
        assert coordinate_dimension >= 2, "coordinate_dimension must >= 2 !"
        self.num_cities = num_cities
        self.coordinate_dimension = coordinate_dimension
        self.box_size = box_size
        self.coordinates, self.cities_coordinates, self.path, self.now_location = None, None, None, None
        self.done = False
        self.total_distance = 0.0
        self.__init_environment = self.Reset
        self.__init_environment()

    def reset(self, start_city=None):
        # Reset the tour but keep the current city coordinates.
        if start_city is not None:
            assert start_city < self.num_cities, "Start city must < num of city !!!"
        self.now_location = start_city if start_city is not None else np.random.choice(list(self.cities_coordinates.keys()))
        self.path = [self.now_location]
        self.done = False
        self.total_distance = 0.0
        valid = self.get_valid_cities(self.path, self.coordinates)
        coordinates = np.array([i for i in self.coordinates])
        path = [i for i in self.path]
        return (coordinates, path, valid), 0.0, self.done

    def Reset(self, start_city=None):
        # Resample the city coordinates and reset the tour.
        if start_city is not None:
            assert start_city < self.num_cities, "Start city must < num of city !!!"
        self.coordinates = np.random.rand(self.num_cities, self.coordinate_dimension) * self.box_size
        self.cities_coordinates = dict(enumerate(self.coordinates))
        self.now_location = start_city if start_city is not None else np.random.choice(list(self.cities_coordinates.keys()))
        self.path = [self.now_location]
        self.done = False
        self.total_distance = 0.0
        valid = self.get_valid_cities(self.path, self.coordinates)
        coordinates = np.array([i for i in self.coordinates])
        path = [i for i in self.path]
        return (coordinates, path, valid), 0.0, self.done

    def step(self, action: int):
        if self.done:
            warn_msg = "The environment {} is done, please call Reset()/reset() or create new environment!".format(self)
            warnings.warn(warn_msg)
            return self.now_location, self.path, self.coordinates, self.cities_coordinates, None, self.done
        else:
            assert self.coordinates is not None, "No coordinates, please call 'Reset()' first!"
            assert self.cities_coordinates is not None, "No cities_coordinates, please call 'Reset()' first!"
            assert self.path is not None, "No path, please call 'Reset()/reset()' first!"
            assert self.now_location is not None, "No now_location, please call 'Reset()/reset()' first!"
            next_city = action
            assert next_city < self.num_cities and next_city >= 0, "There is no city: {} !\n\t\tValid cities: {}".format(next_city, set(self.cities_coordinates.keys()))
            assert next_city not in self.path, "Wrong next city: {}, Can not be repeated access: {} !\n\t\tValid cities: {}.".format(next_city, set(self.path), set(self.cities_coordinates.keys()) - set(self.path))
            # Reward is the negative distance of the move.
            next_city_coordinate = self.cities_coordinates[next_city]
            now_city_coordinate = self.cities_coordinates[self.now_location]
            distance = self.euclidian_distance(next_city_coordinate, now_city_coordinate)
            reward = - distance[0]
            self.total_distance += distance[0]
            self.path.append(next_city)
            self.now_location = next_city
            if set(self.path) == set(self.cities_coordinates.keys()):
                self.done = True
            if self.done:
                # Close the tour: add the distance from the last city back to the start.
                start_end_distance = self.euclidian_distance(self.cities_coordinates[self.path[0]], self.cities_coordinates[self.path[-1]])
                reward += - start_end_distance[0]
                self.total_distance += start_end_distance[0]
            valid = self.get_valid_cities(self.path, self.coordinates)
            coordinates = np.array([i for i in self.coordinates])
            path = [i for i in self.path]
            return (coordinates, path, valid), reward, self.done

    @staticmethod
    def euclidian_distance(x, y):
        return np.sqrt(np.sum((x - y) ** 2, axis=-1, keepdims=True))

    def render(self):
        assert self.coordinates is not None, "No coordinates, please call reset() first!"
        if self.coordinate_dimension != 2:
            warnings.warn("Only show the first two dimension!")
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
        ax.set_title("TSP environment")
        ax.scatter(self.coordinates[:, 0], self.coordinates[:, 1], c="red", s=50, marker="*")
        # plot start city as color blue
        start_city = self.cities_coordinates[self.path[0]]
        text = start_city[0] + 0.1, start_city[1]
        ax.annotate("start city", xy=start_city[[0, 1]], xytext=text, weight="bold")
        ax.scatter(start_city[0], start_city[1], c="blue", marker="*", s=50)
        # plot path as color orange, visited cities as color green
        ax.plot(self.coordinates[self.path, 0], self.coordinates[self.path, 1], c="orange", linewidth=1, linestyle="--")
        ax.scatter(self.coordinates[self.path[1:], 0], self.coordinates[self.path[1:], 1], c="green", s=50, marker="*")
        if self.done:
            end_city = self.cities_coordinates[self.path[-1]]
            text = end_city[0] + 0.1, end_city[1]
            ax.annotate("end city", xy=end_city[[0, 1]], xytext=text, weight="bold")
            ax.scatter(end_city[0], end_city[1], c="black", s=50, marker="*")
            ax.plot([start_city[0], end_city[0]], [start_city[1], end_city[1]], c="orange", linewidth=1, linestyle="--")
        plt.xticks([])
        plt.yticks([])
        plt.show()

    def get_total_distance(self):
        return self.total_distance

    @staticmethod
    def get_valid_cities(path, coordinates):
        # Returns (coordinates of unvisited cities, indices of unvisited cities).
        return (np.delete(coordinates, path, axis=0),
                [i for i in range(coordinates.shape[0]) if i not in path])


if __name__ == '__main__':
    # Quick sanity check of the environment / PointerNetwork wiring.
    from PointerNetwork import PointerNetwork
    import torch

    env = TSPEnvironment(num_cities=20)
    nn = PointerNetwork(2, 512, n_layers=1)
    (coordinates, path, valid), _, _ = env.reset()
    env.render()
    coordinates = torch.from_numpy(np.array([coordinates, coordinates])).float()
    path = torch.from_numpy(np.array([path, path]))
    print(coordinates.shape)
    print(path.shape)
    opt = torch.optim.Adam(nn.parameters(), lr=0.0001)
    for _ in range(100):
        y = nn(coordinates, path)
        t = torch.tensor([[0] * 20] * 2).float()
        t[:, 0] = 0.5
        t[:, 1] = 0.5
        loss = torch.nn.functional.cross_entropy(y[0], t)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())
    print(y[0])
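For quick reference, a minimal interaction loop against this environment with uniformly random valid actions looks roughly like this (a sketch; it only exercises the API above, no learning involved):

# Minimal random-policy rollout: exercises reset()/step()/render() (sketch, no training).
import numpy as np
from Environment import TSPEnvironment

env = TSPEnvironment(num_cities=20, box_size=10.)
(coordinates, path, valid), reward, done = env.reset()
while not done:
    # valid[1] is the list of indices of the cities that have not been visited yet
    action = int(np.random.choice(valid[1]))
    (coordinates, path, valid), reward, done = env.step(action)
env.render()
print("random tour length:", env.get_total_distance())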
"PointerNetwork.py"
import torch
import torch.nn as nn


class Encoder(nn.Module):
    # Linear embedding of the city coordinates followed by a Transformer encoder.
    def __init__(self, input_size, embedding_size, n_layers=2):
        super().__init__()
        self.embedding_layer = nn.Linear(input_size, embedding_size, bias=False)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=8,
                dropout=0.0,
                batch_first=True),
            num_layers=n_layers)

    def forward(self, x):
        x = self.embedding_layer(x)
        return self.transformer(x)


class Decoder(nn.Module):
    # LSTM over the embeddings of the cities visited so far, followed by an MLP.
    def __init__(self, hidden_size, n_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.ReLU(),
            nn.Linear(hidden_size * 4, hidden_size))

    def forward(self, last_embedding, hx):
        x, hx = self.lstm(last_embedding, hx)
        return self.mlp(x), hx


class PointerNetwork(nn.Module):
    ratio = int(2)

    def __init__(self, input_size, hidden_size, n_layers=2):
        super().__init__()
        self.encoder = Encoder(input_size, hidden_size, n_layers)
        self.decoder = Decoder(hidden_size, n_layers)
        self.cell_state = nn.Linear(hidden_size, hidden_size * n_layers)
        self.hidden_state = nn.Linear(hidden_size, hidden_size * n_layers)
        self.n_layers = n_layers

    def get_lstm_state(self, memory):
        # Project the pooled encoder memory into the decoder LSTM's initial (h, c).
        batch_size = memory.size(0)
        cs = self.cell_state(memory).view(batch_size, self.n_layers, -1).permute(1, 0, 2)
        hs = self.hidden_state(memory).view(batch_size, self.n_layers, -1).permute(1, 0, 2)
        return (hs, cs)

    def forward(self, coordinates, path):
        # coordinates: (batch, num_cities, input_size); path: (batch, path_len) of city indices.
        enc_output = self.encoder(coordinates)
        hx = self.get_lstm_state(enc_output.sum(-2))
        path = path.unsqueeze(-1).expand(-1, -1, enc_output.size(-1))
        path_embed = torch.gather(enc_output, dim=1, index=path)
        dec_output, hx = self.decoder(path_embed, hx)
        return enc_output, dec_output
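The forward pass takes the city coordinates and the indices of the cities visited so far, and returns per-city encoder embeddings together with decoder outputs over the partial tour. A quick shape check (a sketch, run on CPU; the sizes are just illustrative):

# Shape sanity check for PointerNetwork (sketch).
import torch
from PointerNetwork import PointerNetwork

net = PointerNetwork(input_size=2, hidden_size=256, n_layers=1)
coordinates = torch.rand(4, 20, 2)       # (batch, num_cities, coordinate_dim)
path = torch.randint(0, 20, (4, 5))      # indices of the 5 cities visited so far
enc_output, dec_output = net(coordinates, path)
print(enc_output.shape)                  # torch.Size([4, 20, 256])
print(dec_output.shape)                  # torch.Size([4, 5, 256])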
"SoftActorCritic.py"
from PointerNetwork import PointerNetwork
from collections import deque
import torch
import random
import torch.nn as nn
import torch.nn.functional as F


class ReplayBuffer:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def save_memory(self, state, action, reward, next_state, done):
        self.memory.append([state, action, reward, next_state, done])

    def sample(self, batch_size):
        sample_size = min(len(self), batch_size)
        experiences = random.sample(self.memory, sample_size)
        return experiences


class QNetwork(nn.Module):
    # Two critic variants: type "0" concatenates decoder and encoder features,
    # type "1" adds them before scoring each city.
    def __init__(self, hidden_size, type="0"):
        super().__init__()
        self.type = type
        if type == "0":
            self.nn = nn.Sequential(
                nn.Linear(2 * hidden_size, 4 * hidden_size),
                nn.ReLU(),
                nn.Linear(4 * hidden_size, 1))
        elif type == "1":
            self.nn = nn.Sequential(
                nn.Linear(hidden_size, 2 * hidden_size),
                nn.ReLU(),
                nn.Linear(2 * hidden_size, 1))
        else:
            raise Exception("Invalid QNetwork type")

    def forward(self, x):
        return self.nn(x)


class ActorNetwork(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.nn = nn.Sequential(
            nn.Linear(2 * hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, 1))

    def forward(self, x):
        return self.nn(x)


class SoftActorCritic:
    def __init__(self,
                 input_size,
                 hidden_size,
                 n_layers,
                 gamma,
                 buffer_capacity,
                 learning_rate,
                 target_entropy_scaling,
                 log_alpha,
                 tau):
        self.gamma = gamma
        self.tau = tau
        self.pointer_network = PointerNetwork(input_size, hidden_size, n_layers)
        self.actor = ActorNetwork(hidden_size)
        self.q_network1 = QNetwork(hidden_size, type="0")
        self.target_q_network1 = QNetwork(hidden_size, type="0")
        self.target_q_network1.load_state_dict(self.q_network1.state_dict())
        self.q_network2 = QNetwork(hidden_size, type="1")
        self.target_q_network2 = QNetwork(hidden_size, type="1")
        self.target_q_network2.load_state_dict(self.q_network2.state_dict())
        self.optimizer = torch.optim.Adam(
            list(self.pointer_network.parameters())
            + list(self.actor.parameters())
            + list(self.q_network1.parameters())
            + list(self.q_network2.parameters()),
            lr=learning_rate)
        self.log_alpha = torch.tensor(log_alpha, requires_grad=True)
        self.target_entropy_scaling = target_entropy_scaling
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=learning_rate)
        self.replay_buffer = ReplayBuffer(buffer_capacity)

    def soft_update(self, target, source):
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def update_weights(self, batch_size):
        batch_data = self.replay_buffer.sample(batch_size)

        # 1) Soft TD targets from the target critics (no gradient).
        target_values = []
        with torch.no_grad():
            alpha = self.log_alpha.exp()
            for i in range(len(batch_data)):
                state, action, reward, next_state, done = batch_data[i]
                coordinates, path, valid = next_state
                enc_output, dec_output = self.pointer_network_forward(coordinates, path)
                if not done:
                    next_pi = self.actor_network_forward(enc_output, dec_output, valid)
                    target_q1 = self.q_network_forward(self.target_q_network1, enc_output, dec_output)
                    target_q2 = self.q_network_forward(self.target_q_network2, enc_output, dec_output)
                    target_q1 = target_q1.gather(1, torch.tensor(valid[-1]).unsqueeze(0))
                    target_q2 = target_q2.gather(1, torch.tensor(valid[-1]).unsqueeze(0))
                if not done:
                    min_q = torch.min(
                        (next_pi * (target_q1 - alpha * (next_pi + 1e-6).log())).sum(-1, keepdim=True),
                        (next_pi * (target_q2 - alpha * (next_pi + 1e-6).log())).sum(-1, keepdim=True))
                else:
                    min_q = 0.0
                td_target = torch.tensor(reward).view(-1, 1) + (1 - torch.tensor(done).float().view(-1, 1)) * min_q * self.gamma
                target_values.append(td_target)
        td_target = torch.cat(target_values, dim=0).float()

        # 2) Critic and actor losses.
        actor_losses, q1_losses, q2_losses = [], [], []
        for i in range(len(batch_data)):
            state, action, reward, next_state, done = batch_data[i]
            coordinates, path, valid = state
            enc_output, dec_output = self.pointer_network_forward(coordinates, path)
            pi = self.actor_network_forward(enc_output, dec_output, valid)
            q1 = self.q_network_forward(self.q_network1, enc_output, dec_output)
            q2 = self.q_network_forward(self.q_network2, enc_output, dec_output)
            q1_losses.append(F.mse_loss(q1[:, action], td_target[i]))
            q2_losses.append(F.mse_loss(q2[:, action], td_target[i]))
            q_min = torch.min(
                q1.gather(1, torch.tensor(valid[-1]).unsqueeze(0)),
                q2.gather(1, torch.tensor(valid[-1]).unsqueeze(0))).detach()
            actor_losses.append((pi * ((alpha * (pi + 1e-6).log()) - q_min)).sum())
        q1_losses = sum(q1_losses) / len(q1_losses)
        q2_losses = sum(q2_losses) / len(q2_losses)
        actor_losses = sum(actor_losses) / len(actor_losses)
        losses = q1_losses + q2_losses + actor_losses
        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()
        self.soft_update(self.target_q_network1, self.q_network1)
        self.soft_update(self.target_q_network2, self.q_network2)

        # 3) Temperature (alpha) loss.
        alpha_losses = []
        for i in range(len(batch_data)):
            state, action, reward, next_state, done = batch_data[i]
            with torch.no_grad():
                coordinates, path, valid = state
                enc_output, dec_output = self.pointer_network_forward(coordinates, path)
                pi = self.actor_network_forward(enc_output, dec_output, valid)
            tgt_ent = self.target_entropy_scaling * (- torch.log(1 / torch.tensor(pi.shape[-1])))
            alpha_loss = pi * (- self.log_alpha.exp() * ((pi).log() + tgt_ent))
            alpha_loss = alpha_loss.sum()
            alpha_losses.append(alpha_loss)
        alpha_losses = sum(alpha_losses) / len(alpha_losses)
        self.alpha_optimizer.zero_grad()
        alpha_losses.backward()
        self.alpha_optimizer.step()
        return losses.item(), alpha_losses.item()

    def q_network_forward(self, q_network, enc_output, dec_output):
        # Combine the last decoder step with every encoder embedding to get one Q value per city.
        x = dec_output[:, -1:, :]
        if q_network.type == "0":
            x = q_network(torch.cat([x.expand_as(enc_output), enc_output], dim=-1))
        elif q_network.type == "1":
            x = q_network(x + enc_output)
        else:
            raise Exception("Invalid QNetwork type")
        return x.squeeze(-1)

    def pointer_network_forward(self, coordinates, path):
        coordinates = torch.tensor(coordinates)
        path = torch.tensor(path)
        return self.pointer_network(
            coordinates.float().unsqueeze(0),
            path.unsqueeze(0))

    def actor_network_forward(self, enc_output, dec_output, valid):
        # Pointer-style scores over the not-yet-visited cities, normalised with softmax.
        x = self.actor(torch.cat([enc_output, dec_output[:, -1:].expand_as(enc_output)], dim=-1)).squeeze(-1)
        x = x.gather(1, torch.tensor(valid[-1]).unsqueeze(0))
        return x.softmax(-1)

    def choice_action(self, state):
        # Sample one of the valid cities from the current policy.
        coordinates, path, valid = state
        enc_output, dec_output = self.pointer_network_forward(coordinates, path)
        pi = self.actor_network_forward(enc_output, dec_output, valid)
        dist = torch.distributions.categorical.Categorical(pi)
        index = dist.sample()
        action = valid[1][index.item()]
        return action
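For reference, the way I read update_weights above, it implements the discrete-action SAC losses restricted to the still-valid cities, with the soft TD target

$$ y = r + \gamma\,(1-d)\,\min_{j\in\{1,2\}} \sum_{a' \in \mathrm{valid}(s')} \pi(a'\mid s')\left[\bar{Q}_j(s',a') - \alpha \log \pi(a'\mid s')\right], $$

the actor loss $\sum_{a}\pi(a\mid s)\left[\alpha \log \pi(a\mid s) - \min_j Q_j(s,a)\right]$, and a temperature loss that pushes the policy entropy towards target_entropy_scaling times the entropy of a uniform distribution over the remaining cities. Here $\bar{Q}_j$ are the target critics and $\alpha = \exp(\texttt{log\_alpha})$.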
"training.py"
from SoftActorCritic import SoftActorCritic
from Environment import TSPEnvironment
from collections import deque
from tqdm import trange
import matplotlib.pyplot as plt
import numpy as np
import copy


def avg_move(l, window_size=10):
    # Moving average over a sliding window, used to smooth the distance curve.
    window = deque(maxlen=window_size)
    new_list = []
    for i in l:
        window.append(i)
        new_list.append(np.array(window).mean())
    return new_list


if __name__ == '__main__':
    episodes = 1000
    batch_size = 128
    train_timesteps = 20
    render_ep = 100
    window_size = 50
    agent = SoftActorCritic(
        input_size=2,
        hidden_size=256,
        n_layers=1,
        gamma=0.9,
        buffer_capacity=int(1e6),
        learning_rate=1e-3,
        target_entropy_scaling=0.1,
        log_alpha=0.0,
        tau=0.05)
    env = TSPEnvironment(num_cities=20, box_size=10.)
    pbar = trange(1, episodes + 1)
    timestep = 0
    distances = []
    for episode in pbar:
        state, _, _ = env.reset()
        while True:
            timestep += 1
            # Update the networks every train_timesteps environment steps.
            if timestep % train_timesteps == 0:
                agent.update_weights(batch_size)
            action = agent.choice_action(state)
            next_state, reward, done = env.step(action)
            agent.replay_buffer.save_memory(state, action, reward, next_state, done)
            state = copy.deepcopy(next_state)
            if done:
                break
        if episode % render_ep == 0:
            env.render()
        distances.append(env.get_total_distance())
        pbar.set_description("Episode {}/{}: Distances: {:.2f}, Timesteps: {}, Alpha: {:.2f}".format(
            episode, episodes, distances[-1], timestep, agent.log_alpha.exp().item()))
    plt.plot(distances)
    plt.plot(avg_move(distances, window_size))
    plt.show()
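After training, one way to evaluate the learned policy is to pick the most probable valid city at each step instead of sampling. The helper below is hypothetical (not part of the original code), just to illustrate how the agent and environment compose:

# Hypothetical greedy evaluation helper (not in the original code): argmax instead of sampling.
import torch

def greedy_rollout(agent, env):
    state, _, done = env.reset()
    with torch.no_grad():
        while not done:
            coordinates, path, valid = state
            enc_output, dec_output = agent.pointer_network_forward(coordinates, path)
            pi = agent.actor_network_forward(enc_output, dec_output, valid)
            action = valid[1][pi.argmax(-1).item()]   # most probable not-yet-visited city
            state, _, done = env.step(action)
    return env.get_total_distance()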