Source code for commonpower.control.parsing

from collections import OrderedDict
from typing import Union

import gymnasium as gym
import numpy as np

from commonpower.control.environments import ControlEnv


class ParserFactory:
    """
    Selects the parser that fits the observation/action space layout of an environment.

    The parsers convert observations and actions during deployment of RL controllers.
    """

    def __init__(self, env: Union[gym.Wrapper, ControlEnv]):
        """
        Keep references to the environment, its unwrapped core and both of its spaces.

        Args:
            env (Union[gym.Wrapper, ControlEnv]): environment (potentially wrapped)

        Raises:
            TypeError: if the observation space and the action space are instances of
                different classes.
        """
        self.env = env
        self.unwrapped_env = env.unwrapped
        self.env_obs_space = env.observation_space
        self.env_action_space = env.action_space

        # Both spaces have to be of the same kind, otherwise no parser can serve them together.
        if self.env_obs_space.__class__ != self.env_action_space.__class__:
            raise TypeError(
                "The type of the observation space does not match the type of the action space."
            )

    def get_parser(self):
        """
        Decide which parser to use, based on the type of a sampled observation.

        Returns:
            ArrayParser for np.ndarray samples, ListParser for list samples and
            DictParser for dict samples.

        Raises:
            NotImplementedError: if the sampled observation has any other type.
        """
        obs_space = self.env_obs_space
        act_space = self.env_action_space

        # A list of spaces is sampled entry by entry; observations are drawn before actions.
        if isinstance(obs_space, list):
            sample_obs = [space.sample() for space in obs_space]
            sample_act = [space.sample() for space in act_space]
        else:
            sample_obs = obs_space.sample()
            sample_act = act_space.sample()

        if isinstance(sample_obs, np.ndarray):
            # single-agent
            return ArrayParser(env=self.env, unwrapped_env=self.unwrapped_env)

        if isinstance(sample_obs, list):
            # multi-agent (MAPPO)
            return ListParser(env=self.env, unwrapped_env=self.unwrapped_env, sample_action=sample_act)

        # OrderedDict is a dict subclass, so a single check covers both.
        if isinstance(sample_obs, dict):
            return DictParser(env=self.env, unwrapped_env=self.unwrapped_env)

        raise NotImplementedError("The current space is not supported!")
class BaseParser:
    """
    Base class for parsers.

    Subclasses override parse_obs() and parse_action(); the versions defined here
    only raise NotImplementedError.
    """

    def __init__(self, env: Union[gym.Wrapper, ControlEnv], unwrapped_env: ControlEnv):
        """
        Store the (potentially wrapped) environment together with its unwrapped version.

        Args:
            env (Union[gym.Wrapper, ControlEnv]): environment (potentially wrapped)
            unwrapped_env (ControlEnv): unwrapped environment
        """
        self.env, self.unwrapped_env = env, unwrapped_env

    def parse_obs(self, original_obs: Union[np.ndarray, list, dict, OrderedDict]) -> OrderedDict:
        """
        Convert an environment observation into the {ctrl_id: np.ndarray} form used in
        the DeploymentRunner _run() function.

        Args:
            original_obs (Union[np.ndarray, list, dict, OrderedDict]): observation from environment

        Returns:
            (OrderedDict): transformed observation as {ctrl_id: np.ndarray}

        Raises:
            NotImplementedError: always; subclasses provide the conversion.
        """
        raise NotImplementedError

    def parse_action(self, original_action: OrderedDict) -> Union[np.ndarray, list, dict, OrderedDict]:
        """
        Convert the action provided by the DeploymentRunner _run() function into the
        format required by the underlying environment.

        Args:
            original_action (OrderedDict): action provided by DeploymentRunner

        Returns:
            (Union[np.ndarray, list, dict, OrderedDict]): transformed action

        Raises:
            NotImplementedError: always; subclasses provide the conversion.
        """
        raise NotImplementedError
