style: improve code style
Gaiejj committed May 5, 2024
1 parent ef30616 commit 9fd5fd7
Showing 12 changed files with 619 additions and 196 deletions.
12 changes: 12 additions & 0 deletions docs/source/spelling_wordlist.txt
@@ -486,3 +486,15 @@ UpdateDynamics
mathbb
meger
Jupyter
LazyFrames
SLAC
Leibler
Kullback
slac
Tal
Nils
Simão
Hogewind
Yannick
Kachman
Thiago
1 change: 1 addition & 0 deletions omnisafe/adapter/__init__.py
@@ -18,6 +18,7 @@
from omnisafe.adapter.modelbased_adapter import ModelBasedAdapter
from omnisafe.adapter.offline_adapter import OfflineAdapter
from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter
from omnisafe.adapter.offpolicy_latent_adapter import OffPolicyLatentAdapter
from omnisafe.adapter.online_adapter import OnlineAdapter
from omnisafe.adapter.onpolicy_adapter import OnPolicyAdapter
from omnisafe.adapter.saute_adapter import SauteAdapter
106 changes: 95 additions & 11 deletions omnisafe/adapter/offpolicy_latent_adapter.py
@@ -42,6 +42,18 @@


class OffPolicyLatentAdapter(OnlineAdapter):
"""OffPolicy Adapter on Latent Space for OmniSafe.
:class:`OffPolicyLatentAdapter` is used to adapt the vision-based environment to off-policy
training.
Args:
env_id (str): The environment id.
num_envs (int): The number of environments.
seed (int): The random seed.
cfgs (Config): The configuration.
"""

_current_obs: torch.Tensor
_ep_ret: torch.Tensor
_ep_cost: torch.Tensor
@@ -54,8 +66,9 @@ def __init__( # pylint: disable=too-many-arguments
seed: int,
cfgs: Config,
) -> None:
"""Initialize a instance of :class:`OffPolicyAdapter`."""
"""Initialize a instance of :class:`OffPolicyLatentAdapter`."""
super().__init__(env_id, num_envs, seed, cfgs)
assert self.action_space.shape
self._observation_concator: ObservationConcator = ObservationConcator(
self._cfgs.algo_cfgs.latent_dim_1 + self._cfgs.algo_cfgs.latent_dim_2,
self.action_space.shape,
Expand All @@ -65,8 +78,9 @@ def __init__( # pylint: disable=too-many-arguments
self._current_obs, _ = self.reset()
self._max_ep_len: int = 1000
self._reset_log()
self.z1 = None
self.z2 = None
self.z1: torch.Tensor = torch.zeros(1)
self.z2: torch.Tensor = torch.zeros(1)
self._initialized: bool = False
self._reset_sequence_queue = False

def _wrapper(
@@ -135,6 +149,13 @@ def eval_policy( # pylint: disable=too-many-locals
agent: ConstraintActorQCritic,
logger: Logger,
) -> None:
"""Rollout the environment with deterministic agent action.
Args:
episode (int): Number of episodes.
agent (ConstraintActorQCritic): Agent.
logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
"""
for _ in range(episode):
ep_ret, ep_cost, ep_len = 0.0, 0.0, 0
obs, _ = self._eval_env.reset()
@@ -161,22 +182,37 @@ def eval_policy( # pylint: disable=too-many-locals
},
)

def pre_process(self, latent_model, concated_obs):
def pre_process(
self,
latent_model: CostLatentModel,
concated_obs: ObservationConcator,
) -> torch.Tensor:
"""Processes the concatenated observations to produce latent representation.
Args:
latent_model (CostLatentModel): The latent model containing the encoder and decoder.
concated_obs (ObservationConcator): An object that encapsulates the concatenated observations.
Returns:
A tensor combining the latent variables z1 and z2, representing the current state of
the system in the latent space.
"""
with torch.no_grad():
feature = latent_model.encoder(concated_obs.last_state)

if self.z2 is None:
if not self._initialized:
z1_mean, z1_std = latent_model.z1_posterior_init(feature)
self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std
z2_mean, z2_std = latent_model.z2_posterior_init(self.z1)
self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std
self._initialized = True
else:
z1_mean, z1_std = latent_model.z1_posterior(
torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1)
torch.cat([feature.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1),
)
self.z1 = z1_mean + torch.randn_like(z1_std) * z1_std
z2_mean, z2_std = latent_model.z2_posterior(
torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1)
torch.cat([self.z1.squeeze(), self.z2.squeeze(), concated_obs.last_action], dim=-1),
)
self.z2 = z2_mean + torch.randn_like(z2_std) * z2_std

Expand All @@ -191,17 +227,29 @@ def rollout( # pylint: disable=too-many-locals
logger: Logger,
use_rand_action: bool,
) -> None:
"""Rollout the environment and store the data in the buffer.
Args:
rollout_step (int): Number of rollout steps.
agent (ConstraintActorQCritic): Constraint actor-Q-critic, including actor, reward critic,
and cost critic.
latent_model (CostLatentModel): Latent model, including encoder and decoder.
buffer (OffPolicySequenceBuffer): Off-policy sequence replay buffer.
logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
use_rand_action (bool): Whether to use random action.
"""
for step in range(rollout_step):
if not self._reset_sequence_queue:
buffer.reset_sequence_queue(self._current_obs)
self._observation_concator.reset_episode(self._current_obs)
self._reset_sequence_queue = True

if use_rand_action:
act = act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore
act = (torch.rand(self.action_space.shape) * 2 - 1).to(self._device) # type: ignore
else:
act = agent.step(
self.pre_process(latent_model, self._observation_concator), deterministic=False
self.pre_process(latent_model, self._observation_concator),
deterministic=False,
)

next_obs, reward, cost, terminated, truncated, info = self.step(act)
@@ -217,8 +265,9 @@ def rollout( # pylint: disable=too-many-locals
if done:
self._log_metrics(logger, idx)
self._reset_log(idx)
self.z1 = None
self.z2 = None
self.z1 = torch.zeros(1)
self.z2 = torch.zeros(1)
self._initialized = False
self._reset_sequence_queue = False
if 'final_observation' in info:
real_next_obs[idx] = info['final_observation'][idx]
@@ -239,11 +288,30 @@ def _log_value(
cost: torch.Tensor,
info: dict[str, Any],
) -> None:
"""Log value.
.. note::
OmniSafe uses :class:`RewardNormalizer` wrapper, so the original reward and cost will
be stored in ``info['original_reward']`` and ``info['original_cost']``.
Args:
reward (torch.Tensor): The immediate step reward.
cost (torch.Tensor): The immediate step cost.
info (dict[str, Any]): Some information logged by the environment.
"""
self._ep_ret += info.get('original_reward', reward).cpu()
self._ep_cost += info.get('original_cost', cost).cpu()
self._ep_len += info.get('num_step', 1)

def _log_metrics(self, logger: Logger, idx: int) -> None:
"""Log metrics, including ``EpRet``, ``EpCost``, ``EpLen``.
Args:
logger (Logger): Logger, to log ``EpRet``, ``EpCost``, ``EpLen``.
idx (int): The index of the environment.
"""
if hasattr(self._env, 'spec_log'):
self._env.spec_log(logger)
logger.store(
{
'Metrics/EpRet': self._ep_ret[idx],
@@ -253,6 +321,12 @@ def _log_metrics(self, logger: Logger, idx: int) -> None:
)

def _reset_log(self, idx: int | None = None) -> None:
"""Reset the episode return, episode cost and episode length.
Args:
idx (int or None, optional): The index of the environment. Defaults to None,
in which case all environments are reset.
"""
if idx is None:
self._ep_ret = torch.zeros(self._env.num_envs)
self._ep_cost = torch.zeros(self._env.num_envs)
@@ -267,6 +341,16 @@ def reset(
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
"""Reset the environment and returns an initial observation.
Args:
seed (int, optional): The random seed. Defaults to None.
options (dict[str, Any], optional): The options for the environment. Defaults to None.
Returns:
observation: The initial observation of the space.
info: Some information logged by the environment.
"""
obs, info = self._env.reset(seed=seed, options=options)
self._observation_concator.reset_episode(obs)
return obs, info
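The `pre_process` hunk above keeps a recurrent pair of latent variables: on the first call after a reset, z1 and z2 are drawn from the `*_posterior_init` heads; afterwards they are drawn from the conditional `*_posterior` heads given the previous latents and last action, using the reparameterization trick. A minimal standalone sketch of that bookkeeping follows; the `model` argument and its heads are stand-ins for `CostLatentModel`, and the tensor shapes are illustrative assumptions, not code from this commit.

```python
import torch


class LatentTracker:
    """Minimal sketch of the z1/z2 bookkeeping done by pre_process above.

    ``model`` is assumed to expose ``encoder``, ``z1_posterior_init``,
    ``z2_posterior_init``, ``z1_posterior`` and ``z2_posterior`` heads that
    each return a ``(mean, std)`` pair, mirroring CostLatentModel.
    """

    def __init__(self) -> None:
        self.z1 = torch.zeros(1)
        self.z2 = torch.zeros(1)
        self.initialized = False

    @staticmethod
    def _sample(mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        # Reparameterized draw: z = mean + eps * std with eps ~ N(0, I).
        return mean + torch.randn_like(std) * std

    def step(self, model, last_state: torch.Tensor, last_action: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            feature = model.encoder(last_state)
            if not self.initialized:
                # First step of an episode: sample from the initial posteriors.
                self.z1 = self._sample(*model.z1_posterior_init(feature))
                self.z2 = self._sample(*model.z2_posterior_init(self.z1))
                self.initialized = True
            else:
                # Later steps: condition on the previous latents and last action.
                self.z1 = self._sample(
                    *model.z1_posterior(torch.cat([feature, self.z2, last_action], dim=-1)),
                )
                self.z2 = self._sample(
                    *model.z2_posterior(torch.cat([self.z1, self.z2, last_action], dim=-1)),
                )
        return torch.cat([self.z1, self.z2], dim=-1)

    def reset(self) -> None:
        # Called at episode boundaries so the next step re-samples from the init posteriors.
        self.initialized = False
```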
34 changes: 23 additions & 11 deletions omnisafe/algorithms/off_policy/safe_slac.py
@@ -1,4 +1,4 @@
# Copyright 2023 OmniSafe Team. All Rights Reserved.
# Copyright 2024 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,6 +36,16 @@
@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class SafeSLAC(SACLag):
"""Safe SLAC algorithms for vision-based safe RL tasks.
References:
- Title: Safe Reinforcement Learning From Pixels Using a Stochastic Latent Representation.
- Authors: Yannick Hogewind, Thiago D. Simão, Tal Kachman, Nils Jansen.
- URL: `Safe SLAC <https://openreview.net/pdf?id=b39dQt_uffW>`_
"""

_is_latent_model_init_learned: bool

def _init(self) -> None:
if self._cfgs.algo_cfgs.auto_alpha:
self._target_entropy = -torch.prod(torch.Tensor(self._env.action_space.shape)).item()
@@ -53,7 +63,7 @@ def _init(self) -> None:

self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs)

self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer(
self._buf: OffPolicySequenceBuffer = OffPolicySequenceBuffer( # type: ignore
obs_space=self._env.observation_space,
act_space=self._env.action_space,
size=self._cfgs.algo_cfgs.size,
@@ -64,7 +74,7 @@ def _init(self) -> None:
self._is_latent_model_init_learned = False

def _init_env(self) -> None:
self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter(
self._env: OffPolicyLatentAdapter = OffPolicyLatentAdapter( # type: ignore
self._env_id,
self._cfgs.train_cfgs.vector_env_nums,
self._seed,
@@ -96,6 +106,9 @@ def _init_env(self) -> None:
def _init_model(self) -> None:
self._cfgs.model_cfgs.critic['num_critics'] = 2

assert self._env.observation_space.shape
assert self._env.action_space.shape

self._latent_model = CostLatentModel(
obs_shape=self._env.observation_space.shape,
act_shape=self._env.action_space.shape,
@@ -114,9 +127,6 @@ def _init_model(self) -> None:
epochs=self._epochs,
).to(self._device)

self._actor_critic = torch.compile(self._actor_critic)
self._latent_model = torch.compile(self._latent_model)

self._latent_model_optimizer = optim.Adam(
self._latent_model.parameters(),
lr=1e-4,
@@ -218,7 +228,7 @@ def learn(self) -> tuple[float, float, float]:

return ep_ret, ep_cost, ep_len

def _prepare_batch(self, obs_, action_):
def _prepare_batch(self, obs_: torch.Tensor, action_: torch.Tensor) -> tuple[torch.Tensor, ...]:
with torch.no_grad():
feature_ = self._latent_model.encoder(obs_)
z_ = torch.cat(self._latent_model.sample_posterior(feature_, action_)[2:4], dim=-1)
@@ -266,9 +276,7 @@ def _update(self) -> None:
self._update_actor(obs)
self._actor_critic.polyak_update(self._cfgs.algo_cfgs.polyak)

def _update_latent_model(
self,
):
def _update_latent_model(self) -> None:
data = self._buf.sample_batch(32)
obs, act, reward, cost, done = (
data['obs'],
@@ -280,7 +288,11 @@ def _update_latent_model(

self._update_latent_count += 1
loss_kld, loss_image, loss_reward, loss_cost = self._latent_model.calculate_loss(
obs, act, reward, done, cost
obs,
act,
reward,
done,
cost,
)

self._latent_model_optimizer.zero_grad()
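Since SafeSLAC is registered via `@registry.register`, it can be launched through the standard OmniSafe entry point once installed. A hedged usage sketch follows; the environment id is an assumption (any supported vision-based Safety-Gymnasium task would be substituted), not something taken from this commit.

```python
import omnisafe

# Assumed vision-based task id; replace with any environment supported by your install.
env_id = 'SafetyCarGoal1Vision-v0'

agent = omnisafe.Agent('SafeSLAC', env_id)
agent.learn()
```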
43 changes: 42 additions & 1 deletion omnisafe/common/buffer/base.py
@@ -135,7 +135,19 @@ def store(self, **data: torch.Tensor) -> None:
"""


class BaseSequenceBuffer(BaseBuffer):
class BaseSequenceBuffer(ABC):
r"""Abstract base class for sequence buffer.
Attributes:
sequence_queue (SequenceQueue): The queue for storing the data.
Args:
obs_space (OmnisafeSpace): The observation space.
act_space (OmnisafeSpace): The action space.
size (int): The size of the buffer.
device (torch.device): The device of the buffer. Defaults to ``torch.device('cpu')``.
"""

def __init__(
self,
obs_space: OmnisafeSpace,
@@ -178,8 +190,37 @@ def __init__(
self._observation_shape = obs_space.shape

def add_field(self, name: str, shape: tuple[int, ...], dtype: torch.dtype) -> None:
"""Add a field to the buffer.
Args:
name (str): The name of the field.
shape (tuple of int): The shape of the field.
dtype (torch.dtype): The dtype of the field.
"""
self.data[name] = torch.zeros(
(self._size, self._num_sequences, *shape),
dtype=dtype,
device=self._device,
)

@property
def device(self) -> torch.device:
"""The device of the buffer."""
return self._device

@property
def size(self) -> int:
"""The size of the buffer."""
return self._size

def __len__(self) -> int:
"""Return the length of the buffer."""
return self._size

@abstractmethod
def store(self, **data: torch.Tensor) -> None:
"""Store a transition in the buffer.
Args:
data (torch.Tensor): The data to store.
"""
2 changes: 2 additions & 0 deletions omnisafe/common/buffer/offpolicy_buffer.py
@@ -123,6 +123,8 @@ def sample_batch(self) -> dict[str, torch.Tensor]:


class OffPolicySequenceBuffer(BaseSequenceBuffer):
"""Sequence-based Replay buffer for off-policy algorithms."""

def __init__( # pylint: disable=too-many-arguments
self,
obs_space: OmnisafeSpace,