From 8abab90b17fd8141e2779b9f29b3b7ed741ff779 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 03:35:58 +0000 Subject: [PATCH 01/17] Support raw obs and chunk action for VLA --- embodichain/agents/rl/algo/grpo.py | 14 ++- .../agents/rl/buffer/standard_buffer.py | 9 +- embodichain/agents/rl/buffer/utils.py | 8 +- .../agents/rl/collector/sync_collector.py | 104 ++++++++++++++++-- embodichain/agents/rl/utils/trainer.py | 76 +++++++++---- 5 files changed, 177 insertions(+), 34 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 12f7c32f..8f96aa42 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -112,7 +112,13 @@ def _compute_step_group_advantages( return advantages.view(n_envs, t_steps) * seq_mask def update(self, rollout: TensorDict) -> Dict[str, float]: + raw_obs = getattr(rollout, "raw_obs", None) + chunk_step = getattr(rollout, "chunk_step", None) rollout = rollout.clone() + if raw_obs is not None: + rollout.raw_obs = raw_obs + if chunk_step is not None: + rollout.chunk_step = chunk_step num_envs = rollout.batch_size[0] if num_envs % self.cfg.group_size != 0: raise ValueError( @@ -147,7 +153,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: advantages = batch["advantage"].detach() seq_mask_batch = batch["seq_mask"].float() - eval_batch = self.policy.evaluate_actions(batch) + eval_batch = self.policy.evaluate_actions( + batch, rollout=rollout, num_envs=num_envs + ) logprobs = eval_batch["sample_log_prob"] entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() @@ -166,7 +174,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: if self.ref_policy is not None: with torch.no_grad(): - ref_batch = self.ref_policy.evaluate_actions(batch) + ref_batch = self.ref_policy.evaluate_actions( + batch, rollout=rollout, num_envs=num_envs + ) ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index 2df69f86..f14eedbb 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -40,12 +40,14 @@ def __init__( obs_dim: int, action_dim: int, device: torch.device, + use_raw_obs: bool = False, ) -> None: self.num_envs = num_envs self.rollout_len = rollout_len self.obs_dim = obs_dim self.action_dim = action_dim self.device = device + self.use_raw_obs = use_raw_obs self._rollout = self._allocate_rollout() self._is_full = False @@ -54,6 +56,8 @@ def start_rollout(self) -> TensorDict: if self._is_full: raise RuntimeError("RolloutBuffer already contains a rollout.") self._clear_dynamic_fields() + if self.use_raw_obs: + self._rollout.raw_obs = [None] * (self.rollout_len + 1) return self._rollout def add(self, rollout: TensorDict) -> None: @@ -93,7 +97,7 @@ def is_full(self) -> bool: def _allocate_rollout(self) -> TensorDict: """Preallocate rollout storage with uniform `[num_envs, time + 1]` shape.""" - return TensorDict( + td = TensorDict( { "obs": torch.empty( self.num_envs, @@ -149,12 +153,15 @@ def _allocate_rollout(self) -> TensorDict: batch_size=[self.num_envs, self.rollout_len + 1], device=self.device, ) + return td def _clear_dynamic_fields(self) -> None: """Drop algorithm-added fields before reusing the shared rollout.""" for key in ("advantage", "return", "seq_mask", 
"seq_return", "entropy"): if key in self._rollout.keys(): del self._rollout[key] + if self.use_raw_obs and hasattr(self._rollout, "raw_obs"): + delattr(self._rollout, "raw_obs") self._reset_padding_slot() def _reset_padding_slot(self) -> None: diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 7c0d265b..655c1e7c 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -62,6 +62,9 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: if key in rollout.keys(): td[key] = rollout[key][:, :-1] + if hasattr(rollout, "chunk_step") and rollout.chunk_step is not None: + td["chunk_step"] = rollout.chunk_step + if flatten: return td.reshape(num_envs * time_dim) return td @@ -74,4 +77,7 @@ def iterate_minibatches( total = rollout.batch_size[0] indices = torch.randperm(total, device=device) for start in range(0, total, batch_size): - yield rollout[indices[start : start + batch_size]] + batch_indices = indices[start : start + batch_size] + batch = rollout[batch_indices].clone() + batch["_indices"] = batch_indices + yield batch diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 16c5b584..f7904788 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -68,20 +68,92 @@ def collect( if self._supports_shared_rollout: self.env.set_rollout_buffer(rollout) - initial_obs = flatten_dict_observation(self.obs_td) - rollout["obs"][:, 0] = initial_obs + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None + + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) + use_action_chunk = ( + getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 + ) + cached_chunk = None + + if use_action_chunk: + rollout.chunk_step = torch.zeros( + self.env.num_envs, num_steps, + dtype=torch.long, device=self.device, + ) + + if use_raw_obs and raw_obs_list is not None: + raw_obs_list[0] = self.obs_td + else: + rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) + for step_idx in range(num_steps): - step_td = TensorDict( - {"obs": rollout["obs"][:, step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, + step_in_chunk = step_idx % action_chunk_size if use_action_chunk else 0 + + # At chunk boundary, or cached invalidated by env reset, we need a new chunk + need_new_chunk = use_action_chunk and ( + step_in_chunk == 0 or cached_chunk is None ) - step_td = self.policy.get_action(step_td) + + if need_new_chunk: + if use_raw_obs and raw_obs_list is not None: + step_td = TensorDict( + {"obs": raw_obs_list[step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + else: + step_td = TensorDict( + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + step_td = self.policy.get_action(step_td) + cached_chunk = step_td["action_chunk"] + action = step_td["action"] + effective_step_in_chunk = 0 + elif use_action_chunk and cached_chunk is not None: + action = cached_chunk[:, step_in_chunk] + effective_step_in_chunk = step_in_chunk + step_td = TensorDict( + { + "action": action, + "sample_log_prob": torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ), + "value": torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ), + }, + batch_size=[rollout.batch_size[0]], + 
device=self.device, + ) + else: + if use_raw_obs and raw_obs_list is not None: + step_td = TensorDict( + {"obs": raw_obs_list[step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + else: + step_td = TensorDict( + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + step_td = self.policy.get_action(step_td) + action = step_td["action"] next_obs, reward, terminated, truncated, env_info = self.env.step( - self._to_action_dict(step_td["action"]) + self._to_action_dict(action) ) next_obs_td = dict_to_tensordict(next_obs, self.device) + if use_action_chunk: + rollout.chunk_step[:, step_idx] = effective_step_in_chunk + # Invalidate cached_chunk on any env reset to avoid using old chunk for new episode + if (terminated | truncated).any(): + cached_chunk = None self._write_step( rollout=rollout, step_idx=step_idx, @@ -95,7 +167,10 @@ def collect( terminated=terminated, truncated=truncated, ) - rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) + if use_raw_obs and raw_obs_list is not None: + raw_obs_list[step_idx + 1] = next_obs_td + else: + rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) if on_step_callback is not None: on_step_callback(rollout[:, step_idx], env_info) @@ -107,7 +182,12 @@ def collect( def _attach_final_value(self, rollout: TensorDict) -> None: """Populate the bootstrap value for the final observed state.""" - final_obs = rollout["obs"][:, -1] + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None + if use_raw_obs and raw_obs_list is not None: + final_obs = raw_obs_list[-1] + else: + final_obs = rollout["obs"][:, -1] last_next_td = TensorDict( {"obs": final_obs}, batch_size=[rollout.batch_size[0]], @@ -155,8 +235,10 @@ def _write_env_step( def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: """Validate rollout layout expected by the collector.""" + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + obs_dim = 1 if use_raw_obs else self.policy.obs_dim expected_shapes = { - "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim), + "obs": (self.env.num_envs, num_steps + 1, obs_dim), "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), "sample_log_prob": (self.env.num_envs, num_steps + 1), "value": (self.env.num_envs, num_steps + 1), diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 4f660232..0e3b3a3a 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -27,6 +27,7 @@ from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.agents.rl.collector import SyncCollector +from embodichain.agents.rl.utils import dict_to_tensordict from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -84,13 +85,23 @@ def __init__( action_dim = getattr(self.policy, "action_dim", None) if obs_dim is None or action_dim is None: raise RuntimeError("Policy must expose obs_dim and action_dim.") + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) + use_action_chunk = getattr(self.policy, "use_action_chunk", False) + if use_action_chunk and action_chunk_size > 0: + self.buffer_size = ( + (self.buffer_size + action_chunk_size - 1) + // action_chunk_size + * action_chunk_size + ) self.buffer = RolloutBuffer( 
num_envs=num_envs, rollout_len=self.buffer_size, - obs_dim=obs_dim, + obs_dim=max(1, obs_dim) if use_raw_obs else obs_dim, action_dim=action_dim, device=self.device, + use_raw_obs=use_raw_obs, ) self.collector = SyncCollector( env=self.env, @@ -245,28 +256,49 @@ def _eval_once(self, num_episodes: int = 5): episode_returns = [] episode_lengths = [] + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + use_action_chunk = getattr(self.policy, "use_action_chunk", False) + action_chunk_size = getattr(self.policy, "action_chunk_size", 1) + for _ in range(num_episodes): - # Reset and initialize episode tracking obs, _ = self.eval_env.reset() - obs = flatten_dict_observation(obs) - num_envs = obs.shape[0] if obs.ndim == 2 else 1 + if use_raw_obs: + obs_td = dict_to_tensordict(obs, self.device) + else: + obs_td = flatten_dict_observation(obs) + num_envs = obs_td.batch_size[0] if hasattr(obs_td, "batch_size") else (obs_td.shape[0] if hasattr(obs_td, "shape") else 1) done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( num_envs, dtype=torch.float32, device=self.device ) step_count = torch.zeros(num_envs, dtype=torch.int32, device=self.device) + cached_chunk = None + step_in_chunk = 0 - # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy - action_td = TensorDict( - {"obs": obs}, - batch_size=[num_envs], - device=self.device, - ) - action_td = self.policy.get_action(action_td, deterministic=True) - actions = action_td["action"] + if use_action_chunk and (cached_chunk is None or step_in_chunk == 0): + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + cached_chunk = action_td.get("action_chunk") + actions = action_td["action"] + step_in_chunk = 0 + elif use_action_chunk and cached_chunk is not None: + actions = cached_chunk[:, step_in_chunk] + else: + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + actions = action_td["action"] + + step_in_chunk = (step_in_chunk + 1) % action_chunk_size if use_action_chunk else 0 am = getattr(self.eval_env, "action_manager", None) action_type = ( am.action_type @@ -275,15 +307,17 @@ def _eval_once(self, num_episodes: int = 5): ) action_dict = {action_type: actions} - # Environment step obs, reward, terminated, truncated, info = self.eval_env.step( action_dict ) - obs = ( - flatten_dict_observation(obs) - if isinstance(obs, TensorDict) - else obs - ) + if use_raw_obs: + obs_td = dict_to_tensordict(obs, self.device) + else: + obs_td = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) # Update statistics only for still-running environments done = terminated | truncated @@ -292,6 +326,10 @@ def _eval_once(self, num_episodes: int = 5): step_count[still_running] += 1 done_mask |= done + # Invalidate cached_chunk on any env reset + if use_action_chunk and done.any(): + cached_chunk = None + # Trigger evaluation events (e.g., video recording) if hasattr(self, "eval_event_manager"): if "interval" in self.eval_event_manager.available_modes: From d5a0684362469b821215d894e4e75da521553611 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 08:16:00 +0000 Subject: [PATCH 02/17] Extend policy for vla --- embodichain/agents/rl/models/actor_critic.py | 2 +- 
 embodichain/agents/rl/models/actor_only.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py
index 32caf0e3..8016ddcd 100644
--- a/embodichain/agents/rl/models/actor_critic.py
+++ b/embodichain/agents/rl/models/actor_critic.py
@@ -86,7 +86,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict:
         tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1)
         return tensordict
 
-    def evaluate_actions(self, tensordict: TensorDict) -> TensorDict:
+    def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict:
         obs = tensordict["obs"]
         action = tensordict["action"]
         dist = self._distribution(obs)
diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py
index 3d6d1f78..0f93ce8f 100644
--- a/embodichain/agents/rl/models/actor_only.py
+++ b/embodichain/agents/rl/models/actor_only.py
@@ -77,7 +77,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict:
         )
         return tensordict
 
-    def evaluate_actions(self, tensordict: TensorDict) -> TensorDict:
+    def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict:
         obs = tensordict["obs"]
         action = tensordict["action"]
         dist = self._distribution(obs)

From 662a53d7fa2f7328152eeaa65b3449dc26a4b87c Mon Sep 17 00:00:00 2001
From: yangchen73 <122090643@link.cuhk.edu.cn>
Date: Mon, 16 Mar 2026 08:17:44 +0000
Subject: [PATCH 03/17] Implement a registry to discover factories in dexechain via entry points

---
 embodichain/agents/rl/vla_registry.py | 70 +++++++++++++++++++++++++++
 1 file changed, 70 insertions(+)
 create mode 100644 embodichain/agents/rl/vla_registry.py

diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py
new file mode 100644
index 00000000..544afb49
--- /dev/null
+++ b/embodichain/agents/rl/vla_registry.py
@@ -0,0 +1,70 @@
+# ----------------------------------------------------------------------------
+# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from importlib.metadata import entry_points +from typing import Any, Callable + +__all__ = [ + "get_vla_backend", + "get_registered_vla_backend_names", + "create_vla_backend", +] + + +_VLA_BACKENDS: dict[str, Callable[..., Any]] = {} +_ENTRY_POINTS_DISCOVERED = False + + +def _discover_entry_points() -> None: + """Discover and register VLA backends from entry_points.""" + global _ENTRY_POINTS_DISCOVERED + if _ENTRY_POINTS_DISCOVERED: + return + _ENTRY_POINTS_DISCOVERED = True + try: + eps = entry_points(group="embodichain.vla_backends") + for ep in eps: + try: + factory = ep.load() + name = str(ep.name).lower() + if name not in _VLA_BACKENDS: + _VLA_BACKENDS[name] = factory + except Exception: + pass + except Exception: + pass + + +def get_vla_backend(name: str) -> Callable[..., Any] | None: + """Get registered backend factory (discovered from entry_points).""" + name = str(name).lower() + if name in _VLA_BACKENDS: + return _VLA_BACKENDS[name] + _discover_entry_points() + return _VLA_BACKENDS.get(name) + + +def get_registered_vla_backend_names() -> list[str]: + """Return registered VLA backend names.""" + _discover_entry_points() + return list(_VLA_BACKENDS.keys()) + + +def create_vla_backend(name: str, **kwargs) -> Any: + """Create VLA backend: returns (model, action_indices, prepare_batch_fn).""" + factory = get_vla_backend(name) + if factory is None: + available = get_registered_vla_backend_names() + raise ValueError( + f"Unknown VLA backend '{name}'. Available: {available}. " + "Ensure dexechain is installed (pip install dexechain)." + ) + return factory(**kwargs) From 7fe96d0853dc74128f01fb9fb240fb4cfba0f3bc Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 08:18:51 +0000 Subject: [PATCH 04/17] Update registry --- embodichain/agents/rl/vla_registry.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py index 544afb49..8d8aec74 100644 --- a/embodichain/agents/rl/vla_registry.py +++ b/embodichain/agents/rl/vla_registry.py @@ -44,7 +44,6 @@ def _discover_entry_points() -> None: def get_vla_backend(name: str) -> Callable[..., Any] | None: - """Get registered backend factory (discovered from entry_points).""" name = str(name).lower() if name in _VLA_BACKENDS: return _VLA_BACKENDS[name] @@ -53,13 +52,11 @@ def get_vla_backend(name: str) -> Callable[..., Any] | None: def get_registered_vla_backend_names() -> list[str]: - """Return registered VLA backend names.""" _discover_entry_points() return list(_VLA_BACKENDS.keys()) def create_vla_backend(name: str, **kwargs) -> Any: - """Create VLA backend: returns (model, action_indices, prepare_batch_fn).""" factory = get_vla_backend(name) if factory is None: available = get_registered_vla_backend_names() From 7b2287f86304278c7c18888625380a88b1d88415 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 08:19:03 +0000 Subject: [PATCH 05/17] Update registry --- embodichain/agents/rl/vla_registry.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py index 8d8aec74..a8d197d3 100644 --- a/embodichain/agents/rl/vla_registry.py +++ b/embodichain/agents/rl/vla_registry.py @@ -3,6 +3,13 @@ # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ---------------------------------------------------------------------------- From af373f22d3f294b13479fc40dda0049dd746f18e Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 09:15:55 +0000 Subject: [PATCH 06/17] Add vla_policy to wrap vla model --- .../agents/rl/collector/sync_collector.py | 6 +- embodichain/agents/rl/models/__init__.py | 16 +- embodichain/agents/rl/models/vla_policy.py | 249 ++++++++++++++++++ embodichain/agents/rl/utils/trainer.py | 10 +- 4 files changed, 275 insertions(+), 6 deletions(-) create mode 100644 embodichain/agents/rl/models/vla_policy.py diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index f7904788..851f89f8 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -79,8 +79,10 @@ def collect( if use_action_chunk: rollout.chunk_step = torch.zeros( - self.env.num_envs, num_steps, - dtype=torch.long, device=self.device, + self.env.num_envs, + num_steps, + dtype=torch.long, + device=self.device, ) if use_raw_obs and raw_obs_list is not None: diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 51cf7653..101a5e75 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -17,7 +17,7 @@ from __future__ import annotations import inspect -from typing import Dict, Type +from typing import Any, Dict, Optional, Type from gymnasium import spaces import torch @@ -26,6 +26,7 @@ from .actor_only import ActorOnly from .policy import Policy from .mlp import MLP +from .vla_policy import VLAPolicy # In-module policy registry _POLICY_REGISTRY: Dict[str, Type[Policy]] = {} @@ -63,13 +64,16 @@ def build_policy( device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, + env: Optional[Any] = None, ) -> Policy: """Build a policy from config using spaces for extensibility. Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while custom policies may accept richer `obs_space` / `action_space` inputs. + For vla_policy, pass env to enable set_env and _load_vla initialization. 
""" name = policy_block["name"].lower() + if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) raise ValueError( @@ -119,7 +123,13 @@ def build_policy( build_kwargs["actor"] = actor if "critic" in init_params and critic is not None: build_kwargs["critic"] = critic - return policy_cls(**build_kwargs) + if "policy_cfg" in init_params: + build_kwargs["policy_cfg"] = policy_block + policy = policy_cls(**build_kwargs) + if name == "vla_policy" and env is not None: + policy.set_env(env) + policy._load_vla() + return policy def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: @@ -143,10 +153,12 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: # default registrations register_policy("actor_critic", ActorCritic) register_policy("actor_only", ActorOnly) +register_policy("vla_policy", VLAPolicy) __all__ = [ "ActorCritic", "ActorOnly", + "VLAPolicy", "register_policy", "get_registered_policy_names", "build_policy", diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py new file mode 100644 index 00000000..5fc7aef0 --- /dev/null +++ b/embodichain/agents/rl/models/vla_policy.py @@ -0,0 +1,249 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ----------------------------------------------------------------------------
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+from tensordict import TensorDict
+from embodichain.agents.rl.vla_registry import create_vla_backend
+from .policy import Policy
+
+__all__ = ["VLAPolicy"]
+
+
+class VLAPolicy(Policy):
+    """Wraps DexForceVLA as Policy for GRPO fine-tuning."""
+
+    def __init__(
+        self,
+        device: torch.device,
+        policy_cfg: dict[str, Any],
+        obs_space=None,
+        action_space=None,
+    ) -> None:
+        super().__init__()
+        self.device = device
+        self.policy_cfg = dict(policy_cfg)
+        self.vla_cfg = dict(self.policy_cfg.get("vla", {}))
+        self.model_path = str(self.vla_cfg.get("model_path", ""))
+        self.action_horizon = int(self.vla_cfg.get("action_horizon", 32))
+        self.gaussian_sigma = float(self.vla_cfg.get("gaussian_sigma", 0.1))
+
+        if not self.model_path:
+            raise ValueError("VLAPolicy requires 'policy.vla.model_path'.")
+
+        self._vla_model: nn.Module | None = None
+        self._action_indices: list[int] | None = None
+
+        if action_space is None:
+            self.action_dim = 14
+        elif isinstance(action_space, int):
+            self.action_dim = action_space
+        elif hasattr(action_space, "shape") and len(action_space.shape) > 0:
+            self.action_dim = int(action_space.shape[-1])
+        else:
+            self.action_dim = 14
+        self.obs_dim = 0  # VLA consumes raw observations
+        self.use_raw_obs = True  # Tell the collector to pass raw observations
+
+        self.use_action_chunk = True
+        self.action_chunk_size = self.action_horizon
+        self._env = None
+
+    def set_env(self, env) -> None:
+        """Set the env reference used by forward()."""
+        self._env = env
+
+    def _load_vla(self) -> None:
+        if self._vla_model is not None:
+            return
+        backend = create_vla_backend(
+            "dexforce_vla",
+            model_path=self.model_path,
+            device=self.device,
+            action_horizon=self.action_horizon,
+            **{
+                k: v
+                for k, v in self.vla_cfg.items()
+                if k not in ("backend", "model_path", "action_horizon")
+            },
+        )
+        self._vla_model, self._action_indices, self._prepare_batch_fn = backend
+
+    def _vla_chunk_to_env_chunk(
+        self, action_chunk: torch.Tensor, env=None
+    ) -> torch.Tensor:
+        """Convert a VLA output chunk (N, T, vla_dim) to env format (N, T, env_dim)."""
+        if self._action_indices is not None:
+            step = action_chunk[:, :, self._action_indices]
+        else:
+            step = action_chunk
+
+        if env is not None:
+            env_dim = getattr(env.action_space, "shape", (None,))
+            if len(env_dim) > 0 and env_dim[-1] is not None:
+                env_dim = int(env_dim[-1])
+                if step.shape[-1] > env_dim:
+                    step = step[..., :env_dim]
+                elif step.shape[-1] < env_dim:
+                    pad = torch.zeros(
+                        step.shape[0],
+                        step.shape[1],
+                        env_dim - step.shape[-1],
+                        device=step.device,
+                        dtype=step.dtype,
+                    )
+                    step = torch.cat([step, pad], dim=-1)
+        return step
+
+    def forward(
+        self, tensordict: TensorDict, deterministic: bool = False
+    ) -> TensorDict:
+        obs = tensordict["obs"]
+        env = getattr(tensordict, "env", None)
+        if env is None:
+            env = getattr(self, "_env", None)
+        if env is None:
+            raise ValueError(
+                "VLAPolicy needs env. Set policy._env or pass env in tensordict."
+ ) + + self._load_vla() + self._vla_model.eval() + if hasattr(obs, "batch_size") and len(obs.batch_size) > 0: + batch_size = int(obs.batch_size[0]) + elif isinstance(obs, dict) and "robot" in obs and "qpos" in obs["robot"]: + q = obs["robot"]["qpos"] + batch_size = q.shape[0] if hasattr(q, "shape") and len(q.shape) > 0 else 1 + else: + batch_size = 1 + if batch_size == 1: + batch = self._prepare_batch_fn(obs, env) + vla_chunk = self._vla_model.predict_action( + batch, + action_only=True, + inference_horizon=self.action_horizon, + allow_grad=False, + use_fix_aug=False, + ) + action_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) + else: + chunks_env = [] + for i in range(batch_size): + obs_i = obs[i] if hasattr(obs, "__getitem__") else obs + batch_i = self._prepare_batch_fn(obs_i, env) + vla_chunk = self._vla_model.predict_action( + batch_i, + action_only=True, + inference_horizon=self.action_horizon, + allow_grad=False, + use_fix_aug=False, + ) + chunk_i = self._vla_chunk_to_env_chunk(vla_chunk, env=env) + chunks_env.append(chunk_i) + action_chunk_env = torch.cat(chunks_env, dim=0) + + action_chunk_env = action_chunk_env.to(self.device, dtype=torch.float32) + action = action_chunk_env[:, 0] + + tensordict["action"] = action + tensordict["sample_log_prob"] = torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ) + tensordict["value"] = torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ) + if self.use_action_chunk: + tensordict["action_chunk"] = action_chunk_env + return tensordict + + def get_value(self, tensordict: TensorDict) -> TensorDict: + b = tensordict.batch_size[0] + tensordict["value"] = torch.zeros(b, device=self.device, dtype=torch.float32) + return tensordict + + def evaluate_actions( + self, tensordict: TensorDict, rollout=None, num_envs=None, **kwargs + ) -> TensorDict: + """Compute log_prob via Gaussian proxy""" + b = tensordict.batch_size[0] + env = getattr(self, "_env", None) + if env is None: + raise ValueError( + "VLAPolicy.evaluate_actions requires env. Call policy.set_env(env)." + ) + + raw_obs = getattr(rollout, "raw_obs", None) + chunk_step = tensordict.get("chunk_step", None) + indices = tensordict.get("_indices", None) + if raw_obs is None or chunk_step is None or indices is None or num_envs is None: + raise ValueError( + "VLAPolicy.evaluate_actions requires rollout.raw_obs, chunk_step, _indices, num_envs. " + "Ensure collector uses use_raw_obs and use_action_chunk, and GRPO passes rollout and num_envs." 
+ ) + + time_dim = len(raw_obs) - 1 + sigma = self.gaussian_sigma + log_probs = [] + self._load_vla() + self._vla_model.eval() + + for i in range(b): + idx = int(indices[i].item()) + env_idx = idx // time_dim + step_idx = idx % time_dim + step_in_chunk = int(chunk_step[i].item()) + # Action came from chunk predicted at chunk start + chunk_start_idx = max(0, step_idx - step_in_chunk) + obs_i = raw_obs[chunk_start_idx][env_idx] + action_gt = tensordict["action"][i] + + batch_i = self._prepare_batch_fn(obs_i, env) + vla_chunk = self._vla_model.predict_action( + batch_i, + action_only=True, + inference_horizon=self.action_horizon, + allow_grad=True, + use_fix_aug=False, + ) + pred_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) + pred = pred_chunk_env[0, step_in_chunk] + if pred.shape[-1] != action_gt.shape[-1]: + pred = pred[: action_gt.shape[-1]] + mse = ((action_gt - pred).pow(2)).sum(-1) + log_prob = -0.5 * mse / (sigma * sigma + 1e-8) + log_probs.append(log_prob) + + log_probs = torch.stack(log_probs) + entropy = ( + 0.5 * self.action_dim * (1 + np.log(2 * np.pi) + 2 * np.log(sigma + 1e-8)) + ) + entropy = torch.full((b,), entropy, device=self.device, dtype=torch.float32) + + return TensorDict( + { + "sample_log_prob": log_probs, + "entropy": entropy, + "value": torch.zeros(b, device=self.device, dtype=torch.float32), + }, + batch_size=tensordict.batch_size, + device=self.device, + ) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 0e3b3a3a..6876868d 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -266,7 +266,11 @@ def _eval_once(self, num_episodes: int = 5): obs_td = dict_to_tensordict(obs, self.device) else: obs_td = flatten_dict_observation(obs) - num_envs = obs_td.batch_size[0] if hasattr(obs_td, "batch_size") else (obs_td.shape[0] if hasattr(obs_td, "shape") else 1) + num_envs = ( + obs_td.batch_size[0] + if hasattr(obs_td, "batch_size") + else (obs_td.shape[0] if hasattr(obs_td, "shape") else 1) + ) done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( @@ -298,7 +302,9 @@ def _eval_once(self, num_episodes: int = 5): action_td = self.policy.get_action(action_td, deterministic=True) actions = action_td["action"] - step_in_chunk = (step_in_chunk + 1) % action_chunk_size if use_action_chunk else 0 + step_in_chunk = ( + (step_in_chunk + 1) % action_chunk_size if use_action_chunk else 0 + ) am = getattr(self.eval_env, "action_manager", None) action_type = ( am.action_type From 1b2810546bd0ecf1320b915599514e417f9ba12c Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 09:47:12 +0000 Subject: [PATCH 07/17] Update --- embodichain/agents/rl/buffer/utils.py | 2 +- .../agents/rl/collector/sync_collector.py | 4 ++++ embodichain/agents/rl/models/__init__.py | 13 +++++++++-- embodichain/agents/rl/models/vla_policy.py | 22 +++++++++++++++++-- embodichain/agents/rl/utils/trainer.py | 11 +++++----- embodichain/agents/rl/vla_registry.py | 10 +++++---- 6 files changed, 48 insertions(+), 14 deletions(-) diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 655c1e7c..acb61ff8 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -78,6 +78,6 @@ def iterate_minibatches( indices = torch.randperm(total, device=device) for start in range(0, total, batch_size): batch_indices = indices[start : start + batch_size] - batch 
= rollout[batch_indices].clone() + batch = rollout[batch_indices] batch["_indices"] = batch_indices yield batch diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 851f89f8..8f471261 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -87,6 +87,8 @@ def collect( if use_raw_obs and raw_obs_list is not None: raw_obs_list[0] = self.obs_td + # Keep flattened obs populated even when using raw observations. + rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) else: rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) @@ -171,6 +173,8 @@ def collect( ) if use_raw_obs and raw_obs_list is not None: raw_obs_list[step_idx + 1] = next_obs_td + # Also keep flattened obs buffer up to date + rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) else: rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 101a5e75..4a77b52c 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -70,7 +70,8 @@ def build_policy( Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while custom policies may accept richer `obs_space` / `action_space` inputs. - For vla_policy, pass env to enable set_env and _load_vla initialization. + For vla_policy, ``env`` must be provided so that :meth:`set_env` and + ``_load_vla`` can initialize the policy correctly before use. """ name = policy_block["name"].lower() @@ -125,10 +126,18 @@ def build_policy( build_kwargs["critic"] = critic if "policy_cfg" in init_params: build_kwargs["policy_cfg"] = policy_block + policy = policy_cls(**build_kwargs) - if name == "vla_policy" and env is not None: + + if name == "vla_policy": + if env is None: + raise ValueError( + "VLAPolicy requires an 'env' argument to be passed to build_policy " + "so that set_env and _load_vla can be called before use." 
+ ) policy.set_env(env) policy._load_vla() + return policy diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 5fc7aef0..a1fc7c1a 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -75,8 +75,9 @@ def set_env(self, env) -> None: def _load_vla(self) -> None: if self._vla_model is not None: return + backend_name = self.vla_cfg.get("backend", "dexforce_vla") backend = create_vla_backend( - "dexforce_vla", + backend_name, model_path=self.model_path, device=self.device, action_horizon=self.action_horizon, @@ -147,8 +148,25 @@ def forward( action_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) else: chunks_env = [] + + def _index_obs_dict(obs_dict: dict, idx: int) -> dict: + indexed: dict[str, Any] = {} + for key, value in obs_dict.items(): + if isinstance(value, dict): + indexed[key] = _index_obs_dict(value, idx) + elif hasattr(value, "__getitem__"): + indexed[key] = value[idx] + else: + indexed[key] = value + return indexed + for i in range(batch_size): - obs_i = obs[i] if hasattr(obs, "__getitem__") else obs + if isinstance(obs, dict): + obs_i = _index_obs_dict(obs, i) + elif hasattr(obs, "__getitem__"): + obs_i = obs[i] + else: + obs_i = obs batch_i = self._prepare_batch_fn(obs_i, env) vla_chunk = self._vla_model.predict_action( batch_i, diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 6876868d..5102acad 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -89,11 +89,12 @@ def __init__( action_chunk_size = getattr(self.policy, "action_chunk_size", 0) use_action_chunk = getattr(self.policy, "use_action_chunk", False) if use_action_chunk and action_chunk_size > 0: - self.buffer_size = ( - (self.buffer_size + action_chunk_size - 1) - // action_chunk_size - * action_chunk_size - ) + if self.buffer_size % action_chunk_size != 0: + raise ValueError( + "Trainer buffer_size must be a multiple of policy.action_chunk_size " + f"when use_action_chunk is True (buffer_size={self.buffer_size}, " + f"action_chunk_size={action_chunk_size})." 
+ ) self.buffer = RolloutBuffer( num_envs=num_envs, diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py index a8d197d3..f8c442d6 100644 --- a/embodichain/agents/rl/vla_registry.py +++ b/embodichain/agents/rl/vla_registry.py @@ -18,6 +18,7 @@ from importlib.metadata import entry_points from typing import Any, Callable +import logging __all__ = [ "get_vla_backend", @@ -28,6 +29,7 @@ _VLA_BACKENDS: dict[str, Callable[..., Any]] = {} _ENTRY_POINTS_DISCOVERED = False +_LOGGER = logging.getLogger(__name__) def _discover_entry_points() -> None: @@ -35,7 +37,6 @@ def _discover_entry_points() -> None: global _ENTRY_POINTS_DISCOVERED if _ENTRY_POINTS_DISCOVERED: return - _ENTRY_POINTS_DISCOVERED = True try: eps = entry_points(group="embodichain.vla_backends") for ep in eps: @@ -45,9 +46,10 @@ def _discover_entry_points() -> None: if name not in _VLA_BACKENDS: _VLA_BACKENDS[name] = factory except Exception: - pass + _LOGGER.exception("Failed to load VLA backend entry point %r", ep) + _ENTRY_POINTS_DISCOVERED = True except Exception: - pass + _LOGGER.exception("Failed to discover VLA backend entry points") def get_vla_backend(name: str) -> Callable[..., Any] | None: @@ -69,6 +71,6 @@ def create_vla_backend(name: str, **kwargs) -> Any: available = get_registered_vla_backend_names() raise ValueError( f"Unknown VLA backend '{name}'. Available: {available}. " - "Ensure dexechain is installed (pip install dexechain)." + "Ensure dexechain is installed." ) return factory(**kwargs) From 79e584064f1ff0421d78ee28cbb2ef9bd87bda20 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 16 Mar 2026 10:28:23 +0000 Subject: [PATCH 08/17] WIP --- .../agents/rl/buffer/standard_buffer.py | 2 + embodichain/agents/rl/buffer/utils.py | 4 +- .../agents/rl/collector/sync_collector.py | 29 ++++++++++++-- embodichain/agents/rl/models/__init__.py | 6 +-- embodichain/agents/rl/models/vla_policy.py | 23 +---------- embodichain/agents/rl/train.py | 6 ++- embodichain/agents/rl/utils/trainer.py | 38 ++++++++++++++----- embodichain/agents/rl/vla_registry.py | 11 +++--- 8 files changed, 70 insertions(+), 49 deletions(-) diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index f14eedbb..a0dad10a 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -162,6 +162,8 @@ def _clear_dynamic_fields(self) -> None: del self._rollout[key] if self.use_raw_obs and hasattr(self._rollout, "raw_obs"): delattr(self._rollout, "raw_obs") + if hasattr(self._rollout, "chunk_step"): + delattr(self._rollout, "chunk_step") self._reset_padding_slot() def _reset_padding_slot(self) -> None: diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index acb61ff8..1e82c7e4 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -75,9 +75,9 @@ def iterate_minibatches( ) -> Iterator[TensorDict]: """Yield shuffled minibatches from a flattened rollout.""" total = rollout.batch_size[0] - indices = torch.randperm(total, device=device) + indices = torch.randperm(total) for start in range(0, total, batch_size): batch_indices = indices[start : start + batch_size] - batch = rollout[batch_indices] + batch = rollout[batch_indices].clone() batch["_indices"] = batch_indices yield batch diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 
8f471261..6e81c0f9 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -71,6 +71,30 @@ def collect( use_raw_obs = getattr(self.policy, "use_raw_obs", False) raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None + if use_raw_obs: + if raw_obs_list is None: + raise ValueError( + "Policy requires raw observations, " + "but the provided rollout TensorDict has no 'raw_obs' buffer. " + "Create the rollout via RolloutBuffer or " + "start_rollout so that 'raw_obs' is allocated." + ) + try: + raw_obs_len = len(raw_obs_list) + except TypeError: + raise ValueError( + "Rollout field 'raw_obs' must be an indexable sequence of length " + f"{num_steps + 1} when policy.use_raw_obs=True." + ) + expected_len = num_steps + 1 + if raw_obs_len != expected_len: + raise ValueError( + "Rollout 'raw_obs' length mismatch: " + f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. " + "Ensure the rollout was created with use_raw_obs=True and " + "its time dimension matches the requested num_steps." + ) + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) use_action_chunk = ( getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 @@ -87,7 +111,6 @@ def collect( if use_raw_obs and raw_obs_list is not None: raw_obs_list[0] = self.obs_td - # Keep flattened obs populated even when using raw observations. rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) else: rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) @@ -173,7 +196,6 @@ def collect( ) if use_raw_obs and raw_obs_list is not None: raw_obs_list[step_idx + 1] = next_obs_td - # Also keep flattened obs buffer up to date rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) else: rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) @@ -241,8 +263,7 @@ def _write_env_step( def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: """Validate rollout layout expected by the collector.""" - use_raw_obs = getattr(self.policy, "use_raw_obs", False) - obs_dim = 1 if use_raw_obs else self.policy.obs_dim + obs_dim = rollout["obs"].shape[-1] expected_shapes = { "obs": (self.env.num_envs, num_steps + 1, obs_dim), "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 4a77b52c..46231005 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -70,8 +70,7 @@ def build_policy( Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while custom policies may accept richer `obs_space` / `action_space` inputs. - For vla_policy, ``env`` must be provided so that :meth:`set_env` and - ``_load_vla`` can initialize the policy correctly before use. + For vla_policy, pass env to enable set_env and _load_vla initialization. 
""" name = policy_block["name"].lower() @@ -126,9 +125,7 @@ def build_policy( build_kwargs["critic"] = critic if "policy_cfg" in init_params: build_kwargs["policy_cfg"] = policy_block - policy = policy_cls(**build_kwargs) - if name == "vla_policy": if env is None: raise ValueError( @@ -137,7 +134,6 @@ def build_policy( ) policy.set_env(env) policy._load_vla() - return policy diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index a1fc7c1a..a3c1648f 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -16,7 +16,6 @@ from __future__ import annotations -from pathlib import Path from typing import Any import numpy as np @@ -75,9 +74,8 @@ def set_env(self, env) -> None: def _load_vla(self) -> None: if self._vla_model is not None: return - backend_name = self.vla_cfg.get("backend", "dexforce_vla") backend = create_vla_backend( - backend_name, + "dexforce_vla", model_path=self.model_path, device=self.device, action_horizon=self.action_horizon, @@ -148,25 +146,8 @@ def forward( action_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) else: chunks_env = [] - - def _index_obs_dict(obs_dict: dict, idx: int) -> dict: - indexed: dict[str, Any] = {} - for key, value in obs_dict.items(): - if isinstance(value, dict): - indexed[key] = _index_obs_dict(value, idx) - elif hasattr(value, "__getitem__"): - indexed[key] = value[idx] - else: - indexed[key] = value - return indexed - for i in range(batch_size): - if isinstance(obs, dict): - obs_i = _index_obs_dict(obs, i) - elif hasattr(obs, "__getitem__"): - obs_i = obs[i] - else: - obs_i = obs + obs_i = obs[i] if hasattr(obs, "__getitem__") else obs batch_i = self._prepare_batch_fn(obs_i, env) vla_chunk = self._vla_model.predict_action( batch_i, diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index bf08c746..b4678d26 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -230,7 +230,11 @@ def train_from_config(config_path: str): ) else: policy = build_policy( - policy_block, env.observation_space, env.action_space, device + policy_block, + env.observation_space, + env.action_space, + device, + env=env, ) # Build Algorithm via factory diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 5102acad..1ec97d82 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -89,17 +89,30 @@ def __init__( action_chunk_size = getattr(self.policy, "action_chunk_size", 0) use_action_chunk = getattr(self.policy, "use_action_chunk", False) if use_action_chunk and action_chunk_size > 0: - if self.buffer_size % action_chunk_size != 0: - raise ValueError( - "Trainer buffer_size must be a multiple of policy.action_chunk_size " - f"when use_action_chunk is True (buffer_size={self.buffer_size}, " - f"action_chunk_size={action_chunk_size})." 
+ self.buffer_size = ( + (self.buffer_size + action_chunk_size - 1) + // action_chunk_size + * action_chunk_size + ) + + if use_raw_obs: + try: + reset_out = self.env.reset() + sample_obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out + obs_td = dict_to_tensordict(sample_obs, self.device) + flat_obs = flatten_dict_observation(obs_td) + obs_dim = int( + flat_obs.shape[-1] + if isinstance(flat_obs, torch.Tensor) + else np.asarray(flat_obs).shape[-1] ) + except Exception: + obs_dim = max(1, obs_dim) self.buffer = RolloutBuffer( num_envs=num_envs, rollout_len=self.buffer_size, - obs_dim=max(1, obs_dim) if use_raw_obs else obs_dim, + obs_dim=obs_dim, action_dim=action_dim, device=self.device, use_raw_obs=use_raw_obs, @@ -260,6 +273,7 @@ def _eval_once(self, num_episodes: int = 5): use_raw_obs = getattr(self.policy, "use_raw_obs", False) use_action_chunk = getattr(self.policy, "use_action_chunk", False) action_chunk_size = getattr(self.policy, "action_chunk_size", 1) + effective_use_action_chunk = use_action_chunk and action_chunk_size > 0 for _ in range(num_episodes): obs, _ = self.eval_env.reset() @@ -282,7 +296,9 @@ def _eval_once(self, num_episodes: int = 5): step_in_chunk = 0 while not done_mask.all(): - if use_action_chunk and (cached_chunk is None or step_in_chunk == 0): + if effective_use_action_chunk and ( + cached_chunk is None or step_in_chunk == 0 + ): action_td = TensorDict( {"obs": obs_td}, batch_size=[num_envs], @@ -292,7 +308,7 @@ def _eval_once(self, num_episodes: int = 5): cached_chunk = action_td.get("action_chunk") actions = action_td["action"] step_in_chunk = 0 - elif use_action_chunk and cached_chunk is not None: + elif effective_use_action_chunk and cached_chunk is not None: actions = cached_chunk[:, step_in_chunk] else: action_td = TensorDict( @@ -304,7 +320,9 @@ def _eval_once(self, num_episodes: int = 5): actions = action_td["action"] step_in_chunk = ( - (step_in_chunk + 1) % action_chunk_size if use_action_chunk else 0 + (step_in_chunk + 1) % action_chunk_size + if effective_use_action_chunk + else 0 ) am = getattr(self.eval_env, "action_manager", None) action_type = ( @@ -334,7 +352,7 @@ def _eval_once(self, num_episodes: int = 5): done_mask |= done # Invalidate cached_chunk on any env reset - if use_action_chunk and done.any(): + if effective_use_action_chunk and done.any(): cached_chunk = None # Trigger evaluation events (e.g., video recording) diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py index f8c442d6..e9f13546 100644 --- a/embodichain/agents/rl/vla_registry.py +++ b/embodichain/agents/rl/vla_registry.py @@ -18,7 +18,6 @@ from importlib.metadata import entry_points from typing import Any, Callable -import logging __all__ = [ "get_vla_backend", @@ -29,7 +28,6 @@ _VLA_BACKENDS: dict[str, Callable[..., Any]] = {} _ENTRY_POINTS_DISCOVERED = False -_LOGGER = logging.getLogger(__name__) def _discover_entry_points() -> None: @@ -37,6 +35,7 @@ def _discover_entry_points() -> None: global _ENTRY_POINTS_DISCOVERED if _ENTRY_POINTS_DISCOVERED: return + _ENTRY_POINTS_DISCOVERED = True try: eps = entry_points(group="embodichain.vla_backends") for ep in eps: @@ -46,10 +45,9 @@ def _discover_entry_points() -> None: if name not in _VLA_BACKENDS: _VLA_BACKENDS[name] = factory except Exception: - _LOGGER.exception("Failed to load VLA backend entry point %r", ep) - _ENTRY_POINTS_DISCOVERED = True + pass except Exception: - _LOGGER.exception("Failed to discover VLA backend entry points") + pass def get_vla_backend(name: 
str) -> Callable[..., Any] | None: @@ -71,6 +69,7 @@ def create_vla_backend(name: str, **kwargs) -> Any: available = get_registered_vla_backend_names() raise ValueError( f"Unknown VLA backend '{name}'. Available: {available}. " - "Ensure dexechain is installed." + "Ensure a package providing the 'embodichain.vla_backends' entry point " + "group is installed." ) return factory(**kwargs) From 10b2b7b68da9906eb9e9279f3fa9d940fb0f9068 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 30 Mar 2026 03:26:07 +0000 Subject: [PATCH 09/17] wip --- embodichain/agents/rl/algo/grpo.py | 7 ++--- embodichain/agents/rl/algo/ppo.py | 1 - .../agents/rl/collector/sync_collector.py | 31 ++++++++++++------- embodichain/agents/rl/models/vla_policy.py | 8 ++--- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 8f96aa42..18f6c0a2 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -114,7 +114,6 @@ def _compute_step_group_advantages( def update(self, rollout: TensorDict) -> Dict[str, float]: raw_obs = getattr(rollout, "raw_obs", None) chunk_step = getattr(rollout, "chunk_step", None) - rollout = rollout.clone() if raw_obs is not None: rollout.raw_obs = raw_obs if chunk_step is not None: @@ -153,9 +152,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: advantages = batch["advantage"].detach() seq_mask_batch = batch["seq_mask"].float() - eval_batch = self.policy.evaluate_actions( - batch, rollout=rollout, num_envs=num_envs - ) + eval_batch = self.policy.evaluate_actions(batch, rollout=rollout) logprobs = eval_batch["sample_log_prob"] entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() @@ -175,7 +172,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: if self.ref_policy is not None: with torch.no_grad(): ref_batch = self.ref_policy.evaluate_actions( - batch, rollout=rollout, num_envs=num_envs + batch, rollout=rollout ) ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index e33ee5b3..08bb176c 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -48,7 +48,6 @@ def __init__(self, cfg: PPOCfg, policy): def update(self, rollout: TensorDict) -> Dict[str, float]: """Update the policy using a collected rollout.""" - rollout = rollout.clone() compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda) flat_rollout = transition_view(rollout, flatten=True) diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 5c63ed43..2cca25ac 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -22,6 +22,7 @@ from tensordict import TensorDict from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation +from embodichain.utils import logger from .base import BaseCollector __all__ = ["SyncCollector"] @@ -56,13 +57,15 @@ def collect( self.obs_td = self._reset_env() if rollout is None: - raise ValueError( - "SyncCollector.collect() requires a preallocated rollout TensorDict." 
+ logger.log_error( + "SyncCollector.collect() requires a preallocated rollout TensorDict.", + ValueError, ) if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1): - raise ValueError( + logger.log_error( "Preallocated rollout batch size mismatch: " - f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}." + f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}.", + ValueError, ) self._validate_rollout(rollout, num_steps) if self._supports_shared_rollout: @@ -73,26 +76,29 @@ def collect( if use_raw_obs: if raw_obs_list is None: - raise ValueError( + logger.log_error( "Policy requires raw observations, " "but the provided rollout TensorDict has no 'raw_obs' buffer. " "Create the rollout via RolloutBuffer or " - "start_rollout so that 'raw_obs' is allocated." + "start_rollout so that 'raw_obs' is allocated.", + ValueError, ) try: raw_obs_len = len(raw_obs_list) except TypeError: - raise ValueError( + logger.log_error( "Rollout field 'raw_obs' must be an indexable sequence of length " - f"{num_steps + 1} when policy.use_raw_obs=True." + f"{num_steps + 1} when policy.use_raw_obs=True.", + ValueError, ) expected_len = num_steps + 1 if raw_obs_len != expected_len: - raise ValueError( + logger.log_error( "Rollout 'raw_obs' length mismatch: " f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. " "Ensure the rollout was created with use_raw_obs=True and " - "its time dimension matches the requested num_steps." + "its time dimension matches the requested num_steps.", + ValueError, ) action_chunk_size = getattr(self.policy, "action_chunk_size", 0) @@ -277,7 +283,8 @@ def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: for key, expected_shape in expected_shapes.items(): actual_shape = tuple(rollout[key].shape) if actual_shape != expected_shape: - raise ValueError( + logger.log_error( f"Preallocated rollout field '{key}' shape mismatch: " - f"expected {expected_shape}, got {actual_shape}." + f"expected {expected_shape}, got {actual_shape}.", + ValueError, ) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index a3c1648f..ccac90f3 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -180,7 +180,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: return tensordict def evaluate_actions( - self, tensordict: TensorDict, rollout=None, num_envs=None, **kwargs + self, tensordict: TensorDict, rollout=None, **kwargs ) -> TensorDict: """Compute log_prob via Gaussian proxy""" b = tensordict.batch_size[0] @@ -193,10 +193,10 @@ def evaluate_actions( raw_obs = getattr(rollout, "raw_obs", None) chunk_step = tensordict.get("chunk_step", None) indices = tensordict.get("_indices", None) - if raw_obs is None or chunk_step is None or indices is None or num_envs is None: + if raw_obs is None or chunk_step is None or indices is None: raise ValueError( - "VLAPolicy.evaluate_actions requires rollout.raw_obs, chunk_step, _indices, num_envs. " - "Ensure collector uses use_raw_obs and use_action_chunk, and GRPO passes rollout and num_envs." + "VLAPolicy.evaluate_actions requires rollout.raw_obs, chunk_step, and _indices. " + "Ensure collector uses use_raw_obs and use_action_chunk, and GRPO passes rollout." 
) time_dim = len(raw_obs) - 1 From 712e730853b5cda52cf570ff3d8d6663486a1c94 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 30 Mar 2026 10:28:38 +0000 Subject: [PATCH 10/17] fix: device conflict when randomize --- .../gym/envs/managers/randomization/visual.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/embodichain/lab/gym/envs/managers/randomization/visual.py b/embodichain/lab/gym/envs/managers/randomization/visual.py index 66d3d6fb..62dcfcd5 100644 --- a/embodichain/lab/gym/envs/managers/randomization/visual.py +++ b/embodichain/lab/gym/envs/managers/randomization/visual.py @@ -183,9 +183,10 @@ def randomize_camera_extrinsics( ).repeat(num_instance, 1) if pos_range: random_value = sample_uniform( - lower=torch.tensor(pos_range[0]), - upper=torch.tensor(pos_range[1]), + lower=torch.tensor(pos_range[0], device=env.device), + upper=torch.tensor(pos_range[1], device=env.device), size=(num_instance, 3), + device=env.device, ) new_pose[:, :3] += random_value if euler_range: @@ -198,9 +199,10 @@ def randomize_camera_extrinsics( init_euler = torch.stack(euler_xyz_from_quat(init_quat_np), dim=1) # 2. Sample perturbation for euler angles random_value = sample_uniform( - lower=torch.tensor(euler_range[0]), - upper=torch.tensor(euler_range[1]), + lower=torch.tensor(euler_range[0], device=env.device), + upper=torch.tensor(euler_range[1], device=env.device), size=(num_instance, 3), + device=env.device, ) # 3. Add perturbation to each environment and convert back to quaternion roll, pitch, yaw = (init_euler + random_value).unbind(dim=1) @@ -229,9 +231,10 @@ def randomize_camera_extrinsics( if eye_range: eye_delta = sample_uniform( - lower=torch.tensor(eye_range[0]), - upper=torch.tensor(eye_range[1]), + lower=torch.tensor(eye_range[0], device=env.device), + upper=torch.tensor(eye_range[1], device=env.device), size=(num_instance, 3), + device=env.device, ) new_eye = init_eye + eye_delta else: @@ -239,9 +242,10 @@ def randomize_camera_extrinsics( if target_range: target_delta = sample_uniform( - lower=torch.tensor(target_range[0]), - upper=torch.tensor(target_range[1]), + lower=torch.tensor(target_range[0], device=env.device), + upper=torch.tensor(target_range[1], device=env.device), size=(num_instance, 3), + device=env.device, ) new_target = init_target + target_delta else: @@ -249,9 +253,10 @@ def randomize_camera_extrinsics( if up_range: up_delta = sample_uniform( - lower=torch.tensor(up_range[0]), - upper=torch.tensor(up_range[1]), + lower=torch.tensor(up_range[0], device=env.device), + upper=torch.tensor(up_range[1], device=env.device), size=(num_instance, 3), + device=env.device, ) new_up = init_up + up_delta else: From 6aad3f8398d222bcf045ee0e32b2adc4ac91c6ae Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 31 Mar 2026 06:09:18 +0000 Subject: [PATCH 11/17] Support action chunk --- .../agents/rl/buffer/standard_buffer.py | 9 +- .../agents/rl/collector/sync_collector.py | 113 ++++++++++++++++-- embodichain/agents/rl/models/vla_policy.py | 1 + embodichain/agents/rl/utils/trainer.py | 80 ++++++++++++- 4 files changed, 191 insertions(+), 12 deletions(-) diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index f0f6f5a5..fef94d51 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -161,7 +161,14 @@ def _allocate_rollout(self) -> TensorDict: def _clear_dynamic_fields(self) 
-> None: """Drop algorithm-added fields before reusing the shared rollout.""" - for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + for key in ( + "advantage", + "return", + "seq_mask", + "seq_return", + "entropy", + "step_repeat", + ): if key in self._rollout.keys(): del self._rollout[key] if self.use_raw_obs and hasattr(self._rollout, "raw_obs"): diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 2cca25ac..ba3acc24 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -105,7 +105,10 @@ def collect( use_action_chunk = ( getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 ) + # Execute a full predicted action chunk inside one logical rollout step. + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) cached_chunk = None + chunk_cursor = 0 if use_action_chunk: rollout.chunk_step = torch.zeros( @@ -114,6 +117,13 @@ def collect( dtype=torch.long, device=self.device, ) + rollout["step_repeat"] = torch.ones( + self.env.num_envs, + num_steps + 1, + dtype=torch.float32, + device=self.device, + ) + rollout["step_repeat"][:, -1] = 0.0 if use_raw_obs and raw_obs_list is not None: raw_obs_list[0] = self.obs_td @@ -122,11 +132,98 @@ def collect( rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) for step_idx in range(num_steps): - step_in_chunk = step_idx % action_chunk_size if use_action_chunk else 0 + if execute_full_chunk and use_action_chunk: + if use_raw_obs and raw_obs_list is not None: + step_td = TensorDict( + {"obs": raw_obs_list[step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + else: + step_td = TensorDict( + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + step_td = self.policy.get_action(step_td) + chunk = step_td.get("action_chunk") + if chunk is None: + logger.log_error( + "execute_full_chunk=True requires policy to provide 'action_chunk'.", + ValueError, + ) + + reward_sum = torch.zeros( + self.env.num_envs, dtype=torch.float32, device=self.device + ) + terminated = torch.zeros( + self.env.num_envs, dtype=torch.bool, device=self.device + ) + truncated = torch.zeros( + self.env.num_envs, dtype=torch.bool, device=self.device + ) + env_info = {} + next_obs_td = None + + executed_substeps = 0 + # Execute the whole chunk sequentially + for sub_idx in range(action_chunk_size): + sub_action = chunk[:, sub_idx] + next_obs, reward, term_i, trunc_i, env_info = self.env.step( + self._to_action_dict(sub_action) + ) + executed_substeps += 1 + next_obs_td = dict_to_tensordict(next_obs, self.device) + reward_sum += reward.to(self.device).float() + terminated |= term_i.to(self.device) + truncated |= trunc_i.to(self.device) + + # Stop chunk execution when any env reaches terminal/truncated. 
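
Condensed, the full-chunk path implemented in this hunk works as in the sketch below: every action in the predicted chunk is stepped in sequence, rewards are summed into a single logical transition, and execution stops early once any environment finishes, since the rest of the chunk was predicted for episodes that no longer exist. `env` and `chunk` here are stand-ins, not the real interfaces:

    import torch

    def run_chunk(env, chunk: torch.Tensor):
        # chunk: [num_envs, chunk_len, action_dim], predicted from one observation.
        num_envs = chunk.shape[0]
        reward_sum = torch.zeros(num_envs)
        terminated = torch.zeros(num_envs, dtype=torch.bool)
        truncated = torch.zeros(num_envs, dtype=torch.bool)
        executed = 0
        obs = None
        for sub_idx in range(chunk.shape[1]):
            obs, reward, term, trunc, _ = env.step(chunk[:, sub_idx])
            executed += 1                 # recorded as step_repeat for this step
            reward_sum += reward.float()
            terminated |= term
            truncated |= trunc
            if (term | trunc).any():      # the stale remainder is discarded
                break
        return obs, reward_sum, terminated, truncated, executed
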
+ if (term_i | trunc_i).any(): + break + + if next_obs_td is None: + logger.log_error( + "Chunk execution produced no environment transition.", + RuntimeError, + ) + + if use_action_chunk: + rollout.chunk_step[:, step_idx] = 0 + rollout["step_repeat"][:, step_idx] = float(executed_substeps) + + self._write_step( + rollout=rollout, + step_idx=step_idx, + step_td=step_td, + ) + if not self._supports_shared_rollout: + self._write_env_step( + rollout=rollout, + step_idx=step_idx, + reward=reward_sum, + terminated=terminated, + truncated=truncated, + ) + if use_raw_obs and raw_obs_list is not None: + raw_obs_list[step_idx + 1] = next_obs_td + rollout["obs"][:, step_idx + 1] = flatten_dict_observation( + next_obs_td + ) + else: + rollout["obs"][:, step_idx + 1] = flatten_dict_observation( + next_obs_td + ) + + if on_step_callback is not None: + on_step_callback(rollout[:, step_idx], env_info) + + self.obs_td = next_obs_td + continue - # At chunk boundary, or cached invalidated by env reset, we need a new chunk + # Execute a predicted chunk sequentially need_new_chunk = use_action_chunk and ( - step_in_chunk == 0 or cached_chunk is None + cached_chunk is None or chunk_cursor >= action_chunk_size ) if need_new_chunk: @@ -146,9 +243,10 @@ def collect( cached_chunk = step_td["action_chunk"] action = step_td["action"] effective_step_in_chunk = 0 + chunk_cursor = 1 elif use_action_chunk and cached_chunk is not None: - action = cached_chunk[:, step_in_chunk] - effective_step_in_chunk = step_in_chunk + action = cached_chunk[:, chunk_cursor] + effective_step_in_chunk = chunk_cursor step_td = TensorDict( { "action": action, @@ -162,6 +260,7 @@ def collect( batch_size=[rollout.batch_size[0]], device=self.device, ) + chunk_cursor += 1 else: if use_raw_obs and raw_obs_list is not None: step_td = TensorDict( @@ -184,9 +283,7 @@ def collect( next_obs_td = dict_to_tensordict(next_obs, self.device) if use_action_chunk: rollout.chunk_step[:, step_idx] = effective_step_in_chunk - # Invalidate cached_chunk on any env reset to avoid using old chunk for new episode - if (terminated | truncated).any(): - cached_chunk = None + rollout["step_repeat"][:, step_idx] = 1.0 self._write_step( rollout=rollout, step_idx=step_idx, diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index ccac90f3..ad228f99 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -65,6 +65,7 @@ def __init__( self.use_action_chunk = True self.action_chunk_size = self.action_horizon + self.execute_full_chunk = bool(self.vla_cfg.get("execute_full_chunk", True)) self._env = None def set_env(self, env) -> None: diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 7a7ce619..2c2c97ef 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -196,9 +196,14 @@ def on_step(tensordict: TensorDict, info: dict): """Callback called at each step during rollout collection.""" reward = tensordict["reward"] done = tensordict["done"] + step_repeat = tensordict.get( + "step_repeat", + torch.ones_like(reward, dtype=torch.float32, device=self.device), + ) + step_repeat_int = step_repeat.to(dtype=torch.int32) # Episode stats self.curr_ret += reward - self.curr_len += 1 + self.curr_len += step_repeat_int done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) if done_idx.numel() > 0: finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() @@ -209,7 +214,7 @@ def 
on_step(tensordict: TensorDict, info: dict): self.curr_len[done_idx] = 0 if not self.distributed: - self.global_step += tensordict.batch_size[0] + self.global_step += int(step_repeat_int.sum().item()) if self.rank == 0 and isinstance(info, dict): rewards_dict = info.get("rewards") @@ -245,7 +250,10 @@ def on_step(tensordict: TensorDict, info: dict): "Call torch.distributed.init_process_group(...) before creating " "or using Trainer(distributed=True, ...)." ) - local_delta = self.env.num_envs * self.buffer_size + if "step_repeat" in rollout.keys(): + local_delta = int(rollout["step_repeat"][:, :-1].sum().item()) + else: + local_delta = self.env.num_envs * self.buffer_size delta_tensor = torch.tensor( [local_delta], dtype=torch.int64, device=self.device ) @@ -359,6 +367,7 @@ def _eval_once(self, num_episodes: int = 5): use_action_chunk = getattr(self.policy, "use_action_chunk", False) action_chunk_size = getattr(self.policy, "action_chunk_size", 1) effective_use_action_chunk = use_action_chunk and action_chunk_size > 0 + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) if hasattr(self.eval_env, "set_rollout_buffer"): self.eval_env.set_rollout_buffer(self.buffer.buffer) @@ -383,6 +392,71 @@ def _eval_once(self, num_episodes: int = 5): step_in_chunk = 0 while not done_mask.all(): + if execute_full_chunk and effective_use_action_chunk: + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + chunk = action_td.get("action_chunk") + if chunk is None: + raise ValueError( + "execute_full_chunk=True requires policy to provide 'action_chunk'." + ) + + reward_sum = torch.zeros( + num_envs, dtype=torch.float32, device=self.device + ) + terminated = torch.zeros( + num_envs, dtype=torch.bool, device=self.device + ) + truncated = torch.zeros( + num_envs, dtype=torch.bool, device=self.device + ) + executed_substeps = 0 + info = {} + + for sub_idx in range(action_chunk_size): + sub_actions = chunk[:, sub_idx] + am: ActionManager | None = getattr( + self.eval_env, "action_manager", None + ) + if am is None: + action_in = sub_actions + else: + action_in = am.convert_policy_action_to_env_action( + sub_actions + ) + + obs, reward, term_i, trunc_i, info = self.eval_env.step(action_in) + if use_raw_obs: + obs_td = dict_to_tensordict(obs, self.device) + else: + obs_td = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) + + still_running = ~done_mask + reward_sum[still_running] += reward[still_running].float() + step_count[still_running] += 1 + terminated |= term_i + truncated |= trunc_i + done_mask |= term_i | trunc_i + executed_substeps += 1 + + if hasattr(self, "eval_event_manager"): + if "interval" in self.eval_event_manager.available_modes: + self.eval_event_manager.apply(mode="interval") + + if done_mask.all(): + break + + cumulative_reward += reward_sum + continue + if effective_use_action_chunk and ( cached_chunk is None or step_in_chunk == 0 ): From 081c728dab40df8d79c94badf7b25841c9496418 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 31 Mar 2026 07:29:59 +0000 Subject: [PATCH 12/17] fix --- .../agents/rl/buffer/standard_buffer.py | 13 ++++++ embodichain/agents/rl/buffer/utils.py | 3 ++ .../agents/rl/collector/sync_collector.py | 7 ++- embodichain/agents/rl/models/vla_policy.py | 44 ++++++++++++++----- embodichain/agents/rl/utils/trainer.py | 1 + 5 files changed, 57 insertions(+), 11 deletions(-) diff 
--git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index fef94d51..deb6abed 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -41,6 +41,7 @@ def __init__( action_dim: int, device: torch.device, use_raw_obs: bool = False, + action_chunk_size: int = 0, ) -> None: self.num_envs = num_envs self.rollout_len = rollout_len @@ -48,6 +49,7 @@ def __init__( self.action_dim = action_dim self.device = device self.use_raw_obs = use_raw_obs + self.action_chunk_size = action_chunk_size self._rollout = self._allocate_rollout() self._is_full = False @@ -157,6 +159,15 @@ def _allocate_rollout(self) -> TensorDict: batch_size=[self.num_envs, self.rollout_len + 1], device=self.device, ) + if self.action_chunk_size > 0: + td["action_chunk"] = torch.zeros( + self.num_envs, + self.rollout_len + 1, + self.action_chunk_size, + self.action_dim, + dtype=torch.float32, + device=self.device, + ) return td def _clear_dynamic_fields(self) -> None: @@ -186,6 +197,8 @@ def _reset_padding_slot(self) -> None: self._rollout["done"][:, last_idx].fill_(False) self._rollout["terminated"][:, last_idx].fill_(False) self._rollout["truncated"][:, last_idx].fill_(False) + if "action_chunk" in self._rollout.keys(): + self._rollout["action_chunk"][:, last_idx].zero_() def _validate_rollout_layout(self, rollout: TensorDict) -> None: """Validate the expected tensor shapes for the shared rollout.""" diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 1e82c7e4..740d9d69 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -65,6 +65,9 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: if hasattr(rollout, "chunk_step") and rollout.chunk_step is not None: td["chunk_step"] = rollout.chunk_step + if "action_chunk" in rollout.keys(): + td["action_chunk"] = rollout["action_chunk"][:, :-1] + if flatten: return td.reshape(num_envs * time_dim) return td diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index ba3acc24..d91d7f56 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -42,7 +42,10 @@ def __init__( self.policy = policy self.device = device self.reset_every_rollout = reset_every_rollout - self._supports_shared_rollout = hasattr(self.env, "set_rollout_buffer") + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) + self._supports_shared_rollout = ( + hasattr(self.env, "set_rollout_buffer") and not execute_full_chunk + ) self.obs_td = self._reset_env() @torch.no_grad() @@ -348,6 +351,8 @@ def _write_step( rollout["action"][:, step_idx] = step_td["action"] rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] rollout["value"][:, step_idx] = step_td["value"] + if "action_chunk" in rollout.keys() and "action_chunk" in step_td.keys(): + rollout["action_chunk"][:, step_idx] = step_td["action_chunk"] def _write_env_step( self, diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index ad228f99..add54ad1 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -65,7 +65,7 @@ def __init__( self.use_action_chunk = True self.action_chunk_size = self.action_horizon - self.execute_full_chunk = bool(self.vla_cfg.get("execute_full_chunk", True)) + 
self.execute_full_chunk = True self._env = None def set_env(self, env) -> None: @@ -87,6 +87,20 @@ def _load_vla(self) -> None: }, ) self._vla_model, self._action_indices, self._prepare_batch_fn = backend + self._freeze_encoders() + + def _freeze_encoders(self) -> None: + """Freeze vision encoders to avoid catastrophic forgetting""" + if self._vla_model is None: + return + encoders = getattr(self._vla_model, "encoders", None) + if encoders is not None: + for param in encoders.parameters(): + param.requires_grad_(False) + privilege_estimators = getattr(self._vla_model, "privilege_estimators", None) + if privilege_estimators is not None: + for param in privilege_estimators.parameters(): + param.requires_grad_(False) def _vla_chunk_to_env_chunk( self, action_chunk: torch.Tensor, env=None @@ -183,7 +197,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: def evaluate_actions( self, tensordict: TensorDict, rollout=None, **kwargs ) -> TensorDict: - """Compute log_prob via Gaussian proxy""" + """Compute log_prob via Gaussian proxy for GRPO policy gradient.""" b = tensordict.batch_size[0] env = getattr(self, "_env", None) if env is None: @@ -204,17 +218,17 @@ def evaluate_actions( sigma = self.gaussian_sigma log_probs = [] self._load_vla() - self._vla_model.eval() + + stored_chunks = tensordict.get("action_chunk", None) + use_full_chunk = stored_chunks is not None for i in range(b): idx = int(indices[i].item()) env_idx = idx // time_dim step_idx = idx % time_dim step_in_chunk = int(chunk_step[i].item()) - # Action came from chunk predicted at chunk start chunk_start_idx = max(0, step_idx - step_in_chunk) obs_i = raw_obs[chunk_start_idx][env_idx] - action_gt = tensordict["action"][i] batch_i = self._prepare_batch_fn(obs_i, env) vla_chunk = self._vla_model.predict_action( @@ -225,16 +239,26 @@ def evaluate_actions( use_fix_aug=False, ) pred_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) - pred = pred_chunk_env[0, step_in_chunk] - if pred.shape[-1] != action_gt.shape[-1]: - pred = pred[: action_gt.shape[-1]] - mse = ((action_gt - pred).pow(2)).sum(-1) + + if use_full_chunk: + gt_chunk = stored_chunks[i] + pred_chunk = pred_chunk_env[0] + min_len = min(gt_chunk.shape[-1], pred_chunk.shape[-1]) + mse = ((gt_chunk[..., :min_len] - pred_chunk[..., :min_len]).pow(2)).sum(-1).mean(-1) + else: + action_gt = tensordict["action"][i] + pred = pred_chunk_env[0, step_in_chunk] + if pred.shape[-1] != action_gt.shape[-1]: + pred = pred[: action_gt.shape[-1]] + mse = ((action_gt - pred).pow(2)).sum(-1) + log_prob = -0.5 * mse / (sigma * sigma + 1e-8) log_probs.append(log_prob) log_probs = torch.stack(log_probs) + effective_dim = self.action_dim * (self.action_horizon if use_full_chunk else 1) entropy = ( - 0.5 * self.action_dim * (1 + np.log(2 * np.pi) + 2 * np.log(sigma + 1e-8)) + 0.5 * effective_dim * (1 + np.log(2 * np.pi) + 2 * np.log(sigma + 1e-8)) ) entropy = torch.full((b,), entropy, device=self.device, dtype=torch.float32) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 2c2c97ef..92122223 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -124,6 +124,7 @@ def __init__( action_dim=action_dim, device=self.device, use_raw_obs=use_raw_obs, + action_chunk_size=action_chunk_size if use_action_chunk else 0, ) self.collector = SyncCollector( env=self.env, From 62c0de8c863739a9489c6cb214ade76b535a3bb5 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 31 Mar 
2026 07:54:46 +0000 Subject: [PATCH 13/17] update --- embodichain/agents/rl/models/vla_policy.py | 23 ++++++++++++++++------ embodichain/agents/rl/utils/trainer.py | 3 ++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index add54ad1..72f039a7 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -175,18 +175,29 @@ def forward( chunks_env.append(chunk_i) action_chunk_env = torch.cat(chunks_env, dim=0) - action_chunk_env = action_chunk_env.to(self.device, dtype=torch.float32) - action = action_chunk_env[:, 0] + mean_chunk = action_chunk_env.to(self.device, dtype=torch.float32) + if deterministic: + noisy_chunk = mean_chunk + log_prob = torch.zeros( + mean_chunk.shape[0], device=self.device, dtype=torch.float32 + ) + else: + sigma = self.gaussian_sigma + noise = torch.randn_like(mean_chunk) * sigma + noisy_chunk = mean_chunk + noise + log_prob = ( + -0.5 * noise.pow(2).sum(-1).mean(-1) / (sigma * sigma + 1e-8) + ) + + action = noisy_chunk[:, 0] tensordict["action"] = action - tensordict["sample_log_prob"] = torch.zeros( - action.shape[0], device=self.device, dtype=torch.float32 - ) + tensordict["sample_log_prob"] = log_prob tensordict["value"] = torch.zeros( action.shape[0], device=self.device, dtype=torch.float32 ) if self.use_action_chunk: - tensordict["action_chunk"] = action_chunk_env + tensordict["action_chunk"] = noisy_chunk return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 92122223..02f69e9b 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -96,7 +96,8 @@ def __init__( use_raw_obs = getattr(self.policy, "use_raw_obs", False) action_chunk_size = getattr(self.policy, "action_chunk_size", 0) use_action_chunk = getattr(self.policy, "use_action_chunk", False) - if use_action_chunk and action_chunk_size > 0: + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) + if use_action_chunk and action_chunk_size > 0 and not execute_full_chunk: self.buffer_size = ( (self.buffer_size + action_chunk_size - 1) // action_chunk_size From d5ce609d9c20b5e6fcfc1a722ac720abb28974c9 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Tue, 31 Mar 2026 08:33:37 +0000 Subject: [PATCH 14/17] reformat files --- embodichain/agents/rl/models/vla_policy.py | 10 ++++++---- embodichain/agents/rl/utils/trainer.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 72f039a7..3d46ee9b 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -186,9 +186,7 @@ def forward( sigma = self.gaussian_sigma noise = torch.randn_like(mean_chunk) * sigma noisy_chunk = mean_chunk + noise - log_prob = ( - -0.5 * noise.pow(2).sum(-1).mean(-1) / (sigma * sigma + 1e-8) - ) + log_prob = -0.5 * noise.pow(2).sum(-1).mean(-1) / (sigma * sigma + 1e-8) action = noisy_chunk[:, 0] tensordict["action"] = action @@ -255,7 +253,11 @@ def evaluate_actions( gt_chunk = stored_chunks[i] pred_chunk = pred_chunk_env[0] min_len = min(gt_chunk.shape[-1], pred_chunk.shape[-1]) - mse = ((gt_chunk[..., :min_len] - pred_chunk[..., :min_len]).pow(2)).sum(-1).mean(-1) + mse = ( + ((gt_chunk[..., :min_len] - pred_chunk[..., 
:min_len]).pow(2)) + .sum(-1) + .mean(-1) + ) else: action_gt = tensordict["action"][i] pred = pred_chunk_env[0, step_in_chunk] diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 02f69e9b..2b03876f 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -431,7 +431,9 @@ def _eval_once(self, num_episodes: int = 5): sub_actions ) - obs, reward, term_i, trunc_i, info = self.eval_env.step(action_in) + obs, reward, term_i, trunc_i, info = self.eval_env.step( + action_in + ) if use_raw_obs: obs_td = dict_to_tensordict(obs, self.device) else: From 5076ebe4a86efd0b176549a09ef1b07501df005c Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 1 Apr 2026 10:26:53 +0000 Subject: [PATCH 15/17] wip --- docs/source/overview/rl/trainer.md | 5 + embodichain/agents/rl/buffer/utils.py | 3 +- .../agents/rl/collector/sync_collector.py | 41 +++- embodichain/agents/rl/models/__init__.py | 5 +- embodichain/agents/rl/models/vla_policy.py | 49 +++- embodichain/agents/rl/utils/trainer.py | 23 +- embodichain/agents/rl/vla_registry.py | 73 +++++- .../gym/envs/managers/randomization/visual.py | 68 +++-- tests/agents/test_shared_rollout.py | 232 ++++++++++++++++++ tests/agents/test_vla_policy.py | 42 ++++ 10 files changed, 472 insertions(+), 69 deletions(-) create mode 100644 tests/agents/test_vla_policy.py diff --git a/docs/source/overview/rl/trainer.md b/docs/source/overview/rl/trainer.md index 5ef4ee99..b90dd991 100644 --- a/docs/source/overview/rl/trainer.md +++ b/docs/source/overview/rl/trainer.md @@ -51,4 +51,9 @@ trainer.save_checkpoint() - The event mechanism can be used for automated experiments, data collection, and environment reset. - Logging and monitoring help analyze training progress and tune hyperparameters. 
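
The registry functions referenced below resolve and instantiate VLA backends by name. A hedged usage sketch: the backend name and checkpoint path are placeholders, and the three-element unpacking mirrors how `VLAPolicy._load_vla` consumes its backend, though a factory registered under the entry-point group is free to return a different shape:

    import torch
    from embodichain.agents.rl.vla_registry import (
        create_vla_backend,
        get_registered_vla_backend_names,
    )

    print(get_registered_vla_backend_names())  # e.g. ["dexforce_vla"]

    # Raises ValueError, listing the available names, when the name is unknown.
    model, action_indices, prepare_batch_fn = create_vla_backend(
        "dexforce_vla",                        # any registered backend name
        model_path="/path/to/checkpoint",      # placeholder path
        device=torch.device("cpu"),
        action_horizon=16,
    )
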
+## API References +- VLA backend lookup: `embodichain.agents.rl.vla_registry.get_vla_backend()` +- VLA backend listing: `embodichain.agents.rl.vla_registry.get_registered_vla_backend_names()` +- VLA backend creation: `embodichain.agents.rl.vla_registry.create_vla_backend()` + --- diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 740d9d69..0440c42e 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -78,7 +78,8 @@ def iterate_minibatches( ) -> Iterator[TensorDict]: """Yield shuffled minibatches from a flattened rollout.""" total = rollout.batch_size[0] - indices = torch.randperm(total) + idx_device = rollout.device if rollout.device is not None else device + indices = torch.randperm(total, device=idx_device) for start in range(0, total, batch_size): batch_indices = indices[start : start + batch_size] batch = rollout[batch_indices].clone() diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index d91d7f56..be525d57 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -253,6 +253,7 @@ def collect( step_td = TensorDict( { "action": action, + "action_chunk": cached_chunk, "sample_log_prob": torch.zeros( action.shape[0], device=self.device, dtype=torch.float32 ), @@ -371,16 +372,38 @@ def _write_env_step( def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: """Validate rollout layout expected by the collector.""" - obs_dim = rollout["obs"].shape[-1] + num_envs = self.env.num_envs + time_plus_one = num_steps + 1 + policy_obs_dim = int(getattr(self.policy, "obs_dim", 0) or 0) + obs_shape = tuple(rollout["obs"].shape) + if policy_obs_dim > 0: + expected_obs = (num_envs, time_plus_one, policy_obs_dim) + if obs_shape != expected_obs: + logger.log_error( + f"Preallocated rollout field 'obs' shape mismatch: " + f"expected {expected_obs}, got {obs_shape}.", + ValueError, + ) + else: + if ( + len(obs_shape) != 3 + or obs_shape[0] != num_envs + or obs_shape[1] != time_plus_one + ): + logger.log_error( + f"Preallocated rollout field 'obs' shape mismatch: " + f"expected ({num_envs}, {time_plus_one}, *), got {obs_shape}.", + ValueError, + ) + expected_shapes = { - "obs": (self.env.num_envs, num_steps + 1, obs_dim), - "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), - "sample_log_prob": (self.env.num_envs, num_steps + 1), - "value": (self.env.num_envs, num_steps + 1), - "reward": (self.env.num_envs, num_steps + 1), - "done": (self.env.num_envs, num_steps + 1), - "terminated": (self.env.num_envs, num_steps + 1), - "truncated": (self.env.num_envs, num_steps + 1), + "action": (num_envs, time_plus_one, self.policy.action_dim), + "sample_log_prob": (num_envs, time_plus_one), + "value": (num_envs, time_plus_one), + "reward": (num_envs, time_plus_one), + "done": (num_envs, time_plus_one), + "terminated": (num_envs, time_plus_one), + "truncated": (num_envs, time_plus_one), } for key, expected_shape in expected_shapes.items(): actual_shape = tuple(rollout[key].shape) diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 46231005..1eee5983 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -70,7 +70,7 @@ def build_policy( Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while custom policies may accept richer `obs_space` / 
`action_space` inputs. - For vla_policy, pass env to enable set_env and _load_vla initialization. + For vla_policy, pass env so set_env can run; VLA weights load lazily on first use. """ name = policy_block["name"].lower() @@ -130,10 +130,9 @@ def build_policy( if env is None: raise ValueError( "VLAPolicy requires an 'env' argument to be passed to build_policy " - "so that set_env and _load_vla can be called before use." + "so that set_env can be called before use." ) policy.set_env(env) - policy._load_vla() return policy diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 3d46ee9b..1d70c6e0 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -16,6 +16,7 @@ from __future__ import annotations +from collections.abc import Mapping from typing import Any import numpy as np @@ -60,8 +61,8 @@ def __init__( self.action_dim = int(action_space.shape[-1]) else: self.action_dim = 14 - self.obs_dim = 0 # VLA uses raw ob - self.use_raw_obs = True # Tell collector to pass raw ob + self.obs_dim = 0 # VLA uses raw observations + self.use_raw_obs = True # Tell collector to pass raw observations self.use_action_chunk = True self.action_chunk_size = self.action_horizon @@ -75,8 +76,9 @@ def set_env(self, env) -> None: def _load_vla(self) -> None: if self._vla_model is not None: return + backend_name = str(self.vla_cfg.get("backend", "dexforce_vla")).lower() backend = create_vla_backend( - "dexforce_vla", + backend_name, model_path=self.model_path, device=self.device, action_horizon=self.action_horizon, @@ -89,6 +91,11 @@ def _load_vla(self) -> None: self._vla_model, self._action_indices, self._prepare_batch_fn = backend self._freeze_encoders() + def parameters(self, recurse: bool = True): + """Expose trainable parameters, lazily loading VLA backend if needed.""" + self._load_vla() + return super().parameters(recurse=recurse) + def _freeze_encoders(self) -> None: """Freeze vision encoders to avoid catastrophic forgetting""" if self._vla_model is None: @@ -128,6 +135,32 @@ def _vla_chunk_to_env_chunk( step = torch.cat([step, pad], dim=-1) return step + def _infer_batch_size(self, obs: Any) -> int: + """Infer leading batch size from TensorDict / mapping / tensor observations.""" + if hasattr(obs, "batch_size") and len(obs.batch_size) > 0: + return int(obs.batch_size[0]) + if isinstance(obs, Mapping): + robot = obs.get("robot") + if isinstance(robot, Mapping): + q = robot.get("qpos") + if hasattr(q, "shape") and len(q.shape) > 0: + return int(q.shape[0]) + return 1 + + def _slice_obs_item(self, obs: Any, index: int) -> Any: + """Slice one environment sample from a batched observation structure.""" + if hasattr(obs, "batch_size") and len(obs.batch_size) > 0: + return obs[index] + if isinstance(obs, Mapping): + return { + key: self._slice_obs_item(value, index) for key, value in obs.items() + } + if torch.is_tensor(obs): + if obs.dim() == 0: + return obs + return obs[index] + return obs + def forward( self, tensordict: TensorDict, deterministic: bool = False ) -> TensorDict: @@ -142,13 +175,7 @@ def forward( self._load_vla() self._vla_model.eval() - if hasattr(obs, "batch_size") and len(obs.batch_size) > 0: - batch_size = int(obs.batch_size[0]) - elif isinstance(obs, dict) and "robot" in obs and "qpos" in obs["robot"]: - q = obs["robot"]["qpos"] - batch_size = q.shape[0] if hasattr(q, "shape") and len(q.shape) > 0 else 1 - else: - batch_size = 1 + batch_size = self._infer_batch_size(obs) if batch_size == 
1: batch = self._prepare_batch_fn(obs, env) vla_chunk = self._vla_model.predict_action( @@ -162,7 +189,7 @@ def forward( else: chunks_env = [] for i in range(batch_size): - obs_i = obs[i] if hasattr(obs, "__getitem__") else obs + obs_i = self._slice_obs_item(obs, i) batch_i = self._prepare_batch_fn(obs_i, env) vla_chunk = self._vla_model.predict_action( batch_i, diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 2b03876f..8b5a1cec 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -29,7 +29,6 @@ from embodichain.lab.gym.envs.managers.action_manager import ActionManager from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.agents.rl.collector import SyncCollector -from embodichain.agents.rl.utils import dict_to_tensordict from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -106,9 +105,8 @@ def __init__( if use_raw_obs: try: - reset_out = self.env.reset() - sample_obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out - obs_td = dict_to_tensordict(sample_obs, self.device) + sample_obs = self.env.observation_space.sample() + obs_td = self._obs_to_tensordict(sample_obs) flat_obs = flatten_dict_observation(obs_td) obs_dim = int( flat_obs.shape[-1] @@ -143,6 +141,17 @@ def __init__( self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) # ---- lightweight helpers for dense logging ---- + def _obs_to_tensordict(self, obs) -> TensorDict: + """Normalize observation to TensorDict on trainer device.""" + if isinstance(obs, TensorDict): + return obs.to(self.device) + if isinstance(obs, dict): + return TensorDict.from_dict(obs, device=self.device) + raise TypeError( + f"Unsupported raw observation type: {type(obs)!r}. " + "Expected TensorDict or dict." 
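
The `_obs_to_tensordict` helper above leans on `TensorDict.from_dict`, which converts nested plain dictionaries recursively; a toy check of the behavior being assumed:

    import torch
    from tensordict import TensorDict

    obs = {
        "robot": {"qpos": torch.zeros(2, 7)},  # nested dict -> nested TensorDict
        "image": torch.zeros(2, 3, 32, 32),
    }
    td = TensorDict.from_dict(obs, device=torch.device("cpu"))
    assert td["robot", "qpos"].shape == (2, 7)  # tuple keys reach nested fields
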
+ ) + @staticmethod def _mean_scalar(x) -> float: if hasattr(x, "detach"): @@ -376,7 +385,7 @@ def _eval_once(self, num_episodes: int = 5): for _ in range(num_episodes): obs, _ = self.eval_env.reset() if use_raw_obs: - obs_td = dict_to_tensordict(obs, self.device) + obs_td = self._obs_to_tensordict(obs) else: obs_td = flatten_dict_observation(obs) num_envs = ( @@ -435,7 +444,7 @@ def _eval_once(self, num_episodes: int = 5): action_in ) if use_raw_obs: - obs_td = dict_to_tensordict(obs, self.device) + obs_td = self._obs_to_tensordict(obs) else: obs_td = ( flatten_dict_observation(obs) @@ -499,7 +508,7 @@ def _eval_once(self, num_episodes: int = 5): obs, reward, terminated, truncated, info = self.eval_env.step(action_in) if use_raw_obs: - obs_td = dict_to_tensordict(obs, self.device) + obs_td = self._obs_to_tensordict(obs) else: obs_td = ( flatten_dict_observation(obs) diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py index e9f13546..f594c30b 100644 --- a/embodichain/agents/rl/vla_registry.py +++ b/embodichain/agents/rl/vla_registry.py @@ -19,6 +19,8 @@ from importlib.metadata import entry_points from typing import Any, Callable +from embodichain.utils.logger import log_warning + __all__ = [ "get_vla_backend", "get_registered_vla_backend_names", @@ -28,29 +30,59 @@ _VLA_BACKENDS: dict[str, Callable[..., Any]] = {} _ENTRY_POINTS_DISCOVERED = False +_ENTRY_POINTS_ENUM_LOGGED = False def _discover_entry_points() -> None: """Discover and register VLA backends from entry_points.""" - global _ENTRY_POINTS_DISCOVERED + global _ENTRY_POINTS_DISCOVERED, _ENTRY_POINTS_ENUM_LOGGED if _ENTRY_POINTS_DISCOVERED: return - _ENTRY_POINTS_DISCOVERED = True try: eps = entry_points(group="embodichain.vla_backends") - for ep in eps: - try: - factory = ep.load() - name = str(ep.name).lower() - if name not in _VLA_BACKENDS: - _VLA_BACKENDS[name] = factory - except Exception: - pass - except Exception: - pass + except (OSError, ValueError, TypeError) as exc: + if not _ENTRY_POINTS_ENUM_LOGGED: + log_warning( + "Could not enumerate 'embodichain.vla_backends' entry points: " + f"{type(exc).__name__}: {exc}" + ) + _ENTRY_POINTS_ENUM_LOGGED = True + return + + for ep in eps: + try: + factory = ep.load() + except (ImportError, AttributeError, TypeError, ValueError) as exc: + log_warning( + f"Failed to load VLA backend entry point name={ep.name!r} " + f"value={ep.value!r}: {type(exc).__name__}: {exc}" + ) + continue + except Exception as exc: + log_warning( + f"Unexpected error loading VLA backend entry point name={ep.name!r} " + f"value={ep.value!r}: {type(exc).__name__}: {exc}" + ) + continue + name = str(ep.name).lower() + if name not in _VLA_BACKENDS: + _VLA_BACKENDS[name] = factory + + _ENTRY_POINTS_DISCOVERED = True def get_vla_backend(name: str) -> Callable[..., Any] | None: + """Get a registered backend factory by name. + + This checks the in-memory registry first, and then lazily triggers + entry-point discovery if needed. + + Args: + name: Backend identifier (case-insensitive). + + Returns: + The backend factory callable if found, otherwise ``None``. + """ name = str(name).lower() if name in _VLA_BACKENDS: return _VLA_BACKENDS[name] @@ -59,11 +91,28 @@ def get_vla_backend(name: str) -> Callable[..., Any] | None: def get_registered_vla_backend_names() -> list[str]: + """List all currently discoverable VLA backend names. + + Returns: + A list of backend names after lazy entry-point discovery. 
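
Because discovery is name-keyed, a third-party package only needs to expose a factory under the `embodichain.vla_backends` entry-point group (for example `my_backend = "my_pkg.backend:build_backend"` in its `pyproject.toml`; both names here are hypothetical). A sketch of such a factory with stand-in internals; only the triple return mirrors what `VLAPolicy` unpacks:

    import torch

    def _prepare_batch(obs, env):
        # Stand-in for backend-specific observation pre-processing.
        return obs

    def build_backend(model_path: str, device: torch.device, action_horizon: int, **kwargs):
        # Stand-in model; a real factory would load weights from model_path.
        model = torch.nn.Linear(4, 4).to(device)
        action_indices = list(range(action_horizon))
        return model, action_indices, _prepare_batch
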
+ """ _discover_entry_points() return list(_VLA_BACKENDS.keys()) def create_vla_backend(name: str, **kwargs) -> Any: + """Instantiate a VLA backend by name. + + Args: + name: Backend identifier (case-insensitive). + **kwargs: Keyword arguments forwarded to the backend factory. + + Returns: + The instantiated backend object (factory-defined type). + + Raises: + ValueError: If the backend name is unknown. + """ factory = get_vla_backend(name) if factory is None: available = get_registered_vla_backend_names() diff --git a/embodichain/lab/gym/envs/managers/randomization/visual.py b/embodichain/lab/gym/envs/managers/randomization/visual.py index 62dcfcd5..0522db8d 100644 --- a/embodichain/lab/gym/envs/managers/randomization/visual.py +++ b/embodichain/lab/gym/envs/managers/randomization/visual.py @@ -183,8 +183,12 @@ def randomize_camera_extrinsics( ).repeat(num_instance, 1) if pos_range: random_value = sample_uniform( - lower=torch.tensor(pos_range[0], device=env.device), - upper=torch.tensor(pos_range[1], device=env.device), + lower=torch.tensor( + pos_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + pos_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), device=env.device, ) @@ -199,8 +203,12 @@ def randomize_camera_extrinsics( init_euler = torch.stack(euler_xyz_from_quat(init_quat_np), dim=1) # 2. Sample perturbation for euler angles random_value = sample_uniform( - lower=torch.tensor(euler_range[0], device=env.device), - upper=torch.tensor(euler_range[1], device=env.device), + lower=torch.tensor( + euler_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + euler_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), device=env.device, ) @@ -231,8 +239,12 @@ def randomize_camera_extrinsics( if eye_range: eye_delta = sample_uniform( - lower=torch.tensor(eye_range[0], device=env.device), - upper=torch.tensor(eye_range[1], device=env.device), + lower=torch.tensor( + eye_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + eye_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), device=env.device, ) @@ -242,8 +254,12 @@ def randomize_camera_extrinsics( if target_range: target_delta = sample_uniform( - lower=torch.tensor(target_range[0], device=env.device), - upper=torch.tensor(target_range[1], device=env.device), + lower=torch.tensor( + target_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + target_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), device=env.device, ) @@ -253,8 +269,8 @@ def randomize_camera_extrinsics( if up_range: up_delta = sample_uniform( - lower=torch.tensor(up_range[0], device=env.device), - upper=torch.tensor(up_range[1], device=env.device), + lower=torch.tensor(up_range[0], dtype=torch.float32, device=env.device), + upper=torch.tensor(up_range[1], dtype=torch.float32, device=env.device), size=(num_instance, 3), device=env.device, ) @@ -316,8 +332,8 @@ def randomize_light( .repeat(num_instance, 1) ) random_value = sample_uniform( - lower=torch.tensor(position_range[0]), - upper=torch.tensor(position_range[1]), + lower=torch.tensor(position_range[0], dtype=torch.float32), + upper=torch.tensor(position_range[1], dtype=torch.float32), size=new_pos.shape, ) new_pos += random_value @@ -326,8 +342,8 @@ def randomize_light( if color_range: color = torch.zeros((num_instance, 3), dtype=torch.float32) random_value = sample_uniform( - lower=torch.tensor(color_range[0]), - 
upper=torch.tensor(color_range[1]), + lower=torch.tensor(color_range[0], dtype=torch.float32), + upper=torch.tensor(color_range[1], dtype=torch.float32), size=color.shape, ) color += random_value @@ -341,8 +357,8 @@ def randomize_light( .repeat(num_instance, 1) ) random_value = sample_uniform( - lower=torch.tensor(intensity_range[0]), - upper=torch.tensor(intensity_range[1]), + lower=torch.tensor(intensity_range[0], dtype=torch.float32), + upper=torch.tensor(intensity_range[1], dtype=torch.float32), size=new_intensity.shape, ) new_intensity += random_value @@ -377,8 +393,8 @@ def randomize_emission_light( if color_range: color = torch.zeros((1, 3), dtype=torch.float32) random_value = sample_uniform( - lower=torch.tensor(color_range[0]), - upper=torch.tensor(color_range[1]), + lower=torch.tensor(color_range[0], dtype=torch.float32), + upper=torch.tensor(color_range[1], dtype=torch.float32), size=color.shape, ) color += random_value @@ -450,8 +466,8 @@ def randomize_camera_intrinsics( # Randomize focal length x (fx) if focal_x_range: random_value = sample_uniform( - lower=torch.tensor(focal_x_range[0]), - upper=torch.tensor(focal_x_range[1]), + lower=torch.tensor(focal_x_range[0], dtype=torch.float32), + upper=torch.tensor(focal_x_range[1], dtype=torch.float32), size=(num_instance,), ) new_intrinsics[:, 0] += random_value @@ -459,8 +475,8 @@ def randomize_camera_intrinsics( # Randomize focal length y (fy) if focal_y_range: random_value = sample_uniform( - lower=torch.tensor(focal_y_range[0]), - upper=torch.tensor(focal_y_range[1]), + lower=torch.tensor(focal_y_range[0], dtype=torch.float32), + upper=torch.tensor(focal_y_range[1], dtype=torch.float32), size=(num_instance,), ) new_intrinsics[:, 1] += random_value @@ -468,8 +484,8 @@ def randomize_camera_intrinsics( # Randomize principal point x (cx) if cx_range: random_value = sample_uniform( - lower=torch.tensor(cx_range[0]), - upper=torch.tensor(cx_range[1]), + lower=torch.tensor(cx_range[0], dtype=torch.float32), + upper=torch.tensor(cx_range[1], dtype=torch.float32), size=(num_instance,), ) new_intrinsics[:, 2] += random_value @@ -477,8 +493,8 @@ def randomize_camera_intrinsics( # Randomize principal point y (cy) if cy_range: random_value = sample_uniform( - lower=torch.tensor(cy_range[0]), - upper=torch.tensor(cy_range[1]), + lower=torch.tensor(cy_range[0], dtype=torch.float32), + upper=torch.tensor(cy_range[1], dtype=torch.float32), size=(num_instance,), ) new_intrinsics[:, 3] += random_value diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 37dd34fa..5439a014 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -223,3 +223,235 @@ def test_embodied_env_writes_next_fields_into_external_rollout(): env.close() if SimulationManager.is_instantiated(): SimulationManager.get_instance().destroy() + + +class _FakePolicyRawObs: + """Policy that reads nested env observations (collector passes raw_obs slices).""" + + use_raw_obs = True + + def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None: + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + + def train(self) -> None: + pass + + def get_action( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["obs"] + flat = obs["agent"]["state"] + tensordict["action"] = flat[:, : self.action_dim] * 0.25 + tensordict["sample_log_prob"] = flat.sum(dim=-1) * 0.1 + tensordict["value"] = flat.mean(dim=-1) + return tensordict + + def 
get_value(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["obs"] + flat = obs["agent"]["state"] + tensordict["value"] = flat.mean(dim=-1) + return tensordict + + +class _FakePolicyActionChunk: + """Policy that emits a fixed-length action chunk (sequential env steps per chunk).""" + + use_action_chunk = True + action_chunk_size = 2 + execute_full_chunk = False + + def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None: + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + + def train(self) -> None: + pass + + def get_action( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["obs"] + row0 = obs[:, : self.action_dim] * 0.1 + row1 = obs[:, : self.action_dim] * 0.2 + chunk = torch.stack([row0, row1], dim=1) + tensordict["action_chunk"] = chunk + tensordict["action"] = chunk[:, 0] + tensordict["sample_log_prob"] = torch.zeros( + obs.shape[0], device=obs.device, dtype=torch.float32 + ) + tensordict["value"] = torch.zeros( + obs.shape[0], device=obs.device, dtype=torch.float32 + ) + return tensordict + + def get_value(self, tensordict: TensorDict) -> TensorDict: + tensordict["value"] = tensordict["obs"].mean(dim=-1) + return tensordict + + +class _FakePolicyExecuteFullChunk: + """Policy that runs an entire chunk inside one logical rollout step.""" + + use_action_chunk = True + action_chunk_size = 3 + execute_full_chunk = True + + def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None: + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + + def train(self) -> None: + pass + + def get_action( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["obs"] + n = obs.shape[0] + t = self.action_chunk_size + base = obs[:, : self.action_dim] * 0.1 + chunk = base.unsqueeze(1).expand(n, t, -1).clone() + tensordict["action_chunk"] = chunk + tensordict["action"] = chunk[:, 0] + tensordict["sample_log_prob"] = torch.zeros( + n, device=obs.device, dtype=torch.float32 + ) + tensordict["value"] = torch.zeros(n, device=obs.device, dtype=torch.float32) + return tensordict + + def get_value(self, tensordict: TensorDict) -> TensorDict: + tensordict["value"] = tensordict["obs"].mean(dim=-1) + return tensordict + + +def test_collector_populates_raw_obs_buffer(): + device = torch.device("cpu") + num_envs = 2 + rollout_len = 3 + obs_dim = 5 + action_dim = 2 + + env = _FakeEnv( + num_envs=num_envs, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + policy = _FakePolicyRawObs(obs_dim=obs_dim, action_dim=action_dim, device=device) + collector = SyncCollector(env=env, policy=policy, device=device) + buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=rollout_len, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + use_raw_obs=True, + ) + + rollout = collector.collect( + num_steps=rollout_len, + rollout=buffer.start_rollout(), + ) + + assert hasattr(rollout, "raw_obs") + assert len(rollout.raw_obs) == rollout_len + 1 + for t in range(rollout_len + 1): + assert rollout.raw_obs[t] is not None + assert rollout.raw_obs[t].batch_size == torch.Size([num_envs]) + assert torch.allclose( + rollout["obs"][:, 0], torch.zeros(num_envs, obs_dim, dtype=torch.float32) + ) + assert torch.allclose( + rollout["obs"][:, -1], + torch.full((num_envs, obs_dim), float(rollout_len), dtype=torch.float32), + ) + + +def test_collector_chunk_step_alternates_for_sequential_action_chunk(): + device = 
torch.device("cpu") + num_envs = 2 + rollout_len = 4 + obs_dim = 5 + action_dim = 2 + + env = _FakeEnv( + num_envs=num_envs, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + policy = _FakePolicyActionChunk( + obs_dim=obs_dim, action_dim=action_dim, device=device + ) + collector = SyncCollector(env=env, policy=policy, device=device) + buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=rollout_len, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + action_chunk_size=policy.action_chunk_size, + ) + + rollout = collector.collect( + num_steps=rollout_len, + rollout=buffer.start_rollout(), + ) + + assert hasattr(rollout, "chunk_step") + expected = torch.tensor([0, 1, 0, 1], dtype=torch.long, device=device).expand( + num_envs, -1 + ) + assert torch.equal(rollout.chunk_step, expected) + assert torch.all(rollout["step_repeat"][:, :-1] == 1.0) + assert torch.allclose(rollout["action_chunk"][:, 1], rollout["action_chunk"][:, 0]) + assert torch.allclose(rollout["action_chunk"][:, 3], rollout["action_chunk"][:, 2]) + + +def test_collector_execute_full_chunk_sets_step_repeat_and_action_chunk(): + device = torch.device("cpu") + num_envs = 2 + rollout_len = 2 + obs_dim = 5 + action_dim = 2 + chunk_t = 3 + + env = _FakeEnv( + num_envs=num_envs, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + policy = _FakePolicyExecuteFullChunk( + obs_dim=obs_dim, action_dim=action_dim, device=device + ) + assert policy.action_chunk_size == chunk_t + collector = SyncCollector(env=env, policy=policy, device=device) + buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=rollout_len, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + action_chunk_size=chunk_t, + ) + + rollout = collector.collect( + num_steps=rollout_len, + rollout=buffer.start_rollout(), + ) + + assert "action_chunk" in rollout.keys() + assert rollout["action_chunk"].shape == ( + num_envs, + rollout_len + 1, + chunk_t, + action_dim, + ) + assert torch.all(rollout.chunk_step[:, :rollout_len] == 0) + assert torch.all(rollout["step_repeat"][:, :rollout_len] == float(chunk_t)) + assert torch.all(rollout["step_repeat"][:, -1] == 0.0) diff --git a/tests/agents/test_vla_policy.py b/tests/agents/test_vla_policy.py new file mode 100644 index 00000000..3cf20dc7 --- /dev/null +++ b/tests/agents/test_vla_policy.py @@ -0,0 +1,42 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch + +from embodichain.agents.rl.models.vla_policy import VLAPolicy + + +def test_vla_policy_slice_obs_item_with_mapping_batch(): + policy = VLAPolicy( + device=torch.device("cpu"), + policy_cfg={"vla": {"model_path": "dummy_model_path"}}, + action_space=2, + ) + obs = { + "robot": { + "qpos": torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32), + "qvel": torch.tensor([[5.0, 6.0], [7.0, 8.0]], dtype=torch.float32), + }, + "image": torch.tensor([[9.0], [10.0]], dtype=torch.float32), + } + + assert policy._infer_batch_size(obs) == 2 + obs_1 = policy._slice_obs_item(obs, 1) + assert torch.allclose(obs_1["robot"]["qpos"], torch.tensor([3.0, 4.0])) + assert torch.allclose(obs_1["robot"]["qvel"], torch.tensor([7.0, 8.0])) + assert torch.allclose(obs_1["image"], torch.tensor([10.0])) From 42b431b76d113152f4770268761922ffa82023cc Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Wed, 1 Apr 2026 12:55:44 +0000 Subject: [PATCH 16/17] refactor raw observation rollout handling --- .../agents/rl/buffer/standard_buffer.py | 31 ++- embodichain/agents/rl/buffer/utils.py | 24 +- .../agents/rl/collector/sync_collector.py | 208 +++++++++++------- embodichain/agents/rl/train.py | 14 +- embodichain/agents/rl/utils/trainer.py | 24 +- embodichain/lab/gym/envs/embodied_env.py | 1 - tests/agents/test_shared_rollout.py | 25 ++- 7 files changed, 198 insertions(+), 129 deletions(-) diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index deb6abed..f897755d 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -42,6 +42,7 @@ def __init__( device: torch.device, use_raw_obs: bool = False, action_chunk_size: int = 0, + store_flat_obs: bool = True, ) -> None: self.num_envs = num_envs self.rollout_len = rollout_len @@ -50,6 +51,7 @@ def __init__( self.device = device self.use_raw_obs = use_raw_obs self.action_chunk_size = action_chunk_size + self.store_flat_obs = store_flat_obs self._rollout = self._allocate_rollout() self._is_full = False @@ -103,15 +105,18 @@ def is_full(self) -> bool: def _allocate_rollout(self) -> TensorDict: """Preallocate rollout storage with uniform `[num_envs, time + 1]` shape.""" + rollout_tensors = {} + if self.store_flat_obs: + rollout_tensors["obs"] = torch.empty( + self.num_envs, + self.rollout_len + 1, + self.obs_dim, + dtype=torch.float32, + device=self.device, + ) td = TensorDict( { - "obs": torch.empty( - self.num_envs, - self.rollout_len + 1, - self.obs_dim, - dtype=torch.float32, - device=self.device, - ), + **rollout_tensors, "action": torch.empty( self.num_envs, self.rollout_len + 1, @@ -203,7 +208,6 @@ def _reset_padding_slot(self) -> None: def _validate_rollout_layout(self, rollout: TensorDict) -> None: """Validate the expected tensor shapes for the shared rollout.""" expected_shapes = { - "obs": (self.num_envs, self.rollout_len + 1, self.obs_dim), "action": (self.num_envs, self.rollout_len + 1, self.action_dim), "sample_log_prob": (self.num_envs, self.rollout_len + 1), "value": (self.num_envs, self.rollout_len + 1), @@ -212,6 +216,12 @@ def _validate_rollout_layout(self, rollout: TensorDict) -> None: "terminated": (self.num_envs, self.rollout_len + 1), "truncated": (self.num_envs, self.rollout_len + 1), } + if self.store_flat_obs: + expected_shapes["obs"] = ( + self.num_envs, + self.rollout_len + 
1, + self.obs_dim, + ) for key, expected_shape in expected_shapes.items(): actual_shape = tuple(rollout[key].shape) if actual_shape != expected_shape: @@ -219,3 +229,8 @@ def _validate_rollout_layout(self, rollout: TensorDict) -> None: f"Rollout field '{key}' shape mismatch: expected {expected_shape}, " f"got {actual_shape}." ) + if not self.store_flat_obs and "obs" in rollout.keys(): + raise ValueError( + "RolloutBuffer configured with store_flat_obs=False must not contain " + "a preallocated 'obs' field." + ) diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 0440c42e..e65e268e 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -42,18 +42,20 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: """ action = rollout["action"][:, :-1] num_envs, time_dim = action.shape[:2] + transition_fields = { + "action": action, + "sample_log_prob": rollout["sample_log_prob"][:, :-1], + "value": rollout["value"][:, :-1], + "next_value": rollout["value"][:, 1:], + "reward": rollout["reward"][:, :-1], + "done": rollout["done"][:, :-1], + "terminated": rollout["terminated"][:, :-1], + "truncated": rollout["truncated"][:, :-1], + } + if "obs" in rollout.keys(): + transition_fields["obs"] = rollout["obs"][:, :-1] td = TensorDict( - { - "obs": rollout["obs"][:, :-1], - "action": action, - "sample_log_prob": rollout["sample_log_prob"][:, :-1], - "value": rollout["value"][:, :-1], - "next_value": rollout["value"][:, 1:], - "reward": rollout["reward"][:, :-1], - "done": rollout["done"][:, :-1], - "terminated": rollout["terminated"][:, :-1], - "truncated": rollout["truncated"][:, :-1], - }, + transition_fields, batch_size=[num_envs, time_dim], device=rollout.device, ) diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index be525d57..5680e7e0 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -128,26 +128,20 @@ def collect( ) rollout["step_repeat"][:, -1] = 0.0 - if use_raw_obs and raw_obs_list is not None: - raw_obs_list[0] = self.obs_td - rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) - else: - rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td) + self._store_observation( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=0, + obs_td=self.obs_td, + ) for step_idx in range(num_steps): if execute_full_chunk and use_action_chunk: - if use_raw_obs and raw_obs_list is not None: - step_td = TensorDict( - {"obs": raw_obs_list[step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, - ) - else: - step_td = TensorDict( - {"obs": rollout["obs"][:, step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, - ) + step_td = self._policy_input_tensordict( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) step_td = self.policy.get_action(step_td) chunk = step_td.get("action_chunk") if chunk is None: @@ -208,15 +202,12 @@ def collect( terminated=terminated, truncated=truncated, ) - if use_raw_obs and raw_obs_list is not None: - raw_obs_list[step_idx + 1] = next_obs_td - rollout["obs"][:, step_idx + 1] = flatten_dict_observation( - next_obs_td - ) - else: - rollout["obs"][:, step_idx + 1] = flatten_dict_observation( - next_obs_td - ) + self._store_observation( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx + 1, + obs_td=next_obs_td, + ) if on_step_callback is not None: 
on_step_callback(rollout[:, step_idx], env_info) @@ -230,18 +221,11 @@ def collect( ) if need_new_chunk: - if use_raw_obs and raw_obs_list is not None: - step_td = TensorDict( - {"obs": raw_obs_list[step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, - ) - else: - step_td = TensorDict( - {"obs": rollout["obs"][:, step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, - ) + step_td = self._policy_input_tensordict( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) step_td = self.policy.get_action(step_td) cached_chunk = step_td["action_chunk"] action = step_td["action"] @@ -266,18 +250,11 @@ def collect( ) chunk_cursor += 1 else: - if use_raw_obs and raw_obs_list is not None: - step_td = TensorDict( - {"obs": raw_obs_list[step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, - ) - else: - step_td = TensorDict( - {"obs": rollout["obs"][:, step_idx]}, - batch_size=[rollout.batch_size[0]], - device=self.device, - ) + step_td = self._policy_input_tensordict( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) step_td = self.policy.get_action(step_td) action = step_td["action"] @@ -301,11 +278,12 @@ def collect( terminated=terminated, truncated=truncated, ) - if use_raw_obs and raw_obs_list is not None: - raw_obs_list[step_idx + 1] = next_obs_td - rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) - else: - rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) + self._store_observation( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx + 1, + obs_td=next_obs_td, + ) if on_step_callback is not None: on_step_callback(rollout[:, step_idx], env_info) @@ -319,12 +297,14 @@ def _attach_final_value(self, rollout: TensorDict) -> None: """Populate the bootstrap value for the final observed state.""" use_raw_obs = getattr(self.policy, "use_raw_obs", False) raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None - if use_raw_obs and raw_obs_list is not None: - final_obs = raw_obs_list[-1] - else: - final_obs = rollout["obs"][:, -1] last_next_td = TensorDict( - {"obs": final_obs}, + { + "obs": self._policy_obs_at( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=rollout.batch_size[1] - 1, + ) + }, batch_size=[rollout.batch_size[0]], device=self.device, ) @@ -370,31 +350,99 @@ def _write_env_step( rollout["terminated"][:, step_idx] = terminated.to(self.device) rollout["truncated"][:, step_idx] = truncated.to(self.device) - def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: - """Validate rollout layout expected by the collector.""" - num_envs = self.env.num_envs - time_plus_one = num_steps + 1 - policy_obs_dim = int(getattr(self.policy, "obs_dim", 0) or 0) - obs_shape = tuple(rollout["obs"].shape) - if policy_obs_dim > 0: - expected_obs = (num_envs, time_plus_one, policy_obs_dim) - if obs_shape != expected_obs: + def _policy_input_tensordict( + self, + rollout: TensorDict, + raw_obs_list: list[TensorDict | None] | None, + step_idx: int, + ) -> TensorDict: + """Build the policy input TensorDict for a rollout step.""" + obs = self._policy_obs_at( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) + return TensorDict( + {"obs": obs}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + + def _policy_obs_at( + self, + rollout: TensorDict, + raw_obs_list: list[TensorDict | None] | None, + step_idx: int, + ) -> TensorDict | torch.Tensor: + """Read the observation representation 
expected by the current policy.""" + if raw_obs_list is not None: + obs = raw_obs_list[step_idx] + if obs is None: logger.log_error( - f"Preallocated rollout field 'obs' shape mismatch: " - f"expected {expected_obs}, got {obs_shape}.", - ValueError, + f"Missing raw observation at rollout step {step_idx}.", + RuntimeError, ) - else: - if ( - len(obs_shape) != 3 - or obs_shape[0] != num_envs - or obs_shape[1] != time_plus_one - ): + return obs + if "obs" not in rollout.keys(): + logger.log_error( + "Collector requires rollout['obs'] for policies that do not use raw " + "observations.", + ValueError, + ) + return rollout["obs"][:, step_idx] + + def _store_observation( + self, + rollout: TensorDict, + raw_obs_list: list[TensorDict | None] | None, + step_idx: int, + obs_td: TensorDict, + ) -> None: + """Write the current observation into whichever rollout views are enabled.""" + if raw_obs_list is not None: + raw_obs_list[step_idx] = obs_td + if "obs" in rollout.keys(): + if raw_obs_list is not None: logger.log_error( - f"Preallocated rollout field 'obs' shape mismatch: " - f"expected ({num_envs}, {time_plus_one}, *), got {obs_shape}.", + "Rollout should not allocate flat observations when raw observation " + "storage is enabled.", ValueError, ) + rollout["obs"][:, step_idx] = flatten_dict_observation(obs_td) + + def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: + """Validate rollout layout expected by the collector.""" + num_envs = self.env.num_envs + time_plus_one = num_steps + 1 + policy_obs_dim = int(getattr(self.policy, "obs_dim", 0) or 0) + has_flat_obs = "obs" in rollout.keys() + if has_flat_obs: + obs_shape = tuple(rollout["obs"].shape) + if policy_obs_dim > 0: + expected_obs = (num_envs, time_plus_one, policy_obs_dim) + if obs_shape != expected_obs: + logger.log_error( + f"Preallocated rollout field 'obs' shape mismatch: " + f"expected {expected_obs}, got {obs_shape}.", + ValueError, + ) + else: + if ( + len(obs_shape) != 3 + or obs_shape[0] != num_envs + or obs_shape[1] != time_plus_one + ): + logger.log_error( + f"Preallocated rollout field 'obs' shape mismatch: " + f"expected ({num_envs}, {time_plus_one}, *), got {obs_shape}.", + ValueError, + ) + elif not getattr(self.policy, "use_raw_obs", False): + logger.log_error( + "Preallocated rollout TensorDict must contain 'obs' when " + "policy.use_raw_obs=False.", + ValueError, + ) expected_shapes = { "action": (num_envs, time_plus_one, self.policy.action_dim), diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index c122bc1f..65a6ef0a 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -214,10 +214,8 @@ def train_from_config(config_path: str, distributed: bool | None = None): ) env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg) - sample_obs, _ = env.reset() - sample_obs_td = dict_to_tensordict(sample_obs, device) - obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] - flat_obs_space = env.flattened_observation_space + obs_dim = None + flat_obs_space = None # Create evaluation environment only if enabled eval_env = None @@ -246,6 +244,10 @@ def train_from_config(config_path: str, distributed: bool | None = None): ) # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only) if policy_name.lower() == "actor_critic": + sample_obs, _ = env.reset() + sample_obs_td = dict_to_tensordict(sample_obs, device) + obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] + flat_obs_space = 
env.flattened_observation_space actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -265,6 +267,10 @@ def train_from_config(config_path: str, distributed: bool | None = None): critic=critic, ) elif policy_name.lower() == "actor_only": + sample_obs, _ = env.reset() + sample_obs_td = dict_to_tensordict(sample_obs, device) + obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] + flat_obs_space = env.flattened_observation_space actor_cfg = policy_block.get("actor") if actor_cfg is None: raise ValueError( diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 8b5a1cec..fa2058f9 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -90,9 +90,12 @@ def __init__( raise RuntimeError("Env must expose num_envs for trainer statistics.") obs_dim = getattr(self.policy, "obs_dim", None) action_dim = getattr(self.policy, "action_dim", None) - if obs_dim is None or action_dim is None: - raise RuntimeError("Policy must expose obs_dim and action_dim.") use_raw_obs = getattr(self.policy, "use_raw_obs", False) + store_flat_obs = not use_raw_obs + if action_dim is None or (store_flat_obs and obs_dim is None): + raise RuntimeError( + "Policy must expose action_dim and flat-observation metadata." + ) action_chunk_size = getattr(self.policy, "action_chunk_size", 0) use_action_chunk = getattr(self.policy, "use_action_chunk", False) execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) @@ -103,27 +106,18 @@ def __init__( * action_chunk_size ) - if use_raw_obs: - try: - sample_obs = self.env.observation_space.sample() - obs_td = self._obs_to_tensordict(sample_obs) - flat_obs = flatten_dict_observation(obs_td) - obs_dim = int( - flat_obs.shape[-1] - if isinstance(flat_obs, torch.Tensor) - else np.asarray(flat_obs).shape[-1] - ) - except Exception: - obs_dim = max(1, obs_dim) + if not store_flat_obs: + obs_dim = int(obs_dim or 0) self.buffer = RolloutBuffer( num_envs=num_envs, rollout_len=self.buffer_size, - obs_dim=obs_dim, + obs_dim=int(obs_dim), action_dim=action_dim, device=self.device, use_raw_obs=use_raw_obs, action_chunk_size=action_chunk_size if use_action_chunk else 0, + store_flat_obs=store_flat_obs, ) self.collector = SyncCollector( env=self.env, diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index da3ae9b7..090bca25 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -547,7 +547,6 @@ def _initialize_episode( def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: """Infer whether the rollout buffer is expert recording or RL training data.""" if { - "obs", "action", "reward", "done", diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 5439a014..d56b358e 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -231,7 +231,7 @@ class _FakePolicyRawObs: use_raw_obs = True def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None: - self.obs_dim = obs_dim + self.obs_dim = 0 self.action_dim = action_dim self.device = device @@ -350,24 +350,29 @@ def test_collector_populates_raw_obs_buffer(): action_dim=action_dim, device=device, use_raw_obs=True, + store_flat_obs=False, ) rollout = collector.collect( num_steps=rollout_len, rollout=buffer.start_rollout(), ) + buffer.add(rollout) + stored = 
buffer.get(flatten=False)
 
-    assert hasattr(rollout, "raw_obs")
-    assert len(rollout.raw_obs) == rollout_len + 1
+    assert "obs" not in stored.keys()
+    assert hasattr(stored, "raw_obs")
+    assert len(stored.raw_obs) == rollout_len + 1
     for t in range(rollout_len + 1):
-        assert rollout.raw_obs[t] is not None
-        assert rollout.raw_obs[t].batch_size == torch.Size([num_envs])
-    assert torch.allclose(
-        rollout["obs"][:, 0], torch.zeros(num_envs, obs_dim, dtype=torch.float32)
-    )
+        assert stored.raw_obs[t] is not None
+        assert stored.raw_obs[t].batch_size == torch.Size([num_envs])
+    assert torch.allclose(
+        stored.raw_obs[t]["agent"]["state"],
+        torch.full((num_envs, obs_dim), float(t), dtype=torch.float32),
+    )
     assert torch.allclose(
-        rollout["obs"][:, -1],
-        torch.full((num_envs, obs_dim), float(rollout_len), dtype=torch.float32),
+        stored["value"][:, -1],
+        torch.full((num_envs,), float(rollout_len), dtype=torch.float32),
     )

From 1c8e1367254dd5d82e8565dd32c2c5d04c1e5776 Mon Sep 17 00:00:00 2001
From: yangchen73 <122090643@link.cuhk.edu.cn>
Date: Wed, 1 Apr 2026 13:24:04 +0000
Subject: [PATCH 17/17] Propagate per-substep chunk log-probs and values

---
 .../agents/rl/buffer/standard_buffer.py       |   7 +
 embodichain/agents/rl/buffer/utils.py         |  10 +-
 .../agents/rl/collector/sync_collector.py     | 164 ++++++++++++++++--
 embodichain/agents/rl/models/vla_policy.py    |  47 ++++-
 tests/agents/test_shared_rollout.py           | 116 ++++++++++++-
 5 files changed, 314 insertions(+), 30 deletions(-)

diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py
index f897755d..1664e596 100644
--- a/embodichain/agents/rl/buffer/standard_buffer.py
+++ b/embodichain/agents/rl/buffer/standard_buffer.py
@@ -44,6 +44,12 @@ def __init__(
         action_chunk_size: int = 0,
         store_flat_obs: bool = True,
     ) -> None:
+        if use_raw_obs and store_flat_obs:
+            raise ValueError(
+                "RolloutBuffer does not support storing flat observations when "
+                "use_raw_obs=True. Set store_flat_obs=False for raw-observation "
+                "policies."
+            )
         self.num_envs = num_envs
         self.rollout_len = rollout_len
         self.obs_dim = obs_dim
@@ -184,6 +190,7 @@ def _clear_dynamic_fields(self) -> None:
             "seq_return",
             "entropy",
             "step_repeat",
+            "execute_full_chunk",
         ):
             if key in self._rollout.keys():
                 del self._rollout[key]
diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py
index e65e268e..eebc842f 100644
--- a/embodichain/agents/rl/buffer/utils.py
+++ b/embodichain/agents/rl/buffer/utils.py
@@ -60,7 +60,15 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict:
         device=rollout.device,
     )
 
-    for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"):
+    for key in (
+        "advantage",
+        "return",
+        "seq_mask",
+        "seq_return",
+        "entropy",
+        "step_repeat",
+        "execute_full_chunk",
+    ):
         if key in rollout.keys():
             td[key] = rollout[key][:, :-1]
 
diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py
index 5680e7e0..8f6b9236 100644
--- a/embodichain/agents/rl/collector/sync_collector.py
+++ b/embodichain/agents/rl/collector/sync_collector.py
@@ -111,6 +111,10 @@ def collect(
         # Execute a full predicted action chunk inside one logical rollout step.
execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) cached_chunk = None + cached_chunk_log_prob = None + cached_chunk_value = None + cached_chunk_log_prob_scalar = None + cached_chunk_value_scalar = None chunk_cursor = 0 if use_action_chunk: @@ -127,6 +131,12 @@ def collect( device=self.device, ) rollout["step_repeat"][:, -1] = 0.0 + rollout["execute_full_chunk"] = torch.full( + (self.env.num_envs, num_steps + 1), + fill_value=execute_full_chunk, + dtype=torch.bool, + device=self.device, + ) self._store_observation( rollout=rollout, @@ -143,12 +153,7 @@ def collect( step_idx=step_idx, ) step_td = self.policy.get_action(step_td) - chunk = step_td.get("action_chunk") - if chunk is None: - logger.log_error( - "execute_full_chunk=True requires policy to provide 'action_chunk'.", - ValueError, - ) + chunk = self._require_action_chunk(step_td, action_chunk_size) reward_sum = torch.zeros( self.env.num_envs, dtype=torch.float32, device=self.device @@ -188,6 +193,7 @@ def collect( if use_action_chunk: rollout.chunk_step[:, step_idx] = 0 rollout["step_repeat"][:, step_idx] = float(executed_substeps) + self._finalize_execute_full_chunk_step(step_td, executed_substeps) self._write_step( rollout=rollout, @@ -227,7 +233,23 @@ def collect( step_idx=step_idx, ) step_td = self.policy.get_action(step_td) - cached_chunk = step_td["action_chunk"] + cached_chunk = self._require_action_chunk(step_td, action_chunk_size) + cached_chunk_log_prob = self._get_chunk_stat( + step_td, "action_chunk_log_prob" + ) + cached_chunk_value = self._get_chunk_stat(step_td, "action_chunk_value") + cached_chunk_log_prob_scalar = step_td["sample_log_prob"] + cached_chunk_value_scalar = step_td["value"] + step_td["sample_log_prob"] = self._resolve_chunk_stat( + chunk_stat=cached_chunk_log_prob, + fallback=cached_chunk_log_prob_scalar, + step_idx=0, + ) + step_td["value"] = self._resolve_chunk_stat( + chunk_stat=cached_chunk_value, + fallback=cached_chunk_value_scalar, + step_idx=0, + ) action = step_td["action"] effective_step_in_chunk = 0 chunk_cursor = 1 @@ -238,11 +260,15 @@ def collect( { "action": action, "action_chunk": cached_chunk, - "sample_log_prob": torch.zeros( - action.shape[0], device=self.device, dtype=torch.float32 + "sample_log_prob": self._resolve_chunk_stat( + chunk_stat=cached_chunk_log_prob, + fallback=cached_chunk_log_prob_scalar, + step_idx=chunk_cursor, ), - "value": torch.zeros( - action.shape[0], device=self.device, dtype=torch.float32 + "value": self._resolve_chunk_stat( + chunk_stat=cached_chunk_value, + fallback=cached_chunk_value_scalar, + step_idx=chunk_cursor, ), }, batch_size=[rollout.batch_size[0]], @@ -410,6 +436,79 @@ def _store_observation( ) rollout["obs"][:, step_idx] = flatten_dict_observation(obs_td) + def _require_action_chunk( + self, step_td: TensorDict, action_chunk_size: int + ) -> torch.Tensor: + """Return a validated action chunk from the policy output.""" + chunk = step_td.get("action_chunk") + if chunk is None: + logger.log_error( + "Action-chunk policy did not return 'action_chunk'. 
" + f"policy={type(self.policy).__name__}, expected chunk length={action_chunk_size}.", + ValueError, + ) + expected_shape = ( + step_td.batch_size[0], + action_chunk_size, + self.policy.action_dim, + ) + if tuple(chunk.shape) != expected_shape: + logger.log_error( + "Policy-produced 'action_chunk' shape mismatch: " + f"expected {expected_shape}, got {tuple(chunk.shape)}.", + ValueError, + ) + return chunk + + def _get_chunk_stat(self, step_td: TensorDict, key: str) -> torch.Tensor | None: + """Read an optional per-substep chunk statistic from the policy output.""" + stat = step_td.get(key) + if stat is None: + return None + if stat.dim() != 2 or stat.shape[0] != step_td.batch_size[0]: + logger.log_error( + f"Policy-produced '{key}' must have shape [batch, chunk].", + ValueError, + ) + return stat + + def _resolve_chunk_stat( + self, + chunk_stat: torch.Tensor | None, + fallback: torch.Tensor, + step_idx: int, + ) -> torch.Tensor: + """Resolve the scalar statistic to store for one chunk substep.""" + if chunk_stat is None: + return fallback + if step_idx >= chunk_stat.shape[1]: + logger.log_error( + f"Chunk statistic lookup out of range: step_idx={step_idx}, " + f"chunk_len={chunk_stat.shape[1]}.", + ValueError, + ) + return chunk_stat[:, step_idx] + + def _finalize_execute_full_chunk_step( + self, step_td: TensorDict, executed_substeps: int + ) -> None: + """Mask unexecuted chunk suffixes and align scalar stats to the executed prefix.""" + chunk = step_td.get("action_chunk") + if chunk is None: + return + if executed_substeps < chunk.shape[1]: + masked_chunk = chunk.clone() + masked_chunk[:, executed_substeps:].zero_() + step_td["action_chunk"] = masked_chunk + if "action_chunk_log_prob" in step_td.keys(): + step_td["sample_log_prob"] = step_td["action_chunk_log_prob"][ + :, :executed_substeps + ].mean(dim=1) + if "action_chunk_value" in step_td.keys(): + step_td["value"] = step_td["action_chunk_value"][ + :, :executed_substeps + ].mean(dim=1) + def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: """Validate rollout layout expected by the collector.""" num_envs = self.env.num_envs @@ -461,3 +560,46 @@ def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: f"expected {expected_shape}, got {actual_shape}.", ValueError, ) + action_chunk_size = int(getattr(self.policy, "action_chunk_size", 0) or 0) + use_action_chunk = bool( + getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 + ) + if "action_chunk" in rollout.keys(): + if not use_action_chunk: + logger.log_error( + "Preallocated rollout field 'action_chunk' is present, but the " + "current policy is not configured to use action chunks.", + ValueError, + ) + expected_action_chunk_shape = ( + num_envs, + time_plus_one, + action_chunk_size, + self.policy.action_dim, + ) + actual_action_chunk_shape = tuple(rollout["action_chunk"].shape) + if actual_action_chunk_shape != expected_action_chunk_shape: + logger.log_error( + "Preallocated rollout field 'action_chunk' shape mismatch: " + f"expected {expected_action_chunk_shape}, got {actual_action_chunk_shape}.", + ValueError, + ) + if hasattr(rollout, "chunk_step"): + chunk_step_shape = tuple(rollout.chunk_step.shape) + expected_chunk_step_shape = (num_envs, num_steps) + if chunk_step_shape != expected_chunk_step_shape: + logger.log_error( + "Preallocated rollout attribute 'chunk_step' shape mismatch: " + f"expected {expected_chunk_step_shape}, got {chunk_step_shape}.", + ValueError, + ) + for key in ("step_repeat", 
"execute_full_chunk"): + if key in rollout.keys(): + actual_shape = tuple(rollout[key].shape) + expected_shape = (num_envs, time_plus_one) + if actual_shape != expected_shape: + logger.log_error( + f"Preallocated rollout field '{key}' shape mismatch: " + f"expected {expected_shape}, got {actual_shape}.", + ValueError, + ) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 1d70c6e0..550b0264 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -206,14 +206,18 @@ def forward( if deterministic: noisy_chunk = mean_chunk - log_prob = torch.zeros( - mean_chunk.shape[0], device=self.device, dtype=torch.float32 + per_step_log_prob = torch.zeros( + mean_chunk.shape[0], + mean_chunk.shape[1], + device=self.device, + dtype=torch.float32, ) else: sigma = self.gaussian_sigma noise = torch.randn_like(mean_chunk) * sigma noisy_chunk = mean_chunk + noise - log_prob = -0.5 * noise.pow(2).sum(-1).mean(-1) / (sigma * sigma + 1e-8) + per_step_log_prob = -0.5 * noise.pow(2).sum(-1) / (sigma * sigma + 1e-8) + log_prob = per_step_log_prob.mean(dim=-1) action = noisy_chunk[:, 0] tensordict["action"] = action @@ -223,6 +227,8 @@ def forward( ) if self.use_action_chunk: tensordict["action_chunk"] = noisy_chunk + tensordict["action_chunk_log_prob"] = per_step_log_prob + tensordict["action_chunk_value"] = torch.zeros_like(per_step_log_prob) return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: @@ -244,6 +250,8 @@ def evaluate_actions( raw_obs = getattr(rollout, "raw_obs", None) chunk_step = tensordict.get("chunk_step", None) indices = tensordict.get("_indices", None) + execute_full_chunk = tensordict.get("execute_full_chunk", None) + step_repeat = tensordict.get("step_repeat", None) if raw_obs is None or chunk_step is None or indices is None: raise ValueError( "VLAPolicy.evaluate_actions requires rollout.raw_obs, chunk_step, and _indices. 
" @@ -256,7 +264,12 @@ def evaluate_actions( self._load_vla() stored_chunks = tensordict.get("action_chunk", None) - use_full_chunk = stored_chunks is not None + use_stored_chunks = stored_chunks is not None + execute_full_chunk_mask = ( + execute_full_chunk.bool() + if execute_full_chunk is not None + else torch.zeros((b,), device=self.device, dtype=torch.bool) + ) for i in range(b): idx = int(indices[i].item()) @@ -276,12 +289,19 @@ def evaluate_actions( ) pred_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) - if use_full_chunk: + if use_stored_chunks and bool(execute_full_chunk_mask[i].item()): gt_chunk = stored_chunks[i] pred_chunk = pred_chunk_env[0] - min_len = min(gt_chunk.shape[-1], pred_chunk.shape[-1]) + executed_len = int( + step_repeat[i].item() + if step_repeat is not None + else gt_chunk.shape[0] + ) + executed_len = max( + 1, min(executed_len, gt_chunk.shape[0], pred_chunk.shape[0]) + ) mse = ( - ((gt_chunk[..., :min_len] - pred_chunk[..., :min_len]).pow(2)) + ((gt_chunk[:executed_len] - pred_chunk[:executed_len]).pow(2)) .sum(-1) .mean(-1) ) @@ -296,11 +316,20 @@ def evaluate_actions( log_probs.append(log_prob) log_probs = torch.stack(log_probs) - effective_dim = self.action_dim * (self.action_horizon if use_full_chunk else 1) + if step_repeat is not None: + chunk_lengths = step_repeat.to( + device=self.device, dtype=torch.float32 + ).clamp_min(1.0) + else: + chunk_lengths = torch.ones((b,), device=self.device, dtype=torch.float32) + effective_dim = torch.where( + execute_full_chunk_mask, + chunk_lengths * float(self.action_dim), + torch.full((b,), float(self.action_dim), device=self.device), + ) entropy = ( 0.5 * effective_dim * (1 + np.log(2 * np.pi) + 2 * np.log(sigma + 1e-8)) ) - entropy = torch.full((b,), entropy, device=self.device, dtype=torch.float32) return TensorDict( { diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index d56b358e..1a29bc9b 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -277,14 +277,34 @@ def get_action( row0 = obs[:, : self.action_dim] * 0.1 row1 = obs[:, : self.action_dim] * 0.2 chunk = torch.stack([row0, row1], dim=1) - tensordict["action_chunk"] = chunk - tensordict["action"] = chunk[:, 0] - tensordict["sample_log_prob"] = torch.zeros( - obs.shape[0], device=obs.device, dtype=torch.float32 + chunk_log_prob = torch.stack( + [ + torch.full( + (obs.shape[0],), 0.25, device=obs.device, dtype=torch.float32 + ), + torch.full( + (obs.shape[0],), 0.75, device=obs.device, dtype=torch.float32 + ), + ], + dim=1, ) - tensordict["value"] = torch.zeros( - obs.shape[0], device=obs.device, dtype=torch.float32 + chunk_value = torch.stack( + [ + torch.full( + (obs.shape[0],), 1.0, device=obs.device, dtype=torch.float32 + ), + torch.full( + (obs.shape[0],), 2.0, device=obs.device, dtype=torch.float32 + ), + ], + dim=1, ) + tensordict["action_chunk"] = chunk + tensordict["action"] = chunk[:, 0] + tensordict["sample_log_prob"] = chunk_log_prob[:, 0] + tensordict["value"] = chunk_value[:, 0] + tensordict["action_chunk_log_prob"] = chunk_log_prob + tensordict["action_chunk_value"] = chunk_value return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: @@ -315,12 +335,14 @@ def get_action( t = self.action_chunk_size base = obs[:, : self.action_dim] * 0.1 chunk = base.unsqueeze(1).expand(n, t, -1).clone() + chunk_log_prob = torch.tensor( + [0.2, 0.4, 0.6], device=obs.device, dtype=torch.float32 + ).expand(n, -1) tensordict["action_chunk"] = chunk 
tensordict["action"] = chunk[:, 0] - tensordict["sample_log_prob"] = torch.zeros( - n, device=obs.device, dtype=torch.float32 - ) + tensordict["sample_log_prob"] = chunk_log_prob.mean(dim=1) tensordict["value"] = torch.zeros(n, device=obs.device, dtype=torch.float32) + tensordict["action_chunk_log_prob"] = chunk_log_prob return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: @@ -328,6 +350,15 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: return tensordict +class _FakeEnvTerminateOnFirstChunkStep(_FakeEnv): + """Environment that terminates after the first executed chunk substep.""" + + def step(self, action): + next_obs, reward, terminated, truncated, info = super().step(action) + terminated = torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + return next_obs, reward, terminated, truncated, info + + def test_collector_populates_raw_obs_buffer(): device = torch.device("cpu") num_envs = 2 @@ -415,6 +446,16 @@ def test_collector_chunk_step_alternates_for_sequential_action_chunk(): assert torch.all(rollout["step_repeat"][:, :-1] == 1.0) assert torch.allclose(rollout["action_chunk"][:, 1], rollout["action_chunk"][:, 0]) assert torch.allclose(rollout["action_chunk"][:, 3], rollout["action_chunk"][:, 2]) + expected_log_prob = torch.tensor([0.25, 0.75, 0.25, 0.75], device=device).expand( + num_envs, -1 + ) + expected_value = torch.tensor([1.0, 2.0, 1.0, 2.0], device=device).expand( + num_envs, -1 + ) + assert torch.allclose( + rollout["sample_log_prob"][:, :rollout_len], expected_log_prob + ) + assert torch.allclose(rollout["value"][:, :rollout_len], expected_value) def test_collector_execute_full_chunk_sets_step_repeat_and_action_chunk(): @@ -460,3 +501,60 @@ def test_collector_execute_full_chunk_sets_step_repeat_and_action_chunk(): assert torch.all(rollout.chunk_step[:, :rollout_len] == 0) assert torch.all(rollout["step_repeat"][:, :rollout_len] == float(chunk_t)) assert torch.all(rollout["step_repeat"][:, -1] == 0.0) + + +def test_execute_full_chunk_masks_unexecuted_suffix_on_early_done(): + device = torch.device("cpu") + num_envs = 2 + rollout_len = 1 + obs_dim = 5 + action_dim = 2 + + env = _FakeEnvTerminateOnFirstChunkStep( + num_envs=num_envs, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + policy = _FakePolicyExecuteFullChunk( + obs_dim=obs_dim, action_dim=action_dim, device=device + ) + collector = SyncCollector(env=env, policy=policy, device=device) + buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=rollout_len, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + action_chunk_size=policy.action_chunk_size, + ) + + rollout = collector.collect( + num_steps=rollout_len, + rollout=buffer.start_rollout(), + ) + + assert torch.all(rollout["step_repeat"][:, 0] == 1.0) + assert torch.all(rollout["action_chunk"][:, 0, 1:] == 0.0) + assert torch.allclose( + rollout["sample_log_prob"][:, 0], + torch.full((num_envs,), 0.2, dtype=torch.float32, device=device), + ) + + +def test_rollout_buffer_rejects_raw_and_flat_obs_storage_together(): + device = torch.device("cpu") + try: + RolloutBuffer( + num_envs=1, + rollout_len=1, + obs_dim=4, + action_dim=2, + device=device, + use_raw_obs=True, + store_flat_obs=True, + ) + except ValueError as exc: + assert "store_flat_obs=False" in str(exc) + else: + raise AssertionError("Expected raw+flat rollout buffer configuration to fail.")