Feature: add VLA policy and registry for RL#186
Conversation
Pull request overview
Adds support for integrating a VLA (vision-language-action) model into the existing RL stack by introducing a new VLAPolicy, wiring raw (hierarchical) observations + chunked actions through collection/eval/training, and adding an entry-point based backend registry.
Changes:
- Introduce `VLAPolicy` and register it in the RL policy registry.
- Extend rollout collection/training (collector, buffer, GRPO, trainer eval) to support raw observations and action chunks (`action_chunk`/`chunk_step`).
- Add `vla_registry` to discover VLA backend factories via Python entry points.
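The entry-point registry pattern described above can be sketched roughly as follows; names like `register_backend` and `discover_backends` are illustrative, not the actual `vla_registry` API:

```python
from importlib.metadata import entry_points

_VLA_BACKENDS: dict = {}


def register_backend(name: str, factory) -> None:
    """Register a backend factory under a case-insensitive name."""
    _VLA_BACKENDS[name.lower()] = factory


def discover_backends(group: str = "embodichain.vla_backends") -> None:
    """Load every entry point in the group, skipping (and reporting) broken ones."""
    for ep in entry_points(group=group):
        try:
            register_backend(ep.name, ep.load())
        except (ImportError, AttributeError) as exc:
            print(f"skipping backend {ep.name!r}: {exc}")


def create_vla_backend(name: str, **kwargs):
    """Instantiate a registered backend by name."""
    try:
        factory = _VLA_BACKENDS[name.lower()]
    except KeyError:
        raise KeyError(
            f"unknown VLA backend {name!r}; known: {sorted(_VLA_BACKENDS)}"
        ) from None
    return factory(**kwargs)
```

Third-party packages would then expose a factory under the `embodichain.vla_backends` entry-point group in their packaging metadata.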
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| embodichain/agents/rl/vla_registry.py | New entry-point based backend registry + factory creation. |
| embodichain/agents/rl/models/vla_policy.py | New VLAPolicy wrapper for VLA inference + GRPO-compatible evaluate_actions. |
| embodichain/agents/rl/models/__init__.py | Registers vla_policy; extends build_policy to optionally pass env/policy_cfg. |
| embodichain/agents/rl/collector/sync_collector.py | Adds raw-observation storage and action-chunk caching + chunk_step. |
| embodichain/agents/rl/buffer/standard_buffer.py | Adds use_raw_obs and attaches raw_obs list to shared rollout. |
| embodichain/agents/rl/buffer/utils.py | Propagates chunk_step into transition view; adds _indices in minibatches. |
| embodichain/agents/rl/algo/grpo.py | Passes rollout + num_envs into evaluate_actions; preserves raw fields across clone. |
| embodichain/agents/rl/utils/trainer.py | Adjusts buffer sizing for chunked actions; updates eval loop for raw obs/chunks. |
| embodichain/agents/rl/models/actor_only.py | Updates evaluate_actions signature to accept extra kwargs. |
| embodichain/agents/rl/models/actor_critic.py | Updates evaluate_actions signature to accept extra kwargs. |
Pull request overview
Adds first-class support for VLA-backed policies in the RL stack by introducing a VLA policy wrapper, an entry-point-based backend registry, and rollout/collector plumbing for raw observations + chunked actions.
Changes:
- Introduces `VLAPolicy` and registers it in the RL policy registry.
- Adds `vla_registry` to discover/load VLA backend factories via Python entry points.
- Extends rollout collection/training/eval utilities to support `raw_obs`, `chunk_step`, and action-chunk caching.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| embodichain/agents/rl/vla_registry.py | Entry-point discovery + factory creation for pluggable VLA backends. |
| embodichain/agents/rl/utils/trainer.py | Trainer buffer allocation + eval loop updated for raw obs and action chunks. |
| embodichain/agents/rl/models/vla_policy.py | New VLA-backed policy wrapper implementing chunked action inference and proxy log-prob evaluation. |
| embodichain/agents/rl/models/actor_only.py | Broadens evaluate_actions signature to accept extra kwargs. |
| embodichain/agents/rl/models/actor_critic.py | Broadens evaluate_actions signature to accept extra kwargs. |
| embodichain/agents/rl/models/__init__.py | Registers vla_policy and adds env-dependent initialization path in build_policy. |
| embodichain/agents/rl/collector/sync_collector.py | Adds raw_obs storage + action chunk caching + chunk_step tracking. |
| embodichain/agents/rl/buffer/utils.py | Propagates chunk_step into transition view and adds minibatch _indices. |
| embodichain/agents/rl/buffer/standard_buffer.py | Adds use_raw_obs handling and allocates rollout.raw_obs. |
| embodichain/agents/rl/algo/grpo.py | Passes rollout context into evaluate_actions and preserves rollout attributes across clone. |
```python
raw_obs = getattr(rollout, "raw_obs", None)
chunk_step = getattr(rollout, "chunk_step", None)
rollout = rollout.clone()
if raw_obs is not None:
```
This part seems useless?

My mistake, I forgot to remove it.
```python
if self._is_full:
    raise RuntimeError("RolloutBuffer already contains a rollout.")
self._clear_dynamic_fields()
if self.use_raw_obs:
```
What is the purpose of adding raw_obs?

VLA requires full images as input instead of a flattened vector. raw_obs is used here to separate the VLA policy from other policies.
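A toy illustration of that split (the `flatten_dict_observation` here is a stand-in, not the project's helper): an MLP policy consumes the flattened vector, while the VLA path keeps the full nested observation untouched.

```python
import numpy as np


def flatten_dict_observation(obs: dict) -> np.ndarray:
    # Concatenate leaf arrays in sorted key order (illustrative only).
    return np.concatenate([np.ravel(v) for _, v in sorted(obs.items())])


obs = {"qpos": np.zeros(7), "rgb": np.zeros((2, 2, 3))}
flat = flatten_dict_observation(obs)  # vector input for MLP-style policies
raw_obs_list = [obs]                  # VLA path keeps full images/dicts per step
```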
```python
if use_raw_obs:
    if raw_obs_list is None:
        raise ValueError(
```
```python
self.action_chunk_size = self.action_horizon
self._env = None

def set_env(self, env) -> None:
```
Why add env to the policy?

After getting the raw obs from the env, the VLA packages them into batches matching its input requirements; this step needs information from the env.
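A rough sketch of that coupling; the class and method names here are illustrative, not the real `VLAPolicy` interface:

```python
class VLAPolicySketch:
    """Minimal stand-in showing why the policy holds an env reference."""

    def __init__(self, action_horizon: int = 8):
        self.action_horizon = action_horizon
        self._env = None

    def set_env(self, env) -> None:
        self._env = env

    def prepare_batch(self, raw_obs: dict) -> dict:
        if self._env is None:
            raise RuntimeError("call set_env() before inference")
        # Packaging raw obs into the VLA's input format needs env
        # metadata, e.g. the number of parallel envs.
        return {"obs": raw_obs, "num_envs": self._env.num_envs}


class _DummyEnv:
    num_envs = 4


policy = VLAPolicySketch()
policy.set_env(_DummyEnv())
batch = policy.prepare_batch({"rgb": None})
```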
```python
tensordict["value"] = torch.zeros(b, device=self.device, dtype=torch.float32)
return tensordict

def evaluate_actions(
```
num_envs also seems unused.
Pull request overview
Adds VLA (Vision-Language-Action) integration into the RL stack by introducing a VLAPolicy wrapper, extending rollout collection/training to support raw observations and chunked actions, and adding a registry for VLA backends via entry points.
Changes:
- Introduce `VLAPolicy` and `vla_registry` to load and run VLA backends inside RL policies.
- Extend rollout collection/evaluation to support `use_raw_obs` and chunked actions (`action_chunk` + `chunk_step`).
- Adjust minibatching/GRPO plumbing to pass rollout context (`raw_obs`, indices) into `evaluate_actions`.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| embodichain/agents/rl/vla_registry.py | Adds entry-point-based backend discovery + factory creation for VLA backends. |
| embodichain/agents/rl/models/vla_policy.py | New policy wrapper that runs a VLA backend and exposes RL Policy interface with action chunks + raw obs. |
| embodichain/agents/rl/utils/trainer.py | Updates buffer sizing and evaluation loop to handle raw obs + chunked actions. |
| embodichain/agents/rl/train.py | Passes env into build_policy for VLA policy initialization. |
| embodichain/agents/rl/models/__init__.py | Registers vla_policy and extends build_policy to support env/policy_cfg and VLA initialization. |
| embodichain/agents/rl/collector/sync_collector.py | Extends collector to populate raw_obs, generate/consume chunked actions, and track chunk_step. |
| embodichain/agents/rl/buffer/utils.py | Propagates chunk_step into transition view; adds _indices to minibatches for mapping back to rollout. |
| embodichain/agents/rl/buffer/standard_buffer.py | Allocates/clears raw_obs and chunk_step dynamic fields for VLA workflows. |
| embodichain/agents/rl/algo/grpo.py | Passes rollout into evaluate_actions to support VLA log-prob evaluation from raw obs. |
| embodichain/agents/rl/algo/ppo.py | Removes per-update rollout cloning (now relies on shared rollout lifecycle). |
| embodichain/agents/rl/models/actor_only.py | Allows evaluate_actions(..., **kwargs) to accept rollout context without breaking. |
| embodichain/agents/rl/models/actor_critic.py | Allows evaluate_actions(..., **kwargs) to accept rollout context without breaking. |
```python
backend = create_vla_backend(
    "dexforce_vla",
```
VLAPolicy._load_vla() hardcodes the backend name to "dexforce_vla" while also excluding a backend key from vla_cfg. If configuration is meant to choose among entry-point backends, this prevents it. Read backend from config (defaulting to dexforce_vla) and pass it into create_vla_backend.
Suggested change:

```python
backend_name = str(self.vla_cfg.get("backend", "dexforce_vla"))
backend = create_vla_backend(
    backend_name,
```
```diff
 """Validate rollout layout expected by the collector."""
 obs_dim = rollout["obs"].shape[-1]
 expected_shapes = {
-    "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim),
+    "obs": (self.env.num_envs, num_steps + 1, obs_dim),
     "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim),
```
_validate_rollout() uses obs_dim = rollout["obs"].shape[-1] to form the expected shape, so the obs last-dimension check can never fail (making validation ineffective for this field). If the intent is to support policies with obs_dim == 0, consider validating against self.policy.obs_dim when it’s > 0 and skipping only that last-dim check otherwise.
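The suggested fix can be sketched as a small standalone check (the surrounding `_validate_rollout` context is assumed; the helper name is illustrative):

```python
def check_obs_dim(actual_last_dim: int, policy_obs_dim: int) -> None:
    """Enforce the obs last-dim check only when the policy declares one.

    A VLA policy working from raw observations may report obs_dim == 0,
    in which case only the last-dimension check is skipped.
    """
    if policy_obs_dim > 0 and actual_last_dim != policy_obs_dim:
        raise ValueError(
            f"obs dim mismatch: expected {policy_obs_dim}, got {actual_last_dim}"
        )
```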
```diff
 indices = torch.randperm(total)
 for start in range(0, total, batch_size):
-    yield rollout[indices[start : start + batch_size]]
+    batch_indices = indices[start : start + batch_size]
+    batch = rollout[batch_indices].clone()
+    batch["_indices"] = batch_indices
```
iterate_minibatches() now builds indices on the default (CPU) device. If rollout is on CUDA, indexing rollout[batch_indices] will fail because the index tensor must be on the same device. Create indices on rollout.device (or the passed device) and keep _indices consistent with that choice.
```python
        factory = ep.load()
        name = str(ep.name).lower()
        if name not in _VLA_BACKENDS:
            _VLA_BACKENDS[name] = factory
    except Exception:
```
Per-entry-point failures from ep.load() are silently swallowed, which makes plugin misconfiguration very hard to diagnose. Consider catching narrower exceptions and logging a warning/debug message with the entry point name/module when a backend fails to load.
```python
if use_action_chunk:
    rollout.chunk_step[:, step_idx] = effective_step_in_chunk
# Invalidate cached_chunk on any env reset to avoid using old chunk for new episode
if (terminated | truncated).any():
    cached_chunk = None
```
When using action chunks, the collector invalidates the entire cached_chunk whenever any environment terminates/truncates. In a vectorized env, this forces all still-running envs to throw away their remaining chunk and recompute a new one, which can be very expensive for VLA backends and changes the intended per-env chunk semantics. Consider maintaining per-env cached chunks (and invalidating only the done env indices) to avoid unnecessary recomputation.
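A per-env invalidation sketch along those lines, assuming a boolean `chunk_valid` mask is kept alongside the cache (names are illustrative, not the collector's actual state):

```python
import torch

num_envs, horizon, act_dim = 4, 3, 2
cached_chunk = torch.randn(num_envs, horizon, act_dim)
chunk_valid = torch.ones(num_envs, dtype=torch.bool)

terminated = torch.tensor([False, True, False, False])
truncated = torch.tensor([False, False, False, True])

# Invalidate only the finished envs; the others keep consuming their chunk.
done = terminated | truncated
chunk_valid &= ~done
# Envs with chunk_valid == False would recompute a fresh chunk next step.
```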
```python
{"obs": rollout["obs"][:, step_idx]},
batch_size=[rollout.batch_size[0]],
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
```
New behaviors (raw-observation rollouts via raw_obs and chunked-action support via chunk_step/action_chunk) are introduced here but aren’t covered by existing RL tests. Adding a focused unit test with a dummy policy exercising use_raw_obs=True and use_action_chunk=True would help prevent regressions (e.g., raw_obs population and chunk_step alignment).
Suggested change:

```python
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None
if use_raw_obs and raw_obs_list is None and isinstance(rollout, TensorDict):
    # Fallback to key-based access for raw observations when using a TensorDict.
    # This allows 'raw_obs' to be provided either as an attribute or a field key.
    if "raw_obs" in rollout.keys():
        raw_obs_list = rollout.get("raw_obs")
```
```python
except Exception:
    pass
```
The outer except Exception: pass suppresses all discovery errors (including programming errors) and leaves the registry empty with no signal to the caller. At minimum, log the exception; ideally catch only expected errors (e.g., missing entry point group) and let unexpected ones surface.
```python
    "so that set_env and _load_vla can be called before use."
)
policy.set_env(env)
policy._load_vla()
```
Calling the private policy._load_vla() from build_policy() forces heavyweight backend/model loading at construction time and couples the factory to a private method. Prefer lazy-loading in VLAPolicy (or exposing a public init/load hook) so config parsing and policy construction stay lightweight and copying/cloning policies doesn’t accidentally duplicate backend state.
Suggested change:

```python
    "so that set_env can be called before use."
)
policy.set_env(env)
```
```python
_ENTRY_POINTS_DISCOVERED = True
try:
    eps = entry_points(group="embodichain.vla_backends")
```
_ENTRY_POINTS_DISCOVERED is set to True before attempting discovery. If entry_points() fails (or the first discovery attempt is partial), subsequent calls will never retry and the registry can remain empty. Consider only marking discovery complete after a successful pass, or allowing retries on failure.
Pull request overview
Copilot reviewed 13 out of 13 changed files in this pull request and generated 10 comments.
```diff
 if pos_range:
     random_value = sample_uniform(
-        lower=torch.tensor(pos_range[0]),
-        upper=torch.tensor(pos_range[1]),
+        lower=torch.tensor(pos_range[0], device=env.device),
+        upper=torch.tensor(pos_range[1], device=env.device),
         size=(num_instance, 3),
         device=env.device,
     )
```
The new torch.tensor(..., device=env.device) bounds rely on PyTorch’s default floating dtype, which can be float64 depending on global defaults. That can silently promote random_value/new_pose to float64 and increase compute/memory. Prefer constructing these bounds with dtype=torch.float32 (matching new_pose) for this and the other sample_uniform calls in this function.
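A quick, self-contained demonstration of the promotion risk (the global default dtype is restored afterwards):

```python
import torch

torch.set_default_dtype(torch.float64)
try:
    # Without an explicit dtype, bounds follow the global default...
    lower = torch.tensor([0.0, 0.0, 0.0])
    assert lower.dtype == torch.float64
    # ...while pinning dtype keeps them float32 regardless of the default.
    lower32 = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
    assert lower32.dtype == torch.float32
finally:
    torch.set_default_dtype(torch.float32)
```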
```python
def _discover_entry_points() -> None:
    """Discover and register VLA backends from entry_points."""
    global _ENTRY_POINTS_DISCOVERED
    if _ENTRY_POINTS_DISCOVERED:
        return
    _ENTRY_POINTS_DISCOVERED = True
    try:
        eps = entry_points(group="embodichain.vla_backends")
        for ep in eps:
            try:
                factory = ep.load()
                name = str(ep.name).lower()
                if name not in _VLA_BACKENDS:
                    _VLA_BACKENDS[name] = factory
            except Exception:
                pass
    except Exception:
        pass
```
_discover_entry_points() swallows all exceptions (both the outer discovery and individual ep.load() failures). This makes backend discovery failures silent and very hard to debug in production. Consider at least logging the exception (and which entry point failed) or re-raising in debug/dev mode; also prefer catching narrower exception types instead of Exception/pass.
```python
reward_sum = torch.zeros(
    self.env.num_envs, dtype=torch.float32, device=self.device
)
terminated = torch.zeros(
    self.env.num_envs, dtype=torch.bool, device=self.device
)
truncated = torch.zeros(
    self.env.num_envs, dtype=torch.bool, device=self.device
)
env_info = {}
next_obs_td = None

executed_substeps = 0
# Execute the whole chunk sequentially
for sub_idx in range(action_chunk_size):
    sub_action = chunk[:, sub_idx]
    next_obs, reward, term_i, trunc_i, env_info = self.env.step(
        self._to_action_dict(sub_action)
    )
    executed_substeps += 1
    next_obs_td = dict_to_tensordict(next_obs, self.device)
    reward_sum += reward.to(self.device).float()
    terminated |= term_i.to(self.device)
    truncated |= trunc_i.to(self.device)

    # Stop chunk execution when any env reaches terminal/truncated.
    if (term_i | trunc_i).any():
        break

if next_obs_td is None:
    logger.log_error(
        "Chunk execution produced no environment transition.",
        RuntimeError,
    )

if use_action_chunk:
    rollout.chunk_step[:, step_idx] = 0
    rollout["step_repeat"][:, step_idx] = float(executed_substeps)
```
When execute_full_chunk=True, the collector executes multiple env steps but collapses them into a single rollout transition (reward_sum, step_repeat). Downstream return/advantage code currently discounts with a single gamma per logical step (e.g., compute_gae / GRPO returns) and does not use step_repeat, so discounting becomes inconsistent with the actual number of environment steps executed. Consider adjusting return/GAE computation to use gamma ** step_repeat (and similarly for gae_lambda) or storing per-substep transitions instead of aggregating.
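The suggested discounting fix can be sketched for a single env as follows; this is illustrative only, since the real return/GAE code operates on batched tensors and bootstrap values:

```python
def discounted_returns(rewards, step_repeat, gamma=0.99):
    """Backward return computation discounting by gamma ** substeps.

    rewards[t] is the aggregated reward of logical step t, which executed
    step_repeat[t] environment substeps, so the discount between logical
    steps is gamma ** step_repeat[t] rather than a single gamma.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + (gamma ** step_repeat[t]) * running
        returns[t] = running
    return returns
```

With `gamma=0.5`, a first step that executed two substeps discounts the tail by `0.5 ** 2 = 0.25` instead of `0.5`, keeping the return consistent with the actual number of env steps.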
```python
    step_td = self.policy.get_action(step_td)
    cached_chunk = step_td["action_chunk"]
    action = step_td["action"]
    effective_step_in_chunk = 0
    chunk_cursor = 1
elif use_action_chunk and cached_chunk is not None:
    action = cached_chunk[:, chunk_cursor]
    effective_step_in_chunk = chunk_cursor
    step_td = TensorDict(
        {
            "action": action,
            "sample_log_prob": torch.zeros(
                action.shape[0], device=self.device, dtype=torch.float32
            ),
            "value": torch.zeros(
                action.shape[0], device=self.device, dtype=torch.float32
            ),
        },
        batch_size=[rollout.batch_size[0]],
        device=self.device,
    )
    chunk_cursor += 1
```
In the cached-chunk path (use_action_chunk=True, execute_full_chunk=False), subsequent steps build step_td without action_chunk, so rollout["action_chunk"][..., step_idx] stays at its default (zeros). But transition_view() always includes action_chunk, so VLAPolicy.evaluate_actions() will treat those zeros as the ground-truth chunk and compute incorrect log-probs. Ensure each logical step stores the relevant chunk (e.g., write the same cached_chunk into step_td["action_chunk"] for all steps in the chunk, or change evaluation logic to only use full-chunk likelihood when a valid chunk is present).
```diff
 def iterate_minibatches(
     rollout: TensorDict, batch_size: int, device: torch.device
 ) -> Iterator[TensorDict]:
     """Yield shuffled minibatches from a flattened rollout."""
     total = rollout.batch_size[0]
-    indices = torch.randperm(total, device=device)
+    indices = torch.randperm(total)
     for start in range(0, total, batch_size):
-        yield rollout[indices[start : start + batch_size]]
+        batch_indices = indices[start : start + batch_size]
+        batch = rollout[batch_indices].clone()
+        batch["_indices"] = batch_indices
+        yield batch
```
iterate_minibatches() no longer uses the provided device argument and now builds indices on CPU, then injects CPU _indices into (typically GPU) batches. This can introduce avoidable CPU↔GPU sync/transfer overhead and may break assumptions that all batch tensors are on the same device. Consider creating indices on device (or at least moving batch_indices to rollout.device) and either storing _indices on the same device or keeping it out-of-band.
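A device-consistent sketch of the same loop, shown on a plain tensor rather than a TensorDict so it stays self-contained:

```python
import torch


def iterate_minibatches(data: torch.Tensor, batch_size: int):
    """Yield shuffled minibatches with indices built on data's device."""
    total = data.shape[0]
    # Building the permutation on data.device avoids cross-device indexing
    # errors and keeps the attached indices consistent with the batch.
    indices = torch.randperm(total, device=data.device)
    for start in range(0, total, batch_size):
        batch_indices = indices[start : start + batch_size]
        yield data[batch_indices], batch_indices
```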
```python
backend = create_vla_backend(
    "dexforce_vla",
    model_path=self.model_path,
    device=self.device,
    action_horizon=self.action_horizon,
    **{
        k: v
        for k, v in self.vla_cfg.items()
        if k not in ("backend", "model_path", "action_horizon")
    },
)
```
VLAPolicy._load_vla() always uses backend name "dexforce_vla" even though policy_cfg["vla"] appears to reserve a backend key (it’s explicitly filtered out of kwargs). This prevents selecting other backends discovered via vla_registry. Consider reading self.vla_cfg.get("backend", "dexforce_vla") (or similar) and passing that into create_vla_backend(...).
```python
if hasattr(obs, "batch_size") and len(obs.batch_size) > 0:
    batch_size = int(obs.batch_size[0])
elif isinstance(obs, dict) and "robot" in obs and "qpos" in obs["robot"]:
    q = obs["robot"]["qpos"]
    batch_size = q.shape[0] if hasattr(q, "shape") and len(q.shape) > 0 else 1
else:
    batch_size = 1
if batch_size == 1:
    batch = self._prepare_batch_fn(obs, env)
    vla_chunk = self._vla_model.predict_action(
        batch,
        action_only=True,
        inference_horizon=self.action_horizon,
        allow_grad=False,
        use_fix_aug=False,
    )
    action_chunk_env = self._vla_chunk_to_env_chunk(vla_chunk, env=env)
else:
    chunks_env = []
    for i in range(batch_size):
        obs_i = obs[i] if hasattr(obs, "__getitem__") else obs
        batch_i = self._prepare_batch_fn(obs_i, env)
        vla_chunk = self._vla_model.predict_action(
            batch_i,
            action_only=True,
            inference_horizon=self.action_horizon,
            allow_grad=False,
            use_fix_aug=False,
        )
        chunk_i = self._vla_chunk_to_env_chunk(vla_chunk, env=env)
        chunks_env.append(chunk_i)
    action_chunk_env = torch.cat(chunks_env, dim=0)
```
If obs is a Python dict with a leading batch dimension (the code infers batch_size from obs["robot"]["qpos"]), the subsequent loop uses obs[i], which will try to index the dict by integer and fail. Either ensure obs is always converted to a TensorDict/array before batching logic, or add a proper per-env slicing implementation for mapping-based observations.
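One way to handle mapping-based observations is a recursive per-env slicer (sketch; the helper name is illustrative):

```python
import numpy as np


def slice_obs(obs, i: int):
    """Index env i out of a batched observation, nested dicts included."""
    if isinstance(obs, dict):
        return {k: slice_obs(v, i) for k, v in obs.items()}
    return obs[i]


batched = {"robot": {"qpos": np.zeros((4, 7))}, "rgb": np.zeros((4, 2, 2, 3))}
obs_1 = slice_obs(batched, 1)  # per-env view with the batch dim removed
```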
```python
use_raw_obs = getattr(self.policy, "use_raw_obs", False)
raw_obs_list = getattr(rollout, "raw_obs", None) if use_raw_obs else None

if use_raw_obs:
    if raw_obs_list is None:
        logger.log_error(
            "Policy requires raw observations, "
            "but the provided rollout TensorDict has no 'raw_obs' buffer. "
            "Create the rollout via RolloutBuffer or "
            "start_rollout so that 'raw_obs' is allocated.",
            ValueError,
        )
    try:
        raw_obs_len = len(raw_obs_list)
    except TypeError:
        logger.log_error(
            "Rollout field 'raw_obs' must be an indexable sequence of length "
            f"{num_steps + 1} when policy.use_raw_obs=True.",
            ValueError,
        )
    expected_len = num_steps + 1
    if raw_obs_len != expected_len:
        logger.log_error(
            "Rollout 'raw_obs' length mismatch: "
            f"expected {expected_len} (num_steps + 1), got {raw_obs_len}. "
            "Ensure the rollout was created with use_raw_obs=True and "
            "its time dimension matches the requested num_steps.",
            ValueError,
        )

action_chunk_size = getattr(self.policy, "action_chunk_size", 0)
use_action_chunk = (
    getattr(self.policy, "use_action_chunk", False) and action_chunk_size > 0
)
# Execute a full predicted action chunk inside one logical rollout step.
execute_full_chunk = bool(getattr(self.policy, "execute_full_chunk", False))
cached_chunk = None
chunk_cursor = 0

if use_action_chunk:
    rollout.chunk_step = torch.zeros(
        self.env.num_envs,
        num_steps,
        dtype=torch.long,
        device=self.device,
    )
    step_td = self.policy.get_action(step_td)
    rollout["step_repeat"] = torch.ones(
        self.env.num_envs,
        num_steps + 1,
        dtype=torch.float32,
        device=self.device,
    )
    rollout["step_repeat"][:, -1] = 0.0

if use_raw_obs and raw_obs_list is not None:
    raw_obs_list[0] = self.obs_td
    rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)
else:
    rollout["obs"][:, 0] = flatten_dict_observation(self.obs_td)
```
The new raw-obs + chunked-action execution paths are not covered by existing RL tests (e.g., tests/agents/test_shared_rollout.py currently only exercises basic rollout collection). Adding tests for: (1) use_raw_obs populating rollout.raw_obs, (2) action_chunk being written per-step (including cached-chunk mode), and (3) execute_full_chunk producing consistent step_repeat/reward aggregation would help prevent regressions.
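A dependency-free sketch of one such test, simulating only the cached-chunk cursor semantics and asserting `chunk_step` alignment (the collector API itself is not used; this models the intended behavior, not the implementation):

```python
def simulate_chunk_steps(num_steps: int, chunk_size: int):
    """Return the chunk_step value recorded at each logical step."""
    chunk_steps = []
    cursor = 0
    cached = None
    for _ in range(num_steps):
        if cached is None or cursor >= chunk_size:
            cached = object()  # stands in for a freshly predicted chunk
            cursor = 0
        chunk_steps.append(cursor)  # position inside the current chunk
        cursor += 1
    return chunk_steps
```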
Description
- `vla_policy` wrapper to integrate a VLA model into RL policies.
- `vla_registry` to discover VLA-related factories via entry points.

Type of change

Checklist
- Ran the `black .` command to format the code base.