# Source code for commonpower.control.logging_utils.loggers

"""
Collection of loggers for controller performance.
"""
import os
from typing import Callable

import wandb
from stable_baselines3.common.logger import Logger, make_output_format

from commonpower.control.logging_utils.callbacks import (
    BaseCallback,
    MARLBaseCallback,
    MARLWandBCallback,
    SafetyCallback,
    WandBSafetyCallback,
)


class BaseLogger:
    """
    Common base for the loggers that record metrics during RL training.

    Subclasses are expected to override ``log_function`` and ``finish_logging``;
    the base versions only raise ``NotImplementedError``.
    """

    def __init__(self, log_dir: str):
        """
        Args:
            log_dir (str): relative path to logging directory
        """
        self.log_dir = log_dir

    def get_log_dir(self) -> str:
        """
        Returns:
            str: the logging directory currently stored on this logger
        """
        directory = self.log_dir
        return directory

    def log_function(self) -> Callable:
        """Hook for subclasses; the base class provides no implementation."""
        raise NotImplementedError()

    def finish_logging(self) -> None:
        """Hook for subclasses; the base class provides no implementation."""
        raise NotImplementedError()
class TensorboardLogger(BaseLogger):
    """
    Tensorboard logging for single-agent stable-baselines3 algorithms.

    The callback class is stored as given; ``log_function`` builds a fresh instance of it
    to hand over to the stable-baselines3 internal logging.
    """

    def __init__(self, log_dir: str, callback: BaseCallback = SafetyCallback):
        """
        Args:
            log_dir (str): relative path to logging directory
            callback (BaseCallback, optional): callback class that implements the actual logging during
                training. By defining a customized callback, additional information can be logged
                (apart from standard metrics like mean_eps_reward). It is stored uninstantiated.
        """
        super().__init__(log_dir=log_dir)
        self.callback = callback

    def log_function(self) -> BaseCallback:
        """
        Creates the callback used by the stable-baselines3 internal logging.

        Returns:
            BaseCallback: a new instance of the configured callback class (built without arguments)
        """
        callback_cls = self.callback
        return callback_cls()

    def finish_logging(self) -> None:
        """Nothing to clean up for tensorboard logging; intentionally a no-op."""
class WandBLogger(BaseLogger):
    """
    Weights&Biases (wandb) logging for single-agent stable-baselines3 algorithms.

    Constructing an instance starts a wandb run (``wandb.init`` with ``sync_tensorboard=True``).
    After construction, ``log_dir`` points at ``<log_dir>/runs/<run id>`` and models are saved
    under ``<log_dir>/models/<run id>``.
    """

    def __init__(
        self,
        log_dir: str,
        entity_name: str,
        run_name: str,
        project_name: str = None,
        callback: BaseCallback = WandBSafetyCallback,
        model_save_freq: int = 100,
        verbose: int = 2,
        alg_config: dict = None,
    ):
        """
        Args:
            log_dir (str): relative path to logging directory
            entity_name (str): name of the wandb entity to which the runs will be logged
            run_name (str): name under which the run will be displayed in WandB
            project_name (str, optional): name of the wandb project to which the runs will be logged
            callback (BaseCallback, optional): callback class that implements the actual logging during
                training. By defining a customized callback, additional information can be logged
                (apart from standard metrics like mean_eps_reward)
            model_save_freq (int, optional): after how many episodes the current model should be logged
            verbose (int, optional): output verbosity
            alg_config (dict, optional): dictionary of algorithm hyperparameters. Can be used to filter
                runs in wandb API
        """
        super().__init__(log_dir=log_dir)
        self.entity_name = entity_name
        self.run_name = run_name
        self.project_name = project_name
        self.callback = callback
        self.alg_config = alg_config
        self.model_save_freq = model_save_freq
        self.verbose = verbose
        self.run = wandb.init(
            project=self.project_name,
            entity=self.entity_name,
            name=self.run_name,
            config=self.alg_config,
            sync_tensorboard=True,
        )
        # os.path.join instead of "+": a log_dir given without a trailing separator ("logs")
        # now yields "logs/models/<id>" instead of "logsmodels/<id>". With a trailing
        # separator ("logs/") the resulting string is identical to plain concatenation.
        base_dir = self.log_dir
        self.model_save_path = os.path.join(base_dir, f"models/{self.run.id}")
        # From here on, log_dir refers to the run-specific directory.
        self.log_dir = os.path.join(base_dir, f"runs/{self.run.id}")

    def log_function(self) -> BaseCallback:
        """
        Hands over the callback so it can be used by the stable-baselines3 internal logging.

        Returns:
            BaseCallback: callback which is used during training to log additional information
        """
        return self.callback(
            model_save_path=self.model_save_path, model_save_freq=self.model_save_freq, verbose=self.verbose
        )

    def finish_logging(self) -> None:
        """
        Terminates the W&B run.

        Returns:
            None
        """
        wandb.finish()

    @property
    def run_id(self):
        """
        Id of the wandb run started by this logger.

        Read from the run object stored on this logger rather than the global ``wandb.run``,
        so it stays valid after ``finish_logging`` (when ``wandb.run`` is reset) and cannot
        report a different, later run.
        """
        return self.run.id
