4 changes: 4 additions & 0 deletions graphgen/common/init_llm.py
@@ -46,6 +46,10 @@ def __init__(self, backend: str, config: Dict[str, Any]):
             from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

             self.llm_instance = VLLMWrapper(**config)
+        elif backend == "ray_serve":
+            from graphgen.models.llm.api.ray_serve_client import RayServeClient
+
+            self.llm_instance = RayServeClient(**config)
         else:
             raise NotImplementedError(f"Backend {backend} is not implemented yet.")

88 changes: 88 additions & 0 deletions graphgen/models/llm/api/ray_serve_client.py
@@ -0,0 +1,88 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class RayServeClient(BaseLLMWrapper):
    """
    A client to interact with a Ray Serve deployment.
    """

    def __init__(
        self,
        *,
        app_name: Optional[str] = None,
        deployment_name: Optional[str] = None,
        serve_backend: Optional[str] = None,
        **kwargs: Any,
    ):
        try:
            from ray import serve
        except ImportError as e:
            raise ImportError(
                "Ray is not installed. Please install it with `pip install ray[serve]`."
            ) from e

        super().__init__(**kwargs)

        # Try to get existing handle first
        self.handle = None
        if app_name:
            try:
                self.handle = serve.get_app_handle(app_name)
            except Exception:
                pass
        elif deployment_name:
            try:
                self.handle = serve.get_deployment(deployment_name).get_handle()
            except Exception:
                pass
Comment on lines +32 to +40

medium

Catching a generic Exception and silently passing can hide unexpected errors and make debugging difficult. It's better to catch a more specific exception. Both ray.serve.get_app_handle and ray.serve.get_deployment raise ray.serve.exceptions.RayServeException (or a subclass) when the app/deployment is not found. Consider catching this specific exception instead. You'll need to import it.

Suggested change
-            try:
-                self.handle = serve.get_app_handle(app_name)
-            except Exception:
-                pass
-        elif deployment_name:
-            try:
-                self.handle = serve.get_deployment(deployment_name).get_handle()
-            except Exception:
-                pass
+            try:
+                self.handle = serve.get_app_handle(app_name)
+            except serve.exceptions.RayServeException:
+                pass
+        elif deployment_name:
+            try:
+                self.handle = serve.get_deployment(deployment_name).get_handle()
+            except serve.exceptions.RayServeException:
+                pass
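As the comment notes, the exception class also has to be importable; a minimal sketch of the explicit form (assuming its current location in ray.serve.exceptions):

# Top-of-file import, so the except clause can name the class directly.
from ray.serve.exceptions import RayServeException

            try:
                self.handle = serve.get_app_handle(app_name)
            except RayServeException:
                # App not found; fall through to on-demand deployment below.
                pass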


        # If no handle found, try to deploy if serve_backend is provided
        if self.handle is None:
            if serve_backend:
                if not app_name:
                    import uuid

                    app_name = f"llm_app_{serve_backend}_{uuid.uuid4().hex[:8]}"

                print(
                    f"Deploying Ray Serve app '{app_name}' with backend '{serve_backend}'..."
                )
Comment on lines +50 to +52

medium

Using print() for logging in a library is generally discouraged as it offers no configuration for consumers of the library. Please use Python's logging module instead. This allows for log levels, different handlers, and formatting. You'll need to import logging at the top of the file.

Suggested change
-                print(
-                    f"Deploying Ray Serve app '{app_name}' with backend '{serve_backend}'..."
-                )
+                logging.info(
+                    f"Deploying Ray Serve app '{app_name}' with backend '{serve_backend}'..."
+                )
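The suggestion uses the root logger; a module-level logger is the more common library pattern, sketched here with lazy %-formatting (not part of the original suggestion):

import logging

logger = logging.getLogger(__name__)

                logger.info(
                    "Deploying Ray Serve app '%s' with backend '%s'...",
                    app_name,
                    serve_backend,
                )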

                from graphgen.models.llm.local.ray_serve_deployment import LLMDeployment

                # Filter kwargs to avoid passing unrelated args if necessary,
                # but LLMDeployment config accepts everything for now.
                # Note: We need to pass kwargs as the config dict.
                deployment = LLMDeployment.bind(backend=serve_backend, config=kwargs)
                serve.run(deployment, name=app_name, route_prefix=f"/{app_name}")
                self.handle = serve.get_app_handle(app_name)
            elif app_name or deployment_name:
                raise ValueError(
                    f"Ray Serve app/deployment '{app_name or deployment_name}' "
                    "not found and 'serve_backend' not provided to deploy it."
                )
            else:
                raise ValueError(
                    "Either 'app_name', 'deployment_name' or 'serve_backend' "
                    "must be provided for RayServeClient."
                )

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        """Generate answer from the model."""
        return await self.handle.generate_answer.remote(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate top-k tokens for the next token prediction."""
        return await self.handle.generate_topk_per_token.remote(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate probabilities for each token in the input."""
        return await self.handle.generate_inputs_prob.remote(text, history, **extra)
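For context, a minimal usage sketch of the new client (the app name and prompt are placeholders, and a Ray Serve app with that name is assumed to be running already):

import asyncio

from graphgen.models.llm.api.ray_serve_client import RayServeClient


async def main() -> None:
    # Attach to an existing Ray Serve app by name; no redeployment happens here.
    client = RayServeClient(app_name="llm_app")
    answer = await client.generate_answer("What is a knowledge graph?")
    print(answer)


asyncio.run(main())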
84 changes: 84 additions & 0 deletions graphgen/models/llm/local/ray_serve_deployment.py
@@ -0,0 +1,84 @@
import os
from typing import Any, Dict, List, Optional

from ray import serve
from starlette.requests import Request

from graphgen.bases.datatypes import Token
from graphgen.models.tokenizer import Tokenizer


@serve.deployment
class LLMDeployment:
    def __init__(self, backend: str, config: Dict[str, Any]):
        self.backend = backend

medium

The config dictionary is modified in-place on line 20, which can cause unexpected side effects for the caller. It's safer to work on a copy. Please create a copy of the dictionary before it's potentially modified.

        self.backend = backend
        config = config.copy()
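For illustration, the side effect the comment is pointing at, reduced to a generic sketch (not GraphGen code):

def init_in_place(config):
    config["tokenizer"] = "..."   # mutates the dict the caller passed in

def init_on_copy(config):
    config = dict(config)         # shallow copy; the caller's dict stays untouched
    config["tokenizer"] = "..."

caller_cfg = {"model": "some-model"}
init_in_place(caller_cfg)
print(caller_cfg)  # {'model': 'some-model', 'tokenizer': '...'} -- unexpected extra key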


        # Initialize tokenizer if needed
        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        if "tokenizer" not in config:
            tokenizer = Tokenizer(model_name=tokenizer_model)
            config["tokenizer"] = tokenizer

        if backend == "vllm":
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            self.llm_instance = VLLMWrapper(**config)
        elif backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            self.llm_instance = HuggingFaceWrapper(**config)
        elif backend == "sglang":
            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

            self.llm_instance = SGLangWrapper(**config)
        else:
            raise NotImplementedError(
                f"Backend {backend} is not implemented for Ray Serve yet."
            )

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        return await self.llm_instance.generate_answer(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_topk_per_token(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_inputs_prob(text, history, **extra)

    async def __call__(self, request: Request) -> Dict:
        try:
            data = await request.json()
            text = data.get("text")
            history = data.get("history")
            method = data.get("method", "generate_answer")
            kwargs = data.get("kwargs", {})

            if method == "generate_answer":
                result = await self.generate_answer(text, history, **kwargs)
Comment on lines +58 to +63

security-high high

This section is vulnerable to a high-severity prompt injection attack because the history parameter is used without validation, allowing attackers to manipulate the LLM's behavior. Additionally, the current if/elif/else chain for dispatching methods is cumbersome and could be refactored for better maintainability and extensibility. Implement strict validation for history and consider refactoring the method dispatch using getattr with an allowlist of methods.

            if method not in {"generate_answer", "generate_topk_per_token", "generate_inputs_prob"}:
                return {"error": f"Method {method} not supported"}

            method_to_call = getattr(self, method)
            result = await method_to_call(text, history, **kwargs)
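The suggestion covers the method dispatch; for the history validation it also asks for, one possible sketch (assuming history is meant to be an optional list of strings, matching the wrapper signatures):

            # Sketch: reject malformed history before it reaches the model.
            if history is not None and (
                not isinstance(history, list)
                or not all(isinstance(turn, str) for turn in history)
            ):
                return {"error": "history must be a list of strings"}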

            elif method == "generate_topk_per_token":
                result = await self.generate_topk_per_token(text, history, **kwargs)
            elif method == "generate_inputs_prob":
                result = await self.generate_inputs_prob(text, history, **kwargs)
            else:
                return {"error": f"Method {method} not supported"}

            return {"result": result}
Comment on lines +54 to +71

security-high high

The __call__ method, which serves as the HTTP entry point for the Ray Serve deployment, lacks any authentication or authorization checks. When deployed, this endpoint may be exposed to the network, allowing unauthorized users to access the LLM service. This can lead to unauthorized resource consumption, potential abuse of the model, and associated costs.
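No fix is attached to this comment; one minimal sketch of a shared-secret check at the top of __call__ (the GRAPHGEN_SERVE_TOKEN variable and Bearer scheme are assumptions, not an existing GraphGen convention; a real deployment would also return a proper 401 status or terminate auth at a gateway):

    async def __call__(self, request: Request) -> Dict:
        # Reject requests that do not carry the expected shared token.
        expected = os.environ.get("GRAPHGEN_SERVE_TOKEN")
        if expected and request.headers.get("authorization") != f"Bearer {expected}":
            return {"error": "unauthorized"}
        ...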

        except Exception as e:
            return {"error": str(e)}

security-medium medium

Returning raw exception messages (str(e)) directly to the user can lead to information leakage, potentially exposing sensitive internal details. While catching Exception is fine, ensure error messages are sanitized before being returned. Additionally, it's crucial to log the full exception traceback internally for debugging purposes.

        except Exception as e:
            import logging
            logging.exception("Error processing request")
            return {"error": "An internal server error occurred."}



def app_builder(args: Dict[str, str]) -> Any:
    """
    Builder function for 'serve run'.
    Usage: serve run graphgen.models.llm.local.ray_serve_deployment:app_builder backend=vllm model=...
    """
    # args comes from the command line key=value pairs
    backend = args.pop("backend", "vllm")
    # remaining args are treated as config
    return LLMDeployment.bind(backend=backend, config=args)
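Since __call__ expects a JSON body with text, history, method and kwargs, a request against the route_prefix set above could look like this (host and port assume Ray Serve's default HTTP address; the app name is a placeholder):

import requests

resp = requests.post(
    "http://127.0.0.1:8000/llm_app",  # route_prefix is f"/{app_name}"
    json={
        "text": "What is a knowledge graph?",
        "history": [],
        "method": "generate_answer",
        "kwargs": {"max_new_tokens": 256},
    },
    timeout=60,
)
print(resp.json())  # {"result": "..."} or {"error": "..."}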
7 changes: 4 additions & 3 deletions graphgen/models/llm/local/vllm_wrapper.py
@@ -51,7 +51,8 @@ def _build_inputs(self, prompt: str, history: Optional[List[dict]] = None) -> An
         return self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True
+            add_generation_prompt=True,
+            enable_thinking=False
         )
 
     async def _consume_generator(self, generator):
@@ -72,7 +73,7 @@ async def generate_answer(
             temperature=self.temperature if self.temperature >= 0 else 1.0,
             top_p=self.top_p if self.top_p >= 0 else 1.0,
             max_tokens=extra.get("max_new_tokens", 2048),
-            repetition_penalty=extra.get("repetition_penalty", 1.05),
+            repetition_penalty=extra.get("repetition_penalty", 1.05)
         )
 
         try:
@@ -101,7 +102,7 @@ async def generate_topk_per_token(
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
-            logprobs=self.top_k,
+            logprobs=self.top_k
         )
 
         try: