diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 9c2d6e1..b924aeb 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -25,7 +25,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements-dev.txt - pip install hidet pip install . - name: Test with pytest run: | diff --git a/README.md b/README.md index 98ad196..d14fa8f 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,6 @@ ![](https://github.com/CentML/centml-python-client/actions/workflows/unit_tests.yml/badge.svg) ### Installation -First, ensure you meet the requirements for [Hidet](https://github.com/hidet-org/hidet), namely: -- CUDA Toolkit 11.6+ -- Python 3.8+ To install without cloning, run the following command: ```bash @@ -36,46 +33,6 @@ source scripts/completions/completion. Shell language can be: bash, zsh, fish (Hint: add `source /path/to/completions/completion.` to your `~/.bashrc`, `~/.zshrc` or `~/.config/fish/completions/centml.fish`) -### Compilation - -centml-python-client's compiler feature allows you to compile your ML model remotely using the [hidet](https://hidet.org/docs/stable/index.html) backend. \ -Thus, use the compilation feature, make sure to run: -```bash -pip install hidet -``` - -To run the server locally, you can use the following CLI command: -```bash -centml server -``` -By default, the server will run at the URL `http://0.0.0.0:8090`. \ -You can change this by setting the environment variable `CENTML_SERVER_URL` - - -Then, within your python script include the following: -```python -import torch -# This will import the "centml" torch.compile backend -import centml.compiler - -# Define these yourself -model = ... -inputs = ... - -# Pass the "centml" backend -compiled_model = torch.compile(model, backend="centml") -# Since torch.compile is JIT, compilation is only triggered when you first call the model -output = compiled_model(inputs) -``` -Note that the centml backend compiler is non-blocking. This means it that until the server returns the compiled model, your python script will use the uncompiled model to generate the output. - -Again, make sure your script's environment sets `CENTML_SERVER_URL` to communicate with the desired server. - -To see logs, add this to your script before triggering compilation: -```python -logging.basicConfig(level=logging.INFO) -``` - ### Tests To run tests, first install required packages: ```bash diff --git a/centml/__init__.py b/centml/__init__.py index 551291c..8b13789 100644 --- a/centml/__init__.py +++ b/centml/__init__.py @@ -1 +1 @@ -from centml.compiler.main import compile + diff --git a/centml/cli/main.py b/centml/cli/main.py index b1ecc73..2a64d45 100644 --- a/centml/cli/main.py +++ b/centml/cli/main.py @@ -30,13 +30,6 @@ def cli(): cli.add_command(logout) -@cli.command(help="Start remote compilation server") -def server(): - from centml.compiler.server import run - - run() - - @click.group(help="CentML cluster CLI tool") def ccluster(): pass diff --git a/centml/compiler/__init__.py b/centml/compiler/__init__.py deleted file mode 100644 index 03d6a07..0000000 --- a/centml/compiler/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from centml.compiler.main import compile - -all = ["compile"] diff --git a/centml/compiler/backend.py b/centml/compiler/backend.py deleted file mode 100644 index 509e8a0..0000000 --- a/centml/compiler/backend.py +++ /dev/null @@ -1,194 +0,0 @@ -import os -import gc -import time -import hashlib -import logging -import threading as th -from http import HTTPStatus -from weakref import ReferenceType, ref -from tempfile import TemporaryDirectory -from typing import List, Callable, Optional -import requests -import torch -from torch.fx import GraphModule -from centml.compiler.config import settings, CompilationStatus -from centml.compiler.utils import get_backend_compiled_forward_path - - -class Runner: - def __init__(self, module: GraphModule, inputs: List[torch.Tensor]): - if not module: - raise Exception("No module provided") - - self._module: ReferenceType[GraphModule] = ref(module) - self._inputs: List[torch.Tensor] = inputs - self.compiled_forward_function: Optional[Callable[[torch.Tensor], tuple]] = None - self.lock = th.Lock() - self.child_thread = th.Thread(target=self.remote_compilation_starter) - - self.serialized_model_dir: Optional[TemporaryDirectory] = None - self.serialized_model_path: Optional[str] = None - self.serialized_input_path: Optional[str] = None - - try: - self.child_thread.start() - except Exception as e: - logging.getLogger(__name__).exception(f"Failed to start compilation thread\n{e}") - - @property - def module(self) -> Optional[GraphModule]: - return self._module() - - @module.deleter - def module(self): - self._module().graph.owning_module = None - self._module = None - - @property - def inputs(self) -> List[torch.Tensor]: - return self._inputs - - @inputs.deleter - def inputs(self): - self._inputs = None - - def _serialize_model_and_inputs(self): - self.serialized_model_dir = TemporaryDirectory() # pylint: disable=consider-using-with - self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_MODEL_FILE) - self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.CENTML_SERIALIZED_INPUT_FILE) - - # torch.save saves a zip file full of pickled files with the model's states. - try: - torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL) - torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL) - except Exception as e: - raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e - - def _get_model_id(self) -> str: - if not self.serialized_model_path or not os.path.isfile(self.serialized_model_path): - raise Exception(f"Model not saved at path {self.serialized_model_path}") - - sha_hash = hashlib.sha256() - with open(self.serialized_model_path, "rb") as serialized_model_file: - # Read in chunks to not load too much into memory - for block in iter(lambda: serialized_model_file.read(settings.CENTML_HASH_CHUNK_SIZE), b""): - sha_hash.update(block) - - model_id = sha_hash.hexdigest() - logging.info(f"Model has id {model_id}") - return model_id - - def _download_model(self, model_id: str): - download_response = requests.get( - url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT - ) - if download_response.status_code != HTTPStatus.OK: - raise Exception( - f"Download: request failed, exception from server:\n{download_response.json().get('detail')}" - ) - if download_response.content == b"": - raise Exception("Download: empty response from server") - download_path = get_backend_compiled_forward_path(model_id) - with open(download_path, "wb") as f: - f.write(download_response.content) - return torch.load(download_path) - - def _compile_model(self, model_id: str): - # The model should have been saved using torch.save when we found the model_id - if not self.serialized_model_path or not self.serialized_input_path: - raise Exception("Model or inputs not serialized") - if not os.path.isfile(self.serialized_model_path): - raise Exception(f"Model not saved at path {self.serialized_model_path}") - if not os.path.isfile(self.serialized_input_path): - raise Exception(f"Inputs not saved at path {self.serialized_input_path}") - - with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file: - compile_response = requests.post( - url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}", - files={"model": model_file, "inputs": input_file}, - timeout=settings.CENTML_COMPILER_TIMEOUT, - ) - if compile_response.status_code != HTTPStatus.OK: - raise Exception( - f"Compile model: request failed, exception from server:\n{compile_response.json().get('detail')}\n" - ) - - def _wait_for_status(self, model_id: str) -> bool: - tries = 0 - while True: - # get server compilation status - status = None - try: - status_response = requests.get( - f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.CENTML_COMPILER_TIMEOUT - ) - if status_response.status_code != HTTPStatus.OK: - raise Exception( - f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}" - ) - status = status_response.json().get("status") - except Exception as e: - logging.getLogger(__name__).exception(f"Status check failed:\n{e}") - - if status == CompilationStatus.DONE.value: - return True - elif status == CompilationStatus.COMPILING.value: - pass - elif status == CompilationStatus.NOT_FOUND.value: - logging.info("Submitting model to server for compilation.") - try: - self._compile_model(model_id) - except Exception as e: - logging.getLogger(__name__).exception(f"Submitting compilation failed:\n{e}") - tries += 1 - else: - tries += 1 - - if tries > settings.CENTML_COMPILER_MAX_RETRIES: - raise Exception("Waiting for status: compilation failed too many times.\n") - - time.sleep(settings.CENTML_COMPILER_SLEEP_TIME) - - def remote_compilation_starter(self): - try: - self.remote_compilation() - except Exception as e: - logging.getLogger(__name__).exception(f"Compilation thread failed:\n{e}") - - def remote_compilation(self): - self._serialize_model_and_inputs() - - model_id = self._get_model_id() - - # check if compiled forward is saved locally - compiled_forward_path = get_backend_compiled_forward_path(model_id) - if os.path.isfile(compiled_forward_path): - logging.info("Compiled model found in local cache. Not submitting to server.") - compiled_forward = torch.load(compiled_forward_path) - else: - self._wait_for_status(model_id) - compiled_forward = self._download_model(model_id) - - self.compiled_forward_function = compiled_forward - - logging.info("Compilation successful.") - - # Let garbage collector free the memory used by the uncompiled model - with self.lock: - del self.inputs - if self.module: - del self.module - gc.collect() - torch.cuda.empty_cache() - - def __call__(self, *args, **kwargs): - # If model is currently compiling, return the uncompiled forward function - with self.lock: - if not self.compiled_forward_function: - return self.module.forward(*args, **kwargs) - - return self.compiled_forward_function(*args) - - -def centml_dynamo_backend(gm: GraphModule, example_inputs: List[torch.Tensor]): - return Runner(gm, example_inputs) diff --git a/centml/compiler/config.py b/centml/compiler/config.py deleted file mode 100644 index 9c11c53..0000000 --- a/centml/compiler/config.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -from enum import Enum -from pydantic_settings import BaseSettings - - -class CompilationStatus(Enum): - NOT_FOUND = "NOT_FOUND" - COMPILING = "COMPILING" - DONE = "DONE" - - -class OperationMode(Enum): - PREDICTION = "PREDICTION" - REMOTE_COMPILATION = "REMOTE_COMPILATION" - - -class Config(BaseSettings): - CENTML_COMPILER_TIMEOUT: int = 10 - CENTML_COMPILER_MAX_RETRIES: int = 3 - CENTML_COMPILER_SLEEP_TIME: int = 15 - - CENTML_BASE_CACHE_DIR: str = os.path.expanduser("~/.cache/centml") - CENTML_BACKEND_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "backend") - CENTML_SERVER_BASE_PATH: str = os.path.join(CENTML_BASE_CACHE_DIR, "server") - - CENTML_SERVER_URL: str = "http://0.0.0.0:8090" - - # Use a constant path since torch.save uses the given file name in it's zipfile. - # Using a different filename would result in a different hash. - CENTML_SERIALIZED_MODEL_FILE: str = "serialized_model.zip" - CENTML_SERIALIZED_INPUT_FILE: str = "serialized_input.zip" - CENTML_PICKLE_PROTOCOL: int = 4 - - CENTML_HASH_CHUNK_SIZE: int = 4096 - - # If the server response is smaller than this, don't gzip it - CENTML_MINIMUM_GZIP_SIZE: int = 1000 - - CENTML_MODE: OperationMode = OperationMode.REMOTE_COMPILATION - CENTML_PREDICTION_DATA_FILE: str = 'tests/sample_data.csv' - CENTML_PREDICTION_GPUS: str = "A10G,A100SXM440GB,L4,H10080GBHBM3" - CENTML_PROMETHEUS_PORT: int = 8000 - - -settings = Config() diff --git a/centml/compiler/main.py b/centml/compiler/main.py deleted file mode 100644 index d3efee1..0000000 --- a/centml/compiler/main.py +++ /dev/null @@ -1,57 +0,0 @@ -import builtins -from typing import Callable, Dict, Optional, Union - -from centml.compiler.config import OperationMode, settings - - -def compile( - model: Optional[Callable] = None, - *, - fullgraph: builtins.bool = False, - dynamic: Optional[builtins.bool] = None, - mode: Union[str, None] = None, - options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None, - disable: builtins.bool = False, -) -> Callable: - import torch - - if settings.CENTML_MODE == OperationMode.REMOTE_COMPILATION: - from centml.compiler.backend import centml_dynamo_backend - - # Return the remote-compiled model - compiled_model = torch.compile( - model, - backend=centml_dynamo_backend, # Compilation backend - fullgraph=fullgraph, - dynamic=dynamic, - mode=mode, - options=options, - disable=disable, - ) - return compiled_model - elif settings.CENTML_MODE == OperationMode.PREDICTION: - from centml.compiler.prediction.backend import centml_prediction_backend, get_gauge - - # Proceed with prediction workflow - compiled_model = torch.compile( - model, - backend=centml_prediction_backend, # Prediction backend - fullgraph=fullgraph, - dynamic=dynamic, - mode=mode, - options=options, - disable=disable, - ) - - def centml_wrapper(*args, **kwargs): - out = compiled_model(*args, **kwargs) - # Update the prometheus metrics with final values - gauge = get_gauge() - for gpu in settings.CENTML_PREDICTION_GPUS.split(','): - gauge.set_metric_value(gpu) - - return out - - return centml_wrapper - else: - raise Exception("Invalid operation mode") diff --git a/centml/compiler/prediction/backend.py b/centml/compiler/prediction/backend.py deleted file mode 100644 index 6610812..0000000 --- a/centml/compiler/prediction/backend.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import List - -import torch -from torch._subclasses.fake_tensor import FakeTensorMode - -from centml.compiler.config import settings -from centml.compiler.prediction.kdtree import get_tree_db -from centml.compiler.prediction.metric import get_gauge -from centml.compiler.prediction.profiler import Profiler - - -def centml_prediction_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - profilers = [] - tree_db = get_tree_db() - for gpu in settings.CENTML_PREDICTION_GPUS.split(','): - profilers.append(Profiler(gm, gpu, tree_db)) - - def forward(*args): - fake_mode = FakeTensorMode(allow_non_fake_inputs=True) - fake_args = [fake_mode.from_tensor(arg) if isinstance(arg, torch.Tensor) else arg for arg in args] - with fake_mode: - for prof in profilers: - out, t = prof.propagate(*fake_args) - gauge = get_gauge() - gauge.increment(prof.gpu, t) - return out - - return forward diff --git a/centml/compiler/prediction/kdtree.py b/centml/compiler/prediction/kdtree.py deleted file mode 100644 index f42a63e..0000000 --- a/centml/compiler/prediction/kdtree.py +++ /dev/null @@ -1,69 +0,0 @@ -import ast -import csv -import logging - -from sklearn.neighbors import KDTree # type: ignore - -from centml.compiler.config import settings - -_tree_db = None - - -class KDTreeWithValues: - def __init__(self, points=None, values=None): - self.points = points if points else [] - self.values = values if values else [] - if self.points: - self.tree = KDTree(self.points) - else: - self.tree = None - - def add(self, point, value): - self.points.append(point) - self.values.append(value) - self.tree = KDTree(self.points) - - def query(self, point): - if self.tree is None: - return None, None - - dist, idx = self.tree.query([point], k=1) - return dist[0][0], self.values[idx[0][0]] - - -class TreeDB: - def __init__(self, data_csv): - self.db = {} - self._populate_db(data_csv) - - def get(self, key, inp): - if key not in self.db: - logging.getLogger(__name__).warning(f"Key {key} not found in database") - return float('-inf') - # TODO: Handle the case of unfound keys better. For now, return -inf to indicate something went wrong. - # Ideally, we shouldn't throw away a whole prediction because of one possibly insignificant node. - - _, val = self.db[key].query(inp) - return val - - def _add_from_db(self, key, points, values): - self.db[key] = KDTreeWithValues(points, values) - - def _populate_db(self, data_csv): - with open(data_csv, newline='') as f: - reader = csv.DictReader(f) - for row in reader: - try: - key = (row['op'], int(row['dim']), row['inp_dtypes'], row['out_dtypes'], row['gpu']) - points = ast.literal_eval(row['points']) - values = ast.literal_eval(row['values']) - self._add_from_db(key, points, values) - except ValueError as e: - logging.getLogger(__name__).exception(f"Error parsing row: {row}\n{e}") - - -def get_tree_db(): - global _tree_db - if _tree_db is None: - _tree_db = TreeDB(settings.CENTML_PREDICTION_DATA_FILE) - return _tree_db diff --git a/centml/compiler/prediction/metric.py b/centml/compiler/prediction/metric.py deleted file mode 100644 index 8736136..0000000 --- a/centml/compiler/prediction/metric.py +++ /dev/null @@ -1,30 +0,0 @@ -import time - -from prometheus_client import Gauge, start_http_server - -from centml.compiler.config import settings - -_gauge = None - - -def get_gauge(): - global _gauge - if _gauge is None: - _gauge = GaugeMetric() - return _gauge - - -class GaugeMetric: - def __init__(self): - start_http_server(settings.CENTML_PROMETHEUS_PORT) - self._gauge = Gauge('execution_time_microseconds', 'Kernel execution times by GPU', ['gpu', 'timestamp']) - self._values = {} - - def increment(self, gpu_name, value): - if gpu_name not in self._values: - self._values[gpu_name] = 0 - self._values[gpu_name] += value - - def set_metric_value(self, gpu_name): - self._gauge.labels(gpu=gpu_name, timestamp=time.time()).set(self._values[gpu_name]) - self._values[gpu_name] = 0 diff --git a/centml/compiler/prediction/profiler.py b/centml/compiler/prediction/profiler.py deleted file mode 100644 index 9325e7a..0000000 --- a/centml/compiler/prediction/profiler.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import Dict - -import torch -import torch.fx -from torch.fx.node import Node - -from scripts.timer import timed - - -class Profiler: - def __init__(self, mod, gpu, treeDB, data_collection_mode=False): - self.mod = mod - self.graph = mod.graph - self.modules = dict(self.mod.named_modules()) - self.tree_db = treeDB - self.gpu = gpu - self.data_collection_mode = data_collection_mode - self.trace_event_idx = 0 - - def propagate(self, *args): - args_iter = iter(args) - env: Dict[str, Node] = {} - total_gpu_time = 0 - actual_time = 0 - trace_events = [] - if self.data_collection_mode: - # Warmup before profiling - for _ in range(10): - _, t = timed(lambda: self.mod(*args)) - - # actual_time is to compare prediction to execution time of GraphModule - actual_time = t - - with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: - self.mod(*args) - for event in prof.events(): - # Ignore CPU events for now - if event.trace_name is None or event.device_type == torch.autograd.DeviceType.CPU: - continue - # Create a mapping of kernel execution times to the corresponding trace events - trace_events.append(event.time_range.elapsed_us()) - - def load_arg(a): - return torch.fx.graph.map_arg(a, lambda n: env[n.name]) - - def fetch_attr(target: str): - target_atoms = target.split('.') - attr_itr = self.mod - for i, atom in enumerate(target_atoms): - if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}") - attr_itr = getattr(attr_itr, atom) - return attr_itr - - def get_flattened_shapes(args): - flattened_shapes = [] - dtypes = [] - - for arg in args: - if isinstance(arg, (tuple, list)): - if len(arg) > 0 and isinstance(arg[0], (tuple, list, torch.Tensor)): - nested_shapes, nested_dtypes = get_flattened_shapes(arg[0]) - shape = [len(arg)] + nested_shapes - dtypes.extend(nested_dtypes.split(',')) - else: - shape = [len(arg)] - elif isinstance(arg, torch.Tensor): - shape = list(arg.shape) - dtypes.append(str(arg.dtype)) - elif isinstance(arg, bool): - shape = [1 if arg is True else 0] - elif isinstance(arg, (int, float)): - shape = [arg] - else: - shape = [1] - flattened_shapes.extend(shape) - - if len(flattened_shapes) < 2: - flattened_shapes.extend([1]) - - input_dtypes = ','.join(dtypes) if dtypes else 'N/A' - - return flattened_shapes, input_dtypes - - def get_output_dtypes(results): - def find_dtypes(results): - if isinstance(results, torch.Tensor): - return [str(results.dtype)] - if isinstance(results, (list, tuple)): - dtypes = [] - for item in results: - dtypes.extend(find_dtypes(item)) - return dtypes - return [] - - types = find_dtypes(results) - - if types: - return ','.join(types) - return 'N/A' - - def get_time_or_profile(key, inp_shapes, operation, *args, **kwargs): - t = self.tree_db.get(key, inp_shapes) - - if self.data_collection_mode: - with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: - operation(*args, **kwargs) - - if t is None: - # New key - event_time_total = 0 - for event in prof.events(): - if event.trace_name is None or event.device_type == torch.autograd.DeviceType.CPU: - continue - event_time_total += trace_events[self.trace_event_idx] - self.trace_event_idx += 1 - t = event_time_total - self.tree_db.add(key, inp_shapes, t) - else: - # Existing key, increment trace_event_idx by # of events to maintain mapping to trace_events list - for event in prof.events(): - if event.trace_name is None or event.device_type == torch.autograd.DeviceType.CPU: - continue - self.trace_event_idx += 1 - - return t - - for node in self.graph.nodes: - result = None - if node.op == 'placeholder': - result = next(args_iter) - elif node.op == 'get_attr': - result = fetch_attr(node.target) - elif node.op == 'call_function': - args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = node.target(*args, **kwargs) - - inp_shapes, input_dtypes = get_flattened_shapes(args) - output_dtypes = get_output_dtypes(result) - - key = (node.target.__name__, len(inp_shapes), input_dtypes, output_dtypes, self.gpu) - - t = get_time_or_profile(key, inp_shapes, node.target, *args, **kwargs) - - total_gpu_time += t - elif node.op == 'call_method': - self_obj, *args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = getattr(self_obj, node.target)(*args, **kwargs) - - inp_shapes, input_dtypes = get_flattened_shapes(args) - output_dtypes = get_output_dtypes(result) - - key = (node.target, len(inp_shapes), input_dtypes, output_dtypes, self.gpu) - - t = get_time_or_profile(key, inp_shapes, getattr(self_obj, node.target), *args, **kwargs) - - total_gpu_time += t - elif node.op == 'call_module': - mod = self.modules[node.target] - args = load_arg(node.args) - kwargs = load_arg(node.kwargs) - result = mod(*args, **kwargs) - - inp_shapes, input_dtypes = get_flattened_shapes(args) - - param_shapes = [param.shape for name, param in mod.named_parameters()] - param_dtypes = [str(param.dtype) for name, param in mod.named_parameters()] - flattened_params = [dim for shape in param_shapes for dim in shape] - - inp_shapes = inp_shapes + flattened_params - input_dtypes = input_dtypes + ',' + ','.join(param_dtypes) - - output_dtypes = get_output_dtypes(result) - - key = (mod._get_name(), len(inp_shapes), input_dtypes, output_dtypes, self.gpu) - - t = get_time_or_profile(key, inp_shapes, mod, *args, **kwargs) - - total_gpu_time += t - elif node.op == 'output': - args = load_arg(node.args) - if self.data_collection_mode: - # Return full graph execution time as well for accuracy comparison - return args[0], total_gpu_time, actual_time - return args[0], total_gpu_time - - env[node.name] = result diff --git a/centml/compiler/server.py b/centml/compiler/server.py deleted file mode 100644 index bed7435..0000000 --- a/centml/compiler/server.py +++ /dev/null @@ -1,118 +0,0 @@ -import io -import os -from http import HTTPStatus -from urllib.parse import urlparse -import logging -import uvicorn -import torch -from fastapi import FastAPI, UploadFile, HTTPException, BackgroundTasks, Response -from fastapi.responses import FileResponse -from fastapi.middleware.gzip import GZipMiddleware -from centml.compiler.server_compilation import hidet_backend_server -from centml.compiler.utils import dir_cleanup -from centml.compiler.config import settings, CompilationStatus -from centml.compiler.utils import get_server_compiled_forward_path - -app = FastAPI() -app.add_middleware(GZipMiddleware, minimum_size=settings.CENTML_MINIMUM_GZIP_SIZE) # type: ignore - - -def get_status(model_id: str): - if not os.path.isdir(os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id)): - return CompilationStatus.NOT_FOUND - - if not os.path.isfile(get_server_compiled_forward_path(model_id)): - return CompilationStatus.COMPILING - - return CompilationStatus.DONE - - -@app.get("/status/{model_id}") -async def status_handler(model_id: str): - status = get_status(model_id) - if status: - return {"status": status} - else: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="Status check: invalid status state.") - - -def background_compile(model_id: str, tfx_graph, example_inputs): - try: - compiled_graph_module = hidet_backend_server(tfx_graph, example_inputs) - except Exception as e: - logging.getLogger(__name__).exception(f"Compilation: error compiling model. {e}") - dir_cleanup(model_id) - return - - try: - # torch.save's writing is not atomic; it creates an empty zip file then saves the data in multiple calls. - # We don't want this incomplete zipfile to be mistaken for the serialized forward function by /status/. - # To avoid this, we write to a tmp file and rename it to the correct path after saving. - save_path = get_server_compiled_forward_path(model_id) - tmp_path = save_path + ".tmp" - torch.save(compiled_graph_module, tmp_path, pickle_protocol=settings.CENTML_PICKLE_PROTOCOL) - os.rename(tmp_path, save_path) - except Exception as e: - logging.getLogger(__name__).exception(f"Saving graph module failed: {e}") - dir_cleanup(model_id) - - -def read_upload_files(model_id: str, model: UploadFile, inputs: UploadFile): - try: - tfx_contents = io.BytesIO(model.file.read()) - ei_contents = io.BytesIO(inputs.file.read()) - except Exception as e: - dir_cleanup(model_id) - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail=f"Compilation: error reading serialized content: {e}" - ) from e - finally: - model.file.close() - inputs.file.close() - - try: - tfx_graph = torch.load(tfx_contents) - example_inputs = torch.load(ei_contents) - except Exception as e: - dir_cleanup(model_id) - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail=f"Compilation: error loading content with torch.load: {e}" - ) from e - - return tfx_graph, example_inputs - - -@app.post("/submit/{model_id}") -async def compile_model_handler(model_id: str, model: UploadFile, inputs: UploadFile, background_task: BackgroundTasks): - status = get_status(model_id) - if status is None: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="Compilation: error checking status.") - - # Only compile if the model is not compiled or compiling - if status != CompilationStatus.NOT_FOUND: - return Response(status_code=200) - - # This effectively sets the model's status to COMPILING - os.makedirs(os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id)) - - tfx_graph, example_inputs = read_upload_files(model_id, model, inputs) - - # perform the compilation in the background and return HTTP.OK to client - background_task.add_task(background_compile, model_id, tfx_graph, example_inputs) - - -@app.get("/download/{model_id}") -async def download_handler(model_id: str): - compiled_forward_path = get_server_compiled_forward_path(model_id) - if not os.path.isfile(compiled_forward_path): - raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Download: compiled file not found") - return FileResponse(compiled_forward_path) - - -def run(): - parsed = urlparse(settings.CENTML_SERVER_URL) - uvicorn.run(app, host=parsed.hostname, port=parsed.port) - - -if __name__ == "__main__": - run() diff --git a/centml/compiler/server_compilation.py b/centml/compiler/server_compilation.py deleted file mode 100644 index 0ab9dee..0000000 --- a/centml/compiler/server_compilation.py +++ /dev/null @@ -1,51 +0,0 @@ -from enum import Enum -from typing import List -import torch -from torch.fx import GraphModule -from hidet.graph.frontend import from_torch -from hidet.graph.frontend.torch.interpreter import Interpreter -from hidet.graph.frontend.torch.dynamo_backends import ( - get_flow_graph, - get_compiled_graph, - preprocess_inputs, - HidetCompiledModel, -) - - -class CompilerType(Enum): - HIDET = "hidet" - - -class BaseRCReturn: - def __init__(self, compiler_type: CompilerType): - self.compiler_type = compiler_type - - # Implement in child class - def __call__(self, *args, **kwargs): - raise NotImplementedError - - -class HidetRCReturn(BaseRCReturn): - def __init__(self, hidet_compiled_model): - super().__init__(CompilerType.HIDET) - self.compiled_model_forward = hidet_compiled_model - - def __call__(self, *args, **kwargs): - return self.compiled_model_forward(*args) - - -def hidet_backend_server(input_graph_module: GraphModule, example_inputs: List[torch.Tensor]): - assert isinstance(input_graph_module, GraphModule) - - # Create hidet compiled graph - interpreter: Interpreter = from_torch(input_graph_module) - flow_graph, _, output_format = get_flow_graph(interpreter, example_inputs) - cgraph = get_compiled_graph(flow_graph) - - # Perform inference using example inputs to get dispatch table - hidet_inputs = preprocess_inputs(example_inputs) - cgraph.run_async(hidet_inputs) - - # Get compiled forward function - compiled_forward_function = HidetCompiledModel(cgraph, hidet_inputs, output_format) - return HidetRCReturn(compiled_forward_function) diff --git a/centml/compiler/utils.py b/centml/compiler/utils.py deleted file mode 100644 index ec755b2..0000000 --- a/centml/compiler/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import shutil -from centml.compiler.config import settings - - -def get_backend_compiled_forward_path(model_id: str): - os.makedirs(os.path.join(settings.CENTML_BACKEND_BASE_PATH, model_id), exist_ok=True) - return os.path.join(settings.CENTML_BACKEND_BASE_PATH, model_id, "compilation_return.pkl") - - -def get_server_compiled_forward_path(model_id: str): - os.makedirs(os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id), exist_ok=True) - return os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id, "compilation_return.pkl") - - -# This function will delete the storage_path/{model_id} directory -def dir_cleanup(model_id: str): - dir_path = os.path.join(settings.CENTML_SERVER_BASE_PATH, model_id) - if not os.path.exists(dir_path): - return # Directory does not exist, return - - if not os.path.isdir(dir_path): - raise Exception(f"'{dir_path}' is not a directory") - - try: - shutil.rmtree(dir_path) - except Exception as e: - raise Exception("Failed to delete the directory") from e diff --git a/mypy.ini b/mypy.ini index 0250c92..209ecab 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,9 +2,6 @@ warn_return_any = True warn_unused_configs = True -[mypy-hidet.*] -ignore_missing_imports = True - [mypy-transformers.*] ignore_missing_imports = True diff --git a/requirements.txt b/requirements.txt index 6b5e5cd..a1b91bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,6 @@ -fastapi>=0.103.0 -uvicorn>=0.23.0 -python-multipart>=0.0.6 pydantic-settings>=2.0,<3 Requests==2.32.4 tabulate>=0.9.0 pyjwt>=2.8.0 cryptography==44.0.1 -prometheus-client>=0.20.0 -scipy>=1.6.0 -scikit-learn>=1.5.1 platform-api-python-client==4.8.4 diff --git a/scripts/data_collection.py b/scripts/data_collection.py deleted file mode 100644 index 4acdf87..0000000 --- a/scripts/data_collection.py +++ /dev/null @@ -1,265 +0,0 @@ -import csv -import gc -import json - -import torch -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - AutoModelForImageClassification, - AutoModelForObjectDetection, -) - - -from centml.compiler.prediction.kdtree import KDTreeWithValues -from centml.compiler.prediction.profiler import Profiler -from scripts.timer import timed - -torch.set_float32_matmul_precision('high') -torch.set_default_device('cuda') -torch.set_default_dtype(torch.float16) - -CURR_GPU = "A10G" -OUTPUT_FILE = 'data.csv' - -# Different HuggingFace Models + Different Input Sizes -llm_tests = [ - ("google/gemma-7b", (1, 128)), - ("microsoft/phi-2", (1, 512)), - ("microsoft/phi-2", (2, 512)), - ("facebook/bart-large", (1, 1024)), - ("facebook/bart-large", (2, 512)), - ("gpt2-xl", (1, 1024)), - ("gpt2-xl", (1, 720)), - ("gpt2-xl", (1, 512)), - ("gpt2-xl", (2, 512)), - ("gpt2-xl", (4, 256)), - ("EleutherAI/gpt-neo-2.7B", (1, 512)), - ("EleutherAI/gpt-neo-2.7B", (1, 256)), - ("gpt2-large", (1, 1024)), - ("gpt2-large", (1, 720)), - ("gpt2-large", (1, 512)), - ("google-bert/bert-large-uncased", (8, 512)), - ("google-bert/bert-large-uncased", (16, 512)), - ("meta-llama/Meta-Llama-3.1-8B", (1, 256)), - ("gpt2-medium", (1, 1024)), - ("gpt2-medium", (1, 512)), - ("gpt2-medium", (2, 512)), - ("google/pegasus-cnn_dailymail", (1, 1024)), - ("google/pegasus-cnn_dailymail", (1, 512)), - ("google/pegasus-cnn_dailymail", (2, 512)), -] - -# Tests for larger GPUs (A100, H100, etc.) -# large_llm_tests = [ -# ("google/gemma-7b", (1, 256)), -# ("google/gemma-7b", (1, 512)), -# ("google/gemma-7b", (1, 1024)), -# ("microsoft/phi-2", (1,1024)), -# ("microsoft/phi-2", (1,2048)), -# ("microsoft/phi-2", (2,1024)), -# ("EleutherAI/gpt-neo-2.7B", (1, 1024)), -# ("gpt2-xl", (2, 1024)), -# ("gpt2-xl", (4, 512)), -# ("meta-llama/Meta-Llama-3.1-8B", (1, 1024)), -# ("meta-llama/Meta-Llama-3.1-8B", (1, 512)), -# ("google/pegasus-cnn_dailymail", (4, 1024)), -# ("facebook/bart-large", (4, 1024)), -# ("facebook/bart-large", (2, 1024)), -# ("google-bert/bert-large-uncased", (16, 512)), -# ("gpt2-medium", (2, 1024)), -# ("gpt2-medium", (4, 512)), -# ("gpt2-large", (2, 1024)), -# ("gpt2-large", (4, 512)), -# ] - -# Different Batch Sizes for each image classification model -image_classification_tests = [ - ("google/efficientnet-b0", 512), - ("google/efficientnet-b0", 256), - ("google/efficientnet-b0", 128), - ("google/vit-base-patch16-224", 128), - ("microsoft/resnet-50", 256), - ("microsoft/resnet-50", 512), -] - -# Different Batch Sizes for each object detection model -object_detection_tests = [ - ("hustvl/yolos-tiny", 128), - ("hustvl/yolos-tiny", 256), - ("hustvl/yolos-tiny", 512), - ("facebook/detr-resnet-50", 128), - ("facebook/detr-resnet-50", 256), -] - - -def percent_error(observed, true): - return abs((observed - true) / true) * 100 - - -class DataCollectionTreeDB: - def __init__(self): - self.db = {} - - def add(self, key, point, time): - if key not in self.db: - self.db[key] = KDTreeWithValues() - - self.db[key].add(point, time) - - def get(self, key, inp): - if key not in self.db: - # print("New Key") - return None - - dist, val = self.db[key].query(inp) - - if dist > 0: - # print("Distance too large ", dist) - return None - return val - - -db = DataCollectionTreeDB() -cuda_kernel_time = 0 -actual_time = 0 - - -def custom_backend(gm: torch.fx.GraphModule, inps): - print("Compiling") - profiler = Profiler(mod=gm, gpu=CURR_GPU, treeDB=db, data_collection_mode=True) - - def forward(*args): - global cuda_kernel_time - global actual_time - out, t, actual_t = profiler.propagate(*args) - cuda_kernel_time += t - actual_time += actual_t - return out - - return forward - - -def llm_test(model_name, input_size, custom_backend): - global cuda_kernel_time - global actual_time - models_without_tokenizer = {"google/pegasus-cnn_dailymail"} - - model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda:0") - if model_name not in models_without_tokenizer: - tokenizer = AutoTokenizer.from_pretrained(model_name) - model.eval() - - inp = torch.randint( - low=0, - high=tokenizer.vocab_size if model_name not in models_without_tokenizer else 50265, - size=input_size, - dtype=torch.int64, - device='cuda:0', - ) - - with torch.inference_mode(): - for _ in range(10): - _, t = timed(lambda: model(inp)) - print(t) - - compiled_model = torch.compile(model, backend=custom_backend) - compiled_model(inp) - - cuda_kernel_time /= 1000000 - - print(f"{model_name}, {input_size}") - print("Real time: ", actual_time) - print("Kernel execution time: ", cuda_kernel_time) - print("Error: ", percent_error(cuda_kernel_time, actual_time)) - - cuda_kernel_time = 0 - actual_time = 0 - del model, inp, compiled_model - gc.collect() - torch.cuda.empty_cache() - - -def image_classification_test(model_name, batch_size, custom_backend): - global cuda_kernel_time - global actual_time - model = AutoModelForImageClassification.from_pretrained(model_name).to("cuda:0") - model.eval() - if model_name == "google/vit-base-patch16-224": - inp = torch.randn(batch_size, 3, 224, 224).cuda(0) - else: - inp = torch.randn(batch_size, 3, 128, 128).cuda(0) - - with torch.inference_mode(): - for _ in range(10): - _, t = timed(lambda: model(inp)) - print(t) - - compiled_model = torch.compile(model, backend=custom_backend) - compiled_model(inp) - - cuda_kernel_time /= 1000000 - - print(f"{model_name}, {batch_size}") - print("Real time: ", actual_time) - print("TOTAL TIME: ", cuda_kernel_time) - print("Error: ", percent_error(cuda_kernel_time, actual_time)) - - cuda_kernel_time = 0 - actual_time = 0 - del model, inp, compiled_model - gc.collect() - torch.cuda.empty_cache() - - -def object_detection_test(model_name, batch_size, custom_backend): - global cuda_kernel_time - global actual_time - model = AutoModelForObjectDetection.from_pretrained(model_name).to("cuda:0") - model.eval() - inp = torch.randn(batch_size, 3, 128, 128).cuda(0) - - with torch.inference_mode(): - for _ in range(10): - _, t = timed(lambda: model(inp)) - print(t) - - compiled_model = torch.compile(model, backend=custom_backend) - compiled_model(inp) - - cuda_kernel_time /= 1000000 - - print(f"{model_name}, {batch_size}") - print("Real time: ", actual_time) - print("TOTAL TIME: ", cuda_kernel_time) - print("Error: ", percent_error(cuda_kernel_time, actual_time)) - - cuda_kernel_time = 0 - actual_time = 0 - del model, inp, compiled_model - gc.collect() - torch.cuda.empty_cache() - - -# for model_name, input_size in large_llm_tests: -# llm_test(model_name, input_size, custom_backend) - -for model_name, input_size in llm_tests: - llm_test(model_name, input_size, custom_backend) - -for model_name, batch_size in object_detection_tests: - object_detection_test(model_name, batch_size, custom_backend) - -for model_name, batch_size in image_classification_tests: - image_classification_test(model_name, batch_size, custom_backend) - -# Write to CSV -with open(OUTPUT_FILE, 'w', newline='') as csvfile: - csvwriter = csv.writer(csvfile, quoting=csv.QUOTE_ALL) - csvwriter.writerow(['op', 'dim', 'inp_dtypes', 'out_dtypes', 'gpu', 'points', 'values']) - - for key, tree in db.db.items(): - op, dim, inp_dtypes, out_dtype, gpu = key - points_str = json.dumps(tree.points) - values_str = json.dumps(tree.values) - csvwriter.writerow([op, dim, inp_dtypes, out_dtype, gpu, points_str, values_str]) diff --git a/scripts/prediction_sample.py b/scripts/prediction_sample.py deleted file mode 100644 index 6333351..0000000 --- a/scripts/prediction_sample.py +++ /dev/null @@ -1,27 +0,0 @@ -import time - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -import centml.compile - -torch.set_default_device('cpu') -torch.set_float32_matmul_precision('high') -torch.set_default_dtype(torch.float16) - -model_name = "gpt2-xl" -model = AutoModelForCausalLM.from_pretrained(model_name) -tokenizer = AutoTokenizer.from_pretrained(model_name) -model.eval() - -inputs = torch.randint(low=0, high=tokenizer.vocab_size, size=(1, 1024), dtype=torch.int64, device='cpu') - -compiled_model = centml.compile(model) -output = compiled_model(inputs) - -inputs_2 = torch.randint(low=0, high=tokenizer.vocab_size, size=(1, 512), dtype=torch.int64, device='cpu') - -output = compiled_model(inputs_2) - -while True: - time.sleep(1) diff --git a/scripts/pylintrc b/scripts/pylintrc index 4927aef..4f041c3 100644 --- a/scripts/pylintrc +++ b/scripts/pylintrc @@ -430,10 +430,6 @@ disable=raw-checker-failed, too-many-nested-blocks, too-many-ancestors, # our ir functor may have many ancestors redefined-builtin, - unsupported-assignment-operation, # might happen in hidet script - unsubscriptable-object, # might happen in hidet script - cell-var-from-loop, # might happen in hidet script - invalid-unary-operand-type, # might happen in hidet script arguments-differ, # will happen when we over forward function consider-using-enumerate, no-self-use, diff --git a/tests/test_backend.py b/tests/test_backend.py deleted file mode 100644 index 8bd08ad..0000000 --- a/tests/test_backend.py +++ /dev/null @@ -1,260 +0,0 @@ -from http import HTTPStatus -from copy import deepcopy -from unittest import TestCase -from unittest.mock import patch, MagicMock -import torch -from parameterized import parameterized_class -from torch.fx import GraphModule -import centml -from centml.compiler.backend import Runner -from centml.compiler.config import CompilationStatus, settings -from .test_helpers import MODEL_SUITE - - -# Ensure remote_compilation is called in the same thread -def start_func(thread_self): - thread_self._target() - - -class SetUpGraphModule(TestCase): - @patch('threading.Thread.start', new=lambda x: None) - def setUp(self) -> None: - model = MagicMock(spec=GraphModule) - inputs = [torch.tensor([1.0])] - self.runner = Runner(model, inputs) - - -@parameterized_class(list(MODEL_SUITE.values())) -class TestGetModelId(SetUpGraphModule): - # Reset the dynamo cache to force recompilation - def tearDown(self) -> None: - torch._dynamo.reset() - - @patch("centml.compiler.backend.os.path.isfile", new=lambda x: False) - def test_no_serialized_model(self): - with self.assertRaises(Exception) as context: - self.runner._get_model_id() - - self.assertIn("Model not saved at path", str(context.exception)) - - # Given the same model graph, the model id should be the same - # Grab the model_id's passed to get_backend_compiled_forward_path - @patch("threading.Thread.start", new=start_func) - @patch("centml.compiler.backend.get_backend_compiled_forward_path", side_effect=Exception("Exiting early")) - def test_model_id_consistency(self, mock_get_path): - # self.model and self.inputs come from @parameterized_class - model_compiled_1 = centml.compile(self.model) - model_compiled_1(self.inputs) - hash_1 = mock_get_path.call_args[0][0] - torch._dynamo.reset() # Reset the dynamo cache to force recompilation - - model_compiled_2 = centml.compile(self.model) - model_compiled_2(self.inputs) - hash_2 = mock_get_path.call_args[0][0] - torch._dynamo.reset() - - self.assertEqual(hash_1, hash_2) - - # Given two different models, the model ids should be different - # We made the models different by adding 1 to the first value in some layer's - # Grab the model_id's passed to get_backend_compiled_forward_path - @patch("threading.Thread.start", new=start_func) - @patch("centml.compiler.backend.get_backend_compiled_forward_path", side_effect=Exception("Exiting early")) - def test_model_id_uniqueness(self, mock_get_path): - def get_modified_model(model): - modified = deepcopy(model) - state_dict = modified.state_dict() - some_layer = list(state_dict.values())[0] - some_layer.view(-1)[0] += 1 - return modified - - # self.model and self.inputs come from @parameterized_class - model_compiled_1 = centml.compile(self.model) - model_compiled_1(self.inputs) - hash_1 = mock_get_path.call_args[0][0] - - model_2 = get_modified_model(self.model) - model_compiled_2 = centml.compile(model_2) - model_compiled_2(self.inputs) - hash_2 = mock_get_path.call_args[0][0] - - self.assertNotEqual(hash_1, hash_2) - - -class TestDownloadModel(SetUpGraphModule): - @patch("os.makedirs") - @patch("centml.compiler.backend.requests") - def test_failed_download(self, mock_requests, mock_makedirs): - # Mock the response from the requests library - mock_response = MagicMock() - mock_response.status_code = HTTPStatus.NOT_FOUND - mock_requests.get.return_value = mock_response - - model_id = "download_fail" - with self.assertRaises(Exception) as context: - self.runner._download_model(model_id) - - mock_requests.get.assert_called_once() - self.assertIn("Download: request failed, exception from server", str(context.exception)) - mock_makedirs.assert_not_called() - - @patch("os.makedirs") - @patch("builtins.open") - @patch("centml.compiler.backend.torch.load") - @patch("centml.compiler.backend.requests") - def test_successful_download(self, mock_requests, mock_load, mock_open, mock_makedirs): - # Mock the response from the requests library - mock_response = MagicMock(spec=bytes) - mock_response.status_code = HTTPStatus.OK - mock_response.content = b"model_data" - mock_requests.get.return_value = mock_response - - # Call the _download_model function - model_id = "download_success" - self.runner._download_model(model_id) - - mock_requests.get.assert_called_once() - mock_load.assert_called_once() - mock_open.assert_called_once() - mock_makedirs.assert_called_once() - - -class TestWaitForStatus(SetUpGraphModule): - @patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0) - @patch("centml.compiler.backend.requests") - @patch("logging.Logger.exception") - def test_invalid_status(self, mock_logger, mock_requests): - mock_response = MagicMock() - mock_response.status_code = HTTPStatus.BAD_REQUEST - mock_requests.get.return_value = mock_response - - model_id = "invalid_status" - with self.assertRaises(Exception) as context: - self.runner._wait_for_status(model_id) - - mock_requests.get.assert_called() - assert mock_requests.get.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1 - assert len(mock_logger.call_args_list) == settings.CENTML_COMPILER_MAX_RETRIES + 1 - print(mock_logger.call_args_list) - assert mock_logger.call_args_list[0].startswith("Status check failed:") - assert "Waiting for status: compilation failed too many times.\n" == str(context.exception) - - @patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0) - @patch("centml.compiler.backend.requests") - @patch("logging.Logger.exception") - def test_exception_in_status(self, mock_logger, mock_requests): - exception_message = "Exiting early" - mock_requests.get.side_effect = Exception(exception_message) - - model_id = "exception_in_status" - with self.assertRaises(Exception) as context: - self.runner._wait_for_status(model_id) - - mock_requests.get.assert_called() - assert mock_requests.get.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1 - mock_logger.assert_called_with(f"Status check failed:\n{exception_message}") - assert str(context.exception) == "Waiting for status: compilation failed too many times.\n" - - @patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0) - @patch("centml.compiler.backend.Runner._compile_model") - @patch("centml.compiler.backend.requests") - def test_max_tries(self, mock_requests, mock_compile): - mock_response = MagicMock() - mock_response.status_code = HTTPStatus.OK - mock_response.json.return_value = {"status": CompilationStatus.NOT_FOUND.value} - mock_requests.get.return_value = mock_response - - model_id = "max_tries" - with self.assertRaises(Exception) as context: - self.runner._wait_for_status(model_id) - - self.assertEqual(mock_compile.call_count, settings.CENTML_COMPILER_MAX_RETRIES + 1) - self.assertIn("Waiting for status: compilation failed too many times", str(context.exception)) - - @patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0) - @patch("centml.compiler.backend.requests") - def test_wait_on_compilation(self, mock_requests): - # Mock the status check - COMPILATION_STEPS = 10 - mock_response = MagicMock() - mock_response.status_code = HTTPStatus.OK - mock_response.json.side_effect = [{"status": CompilationStatus.COMPILING.value}] * COMPILATION_STEPS + [ - {"status": CompilationStatus.DONE.value} - ] - mock_requests.get.return_value = mock_response - - model_id = "compilation_done" - # _wait_for_status should return True when compilation DONE - assert self.runner._wait_for_status(model_id) - - @patch("centml.compiler.config.settings.CENTML_COMPILER_SLEEP_TIME", new=0) - @patch("centml.compiler.backend.requests") - @patch("centml.compiler.backend.Runner._compile_model") - @patch("logging.Logger.exception") - def test_exception_in_compilation(self, mock_logger, mock_compile, mock_requests): - # Mock the status check - mock_response = MagicMock() - mock_response.status_code = HTTPStatus.OK - mock_response.json.return_value = {"status": CompilationStatus.NOT_FOUND.value} - mock_requests.get.return_value = mock_response - - # Mock the compile model function - exception_message = "Exiting early" - mock_compile.side_effect = Exception(exception_message) - - model_id = "exception_in_compilation" - with self.assertRaises(Exception) as context: - self.runner._wait_for_status(model_id) - - mock_requests.get.assert_called() - assert mock_requests.get.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1 - - mock_compile.assert_called() - assert mock_compile.call_count == settings.CENTML_COMPILER_MAX_RETRIES + 1 - - mock_logger.assert_called_with(f"Submitting compilation failed:\n{exception_message}") - assert str(context.exception) == "Waiting for status: compilation failed too many times.\n" - - @patch("centml.compiler.backend.requests") - def test_compilation_done(self, mock_requests): - mock_response = MagicMock() - mock_response.status_code = HTTPStatus.OK - mock_response.json.return_value = {"status": CompilationStatus.DONE.value} - mock_requests.get.return_value = mock_response - - model_id = "compilation_done" - # _wait_for_status should return True when compilation DONE - assert self.runner._wait_for_status(model_id) - - -@parameterized_class(list(MODEL_SUITE.values())) -class TestRemoteCompilation(TestCase): - def call_remote_compilation(self): - with patch("threading.Thread.start", new=start_func), patch( - "centml.compiler.backend.Runner.__call__", new=self.model.forward - ): - compiled_model = centml.compile(self.model) - compiled_model(self.inputs) - - torch._dynamo.reset() - - @patch("os.path.isfile", new=lambda x: True) - @patch("centml.compiler.backend.Runner._get_model_id", new=lambda x: "1234") - @patch("centml.compiler.backend.torch.load") - def test_compiled_cached(self, mock_load): - mock_load.return_value = MagicMock() - self.call_remote_compilation() - mock_load.assert_called_once() - - @patch('os.path.isfile', new=lambda x: False) - @patch("centml.compiler.backend.Runner._get_model_id", new=lambda x: "1234") - @patch('centml.compiler.backend.Runner._download_model') - @patch('centml.compiler.backend.Runner._wait_for_status') - def test_compiled_return_not_cached(self, mock_status, mock_download): - mock_status.return_value = True - mock_download.return_value = MagicMock() - - self.call_remote_compilation() - - mock_status.assert_called_once() - mock_download.assert_called_once() diff --git a/tests/test_helpers.py b/tests/test_helpers.py deleted file mode 100644 index 2bee524..0000000 --- a/tests/test_helpers.py +++ /dev/null @@ -1,25 +0,0 @@ -from io import BytesIO -import torch -from transformers import BertForPreTraining, AutoTokenizer - -MODEL_SUITE = { - "resnet18": { - "model": torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True, verbose=False).eval(), - "inputs": torch.zeros(1, 3, 224, 224), - }, - "bert-base-uncased": { - "model": BertForPreTraining.from_pretrained("bert-base-uncased", ignore_mismatched_sizes=True).eval(), - "inputs": AutoTokenizer.from_pretrained("bert-base-uncased")( - "Hello, my dog is cute", padding='max_length', return_tensors="pt" - )['input_ids'], - }, -} - - -def get_dummy_model_and_inputs(model_content, input_content): - model, inputs = BytesIO(), BytesIO() - torch.save(model_content, model) - torch.save(input_content, inputs) - model.seek(0) - inputs.seek(0) - return model, inputs diff --git a/tests/test_server.py b/tests/test_server.py deleted file mode 100644 index a484dd6..0000000 --- a/tests/test_server.py +++ /dev/null @@ -1,168 +0,0 @@ -from unittest import TestCase -from unittest.mock import MagicMock, patch -from http import HTTPStatus -import pytest -import torch -import hidet -from fastapi import UploadFile, HTTPException -from fastapi.testclient import TestClient -from parameterized import parameterized_class -from centml.compiler.server import app, background_compile, read_upload_files -from centml.compiler.config import CompilationStatus -from tests.test_helpers import MODEL_SUITE, get_dummy_model_and_inputs - -client = TestClient(app) - - -class TestStatusHandler(TestCase): - def test_empty_request(self): - response = client.get("/status/") - self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND) - - @patch("os.path.isdir", new=lambda x: False) - def test_model_not_found(self): - model_id = "nonexistent_model" - response = client.get(f"/status/{model_id}") - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertEqual(response.json(), {"status": CompilationStatus.NOT_FOUND.value}) - - @patch("os.path.isfile", new=lambda x: False) - @patch("os.path.isdir", new=lambda x: True) - def test_model_compiling(self): - model_id = "compiling_model" - response = client.get(f"/status/{model_id}") - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertEqual(response.json(), {"status": CompilationStatus.COMPILING.value}) - - @patch("os.path.isfile", new=lambda x: True) - @patch("os.path.isdir", new=lambda x: True) - def test_model_done(self): - model_id = "completed_model" - response = client.get(f"/status/{model_id}") - self.assertEqual(response.status_code, HTTPStatus.OK) - self.assertEqual(response.json(), {"status": CompilationStatus.DONE.value}) - - -@parameterized_class(list(MODEL_SUITE.values())) -class TestBackgroundCompile(TestCase): - @pytest.mark.gpu - @patch("os.rename") - @patch("logging.Logger.exception") - @patch("centml.compiler.server.torch.save") - def test_successful_compilation(self, mock_save, mock_logger, mock_rename): - # For some reason there is a deadlock with parallel builds - hidet.option.parallel_build(False) - - # Get the graph_module and example inputs that would be passed to background compile - class MockRunner: - def __init__(self): - self.graph_module = None - self.example_inputs = None - - def __call__(self, module, inputs): - self.graph_module, self.example_inputs = module, inputs - return module.forward - - mock_init = MockRunner() - - # self.model and self.inputs come from @parameterized_class - model, inputs = self.model.cuda(), self.inputs.cuda() - model_compiled = torch.compile(model, backend=mock_init) - model_compiled(inputs) - - model_id = "successful_model" - background_compile(model_id, mock_init.graph_module, mock_init.example_inputs) - - mock_rename.assert_called_once() - mock_save.assert_called_once() - mock_logger.assert_not_called() - - -class TestReadUploadFiles(TestCase): - def test_mock_cant_read(self): - model_id = "file_cant_be_read" - - mock_file = MagicMock() - mock_file.file.read.side_effect = Exception("an exception occurred") - - with self.assertRaises(HTTPException) as excinfo: - read_upload_files(model_id, mock_file, mock_file) - - self.assertEqual(excinfo.exception.status_code, HTTPStatus.BAD_REQUEST) - self.assertIn("Compilation: error reading serialized content", str(excinfo.exception)) - - @patch("torch.load", side_effect=Exception("an exception occurred")) - def test_cant_load(self, mock_load): - model_id = "file_cant_be_unpickled" - - # Create file-like objects with test data - model_file, input_file = get_dummy_model_and_inputs("model", "inputs") - model = UploadFile(filename="model", file=model_file) - inputs = UploadFile(filename="inputs", file=input_file) - - with self.assertRaises(HTTPException) as excinfo: - read_upload_files(model_id, model, inputs) - - mock_load.assert_called_once() - self.assertEqual(excinfo.exception.status_code, HTTPStatus.BAD_REQUEST) - self.assertIn("Compilation: error loading content with torch.load:", str(excinfo.exception)) - - def test_proper_read(self): - model_id = "test_model_id" - - model_data, inputs_data = "model", "inputs" - - # Create file-like objects with test data - model_file, input_file = get_dummy_model_and_inputs("model", "inputs") - model = UploadFile(filename="model", file=model_file) - inputs = UploadFile(filename="inputs", file=input_file) - - tfx_graph, example_inputs = read_upload_files(model_id, model, inputs) - - self.assertEqual(tfx_graph, model_data) - self.assertEqual(example_inputs, inputs_data) - - -class TestCompileHandler(TestCase): - @patch("centml.compiler.server.get_status") - def test_model_compiling(self, mock_status): - model_id = "compiling_model" - mock_status.return_value = CompilationStatus.COMPILING - - model_file, input_file = get_dummy_model_and_inputs("model", "inputs") - response = client.post(f"/submit/{model_id}", files={"model": model_file, "inputs": input_file}) - self.assertEqual(response.status_code, HTTPStatus.OK) - - @patch("os.makedirs") - @patch("centml.compiler.server.background_compile") - @patch("centml.compiler.server.get_status") - def test_model_not_compiled(self, mock_status, mock_compile, mock_mkdir): - model_id = "compiling_model" - mock_status.return_value = CompilationStatus.NOT_FOUND - # Stop compilation from happening - mock_compile.new = lambda x, y, z: None - - model_file, input_file = get_dummy_model_and_inputs("model", "inputs") - response = client.post(f"/submit/{model_id}", files={"model": model_file, "inputs": input_file}) - - self.assertEqual(response.status_code, HTTPStatus.OK) - mock_mkdir.assert_called_once() - mock_compile.assert_called_once() - - -class TestDownloadHandler(TestCase): - @patch("os.path.isfile", new=lambda x: False) - def test_download_handler_invalid_model_id(self): - model_id = "invalid_model_id" - - response = client.get(f"/download/{model_id}") - self.assertEqual(response.status_code, HTTPStatus.NOT_FOUND) - - @patch("os.path.isfile", new=lambda x: True) - @patch("centml.compiler.server.FileResponse") - def test_download_handler_success(self, mock_file_response): - model_id = "valid_model_id" - - response = client.get(f"/download/{model_id}") - - self.assertEqual(response.status_code, HTTPStatus.OK)