diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index 2e7ae6d57d4..c034618fc70 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -510,10 +510,19 @@ def _build_parser():
 
     parser.add_argument("-v", "--verbose", action="store_true")
 
+    parser.add_argument(
+        "--calibration_num_threads",
+        type=int,
+        default=0,
+        help="Thread count for calibration forward passes. 0 = auto-tune (default).",
+    )
+
     return parser
 
 
 def export_llama(args) -> None:
+    if args.calibration_num_threads < 0:
+        raise ValueError("--calibration_num_threads must be >= 0")
     if args.compile_only and args.pre_gen_pte:
         raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
     if (TASKS_EVAL or SQNR_EVAL) in args.eval_methods and args.model_mode not in {
diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
index ebb9bed8b69..1737eea51c4 100644
--- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
+++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py
@@ -7,6 +7,8 @@
 import inspect
 import json
 import logging
+import os
+import time
 import types
 from functools import partial
 
@@ -412,6 +414,63 @@ def _tag_ios(self, node, fixed_point_type):
 
         return quant_io_type
 
+    def _auto_tune_calibration_threads(self):
+        """Find the optimal thread count for calibration via quick microbenchmark.
+
+        AR1 decode calibration is SGEMV-dominated (memory-bandwidth-bound).
+        The default thread count (os.cpu_count()) is typically far too high,
+        causing massive OpenMP sync overhead. This runs a few forward passes
+        at candidate thread counts and picks the fastest.
+        """
+        # Use sched_getaffinity when available — it respects cgroup/taskset
+        # constraints (e.g. containers), unlike os.cpu_count() which returns
+        # the host total regardless of pinning.
+        available = (
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else (os.cpu_count() or 1)
+        )
+        baseline = min(torch.get_num_threads(), available)
+        # Sample fractions of the thread ceiling from low through the
+        # bandwidth-saturation knee up to the current default.
+        fractions = (1 / 8, 1 / 4, 3 / 8, 1 / 2, 2 / 3, 3 / 4, 1.0)
+        candidates = sorted(
+            {1, baseline} | {max(1, round(baseline * f)) for f in fractions}
+        )
+        original = torch.get_num_threads()
+        best_threads, best_time = original, float("inf")
+        try:
+            for n_threads in candidates:
+                torch.set_num_threads(n_threads)
+                try:
+                    with torch.no_grad():
+                        self.decoder(*self.export_input)  # warmup
+                        t0 = time.perf_counter()
+                        for _ in range(3):
+                            self.decoder(*self.export_input)
+                        elapsed = time.perf_counter() - t0
+                    if elapsed < best_time:
+                        best_threads, best_time = n_threads, elapsed
+                except Exception:
+                    logging.debug("Auto-tune: threads=%d failed, skipping", n_threads)
+                    continue
+        finally:
+            torch.set_num_threads(original)
+        if best_time == float("inf"):
+            logging.warning(
+                "Auto-tune: all candidates %s failed, falling back to %d threads",
+                candidates,
+                baseline,
+            )
+            return baseline
+        logging.info(
+            "Auto-tune calibration threads: tested %s, best=%d (%.1fms/fwd)",
+            candidates,
+            best_threads,
+            best_time / 3 * 1000,
+        )
+        return best_threads
+
     def _calibrate(
         self,
         model,
@@ -552,6 +611,14 @@ def quantize(self, request: Request):  # noqa: C901
             self.decoder, self.export_input, strict=True
         ).module()
 
+        # Auto-tune thread count BEFORE prepare_pt2e so the benchmark
+        # runs on the exported model without observers — no risk of
+        # polluting observer state with synthetic inputs.
+        if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
+            calib_threads = getattr(self.control_args, "calibration_num_threads", 0)
+            if calib_threads <= 0:
+                calib_threads = self._auto_tune_calibration_threads()
+
         self.decoder = prepare_pt2e(self.decoder, quantizer)
         if self.apply_embedding:
             self.tok_embedding = prepare_pt2e(
@@ -560,14 +627,24 @@ def quantize(self, request: Request):  # noqa: C901
 
         # start calibration (only for kv mode or prefill mode without kv cache)
         if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
-            self._calibrate(
-                model=self.decoder,
-                tokenizer=data.tokenizer,
-                event="prepare_pt2e",
-                user_calibration_data=data.calibration_data.datasets,
-                tok_embedding=self.tok_embedding,
-                intermediate_outputs=image_embedding,
+            original_threads = torch.get_num_threads()
+            torch.set_num_threads(calib_threads)
+            logging.info(
+                "Calibration using %d threads (was %d)",
+                calib_threads,
+                original_threads,
             )
+            try:
+                self._calibrate(
+                    model=self.decoder,
+                    tokenizer=data.tokenizer,
+                    event="prepare_pt2e",
+                    user_calibration_data=data.calibration_data.datasets,
+                    tok_embedding=self.tok_embedding,
+                    intermediate_outputs=image_embedding,
+                )
+            finally:
+                torch.set_num_threads(original_threads)
         else:
             # one dummy inference to remove affine observer
             # error happened in convert_pt2e