import operator
import threading
import time
from functools import reduce
import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
OBSERVER_NAME = "observer{}"

class Policy(nn.Module):
    def __init__(self, in_features, nlayers, out_features):
        r"""
        Inits policy class
        Args:
            in_features (int): Number of input features the model takes
            nlayers (int): Number of layers in the model
            out_features (int): Number of features the model outputs
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(1, -1),
            nn.Linear(in_features, out_features),
            *[nn.Linear(out_features, out_features) for _ in range(nlayers)],
        )
        self.dim = 0

    def forward(self, x):
        action_scores = self.model(x)
        return F.softmax(action_scores, dim=self.dim)
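
# Illustrative sketch (not part of the benchmark): how the Policy above could be
# exercised on its own. The shapes below are assumptions chosen for the example.
#
#     policy = Policy(in_features=16, nlayers=2, out_features=2)  # 4x4 state, 2 actions
#     state_batch = torch.rand(5, 4, 4)                           # one row per observer
#     probs = policy(state_batch)                                 # shape: (5, 2)
#
# Note that softmax is applied along self.dim, which is 0 here, i.e. across the
# stacked observer rows rather than across the action scores of a single row.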

class AgentBase:
    def __init__(self):
        r"""
        Inits agent class
        """
        self.id = rpc.get_worker_info().id
        self.running_reward = 0
        self.eps = 1e-7
        self.rewards = {}
        self.future_actions = torch.futures.Future()
        self.lock = threading.Lock()

        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def reset_metrics(self):
        r"""
        Sets all benchmark metrics to their empty values
        """
        self.agent_latency_start = None
        self.agent_latency_end = None
        self.agent_latency = []
        self.agent_throughput = []

    def set_world(self, batch_size, state_size, nlayers, out_features, batch=True):
        r"""
        Further initializes the agent to be aware of the RPC environment
        Args:
            batch_size (int): Size of batches of observer requests to process
            state_size (list): List of ints dictating the dimensions of the state
            nlayers (int): Number of layers in the model
            out_features (int): Number of out features in the model
            batch (bool): Whether to process and respond to observer requests as a batch or one at a time
        """
        self.batch = batch
        self.policy = Policy(reduce(operator.mul, state_size), nlayers, out_features)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)

        self.batch_size = batch_size
        for rank in range(batch_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(rank + 2))
            self.rewards[ob_info.id] = []

        self.saved_log_probs = (
            [] if self.batch else {k: [] for k in range(self.batch_size)}
        )

        self.pending_states = self.batch_size
        self.state_size = state_size
        self.states = torch.zeros(self.batch_size, *state_size)
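
    # Illustrative sketch (assumes RPC has already been initialized on this process and
    # that observer workers named "observer2", "observer3", ... exist, as set_world expects):
    #
    #     agent = AgentBase()
    #     agent.set_world(batch_size=2, state_size=[4, 4], nlayers=2, out_features=2)
    #
    # With batch=True (the default) a single saved_log_probs list is shared across the
    # batch; with batch=False a separate list is kept per observer id.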

    @staticmethod
    @rpc.functions.async_execution
    def select_action_batch(agent_rref, observer_id, state):
        r"""
        Receives a state from an observer to select an action for. Queues the observer's
        request for an action until the queue size equals the batch size specified during
        agent initialization, at which point actions are selected for all pending observer
        requests and communicated back to the observers
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        """
        self = agent_rref.local_value()
        observer_id -= 2

        self.states[observer_id].copy_(state)
        future_action = self.future_actions.then(
            lambda future_actions: future_actions.wait()[observer_id].item()
        )

        with self.lock:
            if self.pending_states == self.batch_size:
                self.agent_latency_start = time.time()
            self.pending_states -= 1
            if self.pending_states == 0:
                self.pending_states = self.batch_size
                probs = self.policy(self.states)
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t())
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions)

                self.agent_latency_end = time.time()
                batch_latency = self.agent_latency_end - self.agent_latency_start
                self.agent_latency.append(batch_latency)
                self.agent_throughput.append(self.batch_size / batch_latency)

        return future_action
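
    # Illustrative sketch (an assumption about the observer side, which is not shown in
    # this file): an observer would typically invoke the static method above over RPC, e.g.
    #
    #     action = rpc.rpc_sync(
    #         agent_rref.owner(),
    #         AgentBase.select_action_batch,
    #         args=(agent_rref, rpc.get_worker_info().id, state),
    #     )
    #
    # Because the method is decorated with @rpc.functions.async_execution, the returned
    # future only completes once a full batch of observer states has arrived.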

    @staticmethod
    def select_action_non_batch(agent_rref, observer_id, state):
        r"""
        Selects an action based on the observer's state and communicates it back to the observer
        Args:
            agent_rref (RRef): RRef of this agent
            observer_id (int): Observer id of observer calling this function
            state (Tensor): Tensor representing current state held by observer
        """
        self = agent_rref.local_value()
        observer_id -= 2
        agent_latency_start = time.time()

        state = state.float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[observer_id].append(m.log_prob(action))

        agent_latency_end = time.time()
        non_batch_latency = agent_latency_end - agent_latency_start
        self.agent_latency.append(non_batch_latency)
        self.agent_throughput.append(1 / non_batch_latency)

        return action.item()

    def finish_episode(self, rets):
        r"""
        Finishes the episode
        Args:
            rets (list): List containing rewards generated by select_action calls during
                the episode run
        Returns:
            tuple: The agent's recorded latency and throughput lists for the episode
        """
        return self.agent_latency, self.agent_throughput
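
# Illustrative sketch (assumes a driver process that owns the agent through an RRef,
# which this file does not define): after an episode, the benchmark metrics recorded
# above could be pulled back with something like
#
#     latency, throughput = agent_rref.rpc_sync().finish_episode(rets)
#
# where latency/throughput are the per-batch (or, in non-batch mode, per-request) lists.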