Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
57cce29
Add Miles Qwen3-8B GRPO training example on H100
xyuzh Feb 23, 2026
2c18184
Change num-rollout from 3000 to 5
xyuzh Feb 26, 2026
6ca34bc
Polish Qwen3-8B GRPO example
robertnishihara Feb 27, 2026
e7aa67c
Use declarative compute config instead of hardcoded instance types
robertnishihara Feb 27, 2026
4c57296
Fix declarative compute config resource requirements
robertnishihara Feb 28, 2026
7e690cd
Use Ray remote to run weight conversion on GPU worker
robertnishihara Feb 28, 2026
eba802a
Add Ray remote wrapper for training script
robertnishihara Feb 28, 2026
764d8fa
Rename miles_qwen3_8b_h100 to rl_with_miles
robertnishihara Mar 6, 2026
97ace47
Add explanation for Ray remote wrappers in README
robertnishihara Mar 6, 2026
5d6ff19
Fix GPU allocation in train_remote.py
robertnishihara Mar 6, 2026
0e270ae
Fix: Remove GPU reservation from train_remote wrapper
robertnishihara Mar 6, 2026
9260e65
Use accelerator type label to ensure GPU node placement
robertnishihara Mar 6, 2026
8e1f17d
Use Ray node affinity scheduling for GPU placement
robertnishihara Mar 6, 2026
1951ae1
Use label_selector for H100 node placement
robertnishihara Mar 6, 2026
8d0fdba
Add label_selector to convert_weights_remote.py
robertnishihara Mar 6, 2026
87d98fd
Simplify wrapper scripts
robertnishihara Mar 6, 2026
6cf53a3
Update README to explain label_selector usage
robertnishihara Mar 6, 2026
24b106f
Revert to sys.exit() instead of raise SystemExit()
robertnishihara Mar 6, 2026
5b1ca54
Remove GPU reservation from conversion wrapper
robertnishihara Mar 6, 2026
25d089d
Fix PYTHONBUFFERED typo in entrypoint.sh
robertnishihara Mar 6, 2026
b68e7e3
Fix PEP 8 spacing in wrapper scripts
robertnishihara Mar 6, 2026
a8f5d82
Use spot instances for H100 workers
robertnishihara Mar 6, 2026
65d7cdd
Fix GPU access in convert_weights_remote.py
robertnishihara Mar 6, 2026
922bebd
Set CUDA_VISIBLE_DEVICES explicitly for all GPUs
robertnishihara Mar 6, 2026
390fcbe
Remove num_gpus reservation from convert wrapper
robertnishihara Mar 6, 2026
395f24f
Fix TensorBoard configuration
robertnishihara Mar 6, 2026
f9cbc9c
Set TENSORBOARD_DIR in job env_vars
robertnishihara Mar 6, 2026
05f7b4b
Remove redundant TENSORBOARD_DIR from entrypoint
robertnishihara Mar 6, 2026
4d9c860
Use on-demand instances instead of spot
robertnishihara Mar 7, 2026
6355c4a
Clean up formatting
robertnishihara Mar 7, 2026
f8e4af1
Merge polish-qwen3-example into miles-qwen3-8b-h100
robertnishihara Mar 7, 2026
b499b61
Fix wrapper scripts with correct versions from polish-qwen3-example
robertnishihara Mar 7, 2026
f7f9d45
Remove specific memory size from H100 GPUs
robertnishihara Mar 7, 2026
551421c
Add explanation for CUDA_DEVICE_MAX_CONNECTIONS setting
robertnishihara Mar 7, 2026
b03c478
Update example to 2-node configuration with spot instances
robertnishihara Mar 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions rl_with_miles/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# syntax=docker/dockerfile:1
FROM anyscale/ray:2.54.0-py312-cu129

# Build-time pins: Miles patch-set version plus exact upstream commits so
# the image is reproducible. Override with --build-arg when bumping.
ARG PATCH_VERSION=latest
ARG MEGATRON_COMMIT=3714d81d418c9f1bca4594fc35f9e8289f652862
ARG SGLANG_COMMIT=24c91001cf99ba642be791e099d358f4dfe955f5
ARG MILES_REF=main

# Anyscale base image runs as non-root; use sudo for system installs.
WORKDIR /home/ray

