import operator
import threading
import time
from functools import reduce

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical


OBSERVER_NAME = "observer{}"
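# Observer workers are addressed by formatting their rank into this template
# ("observer2", "observer3", ...). set_world() below looks observers up at ranks
# 2..batch_size + 1, so ranks 0 and 1 appear to be reserved for other roles
# (e.g. a coordinator process and this agent).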


class Policy(nn.Module):
    def __init__(self, in_features, nlayers, out_features):
        r"""
        Inits policy class
        Args:
            in_features (int): Number of input features the model takes
            nlayers (int): Number of layers in the model
            out_features (int): Number of features the model outputs
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(in_features, out_features),
            *[nn.Linear(out_features, out_features) for _ in range(nlayers)],
        )
        self.dim = 0

    def forward(self, x):
        action_scores = self.model(x)
        return F.softmax(action_scores, dim=self.dim)


class AgentBase:
    def __init__(self):
        r"""
        Inits agent class
        """
        self.id = rpc.get_worker_info().id
        self.running_reward = 0
        self.eps = 1e-7

        self.rewards = {}

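        # Shared future that every pending observer request chains a callback on;
        # it is fulfilled once a full batch of states has been collected. The lock
        # protects the batching state mutated concurrently by RPC threads.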
        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()

        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def reset_metrics(self):
        r"""
        Sets all benchmark metrics to their empty values
        """
        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def set_world(self, batch_size, state_size, nlayers, out_features, batch=True):
        r"""
        Further initializes the agent to be aware of the RPC environment
        Args:
            batch_size (int): Size of batches of observer requests to process
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
            batch (bool): Whether to process and respond to observer requests as a batch or one at a time
        """
        self.batch = batch
        self.policy = Policy(reduce(operator.mul, state_size), nlayers, out_features)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)

        self.batch_size = batch_size
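        # Observers are registered under ranks 2..batch_size + 1 (see OBSERVER_NAME);
        # their worker ids key the per-observer reward lists.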
        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))

            self.rewards[ob_info.id] = []

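        # In batch mode, log probabilities are appended once per batch step as a
        # single tensor; in non-batch mode they are tracked per observer.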
        self.saved_log_probs = (
            [] if self.batch else {k: [] for k in range(self.batch_size)}
        )

        self.pending_states = self.batch_size
        self.state_size = state_size
        self.states = torch.zeros(self.batch_size, *state_size)

    @staticmethod
    @rpc.functions.async_execution
    def select_action_batch(agent_rref, observer_id, state):
        r"""
        Receives a state from an observer to select an action for. Queues the observer's request
        for an action until the queue size equals the batch size specified during agent initialization,
        at which point actions are selected for all pending observer requests and communicated back
        to the observers
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of the observer calling this function
            state (Tensor): Tensor representing the current state held by the observer
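
        Example (illustrative sketch only; assumes an observer process that holds
        this agent's RRef as ``agent_rref``, has joined the same RPC group, and has
        a local ``state`` tensor)::

            action = rpc.rpc_sync(
                agent_rref.owner(),
                AgentBase.select_action_batch,
                args=(agent_rref, rpc.get_worker_info().id, state),
            )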
| """ |
| self = agent_rref.local_value() |
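        # Worker ids of observers start at 2, so shift to a 0-based index into
        # the state buffer.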
        observer_id -= 2

        self.states[observer_id].copy_(state)
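        # Chain a callback on the shared future; it resolves to this observer's
        # action once actions for the whole batch have been computed.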
        future_action = self.future_actions.then(
            lambda future_actions: future_actions.wait()[observer_id].item()
        )

        with self.lock:
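            # Each request decrements pending_states; the request that completes the
            # batch runs the policy over all buffered states, fulfills the shared
            # future (waking every chained callback), and installs a fresh future
            # for the next batch.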
            if self.pending_states == self.batch_size:
                self.agent_latency_start = time.time()
            self.pending_states -= 1
            if self.pending_states == 0:
                self.pending_states = self.batch_size
                probs = self.policy(self.states)
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t())
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions)

                self.agent_latency_end = time.time()

                batch_latency = self.agent_latency_end - self.agent_latency_start
                self.agent_latency.append(batch_latency)
                self.agent_throughput.append(self.batch_size / batch_latency)

        return future_action

    @staticmethod
    def select_action_non_batch(agent_rref, observer_id, state):
        r"""
        Selects an action based on the observer's state and communicates it back to the observer
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of the observer calling this function
            state (Tensor): Tensor representing the current state held by the observer
        """
        self = agent_rref.local_value()
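        # Worker ids of observers start at 2, so shift to a 0-based index into
        # saved_log_probs.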
        observer_id -= 2
        agent_latency_start = time.time()

        state = state.float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[observer_id].append(m.log_prob(action))

        agent_latency_end = time.time()
        non_batch_latency = agent_latency_end - agent_latency_start
        self.agent_latency.append(non_batch_latency)
        self.agent_throughput.append(1 / non_batch_latency)

        return action.item()

    def finish_episode(self, rets):
        r"""
        Finishes the episode and returns the benchmark metrics collected during it
        Args:
            rets (list): List containing rewards generated by select_action calls
                during the episode run
        """
        return self.agent_latency, self.agent_throughput