class ArrayParser(BaseParser):
    """
    Parser for environments whose observations and actions are plain numpy arrays.
    """

    def parse_obs(self, original_obs: Union[np.ndarray, list, dict, OrderedDict]) -> OrderedDict:
        """
        Wrap the array observation into the {ctrl_id: np.ndarray} form used in the
        DeploymentRunner _run() function. The key is read from the `ctrl_id`
        attribute of the stored environment.

        Args:
            original_obs (np.ndarray): observation from environment

        Returns:
            (OrderedDict): transformed observation as {ctrl_id: np.ndarray}
        """
        return OrderedDict([(self.env.ctrl_id, original_obs)])

    def parse_action(self, original_action: OrderedDict) -> Union[np.ndarray, list, dict, OrderedDict]:
        """
        Flatten the action provided by the DeploymentRunner _run() function into the
        one-dimensional array required by the underlying environment.

        Only the first entry of original_action is used. It is walked as a two-level
        mapping (node -> element -> action) and the element actions are collected in
        iteration order.

        Args:
            original_action (OrderedDict): action provided by DeploymentRunner

        Returns:
            (np.ndarray): transformed action
        """
        first_ctrl_actions = next(iter(original_action.values()))
        collected = [
            element_action
            for node_actions in first_ctrl_actions.values()
            for element_action in node_actions.values()
        ]
        return np.array(collected).reshape((-1,))
class ListParser(BaseParser):
    """
    Parser for environments whose observations and actions are lists with one entry
    per controller (multi-agent case).
    """

    def __init__(self, env, unwrapped_env, sample_action):
        """
        Args:
            env (Union[gym.Wrapper, ControlEnv]): environment (potentially wrapped)
            unwrapped_env (ControlEnv): unwrapped environment
            sample_action (list): sampled action of the environment; its array shape is
                the template for the actions built in parse_action()
        """
        super().__init__(env=env, unwrapped_env=unwrapped_env)
        self.sample_action = sample_action

    def parse_obs(self, original_obs: Union[np.ndarray, list, dict, OrderedDict]) -> OrderedDict:
        """
        Transforms the observation returned by the environment to the form
        {ctrl_id: np.ndarray} used in the DeploymentRunner _run() function.

        The i-th entry of original_obs is assigned to the i-th key of the
        `controllers` mapping of the unwrapped environment.

        Args:
            original_obs (list): observation from environment

        Returns:
            (OrderedDict): transformed observation as {ctrl_id: np.ndarray}
        """
        controllers = getattr(self.unwrapped_env, "controllers")  # all RL controllers
        transformed_obs = OrderedDict()
        # enumerate() yields each key's position directly, so the key list is not
        # rebuilt and searched on every iteration.
        for ctrl_idx, ctrl_id in enumerate(controllers.keys()):
            transformed_obs[ctrl_id] = original_obs[ctrl_idx]
        return transformed_obs

    def parse_action(self, original_action: OrderedDict) -> Union[np.ndarray, list, dict, OrderedDict]:
        """
        Transforms the action provided by the DeploymentRunner _run() function to the
        format required by the underlying environment.

        Each agent's nested {node: {element: action}} mapping is written slot by slot
        into a zero-initialised array shaped like the sampled action. Only the first
        entry of every element action is used.

        Args:
            original_action (OrderedDict): action provided by DeploymentRunner

        Returns:
            (list): transformed action, one array of shape (1, n) per agent

        Raises:
            IndexError: if an agent provides more element actions than the sampled
                action has slots for.
        """
        all_agents_actions = np.zeros(np.array(self.sample_action).shape)
        for agent_idx, agent_action in enumerate(original_action.values()):
            action_idx = 0
            for node_action in agent_action.values():
                for element_action in node_action.values():
                    all_agents_actions[agent_idx][action_idx] = element_action[0]
                    # Move to the next slot; without this every element action
                    # would overwrite slot 0 and the other slots would stay zero.
                    action_idx += 1
        return [agent_action.reshape((1, -1)) for agent_action in all_agents_actions]
class DictParser(BaseParser):
    """
    Pass-through parser: observations and actions are returned exactly as received.
    """

    def parse_obs(self, original_obs: Union[np.ndarray, list, dict, OrderedDict]) -> OrderedDict:
        """
        Hand the observation back unchanged to the DeploymentRunner _run() function.

        Args:
            original_obs (OrderedDict): observation from environment

        Returns:
            (OrderedDict): the very same object that was passed in
        """
        return original_obs

    def parse_action(self, original_action: OrderedDict) -> Union[np.ndarray, list, dict, OrderedDict]:
        """
        Hand the action provided by the DeploymentRunner _run() function back unchanged
        to the underlying environment.

        Args:
            original_action (OrderedDict): action provided by DeploymentRunner

        Returns:
            (Union[dict, OrderedDict]): the very same object that was passed in
        """
        return original_action