205 changes: 192 additions & 13 deletions app/api/routers/generative.py
@@ -14,8 +14,10 @@
from app.domain import (
Tags,
TagsGenerative,
OpenAIChatRequest,
OpenAIChatResponse,
OpenAIChatCompletionsRequest,
OpenAIChatCompletionsResponse,
OpenAICompletionsRequest,
OpenAICompletionsResponse,
OpenAIEmbeddingsRequest,
OpenAIEmbeddingsResponse,
PromptMessage,
@@ -29,8 +31,10 @@

PATH_GENERATE = "/generate"
PATH_GENERATE_ASYNC = "/stream/generate"
PATH_OPENAI_COMPLETIONS = "/v1/chat/completions"
PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"
PATH_GENERATE_SSE = "/events/generate"
PATH_CHAT_COMPLETIONS = "/v1/chat/completions"
PATH_COMPLETIONS = "/v1/completions"
PATH_EMBEDDINGS = "/v1/embeddings"

router = APIRouter()
config = get_settings()
@@ -54,6 +58,7 @@ def generate_text(
temperature: Annotated[float, Query(description="The temperature of the generated text", ge=0.0)] = 0.7,
top_p: Annotated[float, Query(description="The Top-P value for nucleus sampling", ge=0.0, le=1.0)] = 0.9,
stop_sequences: Annotated[List[str], Query(description="The list of sequences used to stop the generation")] = [],
ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False,
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> PlainTextResponse:
@@ -67,7 +72,7 @@ def generate_text(
temperature (float): The temperature of the generated text.
top_p (float): The Top-P value for nucleus sampling.
stop_sequences (List[str]): The list of sequences used to stop the generation.
tracking_id (Union[str, None]): An optional tracking ID of the requested task.
ensure_full_sentences (bool): Whether to generate full sentences only.
model_service (AbstractModelService): The model service dependency.

Returns:
@@ -84,6 +89,7 @@ def generate_text(
top_p=top_p,
stop_sequences=stop_sequences,
report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE),
ensure_full_sentences=ensure_full_sentences,
),
headers={"x-cms-tracking-id": tracking_id},
status_code=HTTP_200_OK,
@@ -110,6 +116,7 @@ async def generate_text_stream(
temperature: Annotated[float, Query(description="The temperature of the generated text", ge=0.0)] = 0.7,
top_p: Annotated[float, Query(description="The Top-P value for nucleus sampling", ge=0.0, le=1.0)] = 0.9,
stop_sequences: Annotated[List[str], Query(description="The list of sequences used to stop the generation")] = [],
ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False,
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> StreamingResponse:
@@ -123,6 +130,7 @@ async def generate_text_stream(
temperature (float): The temperature of the generated text.
top_p (float): The Top-P value for nucleus sampling.
stop_sequences (List[str]): The list of sequences used to stop the generation.
ensure_full_sentences (bool): Whether to generate full sentences only.
tracking_id (Union[str, None]): An optional tracking ID of the requested task.
model_service (AbstractModelService): The model service dependency.

@@ -140,6 +148,7 @@ async def generate_text_stream(
top_p=top_p,
stop_sequences=stop_sequences,
report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC),
ensure_full_sentences=ensure_full_sentences,
),
media_type="text/event-stream",
headers={"x-cms-tracking-id": tracking_id},
@@ -155,17 +164,18 @@ async def generate_text_stream(


@router.post(
PATH_OPENAI_COMPLETIONS,
PATH_CHAT_COMPLETIONS,
tags=[Tags.OpenAICompatible],
response_model=None,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Generate chat response based on messages, similar to OpenAI's /v1/chat/completions",
)
def generate_chat_completions(
request: Request,
request_data: Annotated[OpenAIChatRequest, Body(
request_data: Annotated[OpenAIChatCompletionsRequest, Body(
description="OpenAI-like completion request", media_type="application/json"
)],
ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False,
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> Union[StreamingResponse, JSONResponse]:
@@ -175,6 +185,7 @@ def generate_chat_completions(
Args:
request (Request): The request object.
request_data (OpenAIChatCompletionsRequest): The request data containing model, messages, stream, temperature, top_p, and stop_sequences.
ensure_full_sentences (bool): Whether to generate full sentences only.
tracking_id (Union[str, None]): An optional tracking ID of the requested task.
model_service (AbstractModelService): The model service dependency.

@@ -207,7 +218,14 @@ def generate_chat_completions(
headers={"x-cms-tracking-id": tracking_id},
)

async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float, stop_sequences: List[str]) -> AsyncGenerator:
async def _stream(
prompt: str,
max_tokens: int,
temperature: float,
top_p: float,
stop_sequences: List[str],
ensure_full_sentences: bool,
) -> AsyncGenerator:
data = {
"id": tracking_id,
"object": "chat.completion.chunk",
@@ -220,7 +238,8 @@ async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float
temperature=temperature,
top_p=top_p,
stop_sequences=stop_sequences,
report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS)
report_tokens=partial(_send_usage_metrics, handler=PATH_CHAT_COMPLETIONS),
ensure_full_sentences=ensure_full_sentences,
):
data = {
"choices": [
@@ -237,20 +256,31 @@ async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float
prompt = get_prompt_from_messages(model_service.tokenizer, messages)
if stream:
return StreamingResponse(
_stream(prompt, max_tokens, temperature, top_p, stop_sequences or []),
_stream(prompt, max_tokens, temperature, top_p, stop_sequences or [], ensure_full_sentences),
media_type="text/event-stream",
headers={"x-cms-tracking-id": tracking_id},
)
else:
usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
def _report_tokens(prompt_token_num: int, completion_token_num: int) -> None:
usage["prompt_tokens"] = prompt_token_num
usage["completion_tokens"] = completion_token_num
usage["total_tokens"] = prompt_token_num + completion_token_num
_send_usage_metrics(
handler=PATH_CHAT_COMPLETIONS,
prompt_token_num=prompt_token_num,
completion_token_num=completion_token_num,
)
generated_text = model_service.generate(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=stop_sequences or [],
send_metrics=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS),
report_tokens=_report_tokens,
ensure_full_sentences=ensure_full_sentences,
)
completion = OpenAIChatResponse(
completion = OpenAIChatCompletionsResponse(
id=tracking_id,
object="chat.completion",
created=int(time.time()),
@@ -265,12 +295,161 @@ async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float
"finish_reason": "stop",
}
],
usage=usage,
)
return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})


@router.post(
PATH_OPENAI_EMBEDDINGS,
PATH_COMPLETIONS,
tags=[Tags.OpenAICompatible],
response_model=None,
dependencies=[Depends(cms_globals.props.current_active_user)],
description="Generate completion based on prompt, similar to OpenAI's /v1/completions",
)
def generate_text_completions(
request: Request,
request_data: Annotated[OpenAICompletionsRequest, Body(
description="OpenAI-like completion request", media_type="application/json"
)],
ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False,
tracking_id: Union[str, None] = Depends(validate_tracking_id),
model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
) -> Union[StreamingResponse, JSONResponse]:
"""
Generates completion response based on prompt, mimicking OpenAI's /v1/completions endpoint.

Args:
request (Request): The request object.
request_data (OpenAICompletionsRequest): The request data containing model, prompt, stream, temperature, top_p, and stop.
ensure_full_sentences (bool): Whether to generate full sentences only.
tracking_id (Union[str, None]): An optional tracking ID of the requested task.
model_service (AbstractModelService): The model service dependency.

Returns:
StreamingResponse: An OpenAI-like streaming response.
JSONResponse: A response containing the generated text or an error message.
"""

tracking_id = tracking_id or str(uuid.uuid4())
model = model_service.model_name if request_data.model != model_service.model_name else request_data.model
stream = request_data.stream
max_tokens = request_data.max_tokens
temperature = request_data.temperature
top_p = request_data.top_p
stop = request_data.stop

if isinstance(stop, str):
stop_sequences = [stop]
elif isinstance(stop, list):
stop_sequences = stop
else:
stop_sequences = []

if isinstance(request_data.prompt, str):
prompt = request_data.prompt
else:
prompt = "\n".join(request_data.prompt)

if not prompt:
error_response = {
"error": {
"message": "No prompt provided",
"type": "invalid_request_error",
"param": "prompt",
"code": "missing_field",
}
}
return JSONResponse(
content=error_response,
status_code=HTTP_400_BAD_REQUEST,
headers={"x-cms-tracking-id": tracking_id},
)

async def _stream(
prompt: str,
max_tokens: int,
temperature: float,
top_p: float,
stop_sequences: List[str],
ensure_full_sentences: bool,
) -> AsyncGenerator:
data = {
"id": tracking_id,
"object": "text_completion",
"choices": [{"text": "", "index": 0, "logprobs": None, "finish_reason": None}],
}
yield f"data: {json.dumps(data)}\n\n"
async for chunk in model_service.generate_async(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=stop_sequences,
report_tokens=partial(_send_usage_metrics, handler=PATH_COMPLETIONS),
ensure_full_sentences=ensure_full_sentences,
):
data = {
"object": "text_completion",
"choices": [
{
"text": chunk,
"index": 0,
"logprobs": None,
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(data)}\n\n"
yield "data: [DONE]\n\n"

if stream:
return StreamingResponse(
_stream(prompt, max_tokens, temperature, top_p, stop_sequences, ensure_full_sentences),
media_type="text/event-stream",
headers={"x-cms-tracking-id": tracking_id},
)

usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
def _report_tokens(prompt_token_num: int, completion_token_num: int) -> None:
usage["prompt_tokens"] = prompt_token_num
usage["completion_tokens"] = completion_token_num
usage["total_tokens"] = prompt_token_num + completion_token_num
_send_usage_metrics(
handler=PATH_COMPLETIONS,
prompt_token_num=prompt_token_num,
completion_token_num=completion_token_num,
)
generated_text = model_service.generate(
prompt,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop_sequences=stop_sequences,
report_tokens=_report_tokens,
ensure_full_sentences=ensure_full_sentences,
)

completion = OpenAICompletionsResponse(
id=tracking_id,
object="text_completion",
created=int(time.time()),
model=model,
choices=[
{
"index": 0,
"text": generated_text,
"logprobs": None,
"finish_reason": "stop",
}
],
usage=usage,
)
return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id})


@router.post(
PATH_EMBEDDINGS,
tags=[Tags.OpenAICompatible],
response_model=None,
dependencies=[Depends(cms_globals.props.current_active_user)],
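For reviewers who want to exercise the new route end to end, here is a minimal client sketch against the `/v1/completions` endpoint added above. The host, port, and bearer token are assumptions, not part of this PR; the request and response fields mirror `OpenAICompletionsRequest` / `OpenAICompletionsResponse`, and `ensure_full_sentences` is passed as a query parameter as defined in the router signature.

```python
# Minimal client sketch (assumed base URL and auth; not part of this PR).
import requests

BASE_URL = "http://localhost:8000"                 # assumed host/port
HEADERS = {"Authorization": "Bearer <token>"}      # assumed auth scheme

response = requests.post(
    f"{BASE_URL}/v1/completions",
    params={"ensure_full_sentences": True},        # new query parameter from this PR
    json={
        "model": "my-model",                       # the service answers with its own model name
        "prompt": "Summarise the clinical note:",  # a plain string or a list of strings
        "stream": False,
        "max_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,
        "stop": "###",                             # string or list; normalised server-side
    },
    headers=HEADERS,
    timeout=60,
)
response.raise_for_status()
body = response.json()
print(body["choices"][0]["text"])                  # generated completion
print(body["usage"])                               # prompt/completion/total token counts
print(response.headers.get("x-cms-tracking-id"))   # tracking ID echoed back by the router
```

With `"stream": True` the route returns Server-Sent Events (`data: {...}` chunks ending in `data: [DONE]`), so a streaming client would iterate over `response.iter_lines()` instead of calling `response.json()`.
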
33 changes: 31 additions & 2 deletions app/domain.py
@@ -211,7 +211,7 @@ class PromptMessage(BaseModel):
content: str = Field(description="The actual text of the message")


class OpenAIChatRequest(BaseModel):
class OpenAIChatCompletionsRequest(BaseModel):
messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model")
stream: bool = Field(..., description="Whether to stream the response")
max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
@@ -221,12 +221,41 @@ class OpenAIChatRequest(BaseModel):
stop_sequences: Optional[List[str]] = Field(default=None, description="The list of sequences used to stop the generation")


class OpenAIChatResponse(BaseModel):
class OpenAIChatCompletionsResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the chat completion request")
object: str = Field(..., description="The type of the response")
created: int = Field(..., description="The timestamp when the completion was generated")
model: str = Field(..., description="The name of the model used for generating the completion")
choices: List = Field(..., description="The generated messages and their metadata")
usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage information",
)


class OpenAICompletionsRequest(BaseModel):
prompt: Union[str, List[str]] = Field(..., description="Prompt text or list of prompts")
stream: bool = Field(False, description="Whether to stream the response")
max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0)
model: str = Field(..., description="The name of the model used for generating the completion")
temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0)
top_p: float = Field(0.9, description="The top-p value for nucleus sampling", ge=0.0, le=1.0)
stop: Optional[Union[str, List[str]]] = Field(
default=None,
description="The list of sequences used to stop the generation",
)


class OpenAICompletionsResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the completion request")
object: str = Field(..., description="The type of the response")
created: int = Field(..., description="The timestamp when the completion was generated")
model: str = Field(..., description="The name of the model used for generating the completion")
choices: List = Field(..., description="The generated texts and their metadata")
usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage information",
)


class OpenAIEmbeddingsRequest(BaseModel):
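To illustrate the new domain models in isolation, the sketch below constructs an `OpenAICompletionsRequest` and mirrors the stop/prompt normalisation performed in the router. The concrete values are illustrative only, and this snippet is not part of the PR's test suite.

```python
# Standalone sketch of the new request model (illustrative values only).
from app.domain import OpenAICompletionsRequest

req = OpenAICompletionsRequest(
    model="my-model",                         # required; name is illustrative
    prompt=["First prompt", "Second prompt"],
    stream=False,
    stop="###",                               # a bare string is accepted as well as a list
)

# Mirror of the router's normalisation of `stop` into a list of stop sequences.
if isinstance(req.stop, str):
    stop_sequences = [req.stop]
elif isinstance(req.stop, list):
    stop_sequences = req.stop
else:
    stop_sequences = []

# Mirror of the router's handling of a list-of-strings prompt.
prompt = req.prompt if isinstance(req.prompt, str) else "\n".join(req.prompt)

assert stop_sequences == ["###"]
assert prompt == "First prompt\nSecond prompt"
```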