From 9691b81d36f5fc8c001543ac38e8f239db5612e0 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Tue, 10 Feb 2026 16:13:17 +0800
Subject: [PATCH] fix: enable and disable thinking for vllmwrapper

---
 graphgen/models/llm/local/vllm_wrapper.py | 25 +++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py
index cafe6529..f9778183 100644
--- a/graphgen/models/llm/local/vllm_wrapper.py
+++ b/graphgen/models/llm/local/vllm_wrapper.py
@@ -1,7 +1,7 @@
+import asyncio
 import math
 import uuid
 from typing import Any, List, Optional
-import asyncio
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
 from graphgen.bases.datatypes import Token
@@ -43,6 +43,7 @@ def __init__(
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
         self.timeout = float(timeout)
         self.tokenizer = self.engine.engine.tokenizer.tokenizer
+        self.enable_thinking = kwargs.get("enable_thinking", False)
 
     def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
         messages = history or []
@@ -51,7 +52,8 @@ def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> An
         return self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True
+            add_generation_prompt=True,
+            enable_thinking=self.enable_thinking,
         )
 
     async def _consume_generator(self, generator):
@@ -76,10 +78,11 @@ async def generate_answer(
         )
 
         try:
-            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+            result_generator = self.engine.generate(
+                full_prompt, sp, request_id=request_id
+            )
             final_output = await asyncio.wait_for(
-                self._consume_generator(result_generator),
-                timeout=self.timeout
+                self._consume_generator(result_generator), timeout=self.timeout
             )
 
             if not final_output or not final_output.outputs:
@@ -105,13 +108,13 @@ async def generate_topk_per_token(
         )
 
         try:
-            result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+            result_generator = self.engine.generate(
+                full_prompt, sp, request_id=request_id
+            )
             final_output = await asyncio.wait_for(
-                self._consume_generator(result_generator),
-                timeout=self.timeout
+                self._consume_generator(result_generator), timeout=self.timeout
             )
-
             if (
                 not final_output
                 or not final_output.outputs
@@ -124,7 +127,9 @@ async def generate_topk_per_token(
             candidate_tokens = []
             for _, logprob_obj in top_logprobs.items():
                 tok_str = (
-                    logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
+                    logprob_obj.decoded_token.strip()
+                    if logprob_obj.decoded_token
+                    else ""
                 )
                 prob = float(math.exp(logprob_obj.logprob))
                 candidate_tokens.append(Token(tok_str, prob))
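
Usage note (illustrative, not part of the patch): the sketch below shows how
the new flag would be driven end to end. It assumes the wrapper class defined
in vllm_wrapper.py is named VLLMWrapper, that its __init__ accepts extra
keyword arguments (which the kwargs.get("enable_thinking", False) call above
implies), and that the loaded model's chat template understands the
enable_thinking variable, as Qwen3-style templates do. The model name and
constructor arguments are placeholders.

    # Default: enable_thinking is False, so a thinking-capable chat template
    # renders the prompt with the reasoning block suppressed.
    plain = VLLMWrapper(model="Qwen/Qwen3-8B")

    # Opt in: the flag stored in __init__ is forwarded by _build_inputs to
    # tokenizer.apply_chat_template(..., enable_thinking=True), letting the
    # model emit its reasoning before the answer.
    thinking = VLLMWrapper(model="Qwen/Qwen3-8B", enable_thinking=True)

    # Inside an async context:
    answer = await thinking.generate_answer("Why is the sky blue?")

Chat templates that never reference enable_thinking simply ignore the extra
keyword, so defaulting the flag to False keeps the previous behaviour for
models without a thinking mode.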