From b0c845e1ec69605c50a4361b9be89d6a2a8e900e Mon Sep 17 00:00:00 2001 From: Xi Bai Date: Wed, 11 Feb 2026 17:31:07 +0000 Subject: [PATCH] feat: add the endpoint for legacy v1 completion feat: auto-utilise the local chat template if detected feat: add the option to generate full sentences feat: add the option for local 8bit quantisation feat: add the gpt oss chat template fix: skip quantisation if the model being loaded is already quantised --- app/api/routers/generative.py | 205 +++++++- app/domain.py | 33 +- app/model_services/huggingface_llm_model.py | 232 +++++++-- app/processors/prompt_factory.py | 481 +++++++++++++++++- app/trainers/huggingface_llm_trainer.py | 13 +- app/utils.py | 20 +- pyproject.toml | 2 +- tests/app/api/test_serving_hf_llm.py | 94 ++++ .../test_huggingface_llm_model.py | 197 +++++-- tests/app/processors/test_prompt_factory.py | 1 + tests/app/test_utils.py | 26 +- tests/app/trainers/test_hf_llm_trainer.py | 22 +- 12 files changed, 1185 insertions(+), 141 deletions(-) diff --git a/app/api/routers/generative.py b/app/api/routers/generative.py index 761fb51..fa3db63 100644 --- a/app/api/routers/generative.py +++ b/app/api/routers/generative.py @@ -14,8 +14,10 @@ from app.domain import ( Tags, TagsGenerative, - OpenAIChatRequest, - OpenAIChatResponse, + OpenAIChatCompletionsRequest, + OpenAIChatCompletionsResponse, + OpenAICompletionsRequest, + OpenAICompletionsResponse, OpenAIEmbeddingsRequest, OpenAIEmbeddingsResponse, PromptMessage, @@ -29,8 +31,10 @@ PATH_GENERATE = "/generate" PATH_GENERATE_ASYNC = "/stream/generate" -PATH_OPENAI_COMPLETIONS = "/v1/chat/completions" -PATH_OPENAI_EMBEDDINGS = "/v1/embeddings" +PATH_GENERATE_SSE = "/events/generate" +PATH_CHAT_COMPLETIONS = "/v1/chat/completions" +PATH_COMPLETIONS = "/v1/completions" +PATH_EMBEDDINGS = "/v1/embeddings" router = APIRouter() config = get_settings() @@ -54,6 +58,7 @@ def generate_text( temperature: Annotated[float, Query(description="The temperature of the generated text", ge=0.0)] = 0.7, top_p: Annotated[float, Query(description="The Top-P value for nucleus sampling", ge=0.0, le=1.0)] = 0.9, stop_sequences: Annotated[List[str], Query(description="The list of sequences used to stop the generation")] = [], + ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False, tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> PlainTextResponse: @@ -67,7 +72,7 @@ def generate_text( temperature (float): The temperature of the generated text. top_p (float): The Top-P value for nucleus sampling. stop_sequences (List[str]): The list of sequences used to stop the generation. - tracking_id (Union[str, None]): An optional tracking ID of the requested task. + ensure_full_sentences (bool): Whether to generate full sentences only. model_service (AbstractModelService): The model service dependency. Returns: @@ -84,6 +89,7 @@ def generate_text( top_p=top_p, stop_sequences=stop_sequences, report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE), + ensure_full_sentences=ensure_full_sentences, ), headers={"x-cms-tracking-id": tracking_id}, status_code=HTTP_200_OK, @@ -110,6 +116,7 @@ async def generate_text_stream( temperature: Annotated[float, Query(description="The temperature of the generated text", ge=0.0)] = 0.7, top_p: Annotated[float, Query(description="The Top-P value for nucleus sampling", ge=0.0, le=1.0)] = 0.9, stop_sequences: Annotated[List[str], Query(description="The list of sequences used to stop the generation")] = [], + ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False, tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> StreamingResponse: @@ -123,6 +130,7 @@ async def generate_text_stream( temperature (float): The temperature of the generated text. top_p (float): The Top-P value for nucleus sampling. stop_sequences (List[str]): The list of sequences used to stop the generation. + ensure_full_sentences (bool): Whether to generate full sentences only. tracking_id (Union[str, None]): An optional tracking ID of the requested task. model_service (AbstractModelService): The model service dependency. @@ -140,6 +148,7 @@ async def generate_text_stream( top_p=top_p, stop_sequences=stop_sequences, report_tokens=partial(_send_usage_metrics, handler=PATH_GENERATE_ASYNC), + ensure_full_sentences=ensure_full_sentences, ), media_type="text/event-stream", headers={"x-cms-tracking-id": tracking_id}, @@ -155,7 +164,7 @@ async def generate_text_stream( @router.post( - PATH_OPENAI_COMPLETIONS, + PATH_CHAT_COMPLETIONS, tags=[Tags.OpenAICompatible], response_model=None, dependencies=[Depends(cms_globals.props.current_active_user)], @@ -163,9 +172,10 @@ async def generate_text_stream( ) def generate_chat_completions( request: Request, - request_data: Annotated[OpenAIChatRequest, Body( + request_data: Annotated[OpenAIChatCompletionsRequest, Body( description="OpenAI-like completion request", media_type="application/json" )], + ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False, tracking_id: Union[str, None] = Depends(validate_tracking_id), model_service: AbstractModelService = Depends(cms_globals.model_service_dep) ) -> Union[StreamingResponse, JSONResponse]: @@ -175,6 +185,7 @@ def generate_chat_completions( Args: request (Request): The request object. request_data (OpenAIChatRequest): The request data containing model, messages, stream, temperature, top_p, and stop_sequences. + ensure_full_sentences (bool): Whether to generate full sentences only. tracking_id (Union[str, None]): An optional tracking ID of the requested task. model_service (AbstractModelService): The model service dependency. @@ -207,7 +218,14 @@ def generate_chat_completions( headers={"x-cms-tracking-id": tracking_id}, ) - async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float, stop_sequences: List[str]) -> AsyncGenerator: + async def _stream( + prompt: str, + max_tokens: int, + temperature: float, + top_p: float, + stop_sequences: List[str], + ensure_full_sentences: bool, + ) -> AsyncGenerator: data = { "id": tracking_id, "object": "chat.completion.chunk", @@ -220,7 +238,8 @@ async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float temperature=temperature, top_p=top_p, stop_sequences=stop_sequences, - report_tokens=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS) + report_tokens=partial(_send_usage_metrics, handler=PATH_CHAT_COMPLETIONS), + ensure_full_sentences=ensure_full_sentences, ): data = { "choices": [ @@ -237,20 +256,31 @@ async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float prompt = get_prompt_from_messages(model_service.tokenizer, messages) if stream: return StreamingResponse( - _stream(prompt, max_tokens, temperature, top_p, stop_sequences or []), + _stream(prompt, max_tokens, temperature, top_p, stop_sequences or [], ensure_full_sentences), media_type="text/event-stream", headers={"x-cms-tracking-id": tracking_id}, ) else: + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + def _report_tokens(prompt_token_num: int, completion_token_num: int) -> None: + usage["prompt_tokens"] = prompt_token_num + usage["completion_tokens"] = completion_token_num + usage["total_tokens"] = prompt_token_num + completion_token_num + _send_usage_metrics( + handler=PATH_CHAT_COMPLETIONS, + prompt_token_num=prompt_token_num, + completion_token_num=completion_token_num, + ) generated_text = model_service.generate( prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stop_sequences=stop_sequences or [], - send_metrics=partial(_send_usage_metrics, handler=PATH_OPENAI_COMPLETIONS), + report_tokens=_report_tokens, + ensure_full_sentences=ensure_full_sentences, ) - completion = OpenAIChatResponse( + completion = OpenAIChatCompletionsResponse( id=tracking_id, object="chat.completion", created=int(time.time()), @@ -265,12 +295,161 @@ async def _stream(prompt: str, max_tokens: int, temperature: float, top_p: float "finish_reason": "stop", } ], + usage=usage, ) return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id}) @router.post( - PATH_OPENAI_EMBEDDINGS, + PATH_COMPLETIONS, + tags=[Tags.OpenAICompatible], + response_model=None, + dependencies=[Depends(cms_globals.props.current_active_user)], + description="Generate completion based on prompt, similar to OpenAI's /v1/completions", +) +def generate_text_completions( + request: Request, + request_data: Annotated[OpenAICompletionsRequest, Body( + description="OpenAI-like completion request", media_type="application/json" + )], + ensure_full_sentences: Annotated[bool, Query(description="Whether to generate full sentences only")] = False, + tracking_id: Union[str, None] = Depends(validate_tracking_id), + model_service: AbstractModelService = Depends(cms_globals.model_service_dep) +) -> Union[StreamingResponse, JSONResponse]: + """ + Generates completion response based on prompt, mimicking OpenAI's /v1/completions endpoint. + + Args: + request (Request): The request object. + request_data (OpenAICompletionsRequest): The request data containing model, prompt, stream, temperature, top_p, and stop. + ensure_full_sentences (bool): Whether to generate full sentences only. + tracking_id (Union[str, None]): An optional tracking ID of the requested task. + model_service (AbstractModelService): The model service dependency. + + Returns: + StreamingResponse: An OpenAI-like streaming response. + JSONResponse: A response containing the generated text or an error message. + """ + + tracking_id = tracking_id or str(uuid.uuid4()) + model = model_service.model_name if request_data.model != model_service.model_name else request_data.model + stream = request_data.stream + max_tokens = request_data.max_tokens + temperature = request_data.temperature + top_p = request_data.top_p + stop = request_data.stop + + if isinstance(stop, str): + stop_sequences = [stop] + elif isinstance(stop, list): + stop_sequences = stop + else: + stop_sequences = [] + + if isinstance(request_data.prompt, str): + prompt = request_data.prompt + else: + prompt = "\n".join(request_data.prompt) + + if not prompt: + error_response = { + "error": { + "message": "No prompt provided", + "type": "invalid_request_error", + "param": "prompt", + "code": "missing_field", + } + } + return JSONResponse( + content=error_response, + status_code=HTTP_400_BAD_REQUEST, + headers={"x-cms-tracking-id": tracking_id}, + ) + + async def _stream( + prompt: str, + max_tokens: int, + temperature: float, + top_p: float, + stop_sequences: List[str], + ensure_full_sentences: bool, + ) -> AsyncGenerator: + data = { + "id": tracking_id, + "object": "text_completion", + "choices": [{"text": "", "index": 0, "logprobs": None, "finish_reason": None}], + } + yield f"data: {json.dumps(data)}\n\n" + async for chunk in model_service.generate_async( + prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + report_tokens=partial(_send_usage_metrics, handler=PATH_COMPLETIONS), + ensure_full_sentences=ensure_full_sentences, + ): + data = { + "object": "text_completion", + "choices": [ + { + "text": chunk, + "index": 0, + "logprobs": None, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(data)}\n\n" + yield "data: [DONE]\n\n" + + if stream: + return StreamingResponse( + _stream(prompt, max_tokens, temperature, top_p, stop_sequences, ensure_full_sentences), + media_type="text/event-stream", + headers={"x-cms-tracking-id": tracking_id}, + ) + + usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + def _report_tokens(prompt_token_num: int, completion_token_num: int) -> None: + usage["prompt_tokens"] = prompt_token_num + usage["completion_tokens"] = completion_token_num + usage["total_tokens"] = prompt_token_num + completion_token_num + _send_usage_metrics( + handler=PATH_COMPLETIONS, + prompt_token_num=prompt_token_num, + completion_token_num=completion_token_num, + ) + generated_text = model_service.generate( + prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, + report_tokens=_report_tokens, + ensure_full_sentences=ensure_full_sentences, + ) + + completion = OpenAICompletionsResponse( + id=tracking_id, + object="text_completion", + created=int(time.time()), + model=model, + choices=[ + { + "index": 0, + "text": generated_text, + "logprobs": None, + "finish_reason": "stop", + } + ], + usage=usage, + ) + return JSONResponse(content=jsonable_encoder(completion), headers={"x-cms-tracking-id": tracking_id}) + + +@router.post( + PATH_EMBEDDINGS, tags=[Tags.OpenAICompatible], response_model=None, dependencies=[Depends(cms_globals.props.current_active_user)], diff --git a/app/domain.py b/app/domain.py index 655eafe..21a565f 100644 --- a/app/domain.py +++ b/app/domain.py @@ -211,7 +211,7 @@ class PromptMessage(BaseModel): content: str = Field(description="The actual text of the message") -class OpenAIChatRequest(BaseModel): +class OpenAIChatCompletionsRequest(BaseModel): messages: List[PromptMessage] = Field(..., description="A list of messages to be sent to the model") stream: bool = Field(..., description="Whether to stream the response") max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0) @@ -221,12 +221,41 @@ class OpenAIChatRequest(BaseModel): stop_sequences: Optional[List[str]] = Field(default=None, description="The list of sequences used to stop the generation") -class OpenAIChatResponse(BaseModel): +class OpenAIChatCompletionsResponse(BaseModel): id: str = Field(..., description="The unique identifier for the chat completion request") object: str = Field(..., description="The type of the response") created: int = Field(..., description="The timestamp when the completion was generated") model: str = Field(..., description="The name of the model used for generating the completion") choices: List = Field(..., description="The generated messages and their metadata") + usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage information", + ) + + +class OpenAICompletionsRequest(BaseModel): + prompt: Union[str, List[str]] = Field(..., description="Prompt text or list of prompts") + stream: bool = Field(False, description="Whether to stream the response") + max_tokens: int = Field(512, description="The maximum number of tokens to generate", gt=0) + model: str = Field(..., description="The name of the model used for generating the completion") + temperature: float = Field(0.7, description="The temperature of the generated text", ge=0.0, le=1.0) + top_p: float = Field(0.9, description="The top-p value for nucleus sampling", ge=0.0, le=1.0) + stop: Optional[Union[str, List[str]]] = Field( + default=None, + description="The list of sequences used to stop the generation", + ) + + +class OpenAICompletionsResponse(BaseModel): + id: str = Field(..., description="The unique identifier for the completion request") + object: str = Field(..., description="The type of the response") + created: int = Field(..., description="The timestamp when the completion was generated") + model: str = Field(..., description="The name of the model used for generating the completion") + choices: List = Field(..., description="The generated texts and their metadata") + usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage information", + ) class OpenAIEmbeddingsRequest(BaseModel): diff --git a/app/model_services/huggingface_llm_model.py b/app/model_services/huggingface_llm_model.py index 4eafe8b..bb08d2b 100644 --- a/app/model_services/huggingface_llm_model.py +++ b/app/model_services/huggingface_llm_model.py @@ -7,6 +7,7 @@ from transformers import ( AutoModelForCausalLM, AutoTokenizer, + AutoConfig, PreTrainedModel, PreTrainedTokenizerBase, TextIteratorStreamer, @@ -24,6 +25,8 @@ unpack_model_data_package, ensure_tensor_contiguity, get_model_data_package_base_name, + get_default_chat_template, + utilise_local_chat_template, ) logger = logging.getLogger("cms") @@ -61,8 +64,9 @@ def __init__( self._whitelisted_tuis = set([tui.strip() for tui in config.TYPE_UNIQUE_ID_WHITELIST.split(",")]) self._multi_label_threshold = 0.5 self._text_generator = ThreadPoolExecutor(max_workers=50) + self._sentence_endings = ".。!!??::;;\n" self.model_name = model_name or "HuggingFace LLM model" - self.is_4bit_quantised = False + self.is_quantised = False @property def model(self) -> PreTrainedModel: @@ -130,6 +134,7 @@ def load_model( model_file_path: str, *args: Tuple, load_in_4bit: bool = False, + load_in_8bit: bool = False, **kwargs: Dict[str, Any] ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: """ @@ -139,6 +144,7 @@ def load_model( model_file_path (str): The path to the model package file. *args (Tuple): Additional positional arguments. load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False. + load_in_8bit (bool): Whether to load the model in 8-bit precision. Defaults to False. **kwargs (Dict[str, Any]): Additional keyword arguments. Returns: @@ -151,26 +157,49 @@ def load_model( model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path)) if unpack_model_data_package(model_file_path, model_path): try: - if load_in_4bit: - bnb_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=torch.bfloat16, - bnb_4bit_use_double_quant=True, - ) - if get_settings().DEVICE == Device.DEFAULT.value: - model = AutoModelForCausalLM.from_pretrained( - model_path, - quantization_config=bnb_config, - device_map="auto", - ) - else: - model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config) - else: + config = AutoConfig.from_pretrained(model_path) + + if "quantization_config" in config.to_dict(): + logger.info("Model already quantised, loading by ignoring 'load_in_4bit' or 'load_in_8bit' flag") if get_settings().DEVICE == Device.DEFAULT.value: model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") else: model = AutoModelForCausalLM.from_pretrained(model_path) + else: + if load_in_4bit: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + ) + if get_settings().DEVICE == Device.DEFAULT.value: + model = AutoModelForCausalLM.from_pretrained( + model_path, + quantization_config=bnb_config, + device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config) + elif load_in_8bit: + bnb_config = BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_threshold=6.0, + llm_int8_enable_fp32_cpu_offload=False + ) + if get_settings().DEVICE == Device.DEFAULT.value: + model = AutoModelForCausalLM.from_pretrained( + model_path, + quantization_config=bnb_config, + device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config) + else: + if get_settings().DEVICE == Device.DEFAULT.value: + model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto") + else: + model = AutoModelForCausalLM.from_pretrained(model_path) ensure_tensor_contiguity(model) tokenizer = AutoTokenizer.from_pretrained( model_path, @@ -185,11 +214,17 @@ def load_model( else: raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}") - def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> None: + def init_model(self, + load_in_4bit: bool = False, + load_in_8bit: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: """Initialises the HuggingFace model and its tokenizer based on the configuration. Args: load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False. + load_in_8bit (bool): Whether to load the model in 8-bit precision. Defaults to False. *args (Any): Additional positional arguments to be passed to this method. **kwargs (Any): Additional keyword arguments to be passed to this method. """ @@ -202,13 +237,20 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N ]): logger.warning("Model service is already initialised and can be initialised only once") else: - self._model, self._tokenizer = self.load_model(self._model_pack_path, load_in_4bit=load_in_4bit) - if non_default_device_is_available(get_settings().DEVICE): + self._model, self._tokenizer = self.load_model( + self._model_pack_path, load_in_4bit=load_in_4bit, load_in_8bit=load_in_8bit + ) + + if (non_default_device_is_available(get_settings().DEVICE) and + not ( + getattr(self._model, "is_loaded_in_8bit", False) or + getattr(self._model, "is_loaded_in_4bit", False) + ) + ): self._model.to(get_settings().DEVICE) if self._enable_trainer: self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self) self._unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(self) - self.is_4bit_quantised = load_in_4bit def info(self) -> ModelCard: """ @@ -233,28 +275,28 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]: def generate( self, prompt: str, - min_tokens: int = 100, + min_tokens: int = 64, max_tokens: int = 512, - num_beams: int = 5, + num_beams: int = 1, temperature: float = 0.7, top_p: float = 0.9, stop_sequences: Optional[List[str]] = None, report_tokens: Optional[Callable[[str], None]] = None, - **kwargs: Any + ensure_full_sentences: bool = False, ) -> str: """ Generates text based on the prompt. Args: prompt (str): The prompt for the text generation - min_tokens (int): The minimum number of tokens to generate. Defaults to 100. + min_tokens (int): The minimum number of tokens to generate. Defaults to 64. max_tokens (int): The maximum number of tokens to generate. Defaults to 512. - num_beams (int): The number of beams for beam search. Defaults to 5. + num_beams (int): The number of beams for beam search. Defaults to 1. temperature (float): The temperature for the text generation. Defaults to 0.7. top_p (float): The Top-P value for nucleus sampling. Defaults to 0.9. stop_sequences (Optional[List[str]]): List of strings that will stop generation when encountered. Defaults to None. report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None. - **kwargs (Any): Additional keyword arguments to be passed to this method. + ensure_full_sentences (bool): Whether to generate full sentences only. Defaults to False. Returns: Any: The string containing the generated text. @@ -262,16 +304,32 @@ def generate( self.model.eval() - inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt") + if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is None: + logger.warning("The tokenizer does not have a chat template. Using the default one.") + self.tokenizer.chat_template = get_default_chat_template() + else: + if utilise_local_chat_template(self.model.config.model_type, self.tokenizer): + logger.debug("Chat template overwritten by the prompt factory for %s", self.model.config.model_type) + else: + logger.debug(f"Found a chat template in the tokenizer:\n {self.tokenizer.chat_template}") + + prompt_text = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) + inputs = self.tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt") inputs.to(self.model.device) + max_tokens = max(min_tokens, max_tokens) generation_kwargs = dict( inputs=inputs.input_ids, attention_mask=inputs.attention_mask, min_new_tokens=min_tokens, max_new_tokens=max_tokens, + use_cache=True, num_beams=num_beams, - do_sample=True, + do_sample=(num_beams == 1), temperature=temperature, top_p=top_p, repetition_penalty=1.2, @@ -279,7 +337,9 @@ def generate( ) outputs = self.model.generate(**generation_kwargs) - generated_text = self.tokenizer.decode(outputs[0], skip_prompt=True, skip_special_tokens=True) + prompt_len = inputs.input_ids.shape[-1] + completion_ids = outputs[0][prompt_len:] + generated_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) if stop_sequences: for stop_seq in stop_sequences: @@ -287,12 +347,25 @@ def generate( generated_text = generated_text.split(stop_seq)[0] break + if ensure_full_sentences and generated_text and generated_text[-1] not in self._sentence_endings: + last_pos = -1 + for ending in self._sentence_endings: + pos = generated_text.rfind(ending) + if pos > last_pos: + last_pos = pos + if last_pos != -1: + generated_text = generated_text[:last_pos + 1] + logger.debug("Response generation completed") if report_tokens: report_tokens( - prompt_token_num=inputs.input_ids.shape[-1], # type: ignore - completion_token_num=outputs[0].shape[-1], # type: ignore + prompt_token_num=prompt_len, # type: ignore + completion_token_num=self.tokenizer( # type: ignore + generated_text, + add_special_tokens=False, + return_tensors="pt" + ).input_ids.shape[-1], ) return generated_text @@ -300,12 +373,14 @@ def generate( async def generate_async( self, prompt: str, + min_tokens: int = 64, max_tokens: int = 512, + num_beams: int = 1, temperature: float = 0.7, top_p: float = 0.9, stop_sequences: Optional[List[str]] = None, report_tokens: Optional[Callable[[str], None]] = None, - **kwargs: Any + ensure_full_sentences: bool = False, ) -> AsyncIterable: """ Asynchronously generates text stream based on the prompt. @@ -313,11 +388,13 @@ async def generate_async( Args: prompt (str): The prompt for the text generation. max_tokens (int): The maximum number of tokens to generate. Defaults to 512. + min_tokens (int): The minimum number of tokens to generate. Defaults to 64. + num_beams (int): The number of beams for beam search. Defaults to 1. temperature (float): The temperature for the text generation. Defaults to 0.7. top_p (float): The Top-P value for nucleus sampling. Defaults to 0.9. stop_sequences (Optional[List[str]]): List of strings that will stop generation when encountered. Defaults to None. report_tokens (Optional[Callable[[str], None]]): The callback function to send metrics. Defaults to None. - **kwargs (Any): Additional keyword arguments to be passed to the model loader. + ensure_full_sentences (bool): Whether to generate full sentences only. Defaults to False. Returns: AsyncIterable: The stream containing the generated text. @@ -325,7 +402,21 @@ async def generate_async( self.model.eval() - inputs = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt") + if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is None: + logger.warning("The tokenizer does not have a chat template. Using the default one.") + self.tokenizer.chat_template = get_default_chat_template() + else: + if utilise_local_chat_template(self.model.config.model_type, self.tokenizer): + logger.debug("Chat template overwritten by the prompt factory for %s", self.model.config.model_type) + else: + logger.debug(f"Found a chat template in the tokenizer:\n {self.tokenizer.chat_template}") + + prompt_text = self.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) + inputs = self.tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt") inputs.to(self.model.device) streamer = TextIteratorStreamer( @@ -333,12 +424,16 @@ async def generate_async( skip_prompt=True, skip_special_tokens=True ) + max_tokens = max(min_tokens, max_tokens) generation_kwargs = dict( inputs=inputs.input_ids, attention_mask=inputs.attention_mask, streamer=streamer, + min_new_tokens=min_tokens, max_new_tokens=max_tokens, - do_sample=True, + use_cache=True, + num_beams=num_beams, + do_sample=(num_beams == 1), temperature=temperature, top_p=top_p, repetition_penalty=1.2, @@ -347,24 +442,59 @@ async def generate_async( try: _ = self._text_generator.submit(self.model.generate, **generation_kwargs) - output = "" - for content in streamer: - prev_output = output - output += content - if stop_sequences: - for stop_seq in stop_sequences: - if stop_seq in output: - remaining = output[len(prev_output):output.find(stop_seq)] - if remaining: - yield remaining - return - yield content - await asyncio.sleep(0.01) + buffer = "" + full_output = "" + + if not ensure_full_sentences: + for content in streamer: + prev_output = full_output + full_output += content + if stop_sequences: + for stop_seq in stop_sequences: + if stop_seq in full_output: + remaining = full_output[len(prev_output):full_output.find(stop_seq)] + if remaining: + yield remaining + return + yield content + await asyncio.sleep(0.01) + else: + for content in streamer: + buffer += content + + if stop_sequences: + stop_triggered = False + for stop_sequence in stop_sequences: + if stop_sequence in buffer: + remaining = buffer[:buffer.find(stop_sequence)] + if remaining: + yield remaining + full_output += remaining + stop_triggered = True + break + + if stop_triggered: + break + + last_sentence_ending = -1 + for ending in self._sentence_endings: + pos = buffer.rfind(ending) + if pos > last_sentence_ending: + last_sentence_ending = pos + + if last_sentence_ending != -1: + new_sentences = buffer[:last_sentence_ending + 1] + buffer = buffer[last_sentence_ending + 1:] + yield new_sentences + full_output += new_sentences + + await asyncio.sleep(0.01) + if report_tokens: report_tokens( prompt_token_num=inputs.input_ids.shape[-1], # type: ignore completion_token_num=self.tokenizer( # type: ignore - output, + full_output, add_special_tokens=False, return_tensors="pt" ).input_ids.shape[-1], @@ -437,7 +567,7 @@ def create_embeddings( final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True) l2_normalised = torch.nn.functional.normalize(final_embedding, p=2, dim=1) - all_embeddings.append(l2_normalised.cpu().numpy().tolist()[0]) + all_embeddings.append(l2_normalised.float().cpu().numpy().tolist()[0]) return all_embeddings[0] if isinstance(text, str) else all_embeddings diff --git a/app/processors/prompt_factory.py b/app/processors/prompt_factory.py index ee1a45c..12d77c6 100644 --- a/app/processors/prompt_factory.py +++ b/app/processors/prompt_factory.py @@ -234,29 +234,486 @@ class PromptFactory: "{% endif %}" ) + _GPT_OSS = ( + "{#-" + " In addition to the normal inputs of `messages` and `tools`, this template also accepts the" + " following kwargs:" + " - \"builtin_tools\": A list, can contain \"browser\" and/or \"python\"." + " - \"model_identity\": A string that optionally describes the model identity." + " - \"reasoning_effort\": A string that describes the reasoning effort, defaults to \"medium\"." + " #}" + "" + "{#- Tool Definition Rendering ============================================== #}" + "{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}" + " {%- if param_spec.type == \"array\" -%}" + " {%- if param_spec['items'] -%}" + " {%- if param_spec['items']['type'] == \"string\" -%}" + " {{- \"string[]\" }}" + " {%- elif param_spec['items']['type'] == \"number\" -%}" + " {{- \"number[]\" }}" + " {%- elif param_spec['items']['type'] == \"integer\" -%}" + " {{- \"number[]\" }}" + " {%- elif param_spec['items']['type'] == \"boolean\" -%}" + " {{- \"boolean[]\" }}" + " {%- else -%}" + " {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}" + " {%- if inner_type == \"object | object\" or inner_type|length > 50 -%}" + " {{- \"any[]\" }}" + " {%- else -%}" + " {{- inner_type + \"[]\" }}" + " {%- endif -%}" + " {%- endif -%}" + " {%- if param_spec.nullable -%}" + " {{- \" | null\" }}" + " {%- endif -%}" + " {%- else -%}" + " {{- \"any[]\" }}" + " {%- if param_spec.nullable -%}" + " {{- \" | null\" }}" + " {%- endif -%}" + " {%- endif -%}" + " {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}" + " {#- Handle array of types like [\"object\", \"object\"] from Union[dict, list] #}" + " {%- if param_spec.type | length > 1 -%}" + " {{- param_spec.type | join(\" | \") }}" + " {%- else -%}" + " {{- param_spec.type[0] }}" + " {%- endif -%}" + " {%- elif param_spec.oneOf -%}" + " {#- Handle oneOf schemas - check for complex unions and fallback to any #}" + " {%- set has_object_variants = false -%}" + " {%- for variant in param_spec.oneOf -%}" + " {%- if variant.type == \"object\" -%}" + " {%- set has_object_variants = true -%}" + " {%- endif -%}" + " {%- endfor -%}" + " {%- if has_object_variants and param_spec.oneOf|length > 1 -%}" + " {{- \"any\" }}" + " {%- else -%}" + " {%- for variant in param_spec.oneOf -%}" + " {{- render_typescript_type(variant, required_params) -}}" + " {%- if variant.description %}" + " {{- \"// \" + variant.description }}" + " {%- endif -%}" + " {%- if variant.default is defined %}" + " {{ \"// default: \" + variant.default|tojson }}" + " {%- endif -%}" + " {%- if not loop.last %}" + " {{- \" | \" }}" + " {% endif -%}" + " {%- endfor -%}" + " {%- endif -%}" + " {%- elif param_spec.type == \"string\" -%}" + " {%- if param_spec.enum -%}" + " {{- '\"' + param_spec.enum|join('\" | \"') + '\"' -}}" + " {%- else -%}" + " {{- \"string\" }}" + " {%- if param_spec.nullable %}" + " {{- \" | null\" }}" + " {%- endif -%}" + " {%- endif -%}" + " {%- elif param_spec.type == \"number\" -%}" + " {{- \"number\" }}" + " {%- elif param_spec.type == \"integer\" -%}" + " {{- \"number\" }}" + " {%- elif param_spec.type == \"boolean\" -%}" + " {{- \"boolean\" }}" + "" + " {%- elif param_spec.type == \"object\" -%}" + " {%- if param_spec.properties -%}" + " {{- \"{\\n\" }}" + " {%- for prop_name, prop_spec in param_spec.properties.items() -%}" + " {{- prop_name -}}" + " {%- if prop_name not in (param_spec.required or []) -%}" + " {{- \"?\" }}" + " {%- endif -%}" + " {{- \": \" }}" + " {{ render_typescript_type(prop_spec, param_spec.required or []) }}" + " {%- if not loop.last -%}" + " {{- \", \" }}" + " {%- endif -%}" + " {%- endfor -%}" + " {{- \"}\" }}" + " {%- else -%}" + " {{- \"object\" }}" + " {%- endif -%}" + " {%- else -%}" + " {{- \"any\" }}" + " {%- endif -%}" + "{%- endmacro -%}" + "" + "{%- macro render_tool_namespace(namespace_name, tools) -%}" + " {{- \"## \" + namespace_name + \"\\n\\n\" }}" + " {{- \"namespace \" + namespace_name + \" {\\n\\n\" }}" + " {%- for tool in tools %}" + " {%- set tool = tool.function %}" + " {{- \"// \" + tool.description + \"\\n\" }}" + " {{- \"type \"+ tool.name + \" = \" }}" + " {%- if tool.parameters and tool.parameters.properties %}" + " {{- \"(_: {\\n\" }}" + " {%- for param_name, param_spec in tool.parameters.properties.items() %}" + " {%- if param_spec.description %}" + " {{- \"// \" + param_spec.description + \"\\n\" }}" + " {%- endif %}" + " {{- param_name }}" + " {%- if param_name not in (tool.parameters.required or []) -%}" + " {{- \"?\" }}" + " {%- endif -%}" + " {{- \": \" }}" + " {{- render_typescript_type(param_spec, tool.parameters.required or []) }}" + " {%- if param_spec.default is defined -%}" + " {%- if param_spec.enum %}" + " {{- \", // default: \" + param_spec.default }}" + " {%- elif param_spec.oneOf %}" + " {{- \"// default: \" + param_spec.default }}" + " {%- else %}" + " {{- \", // default: \" + param_spec.default|tojson }}" + " {%- endif -%}" + " {%- endif -%}" + " {%- if not loop.last %}" + " {{- \",\\n\" }}" + " {%- else %}" + " {{- \",\\n\" }}" + " {%- endif -%}" + " {%- endfor %}" + " {{- \"}) => any;\\n\\n\" }}" + " {%- else -%}" + " {{- \"() => any;\\n\\n\" }}" + " {%- endif -%}" + " {%- endfor %}" + " {{- \"} // namespace \" + namespace_name }}" + "{%- endmacro -%}" + "" + "{%- macro render_builtin_tools(browser_tool, python_tool) -%}" + " {%- if browser_tool %}" + " {{- \"## browser\\n\\n\" }}" + " {{- \"// Tool for browsing.\\n\" }}" + " {{- \"// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.\\n\" }}" + " {{- \"// Cite information from the tool using the following format:\\n\" }}" + " {{- \"// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.\\n\" }}" + " {{- \"// Do not quote more than 10 words directly from the tool output.\\n\" }}" + " {{- \"// sources=web (default: web)\\n\" }}" + " {{- \"namespace browser {\\n\\n\" }}" + " {{- \"// Searches for information related to `query` and displays `topn` results.\\n\" }}" + " {{- \"type search = (_: {\\n\" }}" + " {{- \"query: string,\\n\" }}" + " {{- \"topn?: number, // default: 10\\n\" }}" + " {{- \"source?: string,\\n\" }}" + " {{- \"}) => any;\\n\\n\" }}" + " {{- \"// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.\\n\" }}" + " {{- \"// Valid link ids are displayed with the formatting: `【{id}†.*】`.\\n\" }}" + " {{- \"// If `cursor` is not provided, the most recent page is implied.\\n\" }}" + " {{- \"// If `id` is a string, it is treated as a fully qualified URL associated with `source`.\\n\" }}" + " {{- \"// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.\\n\" }}" + " {{- \"// Use this function without `id` to scroll to a new location of an opened page.\\n\" }}" + " {{- \"type open = (_: {\\n\" }}" + " {{- \"id?: number | string, // default: -1\\n\" }}" + " {{- \"cursor?: number, // default: -1\\n\" }}" + " {{- \"loc?: number, // default: -1\\n\" }}" + " {{- \"num_lines?: number, // default: -1\\n\" }}" + " {{- \"view_source?: boolean, // default: false\\n\" }}" + " {{- \"source?: string,\\n\" }}" + " {{- \"}) => any;\\n\\n\" }}" + " {{- \"// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.\\n\" }}" + " {{- \"type find = (_: {\\n\" }}" + " {{- \"pattern: string,\\n\" }}" + " {{- \"cursor?: number, // default: -1\\n\" }}" + " {{- \"}) => any;\\n\\n\" }}" + " {{- \"} // namespace browser\\n\\n\" }}" + " {%- endif -%}" + "" + " {%- if python_tool %}" + " {{- \"## python\\n\\n\" }}" + " {{- \"Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).\\n\\n\" }}" + " {{- \"When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.\\n\\n\" }}" + " {%- endif -%}" + "{%- endmacro -%}" + "" + "{#- System Message Construction ============================================ #}" + "{%- macro build_system_message() -%}" + " {%- if model_identity is not defined %}" + " {%- set model_identity = \"You are ChatGPT, a large language model trained by OpenAI.\" %}" + " {%- endif %}" + " {{- model_identity + \"\\n\" }}" + " {{- \"Knowledge cutoff: 2024-06\\n\" }}" + " {{- \"Current date: \" + strftime_now(\"%Y-%m-%d\") + \"\\n\\n\" }}" + " {%- if reasoning_effort is not defined %}" + " {%- set reasoning_effort = \"medium\" %}" + " {%- endif %}" + " {{- \"Reasoning: \" + reasoning_effort + \"\\n\\n\" }}" + " {%- if builtin_tools %}" + " {{- \"# Tools\\n\\n\" }}" + " {%- set available_builtin_tools = namespace(browser=false, python=false) %}" + " {%- for tool in builtin_tools %}" + " {%- if tool == \"browser\" %}" + " {%- set available_builtin_tools.browser = true %}" + " {%- elif tool == \"python\" %}" + " {%- set available_builtin_tools.python = true %}" + " {%- endif %}" + " {%- endfor %}" + " {{- render_builtin_tools(available_builtin_tools.browser, available_builtin_tools.python) }}" + " {%- endif -%}" + " {{- \"# Valid channels: analysis, commentary, final. Channel must be included for every message.\" }}" + " {%- if tools -%}" + " {{- \"\\nCalls to these tools must go to the commentary channel: 'functions'.\" }}" + " {%- endif -%}" + "{%- endmacro -%}" + "" + "{#- Main Template Logic ================================================= #}" + "{#- Set defaults #}" + "" + "{#- Render system message #}" + "{{- \"<|start|>system<|message|>\" }}" + "{{- build_system_message() }}" + "{{- \"<|end|>\" }}" + "" + "{#- Extract developer message #}" + "{%- if messages[0].role == \"developer\" or messages[0].role == \"system\" %}" + " {%- set developer_message = messages[0].content %}" + " {%- set loop_messages = messages[1:] %}" + "{%- else %}" + " {%- set developer_message = \"\" %}" + " {%- set loop_messages = messages %}" + "{%- endif %}" + "" + "{#- Render developer message #}" + "{%- if developer_message or tools %}" + " {{- \"<|start|>developer<|message|>\" }}" + " {%- if developer_message %}" + " {{- \"# Instructions\\n\\n\" }}" + " {{- developer_message }}" + " {{- \"\\n\\n\" }}" + " {%- endif %}" + " {%- if tools -%}" + " {{- \"# Tools\\n\\n\" }}" + " {{- render_tool_namespace(\"functions\", tools) }}" + " {%- endif -%}" + " {{- \"<|end|>\" }}" + "{%- endif %}" + "" + "{#- Render messages #}" + "{%- set last_tool_call = namespace(name=none) %}" + "{%- for message in loop_messages -%}" + " {#- At this point only assistant/user/tool messages should remain #}" + " {%- if message.role == 'assistant' -%}" + " {#- Checks to ensure the messages are being passed in the format we expect #}" + " {%- if \"content\" in message %}" + " {%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}" + " {{- raise_exception(\"You have passed a message containing <|channel|> tags in the content field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.\") }}" + " {%- endif %}" + " {%- endif %}" + " {%- if \"thinking\" in message %}" + " {%- if \"<|channel|>analysis<|message|>\" in message.thinking or \"<|channel|>final<|message|>\" in message.thinking %}" + " {{- raise_exception(\"You have passed a message containing <|channel|> tags in the thinking field. Instead of doing this, you should pass analysis messages (the string between '<|message|>' and '<|end|>') in the 'thinking' field, and final messages (the string between '<|message|>' and '<|end|>') in the 'content' field.\") }}" + " {%- endif %}" + " {%- endif %}" + " {%- if \"tool_calls\" in message %}" + " {#- We need very careful handling here - we want to drop the tool call analysis message if the model #}" + " {#- has output a later <|final|> message, but otherwise we want to retain it. This is the only case #}" + " {#- when we render CoT/analysis messages in inference. #}" + " {%- set future_final_message = namespace(found=false) %}" + " {%- for future_message in loop_messages[loop.index:] %}" + " {%- if future_message.role == 'assistant' and \"tool_calls\" not in future_message %}" + " {%- set future_final_message.found = true %}" + " {%- endif %}" + " {%- endfor %}" + " {#- We assume max 1 tool call per message, and so we infer the tool call name #}" + " {#- in \"tool\" messages from the most recent assistant tool call name #}" + " {%- set tool_call = message.tool_calls[0] %}" + " {%- if tool_call.function %}" + " {%- set tool_call = tool_call.function %}" + " {%- endif %}" + " {%- if message.content and message.thinking %}" + " {{- raise_exception(\"Cannot pass both content and thinking in an assistant message with tool calls! Put the analysis message in one or the other, but not both.\") }}" + " {%- elif message.content and not future_final_message.found %}" + " {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.content + \"<|end|>\" }}" + " {%- elif message.thinking and not future_final_message.found %}" + " {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.thinking + \"<|end|>\" }}" + " {%- endif %}" + " {{- \"<|start|>assistant to=\" }}" + " {{- \"functions.\" + tool_call.name + \"<|channel|>commentary \" }}" + " {{- (tool_call.content_type if tool_call.content_type is defined else \"json\") + \"<|message|>\" }}" + " {{- tool_call.arguments|tojson }}" + " {{- \"<|call|>\" }}" + " {%- set last_tool_call.name = tool_call.name %}" + " {%- elif loop.last and not add_generation_prompt %}" + " {#- Only render the CoT if the final turn is an assistant turn and add_generation_prompt is false #}" + " {#- This is a situation that should only occur in training, never in inference. #}" + " {%- if \"thinking\" in message %}" + " {{- \"<|start|>assistant<|channel|>analysis<|message|>\" + message.thinking + \"<|end|>\" }}" + " {%- endif %}" + " {#- <|return|> indicates the end of generation, but <|end|> does not #}" + " {#- <|return|> should never be an input to the model, but we include it as the final token #}" + " {#- when training, so the model learns to emit it. #}" + " {{- \"<|start|>assistant<|channel|>final<|message|>\" + message.content + \"<|return|>\" }}" + " {%- else %}" + " {#- CoT is dropped during all previous turns, so we never render it for inference #}" + " {{- \"<|start|>assistant<|channel|>final<|message|>\" + message.content + \"<|end|>\" }}" + " {%- set last_tool_call.name = none %}" + " {%- endif %}" + " {%- elif message.role == 'tool' -%}" + " {%- if last_tool_call.name is none %}" + " {{- raise_exception(\"Message has tool role, but there was no previous assistant message with a tool call!\") }}" + " {%- endif %}" + " {{- \"<|start|>functions.\" + last_tool_call.name }}" + " {{- \" to=assistant<|channel|>commentary<|message|>\" + message.content|tojson + \"<|end|>\" }}" + " {%- elif message.role == 'user' -%}" + " {{- \"<|start|>user<|message|>\" + message.content + \"<|end|>\" }}" + " {%- endif -%}" + "{%- endfor -%}" + "" + "{#- Generation prompt #}" + "{%- if add_generation_prompt -%}" + "<|start|>assistant<|channel|>final<|message|>" + "{%- endif -%}" + ) + + _QWEN2 = ( + "{% if not add_generation_prompt is defined %}" + " {% set add_generation_prompt = false %}" + "{% endif %}" + "" + "{% set ns = namespace(" + " is_first=false," + " is_tool=false," + " is_output_first=true," + " system_prompt=''," + " is_first_sp=true" + ") %}" + "" + "{# Collect all system messages #}" + "{%- for message in messages %}" + " {%- if message['role'] == 'system' %}" + " {%- if ns.is_first_sp %}" + " {% set ns.system_prompt = ns.system_prompt + message['content'] %}" + " {% set ns.is_first_sp = false %}" + " {%- else %}" + " {% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}" + " {%- endif %}" + " {%- endif %}" + "{%- endfor %}" + "" + "{# Output BOS token and system prompt #}" + "{{ bos_token }}{{ ns.system_prompt }}" + "" + "{# Process messages #}" + "{%- for message in messages %}" + " " + " {# User message #}" + " {%- if message['role'] == 'user' %}" + " {%- set ns.is_tool = false -%}" + " {{'<|User|>' + message['content']}}" + " {%- endif %}" + " " + " {# Assistant message with tool calls #}" + " {%- if message['role'] == 'assistant' and 'tool_calls' in message %}" + " {%- set ns.is_tool = false -%}" + " {%- for tool in message['tool_calls'] %}" + " {%- if not ns.is_first %}" + " {%- if message['content'] is none %}" + " {{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + " + " tool['type'] + '<|tool▁sep|>' + " + " tool['function']['name'] + '\\n' + " + " '```json' + '\\n' + " + " tool['function']['arguments'] + '\\n' + " + " '```' + '<|tool▁call▁end|>'}}" + " {%- else %}" + " {{'<|Assistant|>' + message['content'] + " + " '<|tool▁calls▁begin|><|tool▁call▁begin|>' + " + " tool['type'] + '<|tool▁sep|>' + " + " tool['function']['name'] + '\\n' + " + " '```json' + '\\n' + " + " tool['function']['arguments'] + '\\n' + " + " '```' + '<|tool▁call▁end|>'}}" + " {%- endif %}" + " {%- set ns.is_first = true -%}" + " {%- else %}" + " {{'\\n' + '<|tool▁call▁begin|>' + " + " tool['type'] + '<|tool▁sep|>' + " + " tool['function']['name'] + '\\n' + " + " '```json' + '\\n' + " + " tool['function']['arguments'] + '\\n' + " + " '```' + '<|tool▁call▁end|>'}}" + " {%- endif %}" + " {%- endfor %}" + " {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}" + " {%- endif %}" + " " + " {# Assistant message without tool calls #}" + " {%- if message['role'] == 'assistant' and 'tool_calls' not in message %}" + " {%- if ns.is_tool %}" + " {{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}" + " {%- set ns.is_tool = false -%}" + " {%- else %}" + " {% set content = message['content'] %}" + " {% if '' in content %}" + " {% set content = content.split('')[-1] %}" + " {% endif %}" + " {{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}" + " {%- endif %}" + " {%- endif %}" + " " + " {# Tool output message #}" + " {%- if message['role'] == 'tool' %}" + " {%- set ns.is_tool = true -%}" + " {%- if ns.is_output_first %}" + " {{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + " + " message['content'] + '<|tool▁output▁end|>'}}" + " {%- set ns.is_output_first = false %}" + " {%- else %}" + " {{'<|tool▁output▁begin|>' + " + " message['content'] + '<|tool▁output▁end|>'}}" + " {%- endif %}" + " {%- endif %}" + " " + "{%- endfor -%}" + "" + "{# Close tool outputs if needed #}" + "{% if ns.is_tool %}" + " {{'<|tool▁outputs▁end|>'}}" + "{% endif %}" + "" + "{# Add generation prompt if requested #}" + "{% if add_generation_prompt and not ns.is_tool %}" + " {{'<|Assistant|>'}}" + "{% endif %}" + ) + @classmethod - def create_chat_template(cls, name: str = "default") -> str: - if name.lower() == "default": + def create_chat_template(cls, tmpl_name: str = "default") -> str: + """Creates and returns a chat template based on the provided name. + + Args: + tmpl_name (str): The name of the chat template to create. + Returns: + str: The chat template string. + + """ + if tmpl_name.lower() == "default": return cls._DEFAULT - elif name.lower() == "alpaca": + elif tmpl_name.lower() == "alpaca": return cls._ALPACA - elif name.lower() == "chat_ml": + elif tmpl_name.lower() == "chat_ml": return cls._CHAT_ML - elif name.lower() == "falcon": + elif tmpl_name.lower() == "falcon": return cls._FALCON - elif name.lower() == "gemma": + elif tmpl_name.lower() == "gemma": return cls._GEMMA - elif name.lower() == "llama_2": + elif tmpl_name.lower() == "llama_2": return cls._LLAMA_2 - elif name.lower() == "llama_3": + elif tmpl_name.lower() == "llama_3": return cls._LLAMA_3 - elif name.lower() == "mistral": + elif tmpl_name.lower() == "mistral": return cls._MISTRAL - elif name.lower() == "phi_2": + elif tmpl_name.lower() == "phi_2": return cls._PHI_2 - elif name.lower() == "phi_3": + elif tmpl_name.lower() == "phi_3": return cls._PHI_3 - elif name.lower() == "qwen": + elif tmpl_name.lower() == "qwen": return cls._QWEN + elif tmpl_name.lower() == "openai": + return cls._GPT_OSS + elif tmpl_name.lower() == "qwen2": + return cls._QWEN2 else: raise ValueError("Invalid template name") diff --git a/app/trainers/huggingface_llm_trainer.py b/app/trainers/huggingface_llm_trainer.py index 89404fa..8596bde 100644 --- a/app/trainers/huggingface_llm_trainer.py +++ b/app/trainers/huggingface_llm_trainer.py @@ -338,7 +338,12 @@ def run( get_model_data_package_base_name(trained_model_pack_path), ) - if non_default_device_is_available(self._config.DEVICE): + if (non_default_device_is_available(self._config.DEVICE) and + not ( + getattr(self._model_service._model, "is_loaded_in_8bit", False) or + getattr(self._model_service._model, "is_loaded_in_4bit", False) + ) + ): model.to(self._config.DEVICE) train_dataset, test_dataset = self._load_dataset_from_config(data_file, training_params) @@ -346,7 +351,7 @@ def run( train_dataset = train_dataset.map(make_conversation) test_dataset = test_dataset.map(make_conversation) - if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None: + if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is None: logger.warning("The tokenizer does not have a chat template. Using the default one.") tokenizer.chat_template = get_default_chat_template() else: @@ -785,7 +790,7 @@ def run( if not eval_mode: try: copied_model_directory = None - if self._model_service.is_4bit_quantised: + if self._model_service.is_quantised: logger.info("Use the LoRA adaptor for the quantised model...") lora_config = LoraConfig( task_type="CAUSAL_LM", @@ -964,7 +969,7 @@ def run( logger.info("Evaluating the running model...") model, tokenizer = self._model_service.model, self._model_service.tokenizer - if self._model_service.is_4bit_quantised: + if self._model_service.is_quantised: logger.error("Cannot evaluate against a quantised model") raise ManagedModelException("Cannot evaluate against a quantised model") diff --git a/app/utils.py b/app/utils.py index 9da3e75..78469f8 100644 --- a/app/utils.py +++ b/app/utils.py @@ -786,6 +786,24 @@ def get_default_chat_template() -> str: ) +def utilise_local_chat_template(hf_model_type: str, tokenizer: PreTrainedTokenizer) -> bool: + """Sets the chat template for the tokenizer if a local template is available. + + Args: + hf_model_type (str): The model type in the model config. + tokenizer (PreTrainedTokenizer): The tokenizer to set the chat template for. + + Returns: + bool: True if the local chat template was detected and utilised, False otherwise. + + """ + try: + tokenizer.chat_template = PromptFactory.create_chat_template(hf_model_type) + return True + except ValueError: + return False + + def get_default_system_prompt() -> str: """ Gets the default system prompt. @@ -849,7 +867,7 @@ def get_prompt_from_messages( prompt = "\n".join(prompt_parts) prompt += "\n<|assistant|>\n" else: - tokenizer.chat_template = PromptFactory.create_chat_template(name=override_template) + tokenizer.chat_template = PromptFactory.create_chat_template(tmpl_name=override_template) prompt = tokenizer.apply_chat_template( [dump_pydantic_object_to_dict(message) for message in messages], tokenize=False, diff --git a/pyproject.toml b/pyproject.toml index dc6c874..8f3c98d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ llm = [ mcp = [ "mcp[cli]==1.26.0", "cms-client==0.0.1", - "loguru~=0.7.3" + "loguru~=0.7.3", ] # For pip versions not supporting PEP 735 diff --git a/tests/app/api/test_serving_hf_llm.py b/tests/app/api/test_serving_hf_llm.py index c01fd89..c209924 100644 --- a/tests/app/api/test_serving_hf_llm.py +++ b/tests/app/api/test_serving_hf_llm.py @@ -72,6 +72,42 @@ async def test_stream_generate(llm_model_service, llm_app): @pytest.mark.asyncio async def test_generate_chat_completions(llm_model_service, llm_app): + llm_model_service.model_name = "HuggingFace LLM model" + llm_model_service.generate.return_value = "I'm a chat bot." + request_data = { + "messages": [ + { + "role": "system", + "content": "You are a chat bot." + }, + { + "role": "user", + "content": "Who are you?" + } + ], + "model": "HuggingFace LLM model", + "stream": False, + "max_tokens": 128, + "temperature": 0.7 + } + + async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: + response = await ac.post( + "/v1/chat/completions?max_tokens=128&temperature=0.7", + data=json.dumps(request_data), + headers={"Content-Type": "application/json"}, + ) + + response_json = response.json() + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + assert response_json["object"] == "chat.completion" + assert response_json["model"] == "HuggingFace LLM model" + assert response_json["choices"][0]["message"]["content"] == "I'm a chat bot." + + +@pytest.mark.asyncio +async def test_generate_chat_completions_stream(llm_model_service, llm_app): llm_model_service.generate.return_value = "I'm a chat bot." request_data = { "messages": [ @@ -89,6 +125,7 @@ async def test_generate_chat_completions(llm_model_service, llm_app): "max_tokens": 128, "temperature": 0.7 } + async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: response = await ac.post( "/v1/chat/completions?max_tokens=128&temperature=0.7", @@ -103,6 +140,63 @@ async def test_generate_chat_completions(llm_model_service, llm_app): assert "chat.completion.chunk" in response.text +@pytest.mark.asyncio +async def test_generate_completions(llm_model_service, llm_app): + llm_model_service.model_name = "HuggingFace LLM model" + llm_model_service.generate.return_value = "I'm a chat bot." + request_data = { + "model": "HuggingFace LLM model", + "prompt": "Who are you?", + "max_tokens": 128, + "temperature": 0.7, + "stream": False, + } + + async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: + response = await ac.post( + "/v1/completions", + data=json.dumps(request_data), + headers={"Content-Type": "application/json"}, + ) + + response_json = response.json() + assert response.status_code == 200 + assert response.headers["content-type"] == "application/json" + assert response_json["object"] == "text_completion" + assert response_json["model"] == "HuggingFace LLM model" + assert response_json["choices"][0]["text"] == "I'm a chat bot." + + +@pytest.mark.asyncio +async def test_generate_completions_stream(llm_model_service, llm_app): + llm_model_service.model_name = "HuggingFace LLM model" + + async def async_gen(): + yield "I'm a chat bot." + + llm_model_service.generate_async.return_value = async_gen() + request_data = { + "model": "HuggingFace LLM model", + "prompt": "Who are you?", + "max_tokens": 128, + "temperature": 0.7, + "stream": True, + } + async with httpx.AsyncClient(app=llm_app, base_url="http://test") as ac: + response = await ac.post( + "/v1/completions", + data=json.dumps(request_data), + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + assert response.text.startswith("data:") + assert "id" in response.text + assert "text_completion" in response.text + assert "[DONE]" in response.text + + def test_create_embeddings(client): request_data = { "input": ["Alright"], diff --git a/tests/app/model_services/test_huggingface_llm_model.py b/tests/app/model_services/test_huggingface_llm_model.py index 0f134f9..26dfc40 100644 --- a/tests/app/model_services/test_huggingface_llm_model.py +++ b/tests/app/model_services/test_huggingface_llm_model.py @@ -1,7 +1,8 @@ import os +import pytest from unittest.mock import MagicMock, patch from tests.app.conftest import MODEL_PARENT_DIR -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import PreTrainedModel, PreTrainedTokenizerBase, TextIteratorStreamer from app import __version__ from app.domain import ModelType from app.model_services.huggingface_llm_model import HuggingFaceLlmModel @@ -42,7 +43,11 @@ def test_info(huggingface_llm_model): assert model_card.model_type == ModelType.HUGGINGFACE_LLM -def test_generate(huggingface_llm_model): +@pytest.mark.parametrize("ensure_full_sentences, expected_output", [ + (False, "Yeah. Hmm"), + (True, "Yeah."), +]) +def test_generate(huggingface_llm_model, ensure_full_sentences, expected_output): huggingface_llm_model.init_model() huggingface_llm_model.model = MagicMock() huggingface_llm_model.tokenizer = MagicMock() @@ -53,7 +58,11 @@ def test_generate(huggingface_llm_model): huggingface_llm_model.tokenizer.return_value = inputs outputs = [MagicMock(shape=[2])] huggingface_llm_model.model.generate.return_value = outputs - huggingface_llm_model.tokenizer.decode.return_value = "Yeah." + completion_ids = MagicMock() + completion_ids.shape = [2] + outputs[0].__getitem__.return_value = completion_ids + huggingface_llm_model.tokenizer.decode.return_value = "Yeah. Hmm" + huggingface_llm_model.tokenizer.apply_chat_template.return_value = "chat template text" result = huggingface_llm_model.generate( prompt="Alright?", @@ -63,11 +72,12 @@ def test_generate(huggingface_llm_model): temperature=0.5, top_p=0.8, stop_sequences=["end"], - report_tokens=mock_send_metrics + report_tokens=mock_send_metrics, + ensure_full_sentences=ensure_full_sentences, ) - huggingface_llm_model.tokenizer.assert_called_once_with( - "Alright?", + huggingface_llm_model.tokenizer.assert_any_call( + "chat template text", add_special_tokens=False, return_tensors="pt", ) @@ -76,26 +86,31 @@ def test_generate(huggingface_llm_model): attention_mask=inputs.attention_mask, min_new_tokens=50, max_new_tokens=128, + use_cache=True, num_beams=2, - do_sample=True, + do_sample=False, temperature=0.5, top_p=0.8, repetition_penalty=1.2, no_repeat_ngram_size=3, ) huggingface_llm_model.tokenizer.decode.assert_called_once_with( - outputs[0], - skip_prompt=True, + outputs[0][2:], skip_special_tokens=True, ) mock_send_metrics.assert_called_once_with( prompt_token_num=2, completion_token_num=2, ) - assert result == "Yeah." + assert result == expected_output -async def test_generate_async(huggingface_llm_model): +@pytest.mark.parametrize("ensure_full_sentences, expected_output", [ + (False, "Yeah. Hmm"), + (True, "Yeah."), +]) +@pytest.mark.asyncio +async def test_generate_async(huggingface_llm_model, ensure_full_sentences, expected_output): huggingface_llm_model.init_model() huggingface_llm_model.model = MagicMock() huggingface_llm_model.tokenizer = MagicMock() @@ -103,49 +118,46 @@ async def test_generate_async(huggingface_llm_model): inputs = MagicMock() inputs.input_ids = MagicMock(shape=[1, 2]) inputs.attention_mask = MagicMock() - huggingface_llm_model.tokenizer.return_value = inputs - outputs = [MagicMock(shape=[2])] - huggingface_llm_model.model.generate.return_value = outputs - huggingface_llm_model.tokenizer.decode.return_value = "Yeah." + + def mock_tokenizer_call(*args, **kwargs): + if args and args[0] == "Alright?": + return inputs + mock_result = MagicMock() + mock_result.input_ids = MagicMock(shape=[1, 2]) + return mock_result + + huggingface_llm_model.tokenizer.side_effect = mock_tokenizer_call + streamer = TextIteratorStreamer(huggingface_llm_model.tokenizer, skip_special_tokens=True) + + for char in "Yeah. Hmm": + streamer.text_queue.put(char) + streamer.text_queue.put(streamer.stop_signal) + + with patch("app.model_services.huggingface_llm_model.TextIteratorStreamer", return_value=streamer): + huggingface_llm_model.model.generate.return_value = MagicMock(shape=[2]) + mock_future = MagicMock() + huggingface_llm_model._text_generator.submit = MagicMock(return_value=mock_future) - result = await huggingface_llm_model.generate_async( - prompt="Alright?", - min_tokens=50, - max_tokens=128, - num_beams=2, - temperature=0.5, - top_p=0.8, - stop_sequences=["end"], - report_tokens=mock_send_metrics - ) + results = [] + async for chunk in huggingface_llm_model.generate_async( + prompt="Alright?", + min_tokens=50, + max_tokens=128, + num_beams=2, + temperature=0.5, + top_p=0.8, + stop_sequences=["end"], + report_tokens=mock_send_metrics, + ensure_full_sentences=ensure_full_sentences, + ): + results.append(chunk) + result = "".join(results) - huggingface_llm_model.tokenizer.assert_called_once_with( - "Alright?", - add_special_tokens=False, - return_tensors="pt", - ) - huggingface_llm_model.model.generate_async.assert_called_once_with( - inputs=inputs.input_ids, - attention_mask=inputs.attention_mask, - min_new_tokens=50, - max_new_tokens=128, - num_beams=2, - do_sample=True, - temperature=0.5, - top_p=0.8, - repetition_penalty=1.2, - no_repeat_ngram_size=3, - ) - huggingface_llm_model.tokenizer.decode.assert_called_once_with( - outputs[0], - skip_prompt=True, - skip_special_tokens=True, - ) mock_send_metrics.assert_called_once_with( prompt_token_num=2, completion_token_num=2, ) - assert result == "Yeah." + assert result == expected_output @patch("torch.nn.functional.normalize") @@ -179,7 +191,7 @@ def tensor_side_effect(*args, **kwargs): mock_cat.return_value = mock_concatenated mock_mean.return_value = mock_final_embedding mock_normalise.return_value = mock_normalised - mock_normalised.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]] + mock_normalised.float.return_value.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]] mock_tensor.side_effect = tensor_side_effect mock_masked = MagicMock() mock_summed = MagicMock() @@ -235,7 +247,7 @@ def tensor_side_effect(*args, **kwargs): mock_cat.return_value = mock_concatenated mock_mean.return_value = mock_final_embedding mock_normalise.return_value = mock_normalised - mock_normalised.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]] + mock_normalised.float.return_value.cpu.return_value.numpy.return_value.tolist.return_value = [[0.1, 0.2, 0.3]] mock_tensor.side_effect = tensor_side_effect mock_masked = MagicMock() mock_summed = MagicMock() @@ -252,3 +264,88 @@ def tensor_side_effect(*args, **kwargs): assert mock_cat.call_count == 2 assert mock_mean.call_count == 2 assert result == [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]] + + +def test_load_model_quantization_check(): + mock_config = MagicMock() + mock_config.to_dict.return_value = {"quantization_config": {}} + mock_model = MagicMock(spec=PreTrainedModel) + mock_model.config = MagicMock() + mock_model.config.max_position_embeddings = 512 + mock_tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + + with patch("app.model_services.huggingface_llm_model.unpack_model_data_package", return_value=True), \ + patch("app.model_services.huggingface_llm_model.AutoConfig.from_pretrained", return_value=mock_config), \ + patch("app.model_services.huggingface_llm_model.AutoModelForCausalLM.from_pretrained", return_value=mock_model), \ + patch("app.model_services.huggingface_llm_model.AutoTokenizer.from_pretrained", return_value=mock_tokenizer), \ + patch("app.model_services.huggingface_llm_model.BitsAndBytesConfig", return_value=MagicMock()), \ + patch("app.model_services.huggingface_llm_model.get_settings") as mock_get_settings, \ + patch("app.model_services.huggingface_llm_model.logger") as mock_logger: + + mock_settings = MagicMock() + mock_settings.DEVICE = "cpu" + mock_get_settings.return_value = mock_settings + + model, tokenizer = HuggingFaceLlmModel.load_model("dummy_path", load_in_4bit=True) + + mock_logger.info.assert_any_call("Model already quantised, loading by ignoring 'load_in_4bit' or 'load_in_8bit' flag") + assert model == mock_model + assert tokenizer == mock_tokenizer + + with patch("app.model_services.huggingface_llm_model.unpack_model_data_package", return_value=True), \ + patch("app.model_services.huggingface_llm_model.AutoConfig.from_pretrained", return_value=mock_config), \ + patch("app.model_services.huggingface_llm_model.AutoModelForCausalLM.from_pretrained", return_value=mock_model), \ + patch("app.model_services.huggingface_llm_model.AutoTokenizer.from_pretrained", return_value=mock_tokenizer), \ + patch("app.model_services.huggingface_llm_model.BitsAndBytesConfig", return_value=MagicMock()), \ + patch("app.model_services.huggingface_llm_model.get_settings") as mock_get_settings, \ + patch("app.model_services.huggingface_llm_model.logger") as mock_logger: + + mock_settings = MagicMock() + mock_settings.DEVICE = "cpu" + mock_get_settings.return_value = mock_settings + mock_config.to_dict.return_value = {"quantization_config": {}} + + model, tokenizer = HuggingFaceLlmModel.load_model("dummy_path", load_in_8bit=True) + + mock_logger.info.assert_any_call("Model already quantised, loading by ignoring 'load_in_4bit' or 'load_in_8bit' flag") + assert model == mock_model + assert tokenizer == mock_tokenizer + + with patch("app.model_services.huggingface_llm_model.unpack_model_data_package", return_value=True), \ + patch("app.model_services.huggingface_llm_model.AutoConfig.from_pretrained", return_value=mock_config), \ + patch("app.model_services.huggingface_llm_model.AutoModelForCausalLM.from_pretrained", return_value=mock_model), \ + patch("app.model_services.huggingface_llm_model.AutoTokenizer.from_pretrained", return_value=mock_tokenizer), \ + patch("app.model_services.huggingface_llm_model.BitsAndBytesConfig", return_value=MagicMock()), \ + patch("app.model_services.huggingface_llm_model.get_settings") as mock_get_settings, \ + patch("app.model_services.huggingface_llm_model.logger") as mock_logger: + + mock_settings = MagicMock() + mock_settings.DEVICE = "cpu" + mock_get_settings.return_value = mock_settings + mock_config.to_dict.return_value = {} + + model, tokenizer = HuggingFaceLlmModel.load_model("dummy_path", load_in_4bit=True) + + mock_logger.info.assert_called_once_with("Model package loaded from %s", "dummy_path") + assert model == mock_model + assert tokenizer == mock_tokenizer + + with patch("app.model_services.huggingface_llm_model.unpack_model_data_package", return_value=True), \ + patch("app.model_services.huggingface_llm_model.AutoConfig.from_pretrained", return_value=mock_config), \ + patch("app.model_services.huggingface_llm_model.AutoModelForCausalLM.from_pretrained", return_value=mock_model), \ + patch("app.model_services.huggingface_llm_model.AutoTokenizer.from_pretrained", return_value=mock_tokenizer), \ + patch("app.model_services.huggingface_llm_model.BitsAndBytesConfig", return_value=MagicMock()), \ + patch("app.model_services.huggingface_llm_model.get_settings") as mock_get_settings, \ + patch("app.model_services.huggingface_llm_model.logger") as mock_logger: + + mock_settings = MagicMock() + mock_settings.DEVICE = "cpu" + mock_get_settings.return_value = mock_settings + mock_config.to_dict.return_value = {} + + model, tokenizer = HuggingFaceLlmModel.load_model("dummy_path", load_in_8bit=True) + + mock_logger.info.assert_called_once_with("Model package loaded from %s", "dummy_path") + assert model == mock_model + assert tokenizer == mock_tokenizer + diff --git a/tests/app/processors/test_prompt_factory.py b/tests/app/processors/test_prompt_factory.py index ab03388..993ba31 100644 --- a/tests/app/processors/test_prompt_factory.py +++ b/tests/app/processors/test_prompt_factory.py @@ -1,4 +1,5 @@ from jinja2.sandbox import ImmutableSandboxedEnvironment +from unittest.mock import Mock from app.processors.prompt_factory import PromptFactory diff --git a/tests/app/test_utils.py b/tests/app/test_utils.py index 3420fe7..1845d97 100644 --- a/tests/app/test_utils.py +++ b/tests/app/test_utils.py @@ -34,8 +34,8 @@ pyproject_dependencies_to_pip_requirements, get_model_data_package_base_name, load_pydantic_object_from_dict, - dump_pydantic_object_to_dict, get_prompt_from_messages, + utilise_local_chat_template, ) from app.domain import Annotation, Entity, PromptMessage, PromptRole @@ -406,7 +406,7 @@ def forward(self, x): def test_get_prompt_with_chat_template(): - with patch('transformers.PreTrainedTokenizer') as tok: + with patch("transformers.PreTrainedTokenizer") as tok: mock_tokenizer = tok.return_value mock_tokenizer.chat_template = "Mock chat template" mock_tokenizer.apply_chat_template.return_value = "Mock chat template applied" @@ -421,7 +421,7 @@ def test_get_prompt_with_chat_template(): def test_get_prompt_with_default_chat_template(): - with patch('transformers.PreTrainedTokenizer') as tok: + with patch("transformers.PreTrainedTokenizer") as tok: mock_tokenizer = tok.return_value mock_tokenizer.chat_template = None mock_tokenizer.default_chat_template = "Mock default chat template" @@ -436,8 +436,24 @@ def test_get_prompt_with_default_chat_template(): assert prompt == "Mock default chat template applied" +def test_utilise_local_chat_template_if_exists(): + tokenizer = MagicMock() + tokenizer.chat_template = "chat template" + result = utilise_local_chat_template("default", tokenizer) + assert result is True + assert tokenizer.chat_template != "chat template" + + +def test_utilise_local_chat_template_if_missing(): + tokenizer = MagicMock() + tokenizer.chat_template = "chat template" + result = utilise_local_chat_template("invalid", tokenizer) + assert result is False + assert tokenizer.chat_template == "chat template" + + def test_get_prompt_without_chat_template(): - with patch('transformers.PreTrainedTokenizer') as tok: + with patch("transformers.PreTrainedTokenizer") as tok: mock_tokenizer = tok.return_value mock_tokenizer.chat_template = None mock_tokenizer.default_chat_template = None @@ -454,7 +470,7 @@ def test_get_prompt_without_chat_template(): def test_get_prompt_with_no_messages(): - with patch('transformers.PreTrainedTokenizer') as tok: + with patch("transformers.PreTrainedTokenizer") as tok: mock_tokenizer = tok.return_value mock_tokenizer.chat_template = None mock_tokenizer.default_chat_template = None diff --git a/tests/app/trainers/test_hf_llm_trainer.py b/tests/app/trainers/test_hf_llm_trainer.py index 67ea601..0652593 100644 --- a/tests/app/trainers/test_hf_llm_trainer.py +++ b/tests/app/trainers/test_hf_llm_trainer.py @@ -3,7 +3,7 @@ from unittest.mock import create_autospec, patch, Mock from transformers import PreTrainedTokenizerFast from app.model_services.huggingface_llm_model import HuggingFaceLlmModel -from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer +from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer, HuggingFaceLlmUnsupervisedTrainer from app.config import Settings @@ -34,9 +34,12 @@ def _triton_installed(): model_service.model = Mock() model_service.model.config.max_position_embeddings = 512 model_service.model_name = "llm_test_model" +model_service.is_4bit_quantised = False supervised_trainer = HuggingFaceLlmSupervisedTrainer(model_service) supervised_trainer.model_name = "supervised_trainer" +unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(model_service) +unsupervised_trainer.model_name = "unsupervised_trainer" data_dir = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "fixture") @@ -59,7 +62,22 @@ def test_huggingface_llm_supervised_trainer(mlflow_fixture): run.assert_called_once() -@skipIf(config.DEVICE != "cuda", "This requires a CUDA device to run") +def test_huggingface_llm_unsupervised_trainer(mlflow_fixture): + with patch.object(unsupervised_trainer, "run", wraps=unsupervised_trainer.run) as run: + unsupervised_trainer._tracker_client = Mock() + unsupervised_trainer._tracker_client.start_tracking = Mock(return_value=("experiment_id", "run_id")) + with open(os.path.join(data_dir, "sample_texts.json"), "r") as f: + unsupervised_trainer.train(f, 1, 1, "training_id", "input_file_name") + unsupervised_trainer._tracker_client.start_tracking.assert_called_once() + run.assert_called_once() + + +@skipIf((not _triton_installed()) or (config.DEVICE != "cuda"), "This requires triton to be installed and a CUDA device to run") def test_huggingface_llm_supervised_run(mlflow_fixture): with open(os.path.join(data_dir, "sample_qa.json"), "r") as data_file: HuggingFaceLlmSupervisedTrainer.run(supervised_trainer, {"nepochs": 1, "print_stats": 1}, data_file, 1, "run_id") + + +def test_huggingface_llm_unsupervised_run(mlflow_fixture): + with open(os.path.join(data_dir, "sample_texts.json"), "r") as data_file: + HuggingFaceLlmUnsupervisedTrainer.run(unsupervised_trainer, {"nepochs": 1}, data_file, 1, "run_id")