Merged
25 changes: 15 additions & 10 deletions graphgen/models/llm/local/vllm_wrapper.py
@@ -1,7 +1,7 @@
import asyncio
import math
import uuid
from typing import Any, List, Optional
import asyncio

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
@@ -43,6 +43,7 @@ def __init__(
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
self.timeout = float(timeout)
self.tokenizer = self.engine.engine.tokenizer.tokenizer
self.enable_thinking = kwargs.get("enable_thinking", False)
Contributor review comment (severity: medium):
The base class BaseLLMWrapper's __init__ already sets any provided kwargs as attributes on the instance. This means self.enable_thinking is assigned twice if it's passed during initialization: once in super().__init__ and again on this line, which is redundant.

To avoid this, you can ensure the attribute is only set if it doesn't already exist after the superclass initialization.

Suggested change
self.enable_thinking = kwargs.get("enable_thinking", False)
if not hasattr(self, "enable_thinking"):
    self.enable_thinking = False
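
For context, a minimal runnable sketch of the pattern the comment describes, assuming the base class simply copies kwargs onto the instance; the actual BaseLLMWrapper implementation and the subclass name used here are simplified for illustration:

class BaseLLMWrapper:
    def __init__(self, **kwargs):
        # Assumed behavior per the review comment: every kwarg passed at
        # construction is copied onto the instance as an attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

class VLLMWrapper(BaseLLMWrapper):  # hypothetical minimal subclass for illustration
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # With the suggested change, the default is applied only when the
        # caller did not pass enable_thinking, so the attribute is set once.
        if not hasattr(self, "enable_thinking"):
            self.enable_thinking = False

# enable_thinking=True is assigned exactly once, inside super().__init__.
wrapper = VLLMWrapper(enable_thinking=True)
assert wrapper.enable_thinking is True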


def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
messages = history or []
@@ -51,7 +52,8 @@ def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> Any:
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
add_generation_prompt=True,
enable_thinking=self.enable_thinking,
)

async def _consume_generator(self, generator):
@@ -76,10 +78,11 @@ async def generate_answer(
)

try:
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
result_generator = self.engine.generate(
full_prompt, sp, request_id=request_id
)
final_output = await asyncio.wait_for(
self._consume_generator(result_generator),
timeout=self.timeout
self._consume_generator(result_generator), timeout=self.timeout
)

if not final_output or not final_output.outputs:
@@ -105,13 +108,13 @@ async def generate_topk_per_token(
)

try:
result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
result_generator = self.engine.generate(
full_prompt, sp, request_id=request_id
)
final_output = await asyncio.wait_for(
self._consume_generator(result_generator),
timeout=self.timeout
self._consume_generator(result_generator), timeout=self.timeout
)


if (
not final_output
or not final_output.outputs
Expand All @@ -124,7 +127,9 @@ async def generate_topk_per_token(
candidate_tokens = []
for _, logprob_obj in top_logprobs.items():
tok_str = (
logprob_obj.decoded_token.strip() if logprob_obj.decoded_token else ""
logprob_obj.decoded_token.strip()
if logprob_obj.decoded_token
else ""
)
prob = float(math.exp(logprob_obj.logprob))
candidate_tokens.append(Token(tok_str, prob))