diff --git a/docs/source/overview/rl/trainer.md b/docs/source/overview/rl/trainer.md index 5ef4ee99..b90dd991 100644 --- a/docs/source/overview/rl/trainer.md +++ b/docs/source/overview/rl/trainer.md @@ -51,4 +51,9 @@ trainer.save_checkpoint() - The event mechanism can be used for automated experiments, data collection, and environment reset. - Logging and monitoring help analyze training progress and tune hyperparameters. +## API References +- VLA backend lookup: `embodichain.agents.rl.vla_registry.get_vla_backend()` +- VLA backend listing: `embodichain.agents.rl.vla_registry.get_registered_vla_backend_names()` +- VLA backend creation: `embodichain.agents.rl.vla_registry.create_vla_backend()` + --- diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index ea26c84e..376502b4 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -113,7 +113,12 @@ def _compute_step_group_advantages( return advantages.view(n_envs, t_steps) * seq_mask def update(self, rollout: TensorDict) -> Dict[str, float]: - rollout = rollout.clone() + raw_obs = getattr(rollout, "raw_obs", None) + chunk_step = getattr(rollout, "chunk_step", None) + if raw_obs is not None: + rollout.raw_obs = raw_obs + if chunk_step is not None: + rollout.chunk_step = chunk_step num_envs = rollout.batch_size[0] if num_envs % self.cfg.group_size != 0: raise ValueError( @@ -149,7 +154,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: seq_mask_batch = batch["seq_mask"].float() policy_module = getattr(self.policy, "module", self.policy) - eval_batch = policy_module.evaluate_actions(batch) + eval_batch = policy_module.evaluate_actions(batch, rollout=rollout) logprobs = eval_batch["sample_log_prob"] entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() @@ -168,7 +173,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: if self.ref_policy is not None: with torch.no_grad(): - ref_batch = 
self.ref_policy.evaluate_actions(batch) + ref_batch = self.ref_policy.evaluate_actions( + batch, rollout=rollout + ) ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 4a4dfb3c..1b326876 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -48,7 +48,6 @@ def __init__(self, cfg: PPOCfg, policy): def update(self, rollout: TensorDict) -> Dict[str, float]: """Update the policy using a collected rollout.""" - rollout = rollout.clone() compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda) flat_rollout = transition_view(rollout, flatten=True) diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index 57a4899b..1664e596 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -40,12 +40,24 @@ def __init__( obs_dim: int, action_dim: int, device: torch.device, + use_raw_obs: bool = False, + action_chunk_size: int = 0, + store_flat_obs: bool = True, ) -> None: + if use_raw_obs and store_flat_obs: + raise ValueError( + "RolloutBuffer does not support storing flat observations when " + "use_raw_obs=True. Set store_flat_obs=False for raw-observation " + "policies." 
+ ) self.num_envs = num_envs self.rollout_len = rollout_len self.obs_dim = obs_dim self.action_dim = action_dim self.device = device + self.use_raw_obs = use_raw_obs + self.action_chunk_size = action_chunk_size + self.store_flat_obs = store_flat_obs self._rollout = self._allocate_rollout() self._is_full = False @@ -58,6 +70,8 @@ def start_rollout(self) -> TensorDict: if self._is_full: raise RuntimeError("RolloutBuffer already contains a rollout.") self._clear_dynamic_fields() + if self.use_raw_obs: + self._rollout.raw_obs = [None] * (self.rollout_len + 1) return self._rollout def add(self, rollout: TensorDict) -> None: @@ -97,15 +111,18 @@ def is_full(self) -> bool: def _allocate_rollout(self) -> TensorDict: """Preallocate rollout storage with uniform `[num_envs, time + 1]` shape.""" - return TensorDict( + rollout_tensors = {} + if self.store_flat_obs: + rollout_tensors["obs"] = torch.empty( + self.num_envs, + self.rollout_len + 1, + self.obs_dim, + dtype=torch.float32, + device=self.device, + ) + td = TensorDict( { - "obs": torch.empty( - self.num_envs, - self.rollout_len + 1, - self.obs_dim, - dtype=torch.float32, - device=self.device, - ), + **rollout_tensors, "action": torch.empty( self.num_envs, self.rollout_len + 1, @@ -153,12 +170,34 @@ def _allocate_rollout(self) -> TensorDict: batch_size=[self.num_envs, self.rollout_len + 1], device=self.device, ) + if self.action_chunk_size > 0: + td["action_chunk"] = torch.zeros( + self.num_envs, + self.rollout_len + 1, + self.action_chunk_size, + self.action_dim, + dtype=torch.float32, + device=self.device, + ) + return td def _clear_dynamic_fields(self) -> None: """Drop algorithm-added fields before reusing the shared rollout.""" - for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + for key in ( + "advantage", + "return", + "seq_mask", + "seq_return", + "entropy", + "step_repeat", + "execute_full_chunk", + ): if key in self._rollout.keys(): del self._rollout[key] + if self.use_raw_obs and 
hasattr(self._rollout, "raw_obs"): + delattr(self._rollout, "raw_obs") + if hasattr(self._rollout, "chunk_step"): + delattr(self._rollout, "chunk_step") self._reset_padding_slot() def _reset_padding_slot(self) -> None: @@ -170,11 +209,12 @@ def _reset_padding_slot(self) -> None: self._rollout["done"][:, last_idx].fill_(False) self._rollout["terminated"][:, last_idx].fill_(False) self._rollout["truncated"][:, last_idx].fill_(False) + if "action_chunk" in self._rollout.keys(): + self._rollout["action_chunk"][:, last_idx].zero_() def _validate_rollout_layout(self, rollout: TensorDict) -> None: """Validate the expected tensor shapes for the shared rollout.""" expected_shapes = { - "obs": (self.num_envs, self.rollout_len + 1, self.obs_dim), "action": (self.num_envs, self.rollout_len + 1, self.action_dim), "sample_log_prob": (self.num_envs, self.rollout_len + 1), "value": (self.num_envs, self.rollout_len + 1), @@ -183,6 +223,12 @@ def _validate_rollout_layout(self, rollout: TensorDict) -> None: "terminated": (self.num_envs, self.rollout_len + 1), "truncated": (self.num_envs, self.rollout_len + 1), } + if self.store_flat_obs: + expected_shapes["obs"] = ( + self.num_envs, + self.rollout_len + 1, + self.obs_dim, + ) for key, expected_shape in expected_shapes.items(): actual_shape = tuple(rollout[key].shape) if actual_shape != expected_shape: @@ -190,3 +236,8 @@ def _validate_rollout_layout(self, rollout: TensorDict) -> None: f"Rollout field '{key}' shape mismatch: expected {expected_shape}, " f"got {actual_shape}." ) + if not self.store_flat_obs and "obs" in rollout.keys(): + raise ValueError( + "RolloutBuffer configured with store_flat_obs=False must not contain " + "a preallocated 'obs' field." 
+ ) diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 7c0d265b..eebc842f 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -42,26 +42,42 @@ def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: """ action = rollout["action"][:, :-1] num_envs, time_dim = action.shape[:2] + transition_fields = { + "action": action, + "sample_log_prob": rollout["sample_log_prob"][:, :-1], + "value": rollout["value"][:, :-1], + "next_value": rollout["value"][:, 1:], + "reward": rollout["reward"][:, :-1], + "done": rollout["done"][:, :-1], + "terminated": rollout["terminated"][:, :-1], + "truncated": rollout["truncated"][:, :-1], + } + if "obs" in rollout.keys(): + transition_fields["obs"] = rollout["obs"][:, :-1] td = TensorDict( - { - "obs": rollout["obs"][:, :-1], - "action": action, - "sample_log_prob": rollout["sample_log_prob"][:, :-1], - "value": rollout["value"][:, :-1], - "next_value": rollout["value"][:, 1:], - "reward": rollout["reward"][:, :-1], - "done": rollout["done"][:, :-1], - "terminated": rollout["terminated"][:, :-1], - "truncated": rollout["truncated"][:, :-1], - }, + transition_fields, batch_size=[num_envs, time_dim], device=rollout.device, ) - for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + for key in ( + "advantage", + "return", + "seq_mask", + "seq_return", + "entropy", + "step_repeat", + "execute_full_chunk", + ): if key in rollout.keys(): td[key] = rollout[key][:, :-1] + if hasattr(rollout, "chunk_step") and rollout.chunk_step is not None: + td["chunk_step"] = rollout.chunk_step + + if "action_chunk" in rollout.keys(): + td["action_chunk"] = rollout["action_chunk"][:, :-1] + if flatten: return td.reshape(num_envs * time_dim) return td @@ -72,6 +88,10 @@ def iterate_minibatches( ) -> Iterator[TensorDict]: """Yield shuffled minibatches from a flattened rollout.""" total = rollout.batch_size[0] - indices = 
torch.randperm(total, device=device) + idx_device = rollout.device if rollout.device is not None else device + indices = torch.randperm(total, device=idx_device) for start in range(0, total, batch_size): - yield rollout[indices[start : start + batch_size]] + batch_indices = indices[start : start + batch_size] + batch = rollout[batch_indices].clone() + batch["_indices"] = batch_indices + yield batch diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 15160e61..8f6b9236 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -22,6 +22,7 @@ from tensordict import TensorDict from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation +from embodichain.utils import logger from .base import BaseCollector __all__ = ["SyncCollector"] @@ -41,7 +42,10 @@ def __init__( self.policy = policy self.device = device self.reset_every_rollout = reset_every_rollout - self._supports_shared_rollout = hasattr(self.env, "set_rollout_buffer") + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) + self._supports_shared_rollout = ( + hasattr(self.env, "set_rollout_buffer") and not execute_full_chunk + ) self.obs_td = self._reset_env() @torch.no_grad() @@ -56,32 +60,237 @@ def collect( self.obs_td = self._reset_env() if rollout is None: - raise ValueError( - "SyncCollector.collect() requires a preallocated rollout TensorDict." + logger.log_error( + "SyncCollector.collect() requires a preallocated rollout TensorDict.", + ValueError, ) if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1): - raise ValueError( + logger.log_error( "Preallocated rollout batch size mismatch: " - f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}." 
+ f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}.", + ValueError, ) self._validate_rollout(rollout, num_steps) if self._supports_shared_rollout: self.env.set_rollout_buffer(rollout) - initial_obs = flatten_dict_observation(self.obs_td) - rollout["obs"][:, 0] = initial_obs - for step_idx in range(num_steps): - step_td = TensorDict( - {"obs": rollout["obs"][:, step_idx]}, - batch_size=[rollout.batch_size[0]], + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None + + if use_raw_obs: + if raw_obs_list is None: + logger.log_error( + "Policy requires raw observations, " + "but the provided rollout TensorDict has no 'raw_obs' buffer. " + "Create the rollout via RolloutBuffer or " + "start_rollout so that 'raw_obs' is allocated.", + ValueError, + ) + try: + raw_obs_len = len(raw_obs_list) + except TypeError: + logger.log_error( + "Rollout field 'raw_obs' must be an indexable sequence of length " + f"{num_steps + 1} when policy.use_raw_obs=True.", + ValueError, + ) + expected_len = num_steps + 1 + if raw_obs_len != expected_len: + logger.log_error( + "Rollout 'raw_obs' length mismatch: " + f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. " + "Ensure the rollout was created with use_raw_obs=True and " + "its time dimension matches the requested num_steps.", + ValueError, + ) + + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) + use_action_chunk = ( + getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 + ) + # Execute a full predicted action chunk inside one logical rollout step. 
+ execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) + cached_chunk = None + cached_chunk_log_prob = None + cached_chunk_value = None + cached_chunk_log_prob_scalar = None + cached_chunk_value_scalar = None + chunk_cursor = 0 + + if use_action_chunk: + rollout.chunk_step = torch.zeros( + self.env.num_envs, + num_steps, + dtype=torch.long, device=self.device, ) - step_td = self.policy.get_action(step_td) + rollout["step_repeat"] = torch.ones( + self.env.num_envs, + num_steps + 1, + dtype=torch.float32, + device=self.device, + ) + rollout["step_repeat"][:, -1] = 0.0 + rollout["execute_full_chunk"] = torch.full( + (self.env.num_envs, num_steps + 1), + fill_value=execute_full_chunk, + dtype=torch.bool, + device=self.device, + ) + + self._store_observation( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=0, + obs_td=self.obs_td, + ) + + for step_idx in range(num_steps): + if execute_full_chunk and use_action_chunk: + step_td = self._policy_input_tensordict( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) + step_td = self.policy.get_action(step_td) + chunk = self._require_action_chunk(step_td, action_chunk_size) + + reward_sum = torch.zeros( + self.env.num_envs, dtype=torch.float32, device=self.device + ) + terminated = torch.zeros( + self.env.num_envs, dtype=torch.bool, device=self.device + ) + truncated = torch.zeros( + self.env.num_envs, dtype=torch.bool, device=self.device + ) + env_info = {} + next_obs_td = None + + executed_substeps = 0 + # Execute the whole chunk sequentially + for sub_idx in range(action_chunk_size): + sub_action = chunk[:, sub_idx] + next_obs, reward, term_i, trunc_i, env_info = self.env.step( + self._to_action_dict(sub_action) + ) + executed_substeps += 1 + next_obs_td = dict_to_tensordict(next_obs, self.device) + reward_sum += reward.to(self.device).float() + terminated |= term_i.to(self.device) + truncated |= trunc_i.to(self.device) + + # Stop chunk execution when any env reaches 
terminal/truncated. + if (term_i | trunc_i).any(): + break + + if next_obs_td is None: + logger.log_error( + "Chunk execution produced no environment transition.", + RuntimeError, + ) + + if use_action_chunk: + rollout.chunk_step[:, step_idx] = 0 + rollout["step_repeat"][:, step_idx] = float(executed_substeps) + self._finalize_execute_full_chunk_step(step_td, executed_substeps) + + self._write_step( + rollout=rollout, + step_idx=step_idx, + step_td=step_td, + ) + if not self._supports_shared_rollout: + self._write_env_step( + rollout=rollout, + step_idx=step_idx, + reward=reward_sum, + terminated=terminated, + truncated=truncated, + ) + self._store_observation( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx + 1, + obs_td=next_obs_td, + ) + + if on_step_callback is not None: + on_step_callback(rollout[:, step_idx], env_info) + + self.obs_td = next_obs_td + continue + + # Execute a predicted chunk sequentially + need_new_chunk = use_action_chunk and ( + cached_chunk is None or chunk_cursor >= action_chunk_size + ) + + if need_new_chunk: + step_td = self._policy_input_tensordict( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) + step_td = self.policy.get_action(step_td) + cached_chunk = self._require_action_chunk(step_td, action_chunk_size) + cached_chunk_log_prob = self._get_chunk_stat( + step_td, "action_chunk_log_prob" + ) + cached_chunk_value = self._get_chunk_stat(step_td, "action_chunk_value") + cached_chunk_log_prob_scalar = step_td["sample_log_prob"] + cached_chunk_value_scalar = step_td["value"] + step_td["sample_log_prob"] = self._resolve_chunk_stat( + chunk_stat=cached_chunk_log_prob, + fallback=cached_chunk_log_prob_scalar, + step_idx=0, + ) + step_td["value"] = self._resolve_chunk_stat( + chunk_stat=cached_chunk_value, + fallback=cached_chunk_value_scalar, + step_idx=0, + ) + action = step_td["action"] + effective_step_in_chunk = 0 + chunk_cursor = 1 + elif use_action_chunk and cached_chunk is not None: + 
action = cached_chunk[:, chunk_cursor] + effective_step_in_chunk = chunk_cursor + step_td = TensorDict( + { + "action": action, + "action_chunk": cached_chunk, + "sample_log_prob": self._resolve_chunk_stat( + chunk_stat=cached_chunk_log_prob, + fallback=cached_chunk_log_prob_scalar, + step_idx=chunk_cursor, + ), + "value": self._resolve_chunk_stat( + chunk_stat=cached_chunk_value, + fallback=cached_chunk_value_scalar, + step_idx=chunk_cursor, + ), + }, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + chunk_cursor += 1 + else: + step_td = self._policy_input_tensordict( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) + step_td = self.policy.get_action(step_td) + action = step_td["action"] next_obs, reward, terminated, truncated, env_info = self.env.step( - self._to_action_dict(step_td["action"]) + self._to_action_dict(action) ) next_obs_td = dict_to_tensordict(next_obs, self.device) + if use_action_chunk: + rollout.chunk_step[:, step_idx] = effective_step_in_chunk + rollout["step_repeat"][:, step_idx] = 1.0 self._write_step( rollout=rollout, step_idx=step_idx, @@ -95,7 +304,12 @@ def collect( terminated=terminated, truncated=truncated, ) - rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) + self._store_observation( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx + 1, + obs_td=next_obs_td, + ) if on_step_callback is not None: on_step_callback(rollout[:, step_idx], env_info) @@ -107,9 +321,16 @@ def collect( def _attach_final_value(self, rollout: TensorDict) -> None: """Populate the bootstrap value for the final observed state.""" - final_obs = rollout["obs"][:, -1] + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None last_next_td = TensorDict( - {"obs": final_obs}, + { + "obs": self._policy_obs_at( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=rollout.batch_size[1] - 1, + ) + }, 
batch_size=[rollout.batch_size[0]], device=self.device, ) @@ -137,6 +358,8 @@ def _write_step( rollout["action"][:, step_idx] = step_td["action"] rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] rollout["value"][:, step_idx] = step_td["value"] + if "action_chunk" in rollout.keys() and "action_chunk" in step_td.keys(): + rollout["action_chunk"][:, step_idx] = step_td["action_chunk"] def _write_env_step( self, @@ -153,22 +376,230 @@ def _write_env_step( rollout["terminated"][:, step_idx] = terminated.to(self.device) rollout["truncated"][:, step_idx] = truncated.to(self.device) + def _policy_input_tensordict( + self, + rollout: TensorDict, + raw_obs_list: list[TensorDict | None] | None, + step_idx: int, + ) -> TensorDict: + """Build the policy input TensorDict for a rollout step.""" + obs = self._policy_obs_at( + rollout=rollout, + raw_obs_list=raw_obs_list, + step_idx=step_idx, + ) + return TensorDict( + {"obs": obs}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + + def _policy_obs_at( + self, + rollout: TensorDict, + raw_obs_list: list[TensorDict | None] | None, + step_idx: int, + ) -> TensorDict | torch.Tensor: + """Read the observation representation expected by the current policy.""" + if raw_obs_list is not None: + obs = raw_obs_list[step_idx] + if obs is None: + logger.log_error( + f"Missing raw observation at rollout step {step_idx}.", + RuntimeError, + ) + return obs + if "obs" not in rollout.keys(): + logger.log_error( + "Collector requires rollout['obs'] for policies that do not use raw " + "observations.", + ValueError, + ) + return rollout["obs"][:, step_idx] + + def _store_observation( + self, + rollout: TensorDict, + raw_obs_list: list[TensorDict | None] | None, + step_idx: int, + obs_td: TensorDict, + ) -> None: + """Write the current observation into whichever rollout views are enabled.""" + if raw_obs_list is not None: + raw_obs_list[step_idx] = obs_td + if "obs" in rollout.keys(): + if raw_obs_list is not None: 
+ logger.log_error( + "Rollout should not allocate flat observations when raw observation " + "storage is enabled.", + ValueError, + ) + rollout["obs"][:, step_idx] = flatten_dict_observation(obs_td) + + def _require_action_chunk( + self, step_td: TensorDict, action_chunk_size: int + ) -> torch.Tensor: + """Return a validated action chunk from the policy output.""" + chunk = step_td.get("action_chunk") + if chunk is None: + logger.log_error( + "Action-chunk policy did not return 'action_chunk'. " + f"policy={type(self.policy).__name__}, expected chunk length={action_chunk_size}.", + ValueError, + ) + expected_shape = ( + step_td.batch_size[0], + action_chunk_size, + self.policy.action_dim, + ) + if tuple(chunk.shape) != expected_shape: + logger.log_error( + "Policy-produced 'action_chunk' shape mismatch: " + f"expected {expected_shape}, got {tuple(chunk.shape)}.", + ValueError, + ) + return chunk + + def _get_chunk_stat(self, step_td: TensorDict, key: str) -> torch.Tensor | None: + """Read an optional per-substep chunk statistic from the policy output.""" + stat = step_td.get(key) + if stat is None: + return None + if stat.dim() != 2 or stat.shape[0] != step_td.batch_size[0]: + logger.log_error( + f"Policy-produced '{key}' must have shape [batch, chunk].", + ValueError, + ) + return stat + + def _resolve_chunk_stat( + self, + chunk_stat: torch.Tensor | None, + fallback: torch.Tensor, + step_idx: int, + ) -> torch.Tensor: + """Resolve the scalar statistic to store for one chunk substep.""" + if chunk_stat is None: + return fallback + if step_idx >= chunk_stat.shape[1]: + logger.log_error( + f"Chunk statistic lookup out of range: step_idx={step_idx}, " + f"chunk_len={chunk_stat.shape[1]}.", + ValueError, + ) + return chunk_stat[:, step_idx] + + def _finalize_execute_full_chunk_step( + self, step_td: TensorDict, executed_substeps: int + ) -> None: + """Mask unexecuted chunk suffixes and align scalar stats to the executed prefix.""" + chunk = 
step_td.get("action_chunk") + if chunk is None: + return + if executed_substeps < chunk.shape[1]: + masked_chunk = chunk.clone() + masked_chunk[:, executed_substeps:].zero_() + step_td["action_chunk"] = masked_chunk + if "action_chunk_log_prob" in step_td.keys(): + step_td["sample_log_prob"] = step_td["action_chunk_log_prob"][ + :, :executed_substeps + ].mean(dim=1) + if "action_chunk_value" in step_td.keys(): + step_td["value"] = step_td["action_chunk_value"][ + :, :executed_substeps + ].mean(dim=1) + def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: """Validate rollout layout expected by the collector.""" + num_envs = self.env.num_envs + time_plus_one = num_steps + 1 + policy_obs_dim = int(getattr(self.policy, "obs_dim", 0) or 0) + has_flat_obs = "obs" in rollout.keys() + if has_flat_obs: + obs_shape = tuple(rollout["obs"].shape) + if policy_obs_dim > 0: + expected_obs = (num_envs, time_plus_one, policy_obs_dim) + if obs_shape != expected_obs: + logger.log_error( + f"Preallocated rollout field 'obs' shape mismatch: " + f"expected {expected_obs}, got {obs_shape}.", + ValueError, + ) + else: + if ( + len(obs_shape) != 3 + or obs_shape[0] != num_envs + or obs_shape[1] != time_plus_one + ): + logger.log_error( + f"Preallocated rollout field 'obs' shape mismatch: " + f"expected ({num_envs}, {time_plus_one}, *), got {obs_shape}.", + ValueError, + ) + elif not getattr(self.policy, "use_raw_obs", False): + logger.log_error( + "Preallocated rollout TensorDict must contain 'obs' when " + "policy.use_raw_obs=False.", + ValueError, + ) + expected_shapes = { - "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim), - "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), - "sample_log_prob": (self.env.num_envs, num_steps + 1), - "value": (self.env.num_envs, num_steps + 1), - "reward": (self.env.num_envs, num_steps + 1), - "done": (self.env.num_envs, num_steps + 1), - "terminated": (self.env.num_envs, num_steps + 1), - 
"truncated": (self.env.num_envs, num_steps + 1), + "action": (num_envs, time_plus_one, self.policy.action_dim), + "sample_log_prob": (num_envs, time_plus_one), + "value": (num_envs, time_plus_one), + "reward": (num_envs, time_plus_one), + "done": (num_envs, time_plus_one), + "terminated": (num_envs, time_plus_one), + "truncated": (num_envs, time_plus_one), } for key, expected_shape in expected_shapes.items(): actual_shape = tuple(rollout[key].shape) if actual_shape != expected_shape: - raise ValueError( + logger.log_error( f"Preallocated rollout field '{key}' shape mismatch: " - f"expected {expected_shape}, got {actual_shape}." + f"expected {expected_shape}, got {actual_shape}.", + ValueError, + ) + action_chunk_size = int(getattr(self.policy, "action_chunk_size", 0) or 0) + use_action_chunk = bool( + getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0 + ) + if "action_chunk" in rollout.keys(): + if not use_action_chunk: + logger.log_error( + "Preallocated rollout field 'action_chunk' is present, but the " + "current policy is not configured to use action chunks.", + ValueError, + ) + expected_action_chunk_shape = ( + num_envs, + time_plus_one, + action_chunk_size, + self.policy.action_dim, + ) + actual_action_chunk_shape = tuple(rollout["action_chunk"].shape) + if actual_action_chunk_shape != expected_action_chunk_shape: + logger.log_error( + "Preallocated rollout field 'action_chunk' shape mismatch: " + f"expected {expected_action_chunk_shape}, got {actual_action_chunk_shape}.", + ValueError, + ) + if hasattr(rollout, "chunk_step"): + chunk_step_shape = tuple(rollout.chunk_step.shape) + expected_chunk_step_shape = (num_envs, num_steps) + if chunk_step_shape != expected_chunk_step_shape: + logger.log_error( + "Preallocated rollout attribute 'chunk_step' shape mismatch: " + f"expected {expected_chunk_step_shape}, got {chunk_step_shape}.", + ValueError, ) + for key in ("step_repeat", "execute_full_chunk"): + if key in rollout.keys(): + 
actual_shape = tuple(rollout[key].shape) + expected_shape = (num_envs, time_plus_one) + if actual_shape != expected_shape: + logger.log_error( + f"Preallocated rollout field '{key}' shape mismatch: " + f"expected {expected_shape}, got {actual_shape}.", + ValueError, + ) diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 51cf7653..1eee5983 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -17,7 +17,7 @@ from __future__ import annotations import inspect -from typing import Dict, Type +from typing import Any, Dict, Optional, Type from gymnasium import spaces import torch @@ -26,6 +26,7 @@ from .actor_only import ActorOnly from .policy import Policy from .mlp import MLP +from .vla_policy import VLAPolicy # In-module policy registry _POLICY_REGISTRY: Dict[str, Type[Policy]] = {} @@ -63,13 +64,16 @@ def build_policy( device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, + env: Optional[Any] = None, ) -> Policy: """Build a policy from config using spaces for extensibility. Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while custom policies may accept richer `obs_space` / `action_space` inputs. + For vla_policy, pass env so set_env can run; VLA weights load lazily on first use. """ name = policy_block["name"].lower() + if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) raise ValueError( @@ -119,7 +123,17 @@ def build_policy( build_kwargs["actor"] = actor if "critic" in init_params and critic is not None: build_kwargs["critic"] = critic - return policy_cls(**build_kwargs) + if "policy_cfg" in init_params: + build_kwargs["policy_cfg"] = policy_block + policy = policy_cls(**build_kwargs) + if name == "vla_policy": + if env is None: + raise ValueError( + "VLAPolicy requires an 'env' argument to be passed to build_policy " + "so that set_env can be called before use." 
+ ) + policy.set_env(env) + return policy def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: @@ -143,10 +157,12 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: # default registrations register_policy("actor_critic", ActorCritic) register_policy("actor_only", ActorOnly) +register_policy("vla_policy", VLAPolicy) __all__ = [ "ActorCritic", "ActorOnly", + "VLAPolicy", "register_policy", "get_registered_policy_names", "build_policy", diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 32caf0e3..8016ddcd 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -86,7 +86,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1) return tensordict - def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict: obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py index 3d6d1f78..0f93ce8f 100644 --- a/embodichain/agents/rl/models/actor_only.py +++ b/embodichain/agents/rl/models/actor_only.py @@ -77,7 +77,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: ) return tensordict - def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + def evaluate_actions(self, tensordict: TensorDict, **kwargs) -> TensorDict: obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py new file mode 100644 index 00000000..550b0264 --- /dev/null +++ b/embodichain/agents/rl/models/vla_policy.py @@ -0,0 +1,342 @@ +# ---------------------------------------------------------------------------- +# Copyright 
(c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from tensordict import TensorDict +from embodichain.agents.rl.vla_registry import create_vla_backend +from .policy import Policy + +__all__ = ["VLAPolicy"] + + +class VLAPolicy(Policy): + """Wraps DexForceVLA as Policy for GRPO fine-tuning.""" + + def __init__( + self, + device: torch.device, + policy_cfg: dict[str, Any], + obs_space=None, + action_space=None, + ) -> None: + super().__init__() + self.device = device + self.policy_cfg = dict(policy_cfg) + self.vla_cfg = dict(self.policy_cfg.get("vla", {})) + self.model_path = str(self.vla_cfg.get("model_path", "")) + self.action_horizon = int(self.vla_cfg.get("action_horizon", 32)) + self.gaussian_sigma = float(self.vla_cfg.get("gaussian_sigma", 0.1)) + + if not self.model_path: + raise ValueError("VLAPolicy requires 'policy.vla.model_path'.") + + self._vla_model: nn.Module | None = None + self._action_indices: list[int] | None = None + + if action_space is None: + self.action_dim = 14 + elif isinstance(action_space, int): + self.action_dim = action_space + elif hasattr(action_space, "shape") and len(action_space.shape) > 0: + self.action_dim = int(action_space.shape[-1]) + else: + 
self.action_dim = 14 + self.obs_dim = 0 # VLA uses raw observations + self.use_raw_obs = True # Tell collector to pass raw observations + + self.use_action_chunk = True + self.action_chunk_size = self.action_horizon + self.execute_full_chunk = True + self._env = None + + def set_env(self, env) -> None: + """Set env reference in forward.""" + self._env = env + + def _load_vla(self) -> None: + if self._vla_model is not None: + return + backend_name = str(self.vla_cfg.get("backend", "dexforce_vla")).lower() + backend = create_vla_backend( + backend_name, + model_path=self.model_path, + device=self.device, + action_horizon=self.action_horizon, + **{ + k: v + for k, v in self.vla_cfg.items() + if k not in ("backend", "model_path", "action_horizon") + }, + ) + self._vla_model, self._action_indices, self._prepare_batch_fn = backend + self._freeze_encoders() + + def parameters(self, recurse: bool = True): + """Expose trainable parameters, lazily loading VLA backend if needed.""" + self._load_vla() + return super().parameters(recurse=recurse) + + def _freeze_encoders(self) -> None: + """Freeze vision encoders to avoid catastrophic forgetting""" + if self._vla_model is None: + return + encoders = getattr(self._vla_model, "encoders", None) + if encoders is not None: + for param in encoders.parameters(): + param.requires_grad_(False) + privilege_estimators = getattr(self._vla_model, "privilege_estimators", None) + if privilege_estimators is not None: + for param in privilege_estimators.parameters(): + param.requires_grad_(False) + + def _vla_chunk_to_env_chunk( + self, action_chunk: torch.Tensor, env=None + ) -> torch.Tensor: + """Convert VLA output (N, T, va_dim) chunk to env format (N, T, env_dim).""" + if self._action_indices is not None: + step = action_chunk[:, :, self._action_indices] + else: + step = action_chunk + + if env is not None: + env_dim = getattr(env.action_space, "shape", (None,)) + if len(env_dim) > 0 and env_dim[-1] is not None: + env_dim = int(env_dim[-1]) 
+ if step.shape[-1] > env_dim: + step = step[..., :env_dim] + elif step.shape[-1] < env_dim: + pad = torch.zeros( + step.shape[0], + step.shape[1], + env_dim - step.shape[-1], + device=step.device, + dtype=step.dtype, + ) + step = torch.cat([step, pad], dim=-1) + return step + + def _infer_batch_size(self, obs: Any) -> int: + """Infer leading batch size from TensorDict / mapping / tensor observations.""" + if hasattr(obs, "batch_size") and len(obs.batch_size) > 0: + return int(obs.batch_size[0]) + if isinstance(obs, Mapping): + robot = obs.get("robot") + if isinstance(robot, Mapping): + q = robot.get("qpos") + if hasattr(q, "shape") and len(q.shape) > 0: + return int(q.shape[0]) + return 1 + + def _slice_obs_item(self, obs: Any, index: int) -> Any: + """Slice one environment sample from a batched observation structure.""" + if hasattr(obs, "batch_size") and len(obs.batch_size) > 0: + return obs[index] + if isinstance(obs, Mapping): + return { + key: self._slice_obs_item(value, index) for key, value in obs.items() + } + if torch.is_tensor(obs): + if obs.dim() == 0: + return obs + return obs[index] + return obs + + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["obs"] + env = getattr(tensordict, "env", None) + if env is None: + env = getattr(self, "_env", None) + if env is None: + raise ValueError( + "VLAPolicy needs env. Set policy._env or pass env in tensordict." 
+ ) + + self._load_vla() + self._vla_model.eval() + batch_size = self._infer_batch_size(obs) + if batch_size == 1: + batch = self._prepare_batch_fn(obs, env) + vla_chunk = self._vla_model.predict_action( + batch, + action_only=True, + inference_horizon=self.action_horizon, + allow_grad=False, + use_fix_aug=False, + ) + action_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) + else: + chunks_env = [] + for i in range(batch_size): + obs_i = self._slice_obs_item(obs, i) + batch_i = self._prepare_batch_fn(obs_i, env) + vla_chunk = self._vla_model.predict_action( + batch_i, + action_only=True, + inference_horizon=self.action_horizon, + allow_grad=False, + use_fix_aug=False, + ) + chunk_i = self._vla_chunk_to_env_chunk(vla_chunk, env=env) + chunks_env.append(chunk_i) + action_chunk_env = torch.cat(chunks_env, dim=0) + + mean_chunk = action_chunk_env.to(self.device, dtype=torch.float32) + + if deterministic: + noisy_chunk = mean_chunk + per_step_log_prob = torch.zeros( + mean_chunk.shape[0], + mean_chunk.shape[1], + device=self.device, + dtype=torch.float32, + ) + else: + sigma = self.gaussian_sigma + noise = torch.randn_like(mean_chunk) * sigma + noisy_chunk = mean_chunk + noise + per_step_log_prob = -0.5 * noise.pow(2).sum(-1) / (sigma * sigma + 1e-8) + log_prob = per_step_log_prob.mean(dim=-1) + + action = noisy_chunk[:, 0] + tensordict["action"] = action + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = torch.zeros( + action.shape[0], device=self.device, dtype=torch.float32 + ) + if self.use_action_chunk: + tensordict["action_chunk"] = noisy_chunk + tensordict["action_chunk_log_prob"] = per_step_log_prob + tensordict["action_chunk_value"] = torch.zeros_like(per_step_log_prob) + return tensordict + + def get_value(self, tensordict: TensorDict) -> TensorDict: + b = tensordict.batch_size[0] + tensordict["value"] = torch.zeros(b, device=self.device, dtype=torch.float32) + return tensordict + + def evaluate_actions( + self, tensordict: 
TensorDict, rollout=None, **kwargs + ) -> TensorDict: + """Compute log_prob via Gaussian proxy for GRPO policy gradient.""" + b = tensordict.batch_size[0] + env = getattr(self, "_env", None) + if env is None: + raise ValueError( + "VLAPolicy.evaluate_actions requires env. Call policy.set_env(env)." + ) + + raw_obs = getattr(rollout, "raw_obs", None) + chunk_step = tensordict.get("chunk_step", None) + indices = tensordict.get("_indices", None) + execute_full_chunk = tensordict.get("execute_full_chunk", None) + step_repeat = tensordict.get("step_repeat", None) + if raw_obs is None or chunk_step is None or indices is None: + raise ValueError( + "VLAPolicy.evaluate_actions requires rollout.raw_obs, chunk_step, and _indices. " + "Ensure collector uses use_raw_obs and use_action_chunk, and GRPO passes rollout." + ) + + time_dim = len(raw_obs) - 1 + sigma = self.gaussian_sigma + log_probs = [] + self._load_vla() + + stored_chunks = tensordict.get("action_chunk", None) + use_stored_chunks = stored_chunks is not None + execute_full_chunk_mask = ( + execute_full_chunk.bool() + if execute_full_chunk is not None + else torch.zeros((b,), device=self.device, dtype=torch.bool) + ) + + for i in range(b): + idx = int(indices[i].item()) + env_idx = idx // time_dim + step_idx = idx % time_dim + step_in_chunk = int(chunk_step[i].item()) + chunk_start_idx = max(0, step_idx - step_in_chunk) + obs_i = raw_obs[chunk_start_idx][env_idx] + + batch_i = self._prepare_batch_fn(obs_i, env) + vla_chunk = self._vla_model.predict_action( + batch_i, + action_only=True, + inference_horizon=self.action_horizon, + allow_grad=True, + use_fix_aug=False, + ) + pred_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env) + + if use_stored_chunks and bool(execute_full_chunk_mask[i].item()): + gt_chunk = stored_chunks[i] + pred_chunk = pred_chunk_env[0] + executed_len = int( + step_repeat[i].item() + if step_repeat is not None + else gt_chunk.shape[0] + ) + executed_len = max( + 1, min(executed_len, 
gt_chunk.shape[0], pred_chunk.shape[0]) + ) + mse = ( + ((gt_chunk[:executed_len] - pred_chunk[:executed_len]).pow(2)) + .sum(-1) + .mean(-1) + ) + else: + action_gt = tensordict["action"][i] + pred = pred_chunk_env[0, step_in_chunk] + if pred.shape[-1] != action_gt.shape[-1]: + pred = pred[: action_gt.shape[-1]] + mse = ((action_gt - pred).pow(2)).sum(-1) + + log_prob = -0.5 * mse / (sigma * sigma + 1e-8) + log_probs.append(log_prob) + + log_probs = torch.stack(log_probs) + if step_repeat is not None: + chunk_lengths = step_repeat.to( + device=self.device, dtype=torch.float32 + ).clamp_min(1.0) + else: + chunk_lengths = torch.ones((b,), device=self.device, dtype=torch.float32) + effective_dim = torch.where( + execute_full_chunk_mask, + chunk_lengths * float(self.action_dim), + torch.full((b,), float(self.action_dim), device=self.device), + ) + entropy = ( + 0.5 * effective_dim * (1 + np.log(2 * np.pi) + 2 * np.log(sigma + 1e-8)) + ) + + return TensorDict( + { + "sample_log_prob": log_probs, + "entropy": entropy, + "value": torch.zeros(b, device=self.device, dtype=torch.float32), + }, + batch_size=tensordict.batch_size, + device=self.device, + ) diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index fa1f5948..65a6ef0a 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -214,10 +214,8 @@ def train_from_config(config_path: str, distributed: bool | None = None): ) env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg) - sample_obs, _ = env.reset() - sample_obs_td = dict_to_tensordict(sample_obs, device) - obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] - flat_obs_space = env.flattened_observation_space + obs_dim = None + flat_obs_space = None # Create evaluation environment only if enabled eval_env = None @@ -246,6 +244,10 @@ def train_from_config(config_path: str, distributed: bool | None = None): ) # Build Policy via registry (actor/critic must be explicitly defined in JSON when using 
actor_critic/actor_only) if policy_name.lower() == "actor_critic": + sample_obs, _ = env.reset() + sample_obs_td = dict_to_tensordict(sample_obs, device) + obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] + flat_obs_space = env.flattened_observation_space actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -265,6 +267,10 @@ def train_from_config(config_path: str, distributed: bool | None = None): critic=critic, ) elif policy_name.lower() == "actor_only": + sample_obs, _ = env.reset() + sample_obs_td = dict_to_tensordict(sample_obs, device) + obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] + flat_obs_space = env.flattened_observation_space actor_cfg = policy_block.get("actor") if actor_cfg is None: raise ValueError( @@ -282,7 +288,11 @@ def train_from_config(config_path: str, distributed: bool | None = None): ) else: policy = build_policy( - policy_block, env.observation_space, env.action_space, device + policy_block, + env.observation_space, + env.action_space, + device, + env=env, ) # Build Algorithm via factory diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 56ea0db2..fa2058f9 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -90,15 +90,34 @@ def __init__( raise RuntimeError("Env must expose num_envs for trainer statistics.") obs_dim = getattr(self.policy, "obs_dim", None) action_dim = getattr(self.policy, "action_dim", None) - if obs_dim is None or action_dim is None: - raise RuntimeError("Policy must expose obs_dim and action_dim.") + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + store_flat_obs = not use_raw_obs + if action_dim is None or (store_flat_obs and obs_dim is None): + raise RuntimeError( + "Policy must expose action_dim and flat-observation metadata." 
+ ) + action_chunk_size = getattr(self.policy, "action_chunk_size", 0) + use_action_chunk = getattr(self.policy, "use_action_chunk", False) + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) + if use_action_chunk and action_chunk_size > 0 and not execute_full_chunk: + self.buffer_size = ( + (self.buffer_size + action_chunk_size - 1) + // action_chunk_size + * action_chunk_size + ) + + if not store_flat_obs: + obs_dim = int(obs_dim or 0) self.buffer = RolloutBuffer( num_envs=num_envs, rollout_len=self.buffer_size, - obs_dim=obs_dim, + obs_dim=int(obs_dim), action_dim=action_dim, device=self.device, + use_raw_obs=use_raw_obs, + action_chunk_size=action_chunk_size if use_action_chunk else 0, + store_flat_obs=store_flat_obs, ) self.collector = SyncCollector( env=self.env, @@ -116,6 +135,17 @@ def __init__( self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) # ---- lightweight helpers for dense logging ---- + def _obs_to_tensordict(self, obs) -> TensorDict: + """Normalize observation to TensorDict on trainer device.""" + if isinstance(obs, TensorDict): + return obs.to(self.device) + if isinstance(obs, dict): + return TensorDict.from_dict(obs, device=self.device) + raise TypeError( + f"Unsupported raw observation type: {type(obs)!r}. " + "Expected TensorDict or dict." 
+ ) + @staticmethod def _mean_scalar(x) -> float: if hasattr(x, "detach"): @@ -171,9 +201,14 @@ def on_step(tensordict: TensorDict, info: dict): """Callback called at each step during rollout collection.""" reward = tensordict["reward"] done = tensordict["done"] + step_repeat = tensordict.get( + "step_repeat", + torch.ones_like(reward, dtype=torch.float32, device=self.device), + ) + step_repeat_int = step_repeat.to(dtype=torch.int32) # Episode stats self.curr_ret += reward - self.curr_len += 1 + self.curr_len += step_repeat_int done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) if done_idx.numel() > 0: finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() @@ -184,7 +219,7 @@ def on_step(tensordict: TensorDict, info: dict): self.curr_len[done_idx] = 0 if not self.distributed: - self.global_step += tensordict.batch_size[0] + self.global_step += int(step_repeat_int.sum().item()) if self.rank == 0 and isinstance(info, dict): rewards_dict = info.get("rewards") @@ -220,7 +255,10 @@ def on_step(tensordict: TensorDict, info: dict): "Call torch.distributed.init_process_group(...) before creating " "or using Trainer(distributed=True, ...)." 
) - local_delta = self.env.num_envs * self.buffer_size + if "step_repeat" in rollout.keys(): + local_delta = int(rollout["step_repeat"][:, :-1].sum().item()) + else: + local_delta = self.env.num_envs * self.buffer_size delta_tensor = torch.tensor( [local_delta], dtype=torch.int64, device=self.device ) @@ -330,42 +368,147 @@ def _eval_once(self, num_episodes: int = 5): episode_returns = [] episode_lengths = [] - self.eval_env.set_rollout_buffer(self.buffer.buffer) + use_raw_obs = getattr(self.policy, "use_raw_obs", False) + use_action_chunk = getattr(self.policy, "use_action_chunk", False) + action_chunk_size = getattr(self.policy, "action_chunk_size", 1) + effective_use_action_chunk = use_action_chunk and action_chunk_size > 0 + execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False)) + + if hasattr(self.eval_env, "set_rollout_buffer"): + self.eval_env.set_rollout_buffer(self.buffer.buffer) for _ in range(num_episodes): - # Reset and initialize episode tracking obs, _ = self.eval_env.reset() - obs = flatten_dict_observation(obs) - num_envs = obs.shape[0] if obs.ndim == 2 else 1 + if use_raw_obs: + obs_td = self._obs_to_tensordict(obs) + else: + obs_td = flatten_dict_observation(obs) + num_envs = ( + obs_td.batch_size[0] + if hasattr(obs_td, "batch_size") + else (obs_td.shape[0] if hasattr(obs_td, "shape") else 1) + ) done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( num_envs, dtype=torch.float32, device=self.device ) step_count = torch.zeros(num_envs, dtype=torch.int32, device=self.device) + cached_chunk = None + step_in_chunk = 0 - # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy - action_td = TensorDict( - {"obs": obs}, - batch_size=[num_envs], - device=self.device, + if execute_full_chunk and effective_use_action_chunk: + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + 
action_td = self.policy.get_action(action_td, deterministic=True) + chunk = action_td.get("action_chunk") + if chunk is None: + raise ValueError( + "execute_full_chunk=True requires policy to provide 'action_chunk'." + ) + + reward_sum = torch.zeros( + num_envs, dtype=torch.float32, device=self.device + ) + terminated = torch.zeros( + num_envs, dtype=torch.bool, device=self.device + ) + truncated = torch.zeros( + num_envs, dtype=torch.bool, device=self.device + ) + executed_substeps = 0 + info = {} + + for sub_idx in range(action_chunk_size): + sub_actions = chunk[:, sub_idx] + am: ActionManager | None = getattr( + self.eval_env, "action_manager", None + ) + if am is None: + action_in = sub_actions + else: + action_in = am.convert_policy_action_to_env_action( + sub_actions + ) + + obs, reward, term_i, trunc_i, info = self.eval_env.step( + action_in + ) + if use_raw_obs: + obs_td = self._obs_to_tensordict(obs) + else: + obs_td = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) + + still_running = ~done_mask + reward_sum[still_running] += reward[still_running].float() + step_count[still_running] += 1 + terminated |= term_i + truncated |= trunc_i + done_mask |= term_i | trunc_i + executed_substeps += 1 + + if hasattr(self, "eval_event_manager"): + if "interval" in self.eval_event_manager.available_modes: + self.eval_event_manager.apply(mode="interval") + + if done_mask.all(): + break + + cumulative_reward += reward_sum + continue + + if effective_use_action_chunk and ( + cached_chunk is None or step_in_chunk == 0 + ): + action_td = TensorDict( + {"obs": obs_td}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + cached_chunk = action_td.get("action_chunk") + actions = action_td["action"] + step_in_chunk = 0 + elif effective_use_action_chunk and cached_chunk is not None: + actions = cached_chunk[:, step_in_chunk] + else: + action_td = TensorDict( + {"obs": obs_td}, + 
batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + actions = action_td["action"] + + step_in_chunk = ( + (step_in_chunk + 1) % action_chunk_size + if effective_use_action_chunk + else 0 + ) + am: ActionManager | None = getattr( + self.eval_env, "action_manager", None ) - action_td = self.policy.get_action(action_td, deterministic=True) - actions = action_td["action"] - am: ActionManager = getattr(self.eval_env, "action_manager", None) if am is None: action_in = actions else: action_in = am.convert_policy_action_to_env_action(actions) - # Environment step obs, reward, terminated, truncated, info = self.eval_env.step(action_in) - obs = ( - flatten_dict_observation(obs) - if isinstance(obs, TensorDict) - else obs - ) + if use_raw_obs: + obs_td = self._obs_to_tensordict(obs) + else: + obs_td = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) # Update statistics only for still-running environments done = terminated | truncated @@ -374,6 +517,10 @@ def _eval_once(self, num_episodes: int = 5): step_count[still_running] += 1 done_mask |= done + # Invalidate cached_chunk on any env reset + if effective_use_action_chunk and done.any(): + cached_chunk = None + # Trigger evaluation events (e.g., video recording) if hasattr(self, "eval_event_manager"): if "interval" in self.eval_event_manager.available_modes: diff --git a/embodichain/agents/rl/vla_registry.py b/embodichain/agents/rl/vla_registry.py new file mode 100644 index 00000000..f594c30b --- /dev/null +++ b/embodichain/agents/rl/vla_registry.py @@ -0,0 +1,124 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from importlib.metadata import entry_points +from typing import Any, Callable + +from embodichain.utils.logger import log_warning + +__all__ = [ + "get_vla_backend", + "get_registered_vla_backend_names", + "create_vla_backend", +] + + +_VLA_BACKENDS: dict[str, Callable[..., Any]] = {} +_ENTRY_POINTS_DISCOVERED = False +_ENTRY_POINTS_ENUM_LOGGED = False + + +def _discover_entry_points() -> None: + """Discover and register VLA backends from entry_points.""" + global _ENTRY_POINTS_DISCOVERED, _ENTRY_POINTS_ENUM_LOGGED + if _ENTRY_POINTS_DISCOVERED: + return + try: + eps = entry_points(group="embodichain.vla_backends") + except (OSError, ValueError, TypeError) as exc: + if not _ENTRY_POINTS_ENUM_LOGGED: + log_warning( + "Could not enumerate 'embodichain.vla_backends' entry points: " + f"{type(exc).__name__}: {exc}" + ) + _ENTRY_POINTS_ENUM_LOGGED = True + return + + for ep in eps: + try: + factory = ep.load() + except (ImportError, AttributeError, TypeError, ValueError) as exc: + log_warning( + f"Failed to load VLA backend entry point name={ep.name!r} " + f"value={ep.value!r}: {type(exc).__name__}: {exc}" + ) + continue + except Exception as exc: + log_warning( + f"Unexpected error loading VLA backend entry point name={ep.name!r} " + f"value={ep.value!r}: {type(exc).__name__}: {exc}" + ) + continue + name = str(ep.name).lower() + if name not in _VLA_BACKENDS: + _VLA_BACKENDS[name] = factory + + _ENTRY_POINTS_DISCOVERED = True + + +def 
get_vla_backend(name: str) -> Callable[..., Any] | None: + """Get a registered backend factory by name. + + This checks the in-memory registry first, and then lazily triggers + entry-point discovery if needed. + + Args: + name: Backend identifier (case-insensitive). + + Returns: + The backend factory callable if found, otherwise ``None``. + """ + name = str(name).lower() + if name in _VLA_BACKENDS: + return _VLA_BACKENDS[name] + _discover_entry_points() + return _VLA_BACKENDS.get(name) + + +def get_registered_vla_backend_names() -> list[str]: + """List all currently discoverable VLA backend names. + + Returns: + A list of backend names after lazy entry-point discovery. + """ + _discover_entry_points() + return list(_VLA_BACKENDS.keys()) + + +def create_vla_backend(name: str, **kwargs) -> Any: + """Instantiate a VLA backend by name. + + Args: + name: Backend identifier (case-insensitive). + **kwargs: Keyword arguments forwarded to the backend factory. + + Returns: + The instantiated backend object (factory-defined type). + + Raises: + ValueError: If the backend name is unknown. + """ + factory = get_vla_backend(name) + if factory is None: + available = get_registered_vla_backend_names() + raise ValueError( + f"Unknown VLA backend '{name}'. Available: {available}. " + "Ensure a package providing the 'embodichain.vla_backends' entry point " + "group is installed." 
+ ) + return factory(**kwargs) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index da3ae9b7..090bca25 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -547,7 +547,6 @@ def _initialize_episode( def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: """Infer whether the rollout buffer is expert recording or RL training data.""" if { - "obs", "action", "reward", "done", diff --git a/embodichain/lab/gym/envs/managers/randomization/visual.py b/embodichain/lab/gym/envs/managers/randomization/visual.py index 66d3d6fb..0522db8d 100644 --- a/embodichain/lab/gym/envs/managers/randomization/visual.py +++ b/embodichain/lab/gym/envs/managers/randomization/visual.py @@ -183,9 +183,14 @@ def randomize_camera_extrinsics( ).repeat(num_instance, 1) if pos_range: random_value = sample_uniform( - lower=torch.tensor(pos_range[0]), - upper=torch.tensor(pos_range[1]), + lower=torch.tensor( + pos_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + pos_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), + device=env.device, ) new_pose[:, :3] += random_value if euler_range: @@ -198,9 +203,14 @@ def randomize_camera_extrinsics( init_euler = torch.stack(euler_xyz_from_quat(init_quat_np), dim=1) # 2. Sample perturbation for euler angles random_value = sample_uniform( - lower=torch.tensor(euler_range[0]), - upper=torch.tensor(euler_range[1]), + lower=torch.tensor( + euler_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + euler_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), + device=env.device, ) # 3. 
Add perturbation to each environment and convert back to quaternion roll, pitch, yaw = (init_euler + random_value).unbind(dim=1) @@ -229,9 +239,14 @@ def randomize_camera_extrinsics( if eye_range: eye_delta = sample_uniform( - lower=torch.tensor(eye_range[0]), - upper=torch.tensor(eye_range[1]), + lower=torch.tensor( + eye_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + eye_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), + device=env.device, ) new_eye = init_eye + eye_delta else: @@ -239,9 +254,14 @@ def randomize_camera_extrinsics( if target_range: target_delta = sample_uniform( - lower=torch.tensor(target_range[0]), - upper=torch.tensor(target_range[1]), + lower=torch.tensor( + target_range[0], dtype=torch.float32, device=env.device + ), + upper=torch.tensor( + target_range[1], dtype=torch.float32, device=env.device + ), size=(num_instance, 3), + device=env.device, ) new_target = init_target + target_delta else: @@ -249,9 +269,10 @@ def randomize_camera_extrinsics( if up_range: up_delta = sample_uniform( - lower=torch.tensor(up_range[0]), - upper=torch.tensor(up_range[1]), + lower=torch.tensor(up_range[0], dtype=torch.float32, device=env.device), + upper=torch.tensor(up_range[1], dtype=torch.float32, device=env.device), size=(num_instance, 3), + device=env.device, ) new_up = init_up + up_delta else: @@ -311,8 +332,8 @@ def randomize_light( .repeat(num_instance, 1) ) random_value = sample_uniform( - lower=torch.tensor(position_range[0]), - upper=torch.tensor(position_range[1]), + lower=torch.tensor(position_range[0], dtype=torch.float32), + upper=torch.tensor(position_range[1], dtype=torch.float32), size=new_pos.shape, ) new_pos += random_value @@ -321,8 +342,8 @@ def randomize_light( if color_range: color = torch.zeros((num_instance, 3), dtype=torch.float32) random_value = sample_uniform( - lower=torch.tensor(color_range[0]), - upper=torch.tensor(color_range[1]), + lower=torch.tensor(color_range[0], 
dtype=torch.float32), + upper=torch.tensor(color_range[1], dtype=torch.float32), size=color.shape, ) color += random_value @@ -336,8 +357,8 @@ def randomize_light( .repeat(num_instance, 1) ) random_value = sample_uniform( - lower=torch.tensor(intensity_range[0]), - upper=torch.tensor(intensity_range[1]), + lower=torch.tensor(intensity_range[0], dtype=torch.float32), + upper=torch.tensor(intensity_range[1], dtype=torch.float32), size=new_intensity.shape, ) new_intensity += random_value @@ -372,8 +393,8 @@ def randomize_emission_light( if color_range: color = torch.zeros((1, 3), dtype=torch.float32) random_value = sample_uniform( - lower=torch.tensor(color_range[0]), - upper=torch.tensor(color_range[1]), + lower=torch.tensor(color_range[0], dtype=torch.float32), + upper=torch.tensor(color_range[1], dtype=torch.float32), size=color.shape, ) color += random_value @@ -445,8 +466,8 @@ def randomize_camera_intrinsics( # Randomize focal length x (fx) if focal_x_range: random_value = sample_uniform( - lower=torch.tensor(focal_x_range[0]), - upper=torch.tensor(focal_x_range[1]), + lower=torch.tensor(focal_x_range[0], dtype=torch.float32), + upper=torch.tensor(focal_x_range[1], dtype=torch.float32), size=(num_instance,), ) new_intrinsics[:, 0] += random_value @@ -454,8 +475,8 @@ def randomize_camera_intrinsics( # Randomize focal length y (fy) if focal_y_range: random_value = sample_uniform( - lower=torch.tensor(focal_y_range[0]), - upper=torch.tensor(focal_y_range[1]), + lower=torch.tensor(focal_y_range[0], dtype=torch.float32), + upper=torch.tensor(focal_y_range[1], dtype=torch.float32), size=(num_instance,), ) new_intrinsics[:, 1] += random_value @@ -463,8 +484,8 @@ def randomize_camera_intrinsics( # Randomize principal point x (cx) if cx_range: random_value = sample_uniform( - lower=torch.tensor(cx_range[0]), - upper=torch.tensor(cx_range[1]), + lower=torch.tensor(cx_range[0], dtype=torch.float32), + upper=torch.tensor(cx_range[1], dtype=torch.float32), 
# NOTE(review): this span arrived as a whitespace-mangled unified diff; the
# Python additions are reconstructed below with conventional formatting.  The
# leading fragment of the diff (the float32-dtype fix for `cy_range` inside
# `randomize_camera_intrinsics`) belongs to a definition outside this view and
# is intentionally not reproduced here.
#
# ---------------------------------------------------------------------------
# Appended to tests/agents/test_shared_rollout.py
# ---------------------------------------------------------------------------


class _FakePolicyRawObs:
    """Policy that reads nested env observations (collector passes raw_obs slices)."""

    use_raw_obs = True

    def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None:
        # NOTE(review): the obs_dim argument is deliberately discarded — a
        # raw-observation policy advertises no flat observation width.
        # Confirm this is intentional rather than a typo.
        self.obs_dim = 0
        self.action_dim = action_dim
        self.device = device

    def train(self) -> None:
        pass

    def get_action(
        self, tensordict: TensorDict, deterministic: bool = False
    ) -> TensorDict:
        # Derive action / log-prob / value from the nested agent state tensor
        # so the test can predict every stored value exactly.
        flat = tensordict["obs"]["agent"]["state"]
        tensordict["action"] = flat[:, : self.action_dim] * 0.25
        tensordict["sample_log_prob"] = flat.sum(dim=-1) * 0.1
        tensordict["value"] = flat.mean(dim=-1)
        return tensordict

    def get_value(self, tensordict: TensorDict) -> TensorDict:
        flat = tensordict["obs"]["agent"]["state"]
        tensordict["value"] = flat.mean(dim=-1)
        return tensordict


class _FakePolicyActionChunk:
    """Policy that emits a fixed-length action chunk (sequential env steps per chunk)."""

    use_action_chunk = True
    action_chunk_size = 2
    execute_full_chunk = False

    def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None:
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

    def train(self) -> None:
        pass

    def get_action(
        self, tensordict: TensorDict, deterministic: bool = False
    ) -> TensorDict:
        obs = tensordict["obs"]

        def _const(value: float) -> torch.Tensor:
            # One scalar per env row, matching the batch leading dimension.
            return torch.full(
                (obs.shape[0],), value, device=obs.device, dtype=torch.float32
            )

        # Two-substep chunk: each substep scales the observation differently so
        # the tests can tell substeps apart.
        chunk = torch.stack(
            [obs[:, : self.action_dim] * 0.1, obs[:, : self.action_dim] * 0.2],
            dim=1,
        )
        chunk_log_prob = torch.stack([_const(0.25), _const(0.75)], dim=1)
        chunk_value = torch.stack([_const(1.0), _const(2.0)], dim=1)

        tensordict["action_chunk"] = chunk
        tensordict["action"] = chunk[:, 0]
        tensordict["sample_log_prob"] = chunk_log_prob[:, 0]
        tensordict["value"] = chunk_value[:, 0]
        tensordict["action_chunk_log_prob"] = chunk_log_prob
        tensordict["action_chunk_value"] = chunk_value
        return tensordict

    def get_value(self, tensordict: TensorDict) -> TensorDict:
        tensordict["value"] = tensordict["obs"].mean(dim=-1)
        return tensordict


class _FakePolicyExecuteFullChunk:
    """Policy that runs an entire chunk inside one logical rollout step."""

    use_action_chunk = True
    action_chunk_size = 3
    execute_full_chunk = True

    def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None:
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

    def train(self) -> None:
        pass

    def get_action(
        self, tensordict: TensorDict, deterministic: bool = False
    ) -> TensorDict:
        obs = tensordict["obs"]
        n = obs.shape[0]
        # Repeat the scaled observation across all chunk substeps; clone()
        # materializes the expand so later writes cannot alias.
        base = obs[:, : self.action_dim] * 0.1
        chunk = base.unsqueeze(1).expand(n, self.action_chunk_size, -1).clone()
        chunk_log_prob = torch.tensor(
            [0.2, 0.4, 0.6], device=obs.device, dtype=torch.float32
        ).expand(n, -1)
        tensordict["action_chunk"] = chunk
        tensordict["action"] = chunk[:, 0]
        tensordict["sample_log_prob"] = chunk_log_prob.mean(dim=1)
        tensordict["value"] = torch.zeros(n, device=obs.device, dtype=torch.float32)
        tensordict["action_chunk_log_prob"] = chunk_log_prob
        return tensordict

    def get_value(self, tensordict: TensorDict) -> TensorDict:
        tensordict["value"] = tensordict["obs"].mean(dim=-1)
        return tensordict


class _FakeEnvTerminateOnFirstChunkStep(_FakeEnv):
    """Environment that terminates after the first executed chunk substep."""

    def step(self, action):
        next_obs, reward, terminated, truncated, info = super().step(action)
        # Force every env to report done immediately after the first substep.
        terminated = torch.ones(self.num_envs, dtype=torch.bool, device=self.device)
        return next_obs, reward, terminated, truncated, info


def test_collector_populates_raw_obs_buffer():
    """Raw-obs policies get per-step nested observations instead of flat obs."""
    device = torch.device("cpu")
    num_envs, rollout_len, obs_dim, action_dim = 2, 3, 5, 2

    env = _FakeEnv(
        num_envs=num_envs, obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    policy = _FakePolicyRawObs(obs_dim=obs_dim, action_dim=action_dim, device=device)
    collector = SyncCollector(env=env, policy=policy, device=device)
    buffer = RolloutBuffer(
        num_envs=num_envs,
        rollout_len=rollout_len,
        obs_dim=obs_dim,
        action_dim=action_dim,
        device=device,
        use_raw_obs=True,
        store_flat_obs=False,
    )

    rollout = collector.collect(
        num_steps=rollout_len, rollout=buffer.start_rollout()
    )
    buffer.add(rollout)
    stored = buffer.get(flatten=False)

    # Flat observations must be absent; raw observations keep T + 1 slices
    # (one per step plus the bootstrap observation).
    assert "obs" not in stored.keys()
    assert hasattr(stored, "raw_obs")
    assert len(stored.raw_obs) == rollout_len + 1
    for t in range(rollout_len + 1):
        assert stored.raw_obs[t] is not None
        assert stored.raw_obs[t].batch_size == torch.Size([num_envs])
        # _FakeEnv emits obs filled with the step index t (see its definition
        # earlier in this file), so each slice is exactly full((N, D), t).
        assert torch.allclose(
            stored.raw_obs[t]["agent"]["state"],
            torch.full((num_envs, obs_dim), float(t), dtype=torch.float32),
        )
    # The bootstrap value equals the mean of the final observation.
    assert torch.allclose(
        stored["value"][:, -1],
        torch.full((num_envs,), float(rollout_len), dtype=torch.float32),
    )


def test_collector_chunk_step_alternates_for_sequential_action_chunk():
    """Sequential chunk execution records alternating substep indices 0,1,0,1."""
    device = torch.device("cpu")
    num_envs, rollout_len, obs_dim, action_dim = 2, 4, 5, 2

    env = _FakeEnv(
        num_envs=num_envs, obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    policy = _FakePolicyActionChunk(
        obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    collector = SyncCollector(env=env, policy=policy, device=device)
    buffer = RolloutBuffer(
        num_envs=num_envs,
        rollout_len=rollout_len,
        obs_dim=obs_dim,
        action_dim=action_dim,
        device=device,
        action_chunk_size=policy.action_chunk_size,
    )

    rollout = collector.collect(
        num_steps=rollout_len, rollout=buffer.start_rollout()
    )

    assert hasattr(rollout, "chunk_step")
    expected = torch.tensor([0, 1, 0, 1], dtype=torch.long, device=device).expand(
        num_envs, -1
    )
    assert torch.equal(rollout.chunk_step, expected)
    # Every executed step advances the env exactly once in sequential mode.
    assert torch.all(rollout["step_repeat"][:, :-1] == 1.0)
    # Substeps of the same chunk reuse the chunk sampled at the first substep.
    assert torch.allclose(rollout["action_chunk"][:, 1], rollout["action_chunk"][:, 0])
    assert torch.allclose(rollout["action_chunk"][:, 3], rollout["action_chunk"][:, 2])
    expected_log_prob = torch.tensor([0.25, 0.75, 0.25, 0.75], device=device).expand(
        num_envs, -1
    )
    expected_value = torch.tensor([1.0, 2.0, 1.0, 2.0], device=device).expand(
        num_envs, -1
    )
    assert torch.allclose(
        rollout["sample_log_prob"][:, :rollout_len], expected_log_prob
    )
    assert torch.allclose(rollout["value"][:, :rollout_len], expected_value)


def test_collector_execute_full_chunk_sets_step_repeat_and_action_chunk():
    """Full-chunk execution stores the whole chunk and step_repeat == chunk size."""
    device = torch.device("cpu")
    num_envs, rollout_len, obs_dim, action_dim = 2, 2, 5, 2
    chunk_t = 3

    env = _FakeEnv(
        num_envs=num_envs, obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    policy = _FakePolicyExecuteFullChunk(
        obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    assert policy.action_chunk_size == chunk_t
    collector = SyncCollector(env=env, policy=policy, device=device)
    buffer = RolloutBuffer(
        num_envs=num_envs,
        rollout_len=rollout_len,
        obs_dim=obs_dim,
        action_dim=action_dim,
        device=device,
        action_chunk_size=chunk_t,
    )

    rollout = collector.collect(
        num_steps=rollout_len, rollout=buffer.start_rollout()
    )

    assert "action_chunk" in rollout.keys()
    assert rollout["action_chunk"].shape == (
        num_envs,
        rollout_len + 1,
        chunk_t,
        action_dim,
    )
    # Each logical step starts a fresh chunk and consumes chunk_t env steps;
    # the bootstrap slot carries no executed substeps.
    assert torch.all(rollout.chunk_step[:, :rollout_len] == 0)
    assert torch.all(rollout["step_repeat"][:, :rollout_len] == float(chunk_t))
    assert torch.all(rollout["step_repeat"][:, -1] == 0.0)


def test_execute_full_chunk_masks_unexecuted_suffix_on_early_done():
    """Early termination keeps only the executed substep and zeros the rest."""
    device = torch.device("cpu")
    num_envs, rollout_len, obs_dim, action_dim = 2, 1, 5, 2

    env = _FakeEnvTerminateOnFirstChunkStep(
        num_envs=num_envs, obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    policy = _FakePolicyExecuteFullChunk(
        obs_dim=obs_dim, action_dim=action_dim, device=device
    )
    collector = SyncCollector(env=env, policy=policy, device=device)
    buffer = RolloutBuffer(
        num_envs=num_envs,
        rollout_len=rollout_len,
        obs_dim=obs_dim,
        action_dim=action_dim,
        device=device,
        action_chunk_size=policy.action_chunk_size,
    )

    rollout = collector.collect(
        num_steps=rollout_len, rollout=buffer.start_rollout()
    )

    # Only one env step executed before done, so the unexecuted chunk suffix
    # is masked out and the stored log-prob is the first substep's (0.2).
    assert torch.all(rollout["step_repeat"][:, 0] == 1.0)
    assert torch.all(rollout["action_chunk"][:, 0, 1:] == 0.0)
    assert torch.allclose(
        rollout["sample_log_prob"][:, 0],
        torch.full((num_envs,), 0.2, dtype=torch.float32, device=device),
    )


def test_rollout_buffer_rejects_raw_and_flat_obs_storage_together():
    """use_raw_obs=True with store_flat_obs=True must raise a ValueError."""
    device = torch.device("cpu")
    try:
        RolloutBuffer(
            num_envs=1,
            rollout_len=1,
            obs_dim=4,
            action_dim=2,
            device=device,
            use_raw_obs=True,
            store_flat_obs=True,
        )
    except ValueError as exc:
        # The error message should point the user at the supported setting.
        assert "store_flat_obs=False" in str(exc)
    else:
        raise AssertionError("Expected raw+flat rollout buffer configuration to fail.")


# ---------------------------------------------------------------------------
# New file: tests/agents/test_vla_policy.py
#
# Copyright (c) 2021-2026 DexForce Technology Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original module header (kept as comments because this reconstruction shares
# one module; restore as real statements when splitting the files back out):
#   from __future__ import annotations
#   import torch
#   from embodichain.agents.rl.models.vla_policy import VLAPolicy
# ---------------------------------------------------------------------------


def test_vla_policy_slice_obs_item_with_mapping_batch():
    """_slice_obs_item indexes every tensor leaf of a nested obs mapping."""
    policy = VLAPolicy(
        device=torch.device("cpu"),
        policy_cfg={"vla": {"model_path": "dummy_model_path"}},
        action_space=2,
    )
    obs = {
        "robot": {
            "qpos": torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32),
            "qvel": torch.tensor([[5.0, 6.0], [7.0, 8.0]], dtype=torch.float32),
        },
        "image": torch.tensor([[9.0], [10.0]], dtype=torch.float32),
    }

    # Batch size is inferred from the leading dimension of the tensor leaves.
    assert policy._infer_batch_size(obs) == 2
    obs_1 = policy._slice_obs_item(obs, 1)
    assert torch.allclose(obs_1["robot"]["qpos"], torch.tensor([3.0, 4.0]))
    assert torch.allclose(obs_1["robot"]["qvel"], torch.tensor([7.0, 8.0]))
    assert torch.allclose(obs_1["image"], torch.tensor([10.0]))