class MARLTensorboardLogger(BaseLogger):
    """
    Tensorboard logging for the multi-agent IPPO/MAPPO algorithms from the on-policy repository
    (https://github.com/marlbenchmark/on-policy/blob/main/README.md).

    Note: unlike in ``BaseLogger``, ``log_function`` is an instance attribute here holding the
    stable-baselines3 ``Logger`` object (it shadows the base-class method); use
    ``get_log_function`` to retrieve it.
    """

    def __init__(
        self,
        log_dir: str,
        callback: MARLBaseCallback = MARLBaseCallback,
        format_strings: list = None,
    ):
        """
        Args:
            log_dir (str): relative path to logging directory
            callback (MARLBaseCallback): callback class that implements the actual logging during training.
                By defining a customized callback, additional information can be logged (apart from
                standard metrics like mean_eps_reward)
            format_strings (list, optional): list of output formats for the SB3 logger.
                Defaults to ["stdout", "tensorboard"].
        """
        super().__init__(log_dir=log_dir)
        self.callback = callback
        if format_strings is None:
            # Fresh list per call instead of a mutable default argument shared by all instances.
            format_strings = ["stdout", "tensorboard"]
        log_suffix = ""
        output_formats = [make_output_format(f, self.log_dir, log_suffix) for f in format_strings]
        # Kept as an attribute named ``log_function`` for backward compatibility with callers.
        self.log_function = Logger(folder=self.log_dir, output_formats=output_formats)

    def get_callback(self) -> MARLBaseCallback:
        """
        Hands over the callback.

        Returns:
            MARLBaseCallback: callback which is used during training to log additional information
        """
        return self.callback()

    def get_log_function(self) -> Callable:
        """
        Hands over the logger we get from stable-baselines3.

        Returns:
            Callable: Logger
        """
        return self.log_function

    def finish_logging(self) -> None:
        """
        Closes every output format held by the stable-baselines3 logger (e.g. the tensorboard
        writer), releasing writers and file handles opened in ``__init__``.

        Returns:
            None
        """
        self.log_function.close()
class MARLWandBLogger(BaseLogger):
    """
    Weights&Biases (wandb) logging for multi-agent (MARL) training.

    Constructing an instance starts a wandb run (``wandb.init`` with ``sync_tensorboard=True``)
    and creates a stable-baselines3 ``Logger`` that writes into ``<log_dir>/runs/<run id>``.
    As in the tensorboard MARL logger, ``log_function`` is an instance attribute holding that
    ``Logger`` (it shadows the base-class method); use ``get_log_function`` to retrieve it.
    """

    def __init__(
        self,
        log_dir: str,
        entity_name: str,
        project_name: str = None,
        callback: BaseCallback = MARLWandBCallback,
        format_strings: list = None,
        model_save_freq: int = 100,
        verbose: int = 2,
        alg_config: dict = None,
        run_name: str = None,
    ):
        """
        Args:
            log_dir (str): relative path to logging directory
            entity_name (str): name of the wandb entity to which the runs will be logged
            project_name (str, optional): name of the wandb project to which the runs will be logged
            callback (BaseCallback, optional): callback class that implements the actual logging during
                training - by defining a customized callback, additional information can be logged
                (apart from standard metrics like mean_eps_reward)
            format_strings (list, optional): list of output formats for the SB3 logger.
                Defaults to ["stdout", "tensorboard"].
            model_save_freq (int, optional): after how many episodes the current model should be logged
            verbose (int, optional): output verbosity
            alg_config (dict, optional): dictionary of algorithm hyperparameters. Can be used to filter
                runs in wandb API
            run_name (str, optional): name under which the run will be displayed in WandB. When None,
                wandb chooses the name itself.
        """
        super().__init__(log_dir=log_dir)
        self.entity_name = entity_name
        self.project_name = project_name
        self.run_name = run_name
        self.callback = callback
        self.alg_config = alg_config
        self.model_save_freq = model_save_freq
        self.verbose = verbose
        if format_strings is None:
            # Fresh list per call instead of a mutable default argument shared by all instances.
            format_strings = ["stdout", "tensorboard"]

        # init WandB (name=None is wandb.init's own default, so omitting run_name is unchanged behavior)
        self.run = wandb.init(
            project=self.project_name,
            entity=self.entity_name,
            name=self.run_name,
            config=self.alg_config,
            sync_tensorboard=True,
        )
        # os.path.join instead of "+": a log_dir without a trailing separator ("logs") now yields
        # "logs/runs/<id>" instead of "logsruns/<id>"; with a trailing separator the string is unchanged.
        base_dir = self.log_dir
        self.model_save_path = os.path.join(base_dir, f"models/{self.run.id}")
        self.log_dir = os.path.join(base_dir, f"runs/{self.run.id}")

        # init logger (imported from SB3), writing into the run-specific directory
        log_suffix = ""
        output_formats = [make_output_format(f, self.log_dir, log_suffix) for f in format_strings]
        self.log_function = Logger(folder=self.log_dir, output_formats=output_formats)

    def get_callback(self) -> MARLBaseCallback:
        """
        Hands over the callback.

        Returns:
            MARLBaseCallback: callback which is used during training to log additional information
        """
        return self.callback(
            model_save_path=self.model_save_path, model_save_freq=self.model_save_freq, verbose=self.verbose
        )

    def get_log_function(self) -> Callable:
        """
        Hands over the logger we get from stable-baselines3.

        Returns:
            Callable: Logger
        """
        return self.log_function

    def finish_logging(self) -> None:
        """
        Closes the stable-baselines3 logger and terminates the W&B run.

        The SB3 output formats are closed first so that buffered tensorboard events are written out
        before the wandb run (which syncs tensorboard) is finished. ``wandb.finish()`` runs even if
        closing the logger raises.

        Returns:
            None
        """
        try:
            self.log_function.close()
        finally:
            wandb.finish()

    @property
    def run_id(self):
        """Id of the wandb run started by this logger (read from the stored run object)."""
        return self.run.id