# System tools: git/rsync for source installs, dnsutils + nvtop for
# cluster debugging. apt lists are removed in the same layer.
RUN sudo apt-get update && \
    sudo apt-get install -y --no-install-recommends git rsync dnsutils nvtop && \
    sudo rm -rf /var/lib/apt/lists/*

# Keep pip tooling current and pin numpy to 1.x for Megatron compatibility.
# --no-cache-dir keeps pip's wheel cache out of the image layers (DL3042).
RUN python -m pip install --no-cache-dir --upgrade pip setuptools wheel && \
    python -m pip install --no-cache-dir "numpy<2" huggingface_hub

# Pin PyTorch 2.9.1 — matches sgl_kernel from PyPI (compiled for torch 2.9.x)
# and has a pre-built flash-attn 2.8.3 wheel available.
RUN python -m pip install --no-cache-dir torch==2.9.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu128

# Pre-built flash-attn wheel for torch 2.9 + cu12 (source compilation
# exceeds Anyscale's ~60 min build timeout).
RUN python -m pip install --no-cache-dir https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3%2Bcu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl

# Apex: install Python-only (no CUDA extensions) to stay within Anyscale's
# ~60 min build timeout. Megatron falls back to PyTorch-native kernels.
# The checkout is pinned to an exact commit for reproducibility.
RUN git clone --filter=blob:none https://github.com/NVIDIA/apex.git /tmp/apex && \
cd /tmp/apex && \
git checkout 10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 && \
python -m pip install --disable-pip-version-check --no-cache-dir \
--no-build-isolation . && \
rm -rf /tmp/apex

# Install SGLang from source. sgl_kernel comes from PyPI, pre-compiled
# for torch 2.9.x — no need to rebuild from source.
# Editable install (-e) keeps the checkout at /home/ray/sglang so the
# patch step below can modify it in place.
RUN git clone https://github.com/sgl-project/sglang.git /home/ray/sglang && \
cd /home/ray/sglang && \
git checkout ${SGLANG_COMMIT} && \
python -m pip install -e "python[all]"

# Install Megatron-LM from source.
# Editable install; the checkout is patched in place below and the path is
# also added to PYTHONPATH at the end of this Dockerfile.
RUN git clone --recursive https://github.com/NVIDIA/Megatron-LM.git /home/ray/Megatron-LM && \
cd /home/ray/Megatron-LM && \
git checkout ${MEGATRON_COMMIT} && \
python -m pip install -e .

# Pull Miles source for patches and dependency manifests.
# Later steps read docker/patch/ and requirements.txt from /tmp/miles,
# and the final WORKDIR points at it.
RUN git clone https://github.com/radixark/miles.git /tmp/miles && \
cd /tmp/miles && \
git checkout ${MILES_REF}

# Apply SGLang patch.
# `git apply --3way` falls back to a three-way merge when the patch context
# has drifted; the grep for leftover "<<<<<<< " conflict markers turns an
# unclean merge into a hard build failure instead of silently shipping
# conflicted source.
RUN cd /home/ray/sglang && \
cp /tmp/miles/docker/patch/${PATCH_VERSION}/sglang.patch ./sglang.patch && \
git update-index --refresh && \
git apply sglang.patch --3way && \
if grep -R -n '^<<<<<<< ' .; then \
echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \
exit 1; \
fi && \
rm sglang.patch

# Apply Megatron-LM patch (same 3-way apply + conflict-marker check as above).
RUN cd /home/ray/Megatron-LM && \
cp /tmp/miles/docker/patch/${PATCH_VERSION}/megatron.patch ./megatron.patch && \
git update-index --refresh && \
git apply megatron.patch --3way && \
if grep -R -n '^<<<<<<< ' .; then \
echo "Megatron patch failed to apply cleanly. Please resolve conflicts." && \
exit 1; \
fi && \
rm megatron.patch

# Install Miles dependencies.
# mbridge and torch_memory_saver are pinned to exact commits; --no-deps and
# --no-build-isolation prevent these installs from pulling in or rebuilding
# a different torch.
RUN python -m pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps && \
python -m pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall && \
python -m pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation && \
python -m pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation

# Make MXFP8 quantizer import conditional — mxfp8_group_quantize was added
# in a newer SGLang than our pinned commit. Not needed for Qwen3-8B training.
# In-place source edit: wraps the import in try/except so the module still
# loads when the symbol is missing (mxfp8_group_quantize becomes None).
RUN python -c "\
import pathlib; \
p = pathlib.Path('/tmp/miles/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_mxfp8.py'); \
t = p.read_text(); \
t = t.replace( \
'from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize', \
'try:\\n from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize\\nexcept ImportError:\\n mxfp8_group_quantize = None' \
); \
p.write_text(t)"

# Install Miles itself.
# --no-deps: dependencies were already resolved from requirements.txt above.
# The int4_qat kernel package is installed from its subdirectory.
RUN python -m pip install -r /tmp/miles/requirements.txt && \
python -m pip install -e /tmp/miles --no-deps && \
cd /tmp/miles/miles/backends/megatron_utils/kernels/int4_qat && \
python -m pip install . --no-build-isolation

# Re-pin PyTorch 2.9.1 and reinstall flash-attn + TE at the end.
# Earlier installs may have upgraded torch, breaking pre-built binary wheels.
# --no-cache-dir keeps pip's wheel cache out of these layers.
RUN python -c "import torch; print(f'Before re-pin: PyTorch {torch.__version__}')"
RUN python -m pip install --no-cache-dir torch==2.9.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu128
RUN python -m pip install --no-cache-dir --force-reinstall --no-deps \
    https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3%2Bcu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
RUN python -m pip install --no-cache-dir --no-build-isolation "transformer_engine[pytorch]==2.10.0"

# Verify torch + flash-attn ABI compatibility.
# sgl_kernel is skipped here — it requires libcuda.so.1 (GPU hardware) to import.
RUN python -c "\
import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.version.cuda}'); \
assert torch.__version__.startswith('2.9'), f'Expected 2.9.x, got {torch.__version__}'; \
from flash_attn import flash_attn_func; print('flash-attn OK')"

# Make the patched Megatron-LM checkout importable.
# ${PYTHONPATH:+:$PYTHONPATH} appends the existing value only when it is
# set; the naive ":$PYTHONPATH" form would otherwise leave a trailing ":"
# (an empty PYTHONPATH entry), which Python interprets as the current
# working directory on sys.path.
ENV PYTHONPATH=/home/ray/Megatron-LM${PYTHONPATH:+:$PYTHONPATH}

WORKDIR /tmp/miles
40 changes: 40 additions & 0 deletions rl_with_miles/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# GRPO Training for Qwen3-8B with MILES

This example demonstrates reinforcement learning fine-tuning of Qwen3-8B using **Group Relative Policy Optimization (GRPO)** on the DAPO-Math-17k dataset. It uses the [MILES](https://github.com/radixark/miles) framework for distributed RL training with disaggregated rollouts on Anyscale.

The training runs on **2 nodes with 8x H100 GPUs each** (16 GPUs total), using:
- **8 GPUs for training** (TP=2, DP=4 with Megatron-LM across 2 nodes)
- **8 GPUs for rollout inference** (disaggregated SGLang engines, 8 total)

## Install the Anyscale CLI

```bash
pip install -U anyscale
anyscale login
```

## Submit the job

Clone the example from GitHub.

```bash
git clone https://github.com/anyscale/examples.git
cd examples/rl_with_miles
```

Submit the job.

```bash
anyscale job submit -f job.yaml
```

The entrypoint will automatically download the model and dataset, convert weights to Megatron format, and start training. Training progress can be monitored via TensorBoard logs in `/mnt/cluster_storage/tensorboard_logs`.

## Understanding the example

- **Algorithm**: This example uses GRPO with DAPO-style asymmetric clipping (ε_low=0.2, ε_high=0.28), which is particularly effective for math reasoning tasks.
- **Dataset**: [DAPO-Math-17k](https://huggingface.co/datasets/zhuzilin/dapo-math-17k) contains 17k integer math problems with deterministic reward signals based on answer correctness.
- **Disaggregated architecture**: Training and rollout happen on separate GPUs for maximum throughput. GPU placement is handled automatically by MILES using Ray placement groups; the split is set by the `--actor-num-nodes 2 --actor-num-gpus-per-node 4` and `--rollout-num-gpus 8` flags in `entrypoint.sh`, which place 4 training GPUs on each node and dedicate the remaining 8 GPUs to rollout.
- **Weight conversion**: On the first run, HuggingFace weights are converted to Megatron-LM's `torch_dist` format. Converted weights are cached in `/mnt/cluster_storage/Qwen3-8B_torch_dist` for subsequent runs.
- **Async training**: The pipeline uses `train_async.py` which overlaps rollout generation and policy updates for better GPU utilization.
- **Ray remote wrappers**: The MILES scripts are wrapped in Ray remote functions (`convert_weights_remote.py` and `train_remote.py`) to ensure they execute on GPU worker nodes rather than the CPU-only head node. Both wrappers use `label_selector={"ray.io/accelerator-type": "H100"}` to match the accelerator type specified in `job.yaml`, ensuring placement on H100 GPU nodes. Both wrappers explicitly set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` in the subprocess environment to provide access to all 8 GPUs. Neither wrapper reserves GPUs with `num_gpus` to allow the subprocesses to manage GPU allocation internally.
36 changes: 36 additions & 0 deletions rl_with_miles/convert_weights_remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Ray remote wrapper for weight conversion - ensures it runs on a GPU worker."""
import sys
import subprocess
import os
import ray


@ray.remote(label_selector={"ray.io/accelerator-type": "H100"})
def convert_weights(cmd_args):
    """Execute the HF -> torch_dist conversion script on a GPU worker.

    The label selector pins this task to an H100 node; it must match the
    ``ray.io/accelerator-type`` label declared in job.yaml's compute_config.

    CUDA_VISIBLE_DEVICES is set explicitly so the subprocess sees all eight
    GPUs on the worker. No ``num_gpus`` reservation is made, leaving GPU
    allocation to the conversion script itself.

    Args:
        cmd_args: extra CLI arguments forwarded to the conversion script.

    Returns:
        Tuple of (returncode, stdout, stderr) from the subprocess.
    """
    child_env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7"}
    cmd = ["python", "/tmp/miles/tools/convert_hf_to_torch_dist.py", *cmd_args]
    proc = subprocess.run(cmd, capture_output=True, text=True, env=child_env)
    return proc.returncode, proc.stdout, proc.stderr


# Run the remote conversion, relay its captured output to our own streams,
# and propagate the subprocess exit code to the caller (entrypoint.sh).
returncode, stdout, stderr = ray.get(convert_weights.remote(sys.argv[1:]))
if stdout:
    sys.stdout.write(stdout)
if stderr:
    sys.stderr.write(stderr)
sys.exit(returncode)
145 changes: 145 additions & 0 deletions rl_with_miles/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/bin/bash
# Anyscale entrypoint: Qwen3-8B GRPO training on 2 workers × 8x H100
# Downloads model/dataset, converts weights, and runs async RL training.
#
# Head node (m5.2xlarge): driver only, no GPUs
# GPU placement is determined by MILES using Ray placement groups:
#   Training: 2 nodes × 4 GPUs (--actor-num-nodes 2 --actor-num-gpus-per-node 4),
#             i.e. 8 training GPUs total; with TP=2 that gives DP=4.
#   Rollout:  8 GPUs total (--rollout-num-gpus 8; 8 SGLang engines, 1 GPU each)

# -e: abort on first error, -x: trace commands, pipefail: a pipeline fails
# if any stage fails (plain -e only checks the last command of a pipe).
set -exo pipefail

export PYTHONUNBUFFERED=1
STORAGE=/mnt/cluster_storage

# Qwen3-8B model architecture args (from scripts/models/qwen3-8B.sh)
MODEL_ARGS=(
    --swiglu
    --num-layers 36
    --hidden-size 4096
    --ffn-hidden-size 12288
    --num-attention-heads 32
    --group-query-attention
    --num-query-groups 8
    --use-rotary-position-embeddings
    --disable-bias-linear
    --normalization "RMSNorm"
    --norm-epsilon 1e-6
    --rotary-base 1000000
    --vocab-size 151936
    --kv-channels 128
    --qk-layernorm
    --untie-embeddings-and-output-weights
)

# ======================== Step 1: Download model & dataset ========================

echo "=== Downloading model ==="
huggingface-cli download Qwen/Qwen3-8B --local-dir "${STORAGE}/Qwen3-8B"

echo "=== Downloading dataset ==="
huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k --local-dir "${STORAGE}/dapo-math-17k"

# ======================== Step 2: Convert HF weights to torch_dist ========================

# Idempotent: iter_0000000 only exists after a successful prior conversion.
if [ ! -d "${STORAGE}/Qwen3-8B_torch_dist/iter_0000000" ]; then
    echo "=== Converting weights (HF -> torch_dist) on GPU worker ==="
    # Quote the array expansion so each element stays a single argv word.
    python convert_weights_remote.py \
        "${MODEL_ARGS[@]}" \
        --no-gradient-accumulation-fusion \
        --hf-checkpoint "${STORAGE}/Qwen3-8B" \
        --save "${STORAGE}/Qwen3-8B_torch_dist"
else
    echo "=== Converted weights already exist, skipping ==="
fi

# ======================== Step 3: Run training ========================

# Checkpoint paths: training starts from the converted base weights
# (--load / --ref-load); new checkpoints are written to --save every
# 20 rollout steps.
CKPT_ARGS=(
--hf-checkpoint ${STORAGE}/Qwen3-8B
--ref-load ${STORAGE}/Qwen3-8B_torch_dist
--load ${STORAGE}/Qwen3-8B_torch_dist
--save ${STORAGE}/Qwen3-8B_miles/
--save-interval 20
)

# Rollout/data settings. --num-rollout 5 keeps this example short
# (a production run would use a much larger value). 32 prompts ×
# 8 samples per prompt per rollout; DAPO-style rule-based reward.
ROLLOUT_ARGS=(
--prompt-data ${STORAGE}/dapo-math-17k/dapo-math-17k.jsonl
--input-key prompt
--label-key label
--apply-chat-template
--rollout-shuffle
--balance-data
--rm-type dapo
--reward-key score
--num-rollout 5
--rollout-batch-size 32
--n-samples-per-prompt 8
--rollout-max-response-len 8192
--rollout-temperature 1
--global-batch-size 256
)

# Parallelism and memory: TP=2 with sequence parallelism, no PP/CP/EP.
# Full uniform activation recomputation trades compute for memory;
# dynamic batching packs up to 9216 tokens per GPU.
PERF_ARGS=(
--tensor-model-parallel-size 2
--sequence-parallel
--pipeline-model-parallel-size 1
--context-parallel-size 1
--expert-model-parallel-size 1
--expert-tensor-parallel-size 1

--recompute-granularity full
--recompute-method uniform
--recompute-num-layers 1

--use-dynamic-batch-size
--max-tokens-per-gpu 9216
)

# GRPO with DAPO-style asymmetric clipping (0.2 low / 0.28 high).
# KL and entropy terms are wired up but zero-weighted (coef 0.00).
GRPO_ARGS=(
--advantage-estimator grpo
--use-kl-loss
--kl-loss-coef 0.00
--kl-loss-type low_var_kl
--entropy-coef 0.00
--eps-clip 0.2
--eps-clip-high 0.28
)

# Adam with a constant 1e-6 learning rate (no decay schedule).
OPTIMIZER_ARGS=(
--optimizer adam
--lr 1e-6
--lr-decay-style constant
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.98
)

# SGLang rollout engines: 1 GPU each; 70% of GPU memory statically
# reserved per engine.
SGLANG_ARGS=(
--rollout-num-gpus-per-engine 1
--sglang-mem-fraction-static 0.7
)

# Numerical-stability and logging settings: fp32 grad all-reduce and
# attention softmax, flash-attention backend, TensorBoard enabled
# (output dir comes from TENSORBOARD_DIR in job.yaml env_vars).
MISC_ARGS=(
--no-gradient-accumulation-fusion
--attention-dropout 0.0
--hidden-dropout 0.0
--accumulate-allreduce-grads-in-fp32
--attention-softmax-in-fp32
--attention-backend flash
--use-tensorboard
)

echo "=== Starting training ==="
# Quote every array expansion so each flag stays a single argv word.
python train_remote.py \
    --actor-num-nodes 2 \
    --actor-num-gpus-per-node 4 \
    --rollout-num-gpus 8 \
    "${MODEL_ARGS[@]}" \
    "${CKPT_ARGS[@]}" \
    "${ROLLOUT_ARGS[@]}" \
    "${OPTIMIZER_ARGS[@]}" \
    "${GRPO_ARGS[@]}" \
    "${PERF_ARGS[@]}" \
    "${SGLANG_ARGS[@]}" \
    "${MISC_ARGS[@]}"
32 changes: 32 additions & 0 deletions rl_with_miles/job.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Anyscale Job config for the Qwen3-8B GRPO training example.
name: miles-qwen3-8b-grpo-h100

# The cluster image is built from the Dockerfile in this directory.
containerfile: ./Dockerfile

compute_config:
  # CPU-only head node: runs the entrypoint / Ray driver, no GPUs.
  head_node:
    required_resources:
      CPU: 8
      memory: 32Gi
  # Two 8x H100 workers (16 GPUs total across training and rollout).
  worker_nodes:
    - required_resources:
        CPU: 192
        memory: 2048Gi
        GPU: 8
      # Must match the label_selector used by the Ray remote wrappers
      # (convert_weights_remote.py / train_remote.py).
      required_labels:
        ray.io/accelerator-type: H100
      min_nodes: 2
      max_nodes: 2

working_dir: .

entrypoint: bash entrypoint.sh

env_vars:
  # Standard setting for Megatron-LM with tensor parallelism on H100 GPUs.
  # Limits concurrent CUDA kernel launches to prevent deadlocks with NCCL
  # collective operations during distributed training.
  CUDA_DEVICE_MAX_CONNECTIONS: "1"
  # Training writes TensorBoard logs here (shared cluster storage).
  TENSORBOARD_DIR: "/mnt/cluster_storage/tensorboard_logs"

# Fail fast: no automatic retries, hard 2-hour timeout.
max_retries: 0
timeout_s: 7200
Loading