Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,19 @@ def _build_parser():

parser.add_argument("-v", "--verbose", action="store_true")

parser.add_argument(
"--calibration_num_threads",
type=int,
default=0,
help="Thread count for calibration forward passes. 0 = auto-tune (default).",
)

return parser


def export_llama(args) -> None:
if args.calibration_num_threads < 0:
raise ValueError("--calibration_num_threads must be >= 0")
if args.compile_only and args.pre_gen_pte:
raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
if (TASKS_EVAL or SQNR_EVAL) in args.eval_methods and args.model_mode not in {
Expand Down
91 changes: 84 additions & 7 deletions examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import inspect
import json
import logging
import os
import time
import types

from functools import partial
Expand Down Expand Up @@ -412,6 +414,63 @@ def _tag_ios(self, node, fixed_point_type):

return quant_io_type

def _auto_tune_calibration_threads(self):
"""Find the optimal thread count for calibration via quick microbenchmark.

AR1 decode calibration is SGEMV-dominated (memory-bandwidth-bound).
The default thread count (os.cpu_count()) is typically far too high,
causing massive OpenMP sync overhead. This runs a few forward passes
at candidate thread counts and picks the fastest.
"""
# Use sched_getaffinity when available — it respects cgroup/taskset
# constraints (e.g. containers), unlike os.cpu_count() which returns
# the host total regardless of pinning.
available = (
len(os.sched_getaffinity(0))
if hasattr(os, "sched_getaffinity")
else (os.cpu_count() or 1)
)
baseline = min(torch.get_num_threads(), available)
# Sample fractions of the thread ceiling from low through the
# bandwidth-saturation knee up to the current default.
fractions = (1 / 8, 1 / 4, 3 / 8, 1 / 2, 2 / 3, 3 / 4, 1.0)
candidates = sorted(
{1, baseline} | {max(1, round(baseline * f)) for f in fractions}
)
original = torch.get_num_threads()
best_threads, best_time = original, float("inf")
try:
for n_threads in candidates:
torch.set_num_threads(n_threads)
try:
with torch.no_grad():
self.decoder(*self.export_input) # warmup
t0 = time.perf_counter()
for _ in range(3):
self.decoder(*self.export_input)
elapsed = time.perf_counter() - t0
if elapsed < best_time:
best_threads, best_time = n_threads, elapsed
except Exception:
logging.debug("Auto-tune: threads=%d failed, skipping", n_threads)
continue
finally:
torch.set_num_threads(original)
if best_time == float("inf"):
logging.warning(
"Auto-tune: all candidates %s failed, falling back to %d threads",
candidates,
baseline,
)
return baseline
logging.info(
"Auto-tune calibration threads: tested %s, best=%d (%.1fms/fwd)",
candidates,
best_threads,
best_time / 3 * 1000,
)
return best_threads

def _calibrate(
self,
model,
Expand Down Expand Up @@ -552,6 +611,14 @@ def quantize(self, request: Request): # noqa: C901
self.decoder, self.export_input, strict=True
).module()

# Auto-tune thread count BEFORE prepare_pt2e so the benchmark
# runs on the exported model without observers — no risk of
# polluting observer state with synthetic inputs.
if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
calib_threads = getattr(self.control_args, "calibration_num_threads", 0)
if calib_threads <= 0:
calib_threads = self._auto_tune_calibration_threads()

self.decoder = prepare_pt2e(self.decoder, quantizer)
if self.apply_embedding:
self.tok_embedding = prepare_pt2e(
Expand All @@ -560,14 +627,24 @@ def quantize(self, request: Request): # noqa: C901

# start calibration (only for kv mode or prefill mode without kv cache)
if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
self._calibrate(
model=self.decoder,
tokenizer=data.tokenizer,
event="prepare_pt2e",
user_calibration_data=data.calibration_data.datasets,
tok_embedding=self.tok_embedding,
intermediate_outputs=image_embedding,
original_threads = torch.get_num_threads()
torch.set_num_threads(calib_threads)
Comment on lines +630 to +631
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it actually do and mean? How is it different between cpu and gpu? Can we use gpu to calibrate still?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked a bit more and this is what claude said

PyTorch uses a heuristic that depends on the environment:

  • Locally / outside containers: It typically defaults to the number of logical CPU cores (os.cpu_count()), which counts hyperthreaded cores.
  • In containers / limited environments (like Docker with CPU limits, Kubernetes, or certain cloud VMs): PyTorch tries to respect CPU affinity and cgroup limits, so the thread count may be lower.
  • With OpenMP: If PyTorch is compiled with OpenMP (common on Linux), the thread count may be governed by OMP_NUM_THREADS, which, if unset, OpenMP often sets to the logical core count.

It seems like this is specific for PyTorch OpenMP built

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what Qualcomm folks set up is. @haowhsu-quic

Copy link
Contributor Author

@abhinaykukkadapu abhinaykukkadapu Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, in my experiments, the high per iteration time is due to threads waiting at the barrier (you can see the large pillar in the flamegraph from the GH linked issues, it is named mkl_blas_sgemv). This is matrix-vector multiply, specific to decode though as the workloads are smaller due to conv2d kernels, pytorch seems to default high thread counts assuming larger workloads.

@haowhsu-quic can you please pull this PR on top of main (i just merged my coarse + fine pr) and see if tuning works on other vms.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about GPU? Does it make a difference?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytorch seems to default high thread counts assuming larger workloads.

what is PyTorch logic here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is PyTorch logic here?

It sets the openmp and mkl to max threads available for the host: https://github.com/pytorch/pytorch/blob/cc57e0e7ca87ea3a9a2367a859112ea16b6afbee/aten/src/ATen/ParallelOpenMP.cpp#L38

How about GPU? Does it make a difference?

No, the thread tuning is relevant to CPU only hosts, the GPU path is untouched.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. It looks like the initial thread also use mkl_get_max_threads. Not quite sure how they're different...

Regarding GPU, I think the GPU logic is shared with CPU? Like we can also do model.to("cuda") and do the rest if needed and it goes through the same path. I ran this path a while ago, unsure if it still works. Just trying to reduce the burden of using a GPU to calibrate the model here

logging.info(
"Calibration using %d threads (was %d)",
calib_threads,
original_threads,
)
try:
self._calibrate(
model=self.decoder,
tokenizer=data.tokenizer,
event="prepare_pt2e",
user_calibration_data=data.calibration_data.datasets,
tok_embedding=self.tok_embedding,
intermediate_outputs=image_embedding,
)
finally:
torch.set_num_threads(original_threads)
else:
# one dummy inference to remove affine observer
# error happened in convert_pt2e
Expand Down
Loading