Federico Belotti

An introduction to RL with SheepRL: the A2C algorithm

What is Reinforcement Learning?

Reinforcement Learning (RL) is a branch of Machine Learning that deals with the problem of learning from experience, i.e. “RL is learning what to do - how to map situations to actions - so as to maximize a numerical reward signal. The learner is not told which actions to take, but instead must discover which actions yield the most reward by trying them.” (Richard S. Sutton and Andrew G. Barto, Reinforcement Learning: An Introduction, 2018). Actions may affect not only the immediate reward but also the next situation and, through that, all subsequent rewards. These two characteristics, trial-and-error search and delayed reward, are the two most important distinguishing features of RL and set it apart from other forms of learning:

  • RL is different from supervised learning because the training data is not given in the form of labelled input-output pairs; instead, the agent must discover which actions yield the most reward by trying them
  • RL is different from unsupervised learning, which is typically used to uncover hidden structure in unlabeled data: even though one may think of RL as an instance of unsupervised learning, since it does not rely on labelled examples of correct behaviour, an RL agent tries to maximize a reward signal rather than to find hidden structure

This informal idea of the agent’s goal being to maximize the total amount of reward it receives is called the reward hypothesis:

That all of what we mean by goals and purposes can be well thought of as the maximization of the expected value of the cumulative sum of a received scalar signal (called reward).

RL at a high-level

At a high level, RL can be summarized by the following figure:

where the main components are:

  • Agent: the learner and decision maker, which acts in the environment and learns from the experience gathered through repeated interactions with it
  • Environment: the world (physical or virtual) where the agent lives and acts
  • State: a representation of the environment at a particular point in time, on which the agent bases its decisions
  • Reward: a numerical value that represents whether the action $A_t$ taken at time $t$ in state $S_t$, resulting in the state $S_{t+1}$, was a good or bad choice

The agent and the environment interact with each other in a sequence of discrete time steps, $t=0,1,2,3,…$. At each time step $t$, the agent receives a representation of the environment’s state, $S_t \in \mathcal{S}$, where $\mathcal{S}$ is the set of all possible states. Based on that, the agent selects an action, $A_t \in \mathcal{A}(S_t)$, where $\mathcal{A}(S_t)$ is the set of all possible actions in state $S_t$. One time step later, the agent receives a numerical reward, $R_{t+1} \in \mathcal{R} \subset \mathbb{R}$, and finds the environment in a new state, $S_{t+1}$, which has changed as a consequence of the agent’s action.
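To make the interaction loop concrete, here is a minimal sketch with a hypothetical toy environment (CorridorEnv, not part of SheepRL or Gymnasium) whose reset and step methods mimic the Gymnasium convention of returning the terminated and truncated flags separately. A uniformly random policy plays one episode and the loop records the $(S_t, A_t, R_{t+1})$ tuples:

```python
import random
from typing import List, Tuple


class CorridorEnv:
    """Hypothetical toy environment: a corridor of `length` cells.

    The agent starts in cell 0; action 0 moves left, action 1 moves right.
    Reaching the last cell ends the episode with reward +1; every other
    step gives reward 0. Episodes are truncated after `max_steps` steps.
    """

    def __init__(self, length: int = 5, max_steps: int = 50):
        self.length = length
        self.max_steps = max_steps

    def reset(self) -> Tuple[int, dict]:
        self.pos, self.t = 0, 0
        return self.pos, {}  # (S_0, info)

    def step(self, action: int) -> Tuple[int, float, bool, bool, dict]:
        self.t += 1
        move = 1 if action == 1 else -1
        self.pos = min(max(self.pos + move, 0), self.length - 1)
        terminated = self.pos == self.length - 1
        truncated = self.t >= self.max_steps and not terminated
        reward = 1.0 if terminated else 0.0
        return self.pos, reward, terminated, truncated, {}


def run_episode(env: CorridorEnv, seed: int = 0) -> List[Tuple[int, int, float]]:
    """Roll out one episode with a uniformly random policy, recording (S_t, A_t, R_{t+1})."""
    rng = random.Random(seed)
    state, _ = env.reset()
    trajectory = []
    done = False
    while not done:
        action = rng.choice([0, 1])  # A_t sampled from a uniform policy
        next_state, reward, terminated, truncated, _ = env.step(action)
        trajectory.append((state, action, reward))
        state = next_state  # S_{t+1} becomes the current state
        done = terminated or truncated
    return trajectory


trajectory = run_episode(CorridorEnv(), seed=0)
print(len(trajectory), sum(r for _, _, r in trajectory))
```

In the A2C code later in this article the same loop runs over a vector of Gymnasium environments, and the random choice is replaced by sampling from the actor.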

The main objective of the agent is to maximize the cumulative reward it receives in the long run, i.e. the expected sum of the rewards it collects from the beginning to the end of the episode, which ends at time $T$ (with $T = \infty$ for never-ending episodes). The sum of rewards is called the return and, for an episode starting at $t=0$, it is defined as:

\[G = R_{1} + R_{2} + R_{3} + ... = \sum_{k=0}^{\infty} R_{k+1}\]

In practice, future rewards are often discounted: the reward $R_{k+1}$ is weighted by $\gamma^k$, where $\gamma \in [0,1)$ is called the discount rate (or discount factor), i.e.:

\[G = R_{1} + \gamma R_{2} + \gamma^2 R_{3} + ... = \sum_{k=0}^{\infty} \gamma^{k} R_{k+1}\]

This is useful for at least two reasons:

  • It keeps the return finite even if the episode is infinite, as long as the rewards are bounded. Without discounting, an infinite episode with a constant, finite reward $R_t = r > 0$ would have an infinite return, whereas with discounting the return becomes a geometric series: \(G = r + \gamma r + \gamma^2 r + ... = \sum_{k=0}^{\infty} \gamma^{k} r = \frac{r}{1-\gamma}\)
  • It lets us choose how much importance is given to immediate versus future rewards: the higher the discount rate, the more the agent cares about future rewards; the lower the discount rate, the more it cares about immediate rewards
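For a finite sequence of rewards, the return can be computed with the backward recursion $G_t = R_{t+1} + \gamma G_{t+1}$, which follows directly from the definition above. The sketch below (the function name is hypothetical) also checks numerically that a long stream of constant rewards approaches $r/(1-\gamma)$:

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """G = R_1 + gamma * R_2 + gamma^2 * R_3 + ...

    rewards[k] holds R_{k+1}. The sum is accumulated backwards using
    G_t = R_{t+1} + gamma * G_{t+1}, so no powers of gamma are needed.
    """
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


# A long (truncated) stream of constant rewards approaches r / (1 - gamma)
r, gamma = 1.0, 0.9
approx = discounted_return([r] * 1000, gamma)
print(approx, r / (1 - gamma))  # both close to 10.0
```

The same backward recursion is what makes returns cheap to compute at the end of a rollout.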

RL at a low-level

At a low level (i.e. mathematically), an RL problem can be described by the following elements:

  • The set of possible states $\mathcal{S}$, where $S_t \in \mathcal{S}$ is the state of the environment at time $t$
  • The set of possible actions $\mathcal{A}$, where $A_t \in \mathcal{A}$ is the action taken by the agent at time $t$, given the environment state $S_t$
  • The transition function $\mathcal{P}$, where \(p(s' \vert s,a) = \mathbb{P}\left[S_{t+1} = s' \vert S_t = s, A_t = a \right]\), with \(p : \mathcal{S} \times \mathcal{A} \times \mathcal{S} \rightarrow \left[0,1\right]\), is the probability of transitioning to state $s'$ at time $t+1$ given that the agent is in state $s$ at time $t$ and takes action $a$ at time $t$
  • The reward function $\mathcal{R}$, where \(r(s,a) = \mathbb{E}\left[ R_{t+1} \vert S_t = s, A_t = a \right]\), with \(r : \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}\), is the expected reward received by the agent at time $t+1$ given that the agent is in state $s$ at time $t$ and takes action $a$ at time $t$
  • The policy $\pi$, where $\pi(a \vert s) : \mathcal{S} \times \mathcal{A} \rightarrow \left[0,1\right]$ is the probability of taking action $a$ at time $t$ given that the agent is in state $s$ at time $t$

The agent, interacting with the environment, generates a sequence of states, actions and rewards, called a trajectory:

\[\tau = s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2, r_3, ...\]

The agent’s goal is to find the policy $\pi$ that maximizes the expected return $\eta(\pi)$, i.e. the return averaged over all possible trajectories:

\[\eta(\pi) = \mathbb{E}_{s_0, a_0, r_1, ...} \left[ \sum_{t=0}^{\infty} \gamma^{t} r(s_t, a_t) \right]\]

where:

  • $s_0$ is a random initial environment state
  • $a_t \sim \pi(a_t \vert s_t)$
  • $s_{t+1} \sim p(s_{t+1} \vert s_t,a_t)$
  • $r_{t+1} = r(s_t, a_t)$
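To connect $\eta(\pi)$ with something computable, the following NumPy sketch uses a hypothetical two-state, two-action MDP (all numbers are made up). It evaluates $\eta(\pi) = \sum_s \mu(s) V^{\pi}(s)$ exactly, where $\mu$ is the initial-state distribution, by solving the Bellman equation in matrix form, $V^{\pi} = (I - \gamma P_{\pi})^{-1} r_{\pi}$, with $P_{\pi}$ and $r_{\pi}$ the transition matrix and expected reward under $\pi$. It then compares the result with a Monte Carlo estimate that samples trajectories as in the list above, truncated after a finite horizon:

```python
import numpy as np

# Hypothetical 2-state, 2-action MDP (illustrative numbers only)
gamma = 0.9
mu = np.array([1.0, 0.0])                    # initial-state distribution: always start in state 0
P = np.array([[[0.9, 0.1], [0.2, 0.8]],      # P[s, a, s'] = p(s' | s, a)
              [[0.7, 0.3], [0.05, 0.95]]])
R = np.array([[0.0, 1.0],                    # R[s, a] = r(s, a)
              [2.0, -1.0]])
pi = np.array([[0.5, 0.5],                   # pi[s, a] = pi(a | s)
               [0.3, 0.7]])


def exact_eta(P, R, pi, mu, gamma):
    """eta(pi) = mu^T V, with V = (I - gamma * P_pi)^{-1} r_pi."""
    P_pi = np.einsum("sa,sat->st", pi, P)    # state-to-state transition matrix under pi
    r_pi = (pi * R).sum(axis=1)              # expected one-step reward under pi
    V = np.linalg.solve(np.eye(len(mu)) - gamma * P_pi, r_pi)
    return mu @ V


def sample(probs, rng):
    """Draw one index per row of `probs` (rows sum to 1) by inverting the CDF."""
    cdf = np.cumsum(probs, axis=-1)
    u = rng.random((probs.shape[0], 1))
    return np.minimum((u > cdf).sum(axis=-1), probs.shape[-1] - 1)


def monte_carlo_eta(P, R, pi, mu, gamma, n_episodes=20_000, horizon=150, seed=0):
    """Average the discounted return of sampled trajectories, truncated at `horizon`."""
    rng = np.random.default_rng(seed)
    s = sample(np.tile(mu, (n_episodes, 1)), rng)   # s_0 ~ mu
    G = np.zeros(n_episodes)
    for t in range(horizon):
        a = sample(pi[s], rng)               # a_t ~ pi(. | s_t)
        G += gamma**t * R[s, a]              # r_{t+1} = r(s_t, a_t)
        s = sample(P[s, a], rng)             # s_{t+1} ~ p(. | s_t, a_t)
    return G.mean()


print(exact_eta(P, R, pi, mu, gamma), monte_carlo_eta(P, R, pi, mu, gamma))
```

The two numbers should agree up to Monte Carlo noise. The exact solve is only feasible because the state space is tiny, which is why algorithms such as A2C learn from sampled trajectories instead.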

RL at a practical level

In practice, many different algorithms can be used to solve RL problems, some of which have already been covered (Dreamer-V1, Dreamer-V2, Dreamer-V3, PPO, SAC and P2E). In this article we will focus on the Advantage Actor-Critic (A2C) algorithm, a policy gradient method based on the actor-critic architecture, which is composed of:

  • An Actor, which takes as input the environment state $s_t$ and outputs a probability distribution over the possible actions $a_t$
  • A Critic, which takes as input the environment state $s_t$ and outputs a value $V(s_t)$, which is an estimate of the expected return if the agent starts in state $s_t$ and follows the policy $\pi$ afterwards

Let’s dive in

Let’s start by following SheepRL’s how-to guide: first install the SheepRL package with pip install "sheeprl[atari,box2d,dev,test]" (quoting the argument prevents shells such as zsh from interpreting the square brackets), then create a new folder for our algorithm, called a2c, containing five new files: a2c.py, agent.py, loss.py, utils.py and the standard __init__.py:

./a2c
├── a2c.py
├── agent.py
├── __init__.py
├── loss.py
└── utils.py
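The same layout can also be created from Python. This is a minimal sketch using only the standard library; the helper name scaffold_a2c is hypothetical and not part of SheepRL:

```python
from pathlib import Path


def scaffold_a2c(root: str = ".") -> Path:
    """Create the a2c/ package with the five (initially empty) files used in this article."""
    pkg = Path(root) / "a2c"
    pkg.mkdir(parents=True, exist_ok=True)
    for name in ("__init__.py", "a2c.py", "agent.py", "loss.py", "utils.py"):
        # touch() creates the file if missing and never truncates an existing one
        (pkg / name).touch(exist_ok=True)
    return pkg
```

Calling scaffold_a2c() from the directory that should contain the package is safe to repeat, since existing files are left untouched.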

1. a2c.py

The first thing we need to do is to import the necessary modules and prepare the skeleton of our algorithm. For simplicity our agent will work with only vector observations, so we need to check that no images will be returned by the environment:

# a2c.py

import os
import warnings
from typing import Any, Dict

import gymnasium as gym
import hydra
import numpy as np
import torch
from lightning.fabric import Fabric
from sheeprl.data import ReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import gae, save_configs
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler
from torchmetrics import SumMetric

from a2c.agent import build_agent
from a2c.loss import policy_loss, value_loss
from a2c.utils import test


def train(
    fabric: Fabric,
    agent: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    data: Dict[str, torch.Tensor],
    aggregator: MetricAggregator,
    cfg: Dict[str, Any],
):
    """Train the agent on the data collected from the environment."""

    # Prepare the sampler
    # If we are in the distributed setting, we need to use a DistributedSampler, which
    # will shuffle the data at each epoch and will ensure that each process will get
    # a different part of the data
    indexes = list(range(next(iter(data.values())).shape[0]))
    if cfg.buffer.share_data:
        sampler = DistributedSampler(
            indexes,
            num_replicas=fabric.world_size,
            rank=fabric.global_rank,
            shuffle=True,
            seed=cfg.seed,
        )
    else:
        sampler = RandomSampler(indexes)
    sampler = BatchSampler(sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False)

    optimizer.zero_grad(set_to_none=True)
    if cfg.buffer.share_data:
        sampler.sampler.set_epoch(0)

    # Train the agent
    # Even though in the Spinning-Up A2C algorithm implementation
    # (https://spinningup.openai.com/en/latest/algorithms/vpg.html) the policy gradient is estimated
    # by taking the mean over all the sequences collected
    # of the sum of the actions log-probabilities gradients' multiplied by the advantages,
    # we do not do that, instead we take the overall sum (or mean, depending on the loss reduction).
    # This is achieved by accumulating the gradients and calling the backward method only at the end.
    for i, batch_idxes in enumerate(sampler):
        batch = {k: v[batch_idxes] for k, v in data.items()}
        obs = {k: batch[k] for k in cfg.algo.mlp_keys.encoder}

        # is_accumulating is True for every i except for the last one
        is_accumulating = i < len(sampler) - 1

        with fabric.no_backward_sync(agent, enabled=is_accumulating):
            _, logprobs, values = agent(obs, torch.split(batch["actions"], agent.actions_dim, dim=-1))

            # Policy loss
            pg_loss = policy_loss(
                logprobs,
                batch["advantages"],
                cfg.algo.loss_reduction,
            )

            # Value loss
            v_loss = value_loss(
                values,
                batch["returns"],
                cfg.algo.loss_reduction,
            )

            loss = pg_loss + v_loss
            fabric.backward(loss)

        if not is_accumulating:
            if cfg.algo.max_grad_norm > 0.0:
                fabric.clip_gradients(agent, optimizer, max_norm=cfg.algo.max_grad_norm)
            optimizer.step()

        # Update metrics
        if aggregator and not aggregator.disabled:
            aggregator.update("Loss/policy_loss", pg_loss.detach())
            aggregator.update("Loss/value_loss", v_loss.detach())


@register_algorithm(decoupled=False)
def main(fabric: Fabric, cfg: Dict[str, Any]):
    rank = fabric.global_rank
    world_size = fabric.world_size
    device = fabric.device
    fabric.seed_everything(cfg.seed)

    # Resume from checkpoint
    if cfg.checkpoint.resume_from:
        state = fabric.load(cfg.checkpoint.resume_from)

    # Create Logger. This will create the logger only on the
    # rank-0 process
    logger = get_logger(fabric, cfg)
    if logger and fabric.is_global_zero:
        fabric._loggers = [logger]
        fabric.logger.log_hyperparams(cfg)
    log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)
    fabric.print(f"Log dir: {log_dir}")

    # Environment setup
    vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv
    envs = vectorized_env(
        [
            make_env(
                cfg,
                cfg.seed + rank * cfg.env.num_envs + i,
                rank * cfg.env.num_envs,
                log_dir if rank == 0 else None,
                "train",
                vector_env_idx=i,
            )
            for i in range(cfg.env.num_envs)
        ]
    )
    observation_space = envs.single_observation_space

    if not isinstance(observation_space, gym.spaces.Dict):
        raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
    if len(cfg.algo.mlp_keys.encoder) == 0:
        raise RuntimeError("You should specify at least one MLP key for the encoder: `algo.mlp_keys.encoder=[state]`")
    for k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
        if k in observation_space.keys() and len(observation_space[k].shape) > 1:
            raise ValueError(
                "Only environments with vector-only observations are supported by the A2C agent. "
                f"The observation with key '{k}' has shape {observation_space[k].shape}. "
                f"Provided environment: {cfg.env.id}"
            )
    if cfg.metric.log_level > 0:
        fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder)
        fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder)
    if len(cfg.algo.cnn_keys.encoder) > 0:
        warnings.warn("The A2C agent does not support pixel-based observations, the CNN keys will be ignored.")
    cfg.algo.cnn_keys.encoder = []
    obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder

    is_continuous = isinstance(envs.single_action_space, gym.spaces.Box)
    is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        envs.single_action_space.shape
        if is_continuous
        else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
    )
    # Create the actor and critic models
    agent = build_agent(
        fabric,
        actions_dim,
        is_continuous,
        cfg,
        observation_space,
        state["agent"] if cfg.checkpoint.resume_from else None,
    )

    # Define the optimizer
    optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

    if fabric.is_global_zero:
        save_configs(cfg, log_dir)

    # Load the state from the checkpoint
    if cfg.checkpoint.resume_from:
        optimizer.load_state_dict(state["optimizer"])

    # Setup agent and optimizer with Fabric
    optimizer = fabric.setup_optimizers(optimizer)

    # Create a metric aggregator to log the metrics
    aggregator = None
    if not MetricAggregator.disabled:
        aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

    # Local data
    if cfg.buffer.size < cfg.algo.rollout_steps:
        raise ValueError(
            f"The size of the buffer ({cfg.buffer.size}) cannot be lower "
            f"than the rollout steps ({cfg.algo.rollout_steps})"
        )
    rb = ReplayBuffer(
        cfg.buffer.size,
        cfg.env.num_envs,
        memmap=cfg.buffer.memmap,
        memmap_dir=os.path.join(log_dir, "memmap_buffer", f"rank_{fabric.global_rank}"),
        obs_keys=obs_keys,
    )

    # Global variables
    last_train = 0
    train_step = 0
    start_step = (
        # + 1 because the checkpoint is at the end of the update step
        # (when resuming from a checkpoint, the update at the checkpoint
        # is ended and you have to start with the next one)
        (state["update"] // fabric.world_size) + 1
        if cfg.checkpoint.resume_from
        else 1
    )
    policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
    last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
    last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
    policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
    if cfg.checkpoint.resume_from:
        cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size

    # Warning for log and checkpoint every
    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
        warnings.warn(
            f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
            f"policy_steps_per_update value ({policy_steps_per_update}), so "
            "the metrics will be logged at the nearest greater multiple of the "
            "policy_steps_per_update value."
        )
    if cfg.checkpoint.every % policy_steps_per_update != 0:
        warnings.warn(
            f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
            f"policy_steps_per_update value ({policy_steps_per_update}), so "
            "the checkpoint will be saved at the nearest greater multiple of the "
            "policy_steps_per_update value."
        )

    # Get the first environment observation and start the optimization
    step_data = {}
    next_obs = envs.reset(seed=cfg.seed)[0]  # [N_envs, N_obs]
    for k in obs_keys:
        step_data[k] = next_obs[k][np.newaxis]

    for update in range(start_step, num_updates + 1):
        for _ in range(0, cfg.algo.rollout_steps):
            policy_step += cfg.env.num_envs * world_size

            # Measure environment interaction time: this considers both the model forward
            # to get the action given the observation and the time taken into the environment
            with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
                with torch.no_grad():
                    # Sample an action given the observation received by the environment
                    torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
                    actions, _, values = agent.module(torch_obs)
                    if is_continuous:
                        real_actions = torch.cat(actions, -1).cpu().numpy()
                    else:
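                        # Stack (rather than concatenate) the per-dimension action indices so that
                        # real_actions has shape [N_envs, N_action_dims] and reshapes correctly
                        real_actions = torch.stack([act.argmax(dim=-1) for act in actions], dim=-1).cpu().numpy()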
                    actions = torch.cat(actions, -1).cpu().numpy()

                # Single environment step
                obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))
                dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
                rewards = rewards.reshape(cfg.env.num_envs, -1)

            # Update the step data
            step_data["dones"] = dones[np.newaxis]
            step_data["values"] = values.cpu().numpy()[np.newaxis]
            step_data["actions"] = actions[np.newaxis]
            step_data["rewards"] = rewards[np.newaxis]
            if cfg.buffer.memmap:
                step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
                step_data["advantages"] = np.zeros_like(rewards, shape=(1, *rewards.shape))

            # Append data to buffer
            rb.add(step_data, validate_args=cfg.buffer.validate_args)

            # Update the observation and dones
            next_obs = obs
            for k in obs_keys:
                step_data[k] = obs[k][np.newaxis]

            if cfg.metric.log_level > 0 and "final_info" in info:
                for i, agent_ep_info in enumerate(info["final_info"]):
                    if agent_ep_info is not None:
                        ep_rew = agent_ep_info["episode"]["r"]
                        ep_len = agent_ep_info["episode"]["l"]
                        if aggregator and "Rewards/rew_avg" in aggregator:
                            aggregator.update("Rewards/rew_avg", ep_rew)
                        if aggregator and "Game/ep_len_avg" in aggregator:
                            aggregator.update("Game/ep_len_avg", ep_len)
                        fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

        # Transform the data into PyTorch Tensors
        local_data = rb.to_tensor(dtype=None, device=device, from_numpy=cfg.buffer.from_numpy)

        # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
        with torch.no_grad():
            torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
            next_values = agent.module.get_value(torch_obs)
            returns, advantages = gae(
                local_data["rewards"].to(torch.float64),
                local_data["values"],
                local_data["dones"],
                next_values,
                cfg.algo.rollout_steps,
                cfg.algo.gamma,
                cfg.algo.gae_lambda,
            )
            # Add returns and advantages to the buffer
            local_data["returns"] = returns.float()
            local_data["advantages"] = advantages.float()

        if cfg.buffer.share_data and fabric.world_size > 1:
            # Gather all the tensors from all the world and reshape them
            gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(local_data)
            # Flatten the first three dimensions: [World_Size, Buffer_Size, Num_Envs]
            gathered_data = {k: v.flatten(start_dim=0, end_dim=2).float() for k, v in gathered_data.items()}
        else:
            # Flatten the first two dimensions: [Buffer_Size, Num_Envs]
            gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()}

        with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
            train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
        train_step += world_size

        if cfg.metric.log_level > 0:
            # Log metrics
            if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
                # Sync distributed metrics
                if aggregator and not aggregator.disabled:
                    metrics_dict = aggregator.compute()
                    fabric.log_dict(metrics_dict, policy_step)
                    aggregator.reset()

                # Sync distributed timers
                if not timer.disabled:
                    timer_metrics = timer.compute()
                    if "Time/train_time" in timer_metrics:
                        fabric.log(
                            "Time/sps_train",
                            (train_step - last_train) / timer_metrics["Time/train_time"],
                            policy_step,
                        )
                    if "Time/env_interaction_time" in timer_metrics:
                        fabric.log(
                            "Time/sps_env_interaction",
                            ((policy_step - last_log) / world_size * cfg.env.action_repeat)
                            / timer_metrics["Time/env_interaction_time"],
                            policy_step,
                        )
                    timer.reset()

                # Reset counters
                last_log = policy_step
                last_train = train_step

        # Checkpoint model
        if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
            update == num_updates and cfg.checkpoint.save_last
        ):
            last_checkpoint = policy_step
            state = {
                "agent": agent.state_dict(),
                "optimizer": optimizer.state_dict(),
                "update": update * world_size,
                "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
                "last_log": last_log,
                "last_checkpoint": last_checkpoint,
            }
            ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
            fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

    envs.close()
    if fabric.is_global_zero:
        test(agent.module, fabric, cfg, log_dir)
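The gae helper imported from sheeprl.utils.utils is used above as a black box. To show what it computes, here is a minimal NumPy re-implementation of Generalized Advantage Estimation. It is a sketch, not SheepRL’s code (whose exact conventions may differ), and it assumes that dones[t] = 1 means the episode ended with the t-th transition, so nothing is bootstrapped past it; this matches how a2c.py stores the done flags returned by envs.step:

```python
import numpy as np


def gae_sketch(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over a rollout of T steps and N envs.

    Shapes: rewards, values, dones -> [T, N]; next_value -> [N] is the critic's
    estimate V(s_T) for the observation that follows the last step.
    """
    T = rewards.shape[0]
    advantages = np.zeros_like(rewards, dtype=np.float64)
    last_adv = np.zeros_like(next_value, dtype=np.float64)
    for t in reversed(range(T)):
        v_next = next_value if t == T - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        # TD error: delta_t = r_{t+1} + gamma * V(s_{t+1}) - V(s_t), no bootstrap after a done
        delta = rewards[t] + gamma * v_next * not_done - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}, reset at episode boundaries
        last_adv = delta + gamma * lam * not_done * last_adv
        advantages[t] = last_adv
    returns = advantages + values  # regression targets for the critic
    return returns, advantages
```

Setting lam=1 recovers the Monte Carlo advantage $G_t - V(s_t)$ (bootstrapped with next_value at the end of the rollout), while lam=0 gives the one-step TD error; intermediate values trade bias for variance.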

2. agent.py

The second file is agent.py, where we define the build_agent function used to create the agent model. We can take inspiration from the PPO agent.py file and, to keep things at a demonstration level, remove everything related to the encoding/processing of pixel-based observations:

# agent.py

from typing import Any, Dict, List, Optional, Sequence, Tuple

import gymnasium
import torch
import torch.nn as nn
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal

from sheeprl.models.models import MLP
from sheeprl.utils.distribution import OneHotCategoricalValidateArgs


class MLPEncoder(nn.Module):
    def __init__(
        self,
        input_dim: int,
        features_dim: int,
        keys: Sequence[str],
        dense_units: int = 64,
        mlp_layers: int = 2,
        dense_act: nn.Module = nn.ReLU,
    ) -> None:
        super().__init__()
        self.keys = keys
        self.input_dim = input_dim
        self.output_dim = features_dim
        self.model = MLP(
            input_dim,
            features_dim,
            [dense_units] * mlp_layers,
            activation=dense_act
        )

    def forward(self, obs: Dict[str, Tensor]) -> Tensor:
        x = torch.cat([obs[k] for k in self.keys], dim=-1)
        return self.model(x)


class A2CAgent(nn.Module):
    def __init__(
        self,
        actions_dim: Sequence[int],
        obs_space: gymnasium.spaces.Dict,
        encoder_cfg: Dict[str, Any],
        actor_cfg: Dict[str, Any],
        critic_cfg: Dict[str, Any],
        distribution_cfg: Dict[str, Any],
        mlp_keys: Sequence[str],
        is_continuous: bool = False,
    ):
        super().__init__()
        self.actions_dim = actions_dim
        self.obs_space = obs_space
        self.distribution_cfg = distribution_cfg
        self.mlp_keys = mlp_keys
        self.is_continuous = is_continuous

        # Feature extractor
        mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys])
        self.feature_extractor = (
            MLPEncoder(
                mlp_input_dim,
                encoder_cfg.mlp_features_dim,
                mlp_keys,
                encoder_cfg.dense_units,
                encoder_cfg.mlp_layers,
                eval(encoder_cfg.dense_act),
            )
            if mlp_keys is not None and len(mlp_keys) > 0
            else None
        )
        features_dim = self.feature_extractor.output_dim

        # Critic
        self.critic = MLP(
            input_dims=features_dim,
            output_dim=1,
            hidden_sizes=[critic_cfg.dense_units] * critic_cfg.mlp_layers,
            activation=eval(critic_cfg.dense_act)
        )

        # Actor
        self.actor_backbone = MLP(
            input_dims=features_dim,
            output_dim=None,
            hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers,
            activation=eval(actor_cfg.dense_act),
            flatten_dim=None
        )
        if is_continuous:
            # Output is a tuple of two elements: mean and log_std, one for every action
            self.actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
        else:
            # Output is a tuple of one element: logits, one for every action
            self.actor_heads = nn.ModuleList(
                [nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]
            )

    def forward(
        self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
    ) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
        feat = self.feature_extractor(obs)
        out: Tensor = self.actor_backbone(feat)
        pre_dist: List[Tensor] = [head(out) for head in self.actor_heads]
        values = self.critic(feat)
        if self.is_continuous:
            mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1)
            std = log_std.exp()
            normal = Independent(
                Normal(mean, std, validate_args=self.distribution_cfg.validate_args),
                1,
                validate_args=self.distribution_cfg.validate_args,
            )
            if actions is None:
                actions = normal.sample()
            else:
                # always composed by a tuple of one element containing all the
                # continuous actions
                actions = actions[0]
            log_prob = normal.log_prob(actions)
            return tuple([actions]), log_prob.unsqueeze(dim=-1), values
        else:
            should_append = False
            actions_logprobs: List[Tensor] = []
            actions_dist: List[Distribution] = []
            if actions is None:
                should_append = True
                actions: List[Tensor] = []
            for i, logits in enumerate(pre_dist):
                actions_dist.append(
                    OneHotCategoricalValidateArgs(
                        logits=logits, validate_args=self.distribution_cfg.validate_args
                    )
                )
                if should_append:
                    actions.append(actions_dist[-1].sample())
                actions_logprobs.append(actions_dist[-1].log_prob(actions[i]))
            return tuple(actions), torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), values

    def get_value(self, obs: Dict[str, Tensor]) -> Tensor:
        feat = self.feature_extractor(obs)
        return self.critic(feat)

    def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]:
        feat = self.feature_extractor(obs)
        out = self.actor_backbone(feat)
        pre_dist: List[Tensor] = [head(out) for head in self.actor_heads]
        if self.is_continuous:
            # Just take the mean of the distribution
            return [torch.chunk(pre_dist[0], 2, -1)[0]]
        else:
            # Take the mode of the distribution
            return tuple(
                [
                    OneHotCategoricalValidateArgs(
                        logits=logits, validate_args=self.distribution_cfg.validate_args
                    ).mode
                    for logits in pre_dist
                ]
            )


def build_agent(
    fabric: Fabric,
    actions_dim: Sequence[int],
    is_continuous: bool,
    cfg: Dict[str, Any],
    obs_space: gymnasium.spaces.Dict,
    agent_state: Optional[Dict[str, Tensor]] = None,
) -> _FabricModule:
    agent = A2CAgent(
        actions_dim=actions_dim,
        obs_space=obs_space,
        encoder_cfg=cfg.algo.encoder,
        actor_cfg=cfg.algo.actor,
        critic_cfg=cfg.algo.critic,
        mlp_keys=cfg.algo.mlp_keys.encoder,
        distribution_cfg=cfg.distribution,
        is_continuous=is_continuous,
    )
    if agent_state:
        agent.load_state_dict(agent_state)
    agent = fabric.setup_module(agent)

    return agent

For simplicity we have only defined the MLPEncoder and A2CAgent classes, which are the ones we will use in this article. The MLPEncoder class builds a multi-layer perceptron (MLP) encoder that extracts features from the observations. The A2CAgent class builds the agent model, which is composed of:

  • A feature extractor, which is an MLP encoder that takes as input the observations and outputs a vector of features
  • A critic, which is an MLP that takes as input the features extracted by the feature extractor and outputs a value $V(s_t)$, which is an estimate of the expected return if the agent starts in state $s_t$ and follows the policy $\pi$ afterwards
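  • An actor, which is an MLP backbone followed by linear heads (one per action dimension for discrete actions, a single one for continuous actions); it takes as input the features extracted by the feature extractor and outputs the parameters of a probability distribution over the possible actions $a_t$: the logits of a categorical distribution for discrete actions, or the mean and log standard deviation of a Gaussian for continuous ones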
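The way A2CAgent.forward turns the heads’ outputs into a single log-probability per sample can be reproduced with plain torch.distributions. In this sketch, random tensors stand in for the actor heads’ outputs and torch’s OneHotCategorical replaces SheepRL’s OneHotCategoricalValidateArgs:

```python
import torch
from torch.distributions import Independent, Normal, OneHotCategorical

torch.manual_seed(0)
batch = 4

# Discrete (MultiDiscrete-like) case: one categorical head per action dimension,
# here with 3 and 2 choices. Random logits stand in for the actor heads' outputs.
logits = [torch.randn(batch, 3), torch.randn(batch, 2)]
dists = [OneHotCategorical(logits=l) for l in logits]
actions = [d.sample() for d in dists]  # one-hot tensors of shape [batch, n_i]
# Dimensions are sampled independently given the state, so the joint
# log-prob is the sum of the per-head log-probs: shape [batch, 1]
logprob = torch.stack([d.log_prob(a) for d, a in zip(dists, actions)], dim=-1).sum(dim=-1, keepdim=True)

# Continuous case: a single head outputs mean and log_std for a 2-D action
head_out = torch.randn(batch, 2 * 2)
mean, log_std = torch.chunk(head_out, chunks=2, dim=-1)
normal = Independent(Normal(mean, log_std.exp()), 1)  # treat the 2 dims as one event
cont_action = normal.sample()                          # shape [batch, 2]
cont_logprob = normal.log_prob(cont_action)            # summed over dims: shape [batch]

print(logprob.shape, cont_logprob.shape)
```

Wrapping the Normal in Independent(..., 1) is what makes log_prob sum over the action dimensions, which is why the agent only needs to unsqueeze the result to match the [batch, 1] shape of the discrete case.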

3. loss.py

The third file is loss.py, where we define the loss functions used to train the agent.

There are two loss functions we need to define:

  • The policy loss, which is used to train the actor and is defined as:

    \[L_{\pi} = -\sum_{t=1}^{T} \log \pi_{\theta}(a_t \vert s_t)[G_t - V_{\theta}^{\pi}(s_t)] = -\sum_{t=1}^{T} \log \pi_{\theta}(a_t \vert s_t) A_t\]

    where $A_t$ is the advantage function, which is defined as:

    \[A_t = G_t - V_{\theta}^{\pi}(s_t)\]

    where $G_t$ is the return at time $t$, which is defined as:

    \[G_t = \sum_{k=t}^{T} \gamma^{k-t} R_{k+1}\]

    The advantage measures how much better it is to take action $a_t$ in state $s_t$ than to take an action sampled from the policy in that same state, on average.

  • The value loss, which is used to train the critic and is the Mean Squared Error (MSE) between the estimated value and the actual return:

    \[L_{V} = \frac{1}{T}\sum_{t=1}^{T} (G_t - V_{\theta}^{\pi}(s_t))^2\]
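As a small worked example of the definitions above, the sketch below computes $G_t$ backwards with the recursion $G_t = R_{t+1} + \gamma G_{t+1}$ and then $A_t = G_t - V_{\theta}^{\pi}(s_t)$. It assumes that the episode terminates right after the last reward (so nothing is bootstrapped) and that rewards[t] stores $R_{t+1}$; the function names are illustrative and the value estimates in the usage note are made-up numbers.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float) -> List[float]:
    # rewards[t] plays the role of R_{t+1}; iterate backwards with G_t = R_{t+1} + gamma * G_{t+1}
    returns = [0.0] * len(rewards)
    running = 0.0  # the return after the last step is 0 because the episode has ended
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def advantages(returns: List[float], values: List[float]) -> List[float]:
    # A_t = G_t - V(s_t)
    return [g - v for g, v in zip(returns, values)]
```

For instance, with rewards [1.0, 1.0, 1.0] and $\gamma = 0.5$, discounted_returns gives [1.75, 1.5, 1.0]; with value estimates of 1.0 in every state the advantages are [0.75, 0.5, 0.0], so the first action gets the largest positive advantage while the last one is exactly as good as the critic expected.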

# loss.py

import torch.nn.functional as F
from torch import Tensor


def policy_loss(
    logprobs: Tensor,
    advantages: Tensor,
    reduction: str = "mean",
) -> Tensor:
    """Compute the A2C policy loss for a batch of data.

        - Multiply the log-probs of the sampled actions by their advantages.
        - Negate the product, so that minimizing the loss increases the probability of actions with a positive advantage.
        - Reduce the per-step losses according to `reduction`.

    Args:
        logprobs (Tensor): the log-probs of the sampled actions from the environment.
        advantages (Tensor): the advantages.
        reduction (str): the reduction to apply to the per-step losses: "none", "mean" or "sum".
            Default: "mean".

    Returns:
        the policy loss
    """
    pg_loss = -(logprobs * advantages)
    reduction = reduction.lower()
    if reduction == "none":
        return pg_loss
    elif reduction == "mean":
        return pg_loss.mean()
    elif reduction == "sum":
        return pg_loss.sum()
    else:
        raise ValueError(f"Unrecognized reduction: {reduction}")


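def value_loss(
    values: Tensor,
    returns: Tensor,
    reduction: str = "mean",
) -> Tensor:
    """Compute the value loss as the Mean Squared Error between the estimated values and the returns.

    Args:
        values (Tensor): the values V(s_t) estimated by the critic.
        returns (Tensor): the returns G_t.
        reduction (str): the reduction to apply: "none", "mean" or "sum". Default: "mean".

    Returns:
        the value loss
    """
    return F.mse_loss(values, returns, reduction=reduction)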
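To make the reduction argument and the sign of the policy loss concrete without PyTorch, here is a plain-Python mirror of the two functions above. The names policy_loss_py, value_loss_py and _reduce are hypothetical; the per-step terms follow the same formulas as policy_loss and F.mse_loss, but this is only an illustration, not a replacement for the tensor versions.

```python
from typing import List, Union


def _reduce(per_step: List[float], reduction: str) -> Union[float, List[float]]:
    # Same three options accepted by policy_loss above
    reduction = reduction.lower()
    if reduction == "none":
        return per_step
    if reduction == "mean":
        return sum(per_step) / len(per_step)
    if reduction == "sum":
        return sum(per_step)
    raise ValueError(f"Unrecognized reduction: {reduction}")


def policy_loss_py(
    logprobs: List[float], advantages: List[float], reduction: str = "mean"
) -> Union[float, List[float]]:
    # Per-step term: -log pi(a_t | s_t) * A_t
    return _reduce([-lp * a for lp, a in zip(logprobs, advantages)], reduction)


def value_loss_py(
    values: List[float], returns: List[float], reduction: str = "mean"
) -> Union[float, List[float]]:
    # Per-step squared error (V(s_t) - G_t)^2, as computed elementwise by F.mse_loss
    return _reduce([(v - g) ** 2 for v, g in zip(values, returns)], reduction)
```

With log-probs [-0.5, -1.0] and advantages [2.0, 3.0], the per-step policy losses are [1.0, 3.0], giving 2.0 with "mean" and 4.0 with "sum". Because the loss is the negated product, gradient descent pushes up the log-probability of actions with a positive advantage and pushes down that of actions with a negative one.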

4. utils.py

The fourth file is utils.py, where we define the utility functions used to train and evaluate the agent:

# utils.py

from typing import Any, Dict

import torch
from lightning import Fabric

from sheeprl.algos.a2c.agent import A2CAgent
from sheeprl.utils.env import make_env

AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


@torch.no_grad()
def test(agent: A2CAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
    o = env.reset(seed=cfg.seed)[0]
    obs = {}
    for k in o.keys():
        if k in cfg.algo.mlp_keys.encoder:
            torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
            torch_obs = torch_obs.float()
            obs[k] = torch_obs

    while not done:
        # Act greedily through the environment
        if agent.is_continuous:
            actions = torch.cat(agent.get_greedy_actions(obs), dim=-1)
        else:
            actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1)

        # Single environment step
        o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
        done = done or truncated
        cumulative_rew += reward
        obs = {}
        for k in o.keys():
            if k in cfg.algo.mlp_keys.encoder:
                torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
                torch_obs = torch_obs.float()
                obs[k] = torch_obs

        if cfg.dry_run:
            done = True
    fabric.print("Test - Reward:", cumulative_rew)
    if cfg.metric.log_level > 0:
        fabric.log_dict({"Test/cumulative_reward": cumulative_rew}, 0)
    env.close()

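In particular, we have defined the test function, which plays a single episode in a test environment acting greedily (i.e. taking the mode of the policy distribution), then prints the cumulative reward and, if logging is enabled, logs it. AGGREGATOR_KEYS collects the names of the metrics (average reward, average episode length and the two losses) meant for the metric aggregator.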
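The control flow of test (reset, act, step, stop on termination or truncation, accumulate the reward) can be exercised in isolation with a toy environment that follows the Gymnasium reset/step signature, where step returns (observation, reward, terminated, truncated, info). CountdownEnv and evaluate below are hypothetical stand-ins for the environment created by make_env and for the agent's greedy policy.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class CountdownEnv:
    """Hypothetical toy environment with a Gymnasium-style reset/step API."""

    def __init__(self, horizon: int = 3, max_steps: int = 10):
        self.horizon = horizon
        self.max_steps = max_steps
        self.remaining = horizon
        self.steps = 0

    def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, int], Dict[str, Any]]:
        self.remaining = self.horizon
        self.steps = 0
        return {"state": self.remaining}, {}

    def step(self, action: int) -> Tuple[Dict[str, int], float, bool, bool, Dict[str, Any]]:
        self.steps += 1
        # Action 1 moves towards the goal and is rewarded; any other action is not
        if action == 1:
            self.remaining -= 1
            reward = 1.0
        else:
            reward = 0.0
        terminated = self.remaining == 0
        truncated = self.steps >= self.max_steps
        return {"state": self.remaining}, reward, terminated, truncated, {}


def evaluate(policy: Callable[[Dict[str, int]], int], env: CountdownEnv, seed: int = 0) -> float:
    obs, _ = env.reset(seed=seed)
    done = False
    cumulative_rew = 0.0
    while not done:
        action = policy(obs)
        obs, reward, terminated, truncated, _ = env.step(action)
        # An episode ends either on termination or on truncation
        done = terminated or truncated
        cumulative_rew += reward
    return cumulative_rew
```

With a policy that always picks action 1, evaluate(lambda obs: 1, CountdownEnv()) collects a cumulative reward of 3.0 and stops on termination; a policy that always picks action 0 earns nothing and the loop only ends when the 10-step truncation limit is reached, which is why the test function merges the two flags with done = done or truncated.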

5. Configs

Almost the last thing we need to do is to define the configuration files used to configure the agent and the default experiment. We advise having a look at our how-to on SheepRL configs.

We first create the configs folder, then we define the following configuration files:

  • configs/algo/a2c.yaml: the algorithm’s configuration, which will be used to configure the agent
  • configs/exp/a2c.yaml: the default experiment configuration file

# configs/algo/a2c.yaml

defaults:
  - default
  - /optim@optimizer: adam
  - _self_

# Training recipe
name: a2c
gamma: 0.99
gae_lambda: 1.0
loss_reduction: mean
rollout_steps: 16
dense_act: torch.nn.Tanh
max_grad_norm: 100

# Encoder
encoder:
  mlp_layers: 1
  dense_units: 64
  mlp_features_dim: 64
  dense_act: ${algo.dense_act}

# Actor
actor:
  mlp_layers: 2
  dense_units: 64
  dense_act: ${algo.dense_act}

# Critic
critic:
  mlp_layers: 2
  dense_units: 64
  dense_act: ${algo.dense_act}

# Single optimizer for both actor and critic
optimizer:
  lr: 1e-3
  eps: 1e-4
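The gae_lambda key suggests that returns and advantages are estimated with Generalized Advantage Estimation (GAE) over each rollout of rollout_steps steps, bootstrapping the last state with the critic. The following is a generic sketch of that computation under this assumption, not SheepRL's own implementation; the function name gae and its argument layout are hypothetical. With gae_lambda: 1.0, as in the config above, the returns it produces are the discounted returns $G_t$ from the loss.py section, except that the tail of an unfinished rollout is replaced by the critic's estimate next_value.

```python
from typing import List, Tuple


def gae(
    rewards: List[float],
    values: List[float],
    dones: List[float],
    next_value: float,
    gamma: float,
    lam: float,
) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation over one rollout (generic sketch).

    rewards[t] is the reward received after acting in s_t, values[t] is V(s_t),
    dones[t] is 1.0 if the episode ended after step t, and next_value is V(s_T),
    the critic's estimate for the state reached after the last step of the rollout.
    """
    T = len(rewards)
    advantages = [0.0] * T
    last_adv = 0.0
    for t in reversed(range(T)):
        # Value of the next state: bootstrap with next_value at the end of the rollout
        next_v = next_value if t == T - 1 else values[t + 1]
        # Zero out the bootstrap (and the accumulated advantage) across episode boundaries
        not_done = 1.0 - dones[t]
        # One-step TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_v * not_done - values[t]
        # Exponentially weighted sum of TD errors with weight (gamma * lam)
        last_adv = delta + gamma * lam * not_done * last_adv
        advantages[t] = last_adv
    # Returns used as critic targets: G_t = A_t + V(s_t)
    returns = [a + v for a, v in zip(advantages, values)]
    return returns, advantages
```

With next_value = 0.0 and gae_lambda = 1.0, the returns coincide with the backward recursion for $G_t$; setting lambda to 0 instead reduces each advantage to the one-step TD error $R_{t+1} + \gamma V(s_{t+1}) - V(s_t)$.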

# @package _global_
# configs/exp/a2c.yaml

defaults:
  - override /algo: a2c
  - override /env: gym
  - _self_

# Algorithm
algo:
  total_steps: 25000
  rollout_steps: 5
  per_rank_batch_size: ${algo.rollout_steps}
  mlp_keys:
    encoder: [state]

# Buffer
buffer:
  share_data: False
  size: ${algo.rollout_steps}

metric:
  aggregator:
    metrics:
      Loss/value_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}
      Loss/policy_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}
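A quick way to reason about the experiment budget is to count how many environment steps each training iteration consumes. The sketch below assumes, as a hypothesis rather than a description of SheepRL's exact training loop, that total_steps counts environment steps across all parallel environments and processes, and that every iteration collects rollout_steps steps from each of the env.num_envs environments on each of world_size processes; the helper training_budget is hypothetical.

```python
from typing import Tuple


def training_budget(
    total_steps: int, num_envs: int, rollout_steps: int, world_size: int = 1
) -> Tuple[int, int]:
    # Environment steps collected by all processes in one iteration
    policy_steps_per_iter = num_envs * rollout_steps * world_size
    # Number of collect-then-update iterations that fit in the step budget (floor division)
    total_iters = total_steps // policy_steps_per_iter
    return policy_steps_per_iter, total_iters
```

Under these assumptions, running with env.num_envs=4 (as in the command of the Run the experiment section), rollout_steps: 5 and total_steps: 25000 from this experiment config gives 20 environment steps per iteration and 1250 iterations, whereas the algorithm default of rollout_steps: 16 would give 64 steps per iteration and 390 iterations.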

6. Algorithm and configs registration

To let SheepRL know about our new algorithm, we need to register it. To do so, we add a new file that acts as our entrypoint, called (for example) a2c_main.py, at the root of our project:

# a2c_main.py

# This will trigger the algorithm registration of SheepRL
from a2c import a2c  # noqa: F401

if __name__ == "__main__":
    # This must be imported after the algorithm registration, otherwise SheepRL
    # will not be able to find the new algorithm given the name specified
    # in the `algo.name` field of the `./configs/algo/a2c.yaml` config file
    from sheeprl.cli import run

    run()
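The comments above stress that sheeprl.cli.run must be imported only after the algorithm has been registered. The reason is the usual decorator-registry pattern: registration happens as a side effect of executing the module that defines the algorithm, and the launcher later looks the algorithm up by the name given in the config. The toy registry below illustrates the idea; ALGORITHM_REGISTRY, register_algorithm and run here are simplified, hypothetical stand-ins and not SheepRL's actual registry code.

```python
from typing import Callable, Dict

# Hypothetical global registry mapping algorithm names to their entrypoints
ALGORITHM_REGISTRY: Dict[str, Callable[[], None]] = {}


def register_algorithm(name: str) -> Callable[[Callable[[], None]], Callable[[], None]]:
    # Decorator that records the function in the registry as a side effect of being executed
    def decorator(fn: Callable[[], None]) -> Callable[[], None]:
        ALGORITHM_REGISTRY[name] = fn
        return fn

    return decorator


def run(algo_name: str) -> None:
    # Look the algorithm up by the name that would come from the `algo.name` config field
    if algo_name not in ALGORITHM_REGISTRY:
        raise KeyError(f"Unknown algorithm: {algo_name}")
    ALGORITHM_REGISTRY[algo_name]()
```

Calling run("a2c") before any function has been decorated with register_algorithm("a2c") raises a KeyError; once the decorated function has been defined (which is what importing the module does), the same call dispatches to it.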

To let SheepRL know about our new configuration files, we need to add a new file called (for example) .env at the root of our project, containing the following environment variable:

SHEEPRL_SEARCH_PATH=file://configs;pkg://sheeprl.configs

7. Run the experiment

In the end you should have the following project structure:

.
├── .env
├── a2c
│   ├── a2c.py
│   ├── agent.py
│   ├── __init__.py
│   ├── loss.py
│   └── utils.py
├── a2c_main.py
└── configs
    ├── algo
    │   └── a2c.yaml
    └── exp
        └── a2c.yaml

To run the experiment, simply execute the following command:

python a2c_main.py exp=a2c env.num_envs=4

This runs the experiment defined in configs/exp/a2c.yaml, selected through exp=a2c, with 4 parallel environments (env.num_envs=4). You can follow the experiment progress with TensorBoard by running the following command:

tensorboard --logdir <log_dir>

where log_dir is the directory where the logs are saved, which is printed in the terminal when the experiment starts.

Conclusion

In this article we have seen how to implement the A2C algorithm in SheepRL. In particular, we have seen how to define the algorithm (starting from the register-new-algorithm how-to), the agent (taking inspiration from the PPO one), the loss functions, the utility functions and the configuration files. We have also seen how to register the algorithm and how to run the experiment. We hope you have enjoyed this article and we hope to see you in the next one!
