diff --git a/.gitignore b/.gitignore index b953209..99ba63f 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,6 @@ cython_debug/ # PyPI configuration file .pypirc + +# Others +.trash \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b1c93cf --- /dev/null +++ b/Makefile @@ -0,0 +1,10 @@ +lint: + black . + isort . + flake8 . + mypy . + +test: + pytest + +check: lint test \ No newline at end of file diff --git a/config/agent_input.yaml b/config/agent_input.yaml deleted file mode 100644 index c122a8b..0000000 --- a/config/agent_input.yaml +++ /dev/null @@ -1,34 +0,0 @@ -id: agent-io -hostname: localhost - -huri: - hostname: localhost - router: - port: 3000 - event-proxy: - xsub: 5555 - xpub: 5556 - log-puller: - port: 8008 - -forwarder-proxy: - down-xsub: 6665 - up-xpub: 6666 - -logging: INFO - -modules: - inp: - name: INP - logging: INFO - out: - name: OUT - logging: INFO - mod: - name: MOD - logging: INFO - rag: - name: RAG - args: - model: deepseek-v2:16b - logging: INFO diff --git a/config/agent_io.yaml b/config/agent_io.yaml deleted file mode 100644 index c9a5646..0000000 --- a/config/agent_io.yaml +++ /dev/null @@ -1,36 +0,0 @@ -id: agent-io -hostname: localhost - -huri: - hostname: localhost - router: - port: 3000 - event-proxy: - xsub: 5555 - xpub: 5556 - log-puller: - port: 8008 - -forwarder-proxy: - down-xsub: 6665 - up-xpub: 6666 - -logging: INFO - -modules: - mic: - name: mic - args: - sample_rate: 18000 - logging: INFO - stt: - name: stt - args: - sample_rate: 18000 - logging: INFO - # tts: - # name: vibe - # args: - # model: vibe-voice - # voice: adrien - # logging: DEBUG diff --git a/config/huri.yaml b/config/huri.yaml deleted file mode 100644 index 13f06b1..0000000 --- a/config/huri.yaml +++ /dev/null @@ -1,11 +0,0 @@ -hostname: localhost - -router: - port: 3000 - -event-proxy: - xsub: 5555 - xpub: 5556 - -log-puller: - port: 8008 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..476f4a8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,31 @@ +[tool.black] +line-length = 88 +target-version = ["py310"] + +[tool.isort] +profile = "black" +line_length = 88 +multi_line_output = 3 +include_trailing_comma = true +skip_gitignore = true + +[tool.flake8] +max-line-length = 88 +extend-ignore = [] +exclude = """ + __pycache__ + venv + .venv +""" + +[tool.mypy] +python_version = "3.10" +ignore_missing_imports = true +strict_optional = true +warn_unused_ignores = true +warn_return_any = true +warn_unused_configs = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] diff --git a/quick_launch.sh b/quick_launch.sh deleted file mode 100755 index a76da2a..0000000 --- a/quick_launch.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash - -set -e - -# Check args -if [ "$#" -lt 2 ]; then - echo "Usage: $0 [CLEAN]" - exit 1 -fi - -HURI_CONFIG="$1" -AGENT_CONFIG="$2" - -LOG_DIR="./tmp/log" - -if [[ " $* " == *" CLEAN "* ]]; then - echo "Cleaning previous logs in ${LOG_DIR}" - rm -rf "${LOG_DIR}" -fi - -mkdir -p "$LOG_DIR" - -TIMESTAMP=$(date +"%Y%m%d-%H%M%S") -HURI_LOG="${LOG_DIR}/huri-${TIMESTAMP}.log" - - -# Run huri with output redirected -python -m src.launch_huri --config "$HURI_CONFIG" > "$HURI_LOG" 2>&1 & -HURI_PID=$! -echo "HURI started in background (PID=${HURI_PID}), logging to ${HURI_LOG}" - -# Run agent -python -m src.launch_agent --config "$AGENT_CONFIG" - -# Ensure HURI is killed on script exit (normal or Ctrl+C) -cleanup() { - echo "Stopping HURI (PID=${HURI_PID})" - kill "${HURI_PID}" 2>/dev/null || true -} -trap cleanup EXIT INT TERM \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 863a44c..dbc30d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,10 @@ +black +isort +mypy +flake8 +flake8-toml-config +pytest + deepfilternet sounddevice soundfile diff --git a/src/client.py b/src/client.py new file mode 100644 index 0000000..6c439ee --- /dev/null +++ b/src/client.py @@ -0,0 +1,45 @@ +import asyncio + +import numpy as np +import sounddevice as sd +import websockets + +SERVER_URL = "ws://localhost:8000/session" +CHUNK_DURATION = 1 +SAMPLE_RATE = 16000 + + +async def stream_audio(): + + async with websockets.connect(SERVER_URL) as ws: + print("Connected to server") + + async def receive(ws: websockets.ClientConnection): + while True: + text = await ws.recv() + print("received:", text) + + async def send(ws: websockets.ClientConnection): + loop = asyncio.get_running_loop() + + queue: asyncio.Queue = asyncio.Queue() + + def callback(indata: np.ndarray, frames, time, status): + loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) + + with sd.InputStream( + samplerate=SAMPLE_RATE, + channels=1, + dtype="int16", + callback=callback, + blocksize=int(CHUNK_DURATION * SAMPLE_RATE), + ): + while True: + chunk = await queue.get() + await ws.send(chunk.tobytes()) + + await asyncio.gather(receive(ws), send(ws)) + + +if __name__ == "__main__": + asyncio.run(stream_audio()) diff --git a/src/core/agent.py b/src/core/agent.py deleted file mode 100644 index df2c182..0000000 --- a/src/core/agent.py +++ /dev/null @@ -1,277 +0,0 @@ -import multiprocessing as mp -import signal -import threading -from dataclasses import dataclass -from multiprocessing.synchronize import Event -from typing import Any, Dict, Mapping - -from src.modules.factory import ModuleFactory -from src.tools.logger import logging, setup_logger - -from .huri import HuriConfig -from .zmq.control_channel import Command, Dealer -from .zmq.event_proxy import EventProxy -from .zmq.log_channel import LogPusher - - -@dataclass -class ForwarderProxyConfig: - down_xsub: int - up_xpub: int - - @classmethod - def from_dict(cls, raw: dict): - return cls( - down_xsub=raw["down-xsub"], - up_xpub=raw["up-xpub"], - ) - - -@dataclass -class ModuleConfig: - name: str - args: Mapping[str, Any] - logging: int - - @classmethod - def from_dict(cls, raw: dict): - level = logging._nameToLevel.get( - raw.get("logging", "INFO"), - logging.INFO, - ) - return cls( - name=raw["name"], - args=raw.get("args", {}), - logging=level, - ) - - -@dataclass -class AgentConfig: - id: str - hostname: str - huri: HuriConfig - logging: int - forwarder_proxy: ForwarderProxyConfig - modules: Dict[str, ModuleConfig] - - @classmethod - def from_dict(cls, raw: dict): - level = logging._nameToLevel.get( - raw.get("logging", "INFO").upper(), - logging.INFO, - ) - modules = { - module_id: ModuleConfig.from_dict(mod_raw) - for module_id, mod_raw in raw.get("modules", {}).items() - } - return cls( - id=raw["id"], - hostname=raw["hostname"], - huri=HuriConfig.from_dict(raw["huri"]), - forwarder_proxy=ForwarderProxyConfig.from_dict(raw["forwarder-proxy"]), - logging=level, - modules=modules, - ) - - -class Agent: - """Control Modules and communication with HuRI""" - - def __init__(self, config: AgentConfig) -> None: - self.modules: Dict[str, ModuleConfig] = config.modules - self.config = config - - self.processes: Dict[str, mp.Process] = {} - self.stop_events: Dict[str, Event] = {} - - self.threads: Dict[str, threading.Thread] = {} - - self.log_pusher = LogPusher( - hostname=config.huri.hostname, port=config.huri.log_puller.port - ) - - self.dealer = Dealer( - hostname=config.huri.hostname, - port=config.huri.router.port, - executor=self._command_handler, - logger=setup_logger("Dealer", log_queue=self.log_pusher.log_queue), - ) - - self.up_proxy = EventProxy( - hostname=config.hostname, - connect_hostname=config.huri.hostname, - xpub_port=config.huri.event_proxy.xsub, - xsub_port=config.forwarder_proxy.up_xpub, - logger=setup_logger("UpProxy", log_queue=self.log_pusher.log_queue), - ) - self.down_proxy = EventProxy( - hostname=config.hostname, - connect_hostname=config.huri.hostname, - xpub_port=config.forwarder_proxy.down_xsub, - xsub_port=config.huri.event_proxy.xpub, - logger=setup_logger("DownProxy", log_queue=self.log_pusher.log_queue), - ) - - self.logger = setup_logger( - f"Agent {self.dealer.identity}", log_queue=self.log_pusher.log_queue - ) - - def _command_handler(self, command: Command) -> bool: - match command.cmd: - case "START": - return self.start_module(*command.args) - case "STOP": - return self.stop_module(*command.args) - case "STATUS": - return self.status() - case _: - return False # todo log - - @staticmethod - def _start_module( - name: str, - module_config: ModuleConfig, - agent_config: AgentConfig, - log_queue: mp.Queue, - stop_event: Event, - ) -> None: - """Helper function to start module in child process.""" - logger = setup_logger( - module_config.name, level=module_config.logging, log_queue=log_queue - ) - - module = ModuleFactory.create(name, module_config.args) - module.set_custom_logger(logger) - - def handle_sigint(signum, frame): - logger.info("Ctrl+C ignored in child module") - - signal.signal(signal.SIGINT, handle_sigint) - - module.start_module( - agent_config.hostname, - agent_config.forwarder_proxy.up_xpub, - agent_config.forwarder_proxy.down_xsub, - stop_event=stop_event, - ) - - def start_module(self, name) -> None: - """Check if module is registered and not already running, and start a child process.""" - if name not in self.modules: - self.logger.warning( - f"{name} is not in the registered Modules: {self.modules.keys()}" - ) - return - if name in self.processes: - self.logger.warning( - f"{name} is already running (PID={self.processes[name].pid})" - ) - return - - module_config = self.modules[name] - stop_event = mp.Event() - p = mp.Process( - target=self._start_module, - args=( - name, - module_config, - self.config, - self.log_pusher.log_queue, - stop_event, - ), - daemon=True, - ) - self.processes[name] = p - self.stop_events[name] = stop_event - self.log_pusher.level_filter.add_level(name) - - p.start() - self.logger.info(f"{name} ({module_config.name}) started (PID={p.pid})") - - def stop_module(self, name) -> None: - if name in self.processes: - self.logger.info(f"Stopping {name}...") - self.stop_events[name].set() - self.processes[name].join(timeout=5) - if self.processes[name].is_alive(): - self.logger.warning(f"{name} did not stop in time, killing") - self.processes[name].kill() - self.logger.info(f"{name} stopped") - del self.processes[name] - del self.stop_events[name] - self.log_pusher.level_filter.del_level(name) - - def stop_all(self) -> None: - for name in list(self.processes.keys()): - self.stop_module(name) - - self.dealer.stop() - self.up_proxy.stop() - self.down_proxy.stop() - for name, thread in self.threads.items(): - self.logger.info(f"Stopping {name} thread...") - thread.join(timeout=5) - self.logger.info(f"{name} thread stopped") - self.log_pusher.level_filter.del_level(name) - - self.log_pusher.stop() - print("Fully stopped") - - def status(self) -> None: - """Print status of all modules and router.""" - print("=== Module Status ===") - for name in self.modules: - process = self.processes.get(name) - if process: - state = "alive" if process.is_alive() else "stopped" - print(f"- {name}: {state} (PID={process.pid})") - else: - print(f"- {name}: stopped") - print("=====================") - - def set_root_log_level(self, level: int) -> None: - self.log_pusher.level_filter.set_root_level(level) - - def set_log_level(self, name: str, level: int) -> None: - self.log_pusher.level_filter.set_level(name, level) - - def set_log_levels(self, level: int) -> None: - self.log_pusher.level_filter.set_levels(level) - - def _connect_to_huri(self) -> None: - self.log_pusher.level_filter.add_level("Dealer") - self.threads["Dealer"] = threading.Thread(target=self.dealer.start) - self.threads["Dealer"].start() - - def _start_event_proxies(self) -> None: - """Used to handle inter-module communication, though events""" - self.log_pusher.level_filter.add_level("UpProxy") - self.log_pusher.level_filter.add_level("DownProxy") - self.threads["UpProxy"] = threading.Thread( - target=self.up_proxy.start, args=[True, False] - ) - self.threads["DownProxy"] = threading.Thread( - target=self.down_proxy.start, args=[False, True] - ) - - self.threads["UpProxy"].start() - self.threads["DownProxy"].start() - - def run(self) -> None: - """Start event router and modules""" # TODO config (also logs levels) - - try: - self.log_pusher.start() - self._connect_to_huri() - self._start_event_proxies() - except Exception as e: - self.logger.error(e) - return - - for name in self.modules: - self.start_module(name) - - while True: - data = input() - self.down_proxy.publish("std.in", data) diff --git a/src/core/events.py b/src/core/events.py index e69de29..0d3653d 100644 --- a/src/core/events.py +++ b/src/core/events.py @@ -0,0 +1,33 @@ +import asyncio +from collections import defaultdict + +from .module import Module + + +class EventGraph: + + def __init__(self): + + self.subscribers = defaultdict(list) + + def register(self, module: Module): + self.subscribers[module.input_type].append(module) + + async def publish(self, event_topic, data): + for module in self.subscribers[event_topic]: + asyncio.create_task(self._run(module, data)) + + async def _run(self, module: Module, data): + + result = module.process(data) + + if hasattr(result, "__aiter__"): + async for item in result: + if item is None: + continue + await self.publish(module.output_type, item) + + else: + value = await result + if value is not None: + await self.publish(module.output_type, value) diff --git a/src/core/huri.py b/src/core/huri.py index 18f52f3..6883687 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -1,97 +1,49 @@ -import sys -import threading -from dataclasses import dataclass +import uuid from typing import Dict -from src.tools.logger import setup_logger +from fastapi import FastAPI, WebSocket +from ray import serve +from ray.serve import handle -from .zmq.control_channel import Router -from .zmq.event_proxy import EventProxy -from .zmq.log_channel import LogPuller +from src.modules.speech_to_text.record_speech import MIC +from src.modules.speech_to_text.speech_to_text import STT +from src.modules.utils.sender import Sender +from .session import Session -@dataclass -class RouterConfig: - port: int - - -@dataclass -class EventProxyConfig: - xsub: int - xpub: int - - -@dataclass -class LogPullerConfig: - port: int - - -@dataclass -class HuriConfig: - hostname: str - router: RouterConfig - event_proxy: EventProxyConfig - log_puller: LogPullerConfig - - @classmethod - def from_dict(cls, raw: dict): - return cls( - hostname=raw["hostname"], - router=RouterConfig(**raw["router"]), - event_proxy=EventProxyConfig(**raw["event-proxy"]), - log_puller=LogPullerConfig(**raw["log-puller"]), - ) +app = FastAPI() +@serve.deployment +@serve.ingress(app) class HuRI: - """Wait for Agent to connect, handle module communication and Logging""" - - def __init__(self, config: HuriConfig) -> None: - self.router = Router(config.hostname, config.router.port) - self.event_proxy = EventProxy( - config.hostname, "", config.event_proxy.xpub, config.event_proxy.xsub - ) - self.log_channel = LogPuller(config.hostname, config.log_puller.port) - - self.threads: Dict[str, threading.Thread] = {} - - self.logger = setup_logger("HuRI") - - def _start_router(self) -> None: - """Used to handle Agent registration and control""" - self.threads["Router"] = threading.Thread(target=self.router.start) - self.threads["Router"].start() - - def _start_event_proxy(self) -> None: - """Used to handle inter-module communication, though events""" - self.threads["EventProxy"] = threading.Thread( - target=self.event_proxy.start, args=[False, False] - ) - self.threads["EventProxy"].start() - - def _start_log_channel(self) -> None: - """Used to handle Agent registration and control""" - self.threads["LogChannel"] = threading.Thread(target=self.log_channel.start) - self.threads["LogChannel"].start() - - def run(self) -> None: - self._start_log_channel() - self._start_router() - self._start_event_proxy() - - if not sys.stdin.isatty(): - threading.Event().wait() - return - - from src.core.shell import RobotShell - - RobotShell(self).cmdloop() - - def stop(self) -> None: - self.router.stop() - self.event_proxy.stop() - self.log_channel.stop() - for name, thread in self.threads.items(): - self.logger.info(f"Stopping {name} thread...") - thread.join(timeout=5) - self.logger.info(f"{name} thread stopped") + def __init__(self, config, handles: Dict[str, handle.DeploymentHandle]) -> None: + self.config = config + self.handles = handles + + self.clients: Dict[str, Session] = {} + + @app.websocket("/session") + async def run_session(self, ws: WebSocket): + await ws.accept() + + modules = [ + STT(self.handles["stt"]), + MIC(5), + Sender(ws, "text"), + ] + session_id = str(uuid.uuid4()) + + self.clients[session_id] = Session(modules) + + async def receive_loop(session: Session, ws: WebSocket): + while True: + msg = await ws.receive() + if "bytes" in msg: + chunk = msg["bytes"] + await session.publish("chunk", chunk) + # else: + # data = msg + # await session.publish(data["type"], data["data"]) + + await receive_loop(self.clients[session_id], ws) diff --git a/src/core/module.py b/src/core/module.py index 4be1246..0ae9cad 100644 --- a/src/core/module.py +++ b/src/core/module.py @@ -1,159 +1,9 @@ -import json -import threading -from abc import ABC, abstractmethod -from multiprocessing.synchronize import Event -from typing import Callable, Dict, final +from typing import Any, Optional -import zmq -from src.tools.logger import logging +class Module: + input_type: Optional[str] + output_type: Optional[str] - -class Module(ABC): - def __init__(self): - """Child Modules must call super.__init__() in their __init__() function.""" - self.ctx = None - self.pub_socket = None - self.connect_hostname = None - self.xpub_port = None - self.xsub_port = None - self.subs: Dict[str, zmq.Socket[bytes]] = {} - self.callbacks = {} - self._poller_running = False - self.poller = None - self.logger = logging.getLogger(__name__) - - @final - def _initialize(self) -> None: - """ - Called inside start_module() or manually before usage. - This function exist because ctx cannot be set in __init__, because of multi-processing. maybe deprecated - """ - self.ctx = zmq.Context() - self.pub_socket = self.ctx.socket(zmq.PUB) - self.pub_socket.connect(f"tcp://{self.connect_hostname}:{self.xpub_port}") - self.poller = threading.Thread(target=self._poll_loop, daemon=True) - self.set_subscriptions() - - @abstractmethod - def set_subscriptions(self) -> None: - """Child module must define this funcction with subscriptions""" - ... - - @final - def subscribe(self, topic: str, callback: Callable) -> None: - sub_socket = self.ctx.socket(zmq.SUB) - sub_socket.connect(f"tcp://{self.connect_hostname}:{self.xsub_port}") - sub_socket.setsockopt_string(zmq.SUBSCRIBE, topic) - self.subs[topic] = sub_socket - self.callbacks[topic] = callback - self.logger.info(f"Subscribe: {topic}") - - @final - def publish( - self, topic: str, msg: object, content_type: str = "str" - ) -> None: # TODO content type enum - if content_type == "json": - payload = json.dumps(msg).encode() - elif content_type == "bytes": - payload = msg - elif content_type == "str": - payload = msg.encode() - else: - raise ValueError(f"Unsupported content_type: {content_type}") - - self.pub_socket.send_multipart([topic.encode(), content_type.encode(), payload]) - self.logger.info(f"Publish: {topic} {content_type}") - - @final - def _start_polling(self) -> None: - self._poller_running = True - self.poller.start() - - @final - def _poll_loop(self) -> None: - poller = zmq.Poller() - for sub in self.subs.values(): - poller.register(sub, zmq.POLLIN) - - while self._poller_running: - events = dict(poller.poll(100)) - for _, sub in self.subs.items(): - if sub in events: - topic, content_type, payload = sub.recv_multipart() - topic_str = topic.decode() - content_type_str = content_type.decode() - self.logger.info(f"Receive: {topic_str} {content_type_str}") - if content_type_str == "json": - kwargs = json.loads(payload.decode()) - self.callbacks[topic_str]( - **kwargs - ) # TODO better and cleaner way ? - elif content_type_str == "bytes": - data = payload - self.callbacks[topic_str](data) - elif content_type_str == "str": - data = payload.decode() - self.callbacks[topic_str](data) - - @final - def start_module( - self, - connect_hostname: str, - xpub_port: int, - xsub_port: int, - stop_event: Event = None, - ) -> None: - self.connect_hostname = connect_hostname - self.xpub_port = xpub_port - self.xsub_port = xsub_port - self._initialize() - if self.subs != {}: - self._start_polling() - try: - self.run_module(stop_event) - except KeyboardInterrupt: - self.logger.info("Ctrl+C pressed, exiting cleanly") - except Exception as e: - self.logger.error(e) - finally: - self.stop_module() - - @final - def stop_module(self) -> None: - """Stop the module gracefully.""" - - if self._poller_running: - self._poller_running = False - self.poller.join() - - for topic, sub in self.subs.items(): - try: - sub.close(0) - except Exception as e: - self.logger.error(f"Error closing SUB socket for '{topic}': {e}") - - self.subs.clear() - self.callbacks.clear() - - try: - self.pub_socket.close(0) - except Exception as e: - self.logger.error(f"Error closing SUB socket for '{topic}': {e}") - - try: - self.ctx.term() - except Exception as e: - self.logger.error(f"Error terminating ZMQ context: {e}") - - self.logger.info("Module stopped gracefully.") - - def run_module(self, stop_event: Event = None) -> None: - """Child modules override this instead of run(). Default: idle wait.""" - if stop_event: - stop_event.wait() - - @final - def set_custom_logger(self, logger) -> None: - """The default logger in set in __init__.""" - self.logger = logger + async def process(self, _) -> Optional[Any]: + raise NotImplementedError diff --git a/src/core/session.py b/src/core/session.py new file mode 100644 index 0000000..85d6712 --- /dev/null +++ b/src/core/session.py @@ -0,0 +1,12 @@ +from .events import EventGraph + + +class Session: + def __init__(self, modules): + self.event_graph = EventGraph() + + for module in modules: + self.event_graph.register(module) + + async def publish(self, topic, data): + await self.event_graph.publish(topic, data) diff --git a/src/core/shell.py b/src/core/shell.py deleted file mode 100644 index 6473357..0000000 --- a/src/core/shell.py +++ /dev/null @@ -1,31 +0,0 @@ -import cmd - -from src.core.huri import HuRI -from src.core.zmq.control_channel import Command - - -class RobotShell(cmd.Cmd): - intro = "HuRI's shell. Type 'help' to see command's list." - prompt = "(HuRI) " - - def __init__(self, huri: HuRI) -> None: - super().__init__() - self.huri = huri - - def do_status(self, arg) -> None: - "Display modules and router status." - self.huri.router.send_commands(Command("STATUS", [])) - - def do_start(self, arg) -> None: - "Start a module." - self.huri.router.send_commands(Command("START", [arg.strip()])) - - def do_stop(self, arg) -> None: - "Stop a module." - self.huri.router.send_commands(Command("STOP", [arg.strip()])) - - def do_exit(self, arg) -> None: - "Exit HuRi." - self.huri.router.send_commands(Command("EXIT", [])) - print("Bye !") - return True diff --git a/src/core/zmq/control_channel.py b/src/core/zmq/control_channel.py deleted file mode 100644 index 89ef839..0000000 --- a/src/core/zmq/control_channel.py +++ /dev/null @@ -1,151 +0,0 @@ -import json -import uuid -from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, List, Optional - -import zmq - -from src.tools.logger import logging, setup_logger - - -@dataclass -class Command: - cmd: str # "STOP", "START", "STATUS", ... - args: List[Any] # JSON-serializable arguments - - def to_bytes(self) -> bytes: - return json.dumps(asdict(self)).encode("utf-8") - - @staticmethod - def from_bytes(data: bytes) -> "Command": - obj = json.loads(data.decode("utf-8")) - return Command(**obj) - - -@dataclass -class Result: - success: bool - result: List[Any] - - def to_bytes(self) -> bytes: - return json.dumps(asdict(self)).encode("utf-8") - - @staticmethod - def from_bytes(data: bytes) -> "Command": - obj = json.loads(data.decode("utf-8")) - return Result(**obj) - - -class Router: - def __init__( - self, - hostname: str, - port: int, - logger: Optional[logging.Logger] = setup_logger("Router"), - ): - - self.ctx = zmq.Context.instance() - self.router = self.ctx.socket(zmq.ROUTER) - self.hostname = hostname - self.port = port - - self.logger = logger or logging.getLogger(__name__) - - self.dealers: Dict[bytes, bool] = {} - - def start(self): - self.router.bind(f"tcp://{self.hostname}:{self.port}") - self.logger.info("Router started") - - try: - while True: - identity, *frames = self.router.recv_multipart() - - if not frames: - continue - - command = frames[0] - - if command == b"REGISTER": - self.dealers[identity] = True - self.logger.info(f"Dealer registered: {identity}") - - elif command == b"RESULT": - payload = frames[1] if len(frames) > 1 else b"" - self.logger.info(f"Result from {identity}: {payload.decode()}") - except Exception as e: - self.logger.exception(e) - pass - finally: - self.router.close() - - def stop(self) -> None: - self.router.close() - - def send_command(self, dealer_id: bytes, command: Command) -> None: - if dealer_id not in self.dealers: - raise ValueError("Dealer not registered") - - self.router.send_multipart([dealer_id, b"COMMAND", command.to_bytes()]) - - def send_commands(self, command: Command) -> None: - for dealer_id, _ in self.dealers.items(): - self.send_command(dealer_id, command) - - -class Dealer: - def __init__( - self, - hostname: str, - port: int, - executor: Callable[[Command], bool], - logger: Optional[logging.Logger] = None, - identity: Optional[str] = None, - ): - self.ctx = zmq.Context.instance() - self.dealer = self.ctx.socket(zmq.DEALER) - - self.hostname = hostname - self.port = port - - self.executor = executor - self.identity = (identity or str(uuid.uuid4())).encode() # TODO agent name - - self.logger = logger or logging.getLogger(f"Dealer {self.identity}") - - def start(self): - self.dealer.connect(f"tcp://{self.hostname}:{self.port}") - self.dealer.setsockopt(zmq.IDENTITY, self.identity) - self.logger.info(f"Dealer started: {self.identity}") - - try: - self.dealer.send(b"REGISTER") - - while True: - frames = self.dealer.recv_multipart() - - command = frames[0] - - if command == b"COMMAND": - self.logger.info("received command") - payload = frames[1] if len(frames) > 1 else b"" - result = self.execute(payload) - - self.dealer.send_multipart([b"RESULT", result]) - except Exception as e: - self.logger.exception(e) - finally: - self.dealer.close() - - def execute(self, command: Command) -> bytes: - """ - Execute command sent by Router - """ - self.executor(command) - - # Example execution - result = f"Executed: {command.cmd}" - return result.encode() - - def stop(self) -> None: - self.dealer.close(linger=0) diff --git a/src/core/zmq/event_proxy.py b/src/core/zmq/event_proxy.py deleted file mode 100644 index 9447de8..0000000 --- a/src/core/zmq/event_proxy.py +++ /dev/null @@ -1,58 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -import zmq - -from src.tools.logger import logging, setup_logger - - -@dataclass -class ZMQEventPorts: - xpub: str - xsub: str - - -class EventProxy: - def __init__( - self, - hostname: str, - connect_hostname: str, - xpub_port: int, - xsub_port: int, - logger: Optional[logging.Logger] = setup_logger("EventProxy"), - ): - - self.ctx = zmq.Context.instance() - self.xpub = self.ctx.socket(zmq.XPUB) - self.xsub = self.ctx.socket(zmq.XSUB) - - self.hostname = hostname - self.connect_hostname = connect_hostname - self.xpub_port = xpub_port - self.xsub_port = xsub_port - - self.logger = logger or logging.getLogger(__name__) - - def start(self, xpub_connect: bool, xsub_connect: bool): - if xpub_connect: - self.xpub.connect(f"tcp://{self.connect_hostname}:{self.xpub_port}") - else: - self.xpub.bind(f"tcp://{self.hostname}:{self.xpub_port}") - if xsub_connect: - self.xsub.connect(f"tcp://{self.connect_hostname}:{self.xsub_port}") - else: - self.xsub.bind(f"tcp://{self.hostname}:{self.xsub_port}") - - try: - self.logger.info("Correctly initialized, starting proxy") - zmq.proxy(self.xsub, self.xpub) - except Exception as e: - self.logger.error(e) - - def stop(self) -> None: - self.xsub.close(linger=0) - self.xpub.close(linger=0) - - def publish(self, topic: str, msg: str) -> None: - self.xpub.send_multipart([topic.encode(), "str".encode(), msg.encode()]) - self.logger.info(f"Publish: {topic} str") diff --git a/src/core/zmq/log_channel.py b/src/core/zmq/log_channel.py deleted file mode 100644 index 6eb2a8e..0000000 --- a/src/core/zmq/log_channel.py +++ /dev/null @@ -1,142 +0,0 @@ -import json -import time -from typing import Any, Dict, Optional - -import zmq - -from src.tools.logger import ( - LevelFilter, - QueueListener, - logging, - mp, - setup_log_listener, - setup_logger, -) - - -def record_to_dict(record: logging.LogRecord) -> Dict[str, Any]: - return { - "name": record.name, - "levelno": record.levelno, - "levelname": record.levelname, - "message": record.getMessage(), - "created": record.created, - "asctime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created)), - "process": record.process, - "processName": record.processName, - "thread": record.thread, - "threadName": record.threadName, - "module": record.module, - "filename": record.filename, - "pathname": record.pathname, - "lineno": record.lineno, - "funcName": record.funcName, - } - - -def dict_to_record(data: Dict[str, Any]) -> logging.LogRecord: - record = logging.LogRecord( - name=data["name"], - level=data["levelno"], - pathname=data["pathname"], - lineno=data["lineno"], - msg=data["message"], - args=(), - exc_info=None, - func=data["funcName"], - ) - - # Restore metadata - record.created = data["created"] - record.process = data["process"] - record.processName = data["processName"] - record.thread = data["thread"] - record.threadName = data["threadName"] - record.module = data["module"] - record.filename = data["filename"] - - return record - - -class LogPuller: - def __init__( - self, - hostname: str, - port: int, - logger: Optional[logging.Logger] = setup_logger("LogPuller"), - ) -> None: - self.ctx = zmq.Context.instance() - self.pull = self.ctx.socket(zmq.PULL) - - self.hostname = hostname - self.port = port - - self.logger = logger or logging.getLogger(__name__) - - def start(self) -> None: - self.pull.bind(f"tcp://{self.hostname}:{self.port}") - - self.logger.info("started") - while True: - payload = self.pull.recv() - - self.logger.handle(dict_to_record(json.loads(payload.decode()))) - - def stop(self) -> None: - self.pull.close() - - -class LogPusher: - class LogPusherHandler(logging.Handler): - def __init__( - self, - hostname: str, - port: int, - ): - super().__init__() - self.ctx = zmq.Context.instance() - self.socket = self.ctx.socket(zmq.PUSH) - - self.hostname = hostname - self.port = port - - def emit(self, record: logging.LogRecord) -> None: - try: - payload = json.dumps(record_to_dict(record)).encode() - self.socket.send(payload) - except Exception: - self.handleError(record) - except Exception: - self.handleError(record) - - def start(self) -> None: - self.socket.connect(f"tcp://{self.hostname}:{self.port}") - - def stop(self) -> None: - self.socket.close() - - def __init__( - self, - hostname: str, - port: int, - ): - - self.log_queue = mp.Queue() - - self.log_handler = self.LogPusherHandler(hostname, port) - self.level_filter = LevelFilter(logging.DEBUG) - self.log_listener: QueueListener = setup_log_listener( - self.log_queue, self.level_filter, self.log_handler - ) - - self.logger = setup_logger("LogPusher", log_queue=self.log_queue) - - def start(self) -> None: - self.log_handler.start() - self.log_listener.start() - - def stop(self): - self.logger.info("stopping") - time.sleep(0.2) - self.log_listener.stop() - self.log_handler.stop() diff --git a/src/emotional_hub/input_analysis.py b/src/emotional_hub/input_analysis.py deleted file mode 100644 index 7391578..0000000 --- a/src/emotional_hub/input_analysis.py +++ /dev/null @@ -1,24 +0,0 @@ -import numpy as np -import torch -from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor - -MODEL_NAME = "superb/hubert-large-superb-er" -model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME) -feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME) - - -def predict_emotion(audio_np: np.ndarray, sr=16000): - if audio_np.dtype != np.float32: - audio_np = audio_np.astype(np.float32) - - inputs = feature_extractor( - audio_np, sampling_rate=sr, return_tensors="pt", padding=True - ) - - with torch.no_grad(): - logits = model(**inputs).logits - - predicted_id = torch.argmax(logits, dim=-1).item() - predicted_label = model.config.id2label[predicted_id] - - return predicted_label diff --git a/src/launch_agent.py b/src/launch_agent.py deleted file mode 100644 index efc63c4..0000000 --- a/src/launch_agent.py +++ /dev/null @@ -1,43 +0,0 @@ -import argparse -import logging -import time - -import yaml - -from src.core.agent import Agent, AgentConfig -from src.modules.factory import build_module_factory - - -def load_config(path: str) -> AgentConfig: - with open(path) as f: - raw = yaml.safe_load(f) - - return AgentConfig.from_dict(raw) - - -def main() -> None: - parser = argparse.ArgumentParser(description="HuRI core") - parser.add_argument( - "--config", - required=True, - help="Path to HuRI config file (YAML)", - ) - - args = parser.parse_args() - - config = load_config(args.config) - - build_module_factory() - - agent = Agent(config) - time.sleep(0.1) - try: - agent.run() - except KeyboardInterrupt: - agent.stop_all() - except Exception as e: - logging.getLogger(__name__).error(e) - - -if __name__ == "__main__": - main() diff --git a/src/launch_huri.py b/src/launch_huri.py index b4f9fd0..729ee79 100644 --- a/src/launch_huri.py +++ b/src/launch_huri.py @@ -1,42 +1,25 @@ -import argparse -import logging import time -import yaml +import ray -from src.core.huri import HuRI, HuriConfig -from src.modules.factory import build_module_factory - - -def load_config(path: str) -> HuriConfig: - with open(path) as f: - raw = yaml.safe_load(f) - - return HuriConfig.from_dict(raw) +from src.core.huri import Dict, HuRI, handle, serve +from src.modules.speech_to_text.speech_to_text import STTHandle def main() -> None: - parser = argparse.ArgumentParser(description="HuRI core") - parser.add_argument( - "--config", - required=True, - help="Path to HuRI config file (YAML)", - ) - - args = parser.parse_args() - - config = load_config(args.config) - - build_module_factory() + ray.init() - huri = HuRI(config) + services: Dict[str, handle.DeploymentHandle] = { + "stt": STTHandle.bind(), # type: ignore[attr-defined] + } + app = HuRI.bind("", services) # type: ignore[attr-defined] time.sleep(0.1) try: - huri.run() + serve.run(app, name="HuRI", blocking=True) except KeyboardInterrupt: - huri.stop() + return except Exception as e: - logging.getLogger(__name__).error(e) + ray.logger.error(e) if __name__ == "__main__": diff --git a/src/modules/factory.py b/src/modules/factory.py deleted file mode 100644 index 74bba03..0000000 --- a/src/modules/factory.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Any, Mapping - -from src.core.module import Module - -from .rag.mode_controller import ModeController -from .rag.rag import Rag -from .speech_to_text.record_speech import RecordSpeech -from .speech_to_text.speech_to_text import SpeechToText -from .textIO.input import TextInput -from .textIO.output import TextOutput - - -class ModuleFactory: - _registry = {} - - @classmethod - def register(cls, name: str, module_cls): - cls._registry[name] = module_cls - - @classmethod - def create(cls, name: str, args: Mapping[str, Any] | None = None) -> Module: - if name not in cls._registry: - raise ValueError(f"Unknown module '{name}'") - return cls._registry[name](**args) - - -def build_module_factory() -> None: - ModuleFactory.register("mic", RecordSpeech) - ModuleFactory.register("stt", SpeechToText) - ModuleFactory.register("inp", TextInput) - ModuleFactory.register("out", TextOutput) - ModuleFactory.register("rag", Rag) - ModuleFactory.register("mod", ModeController) diff --git a/src/modules/rag/__init__.py b/src/modules/rag/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/modules/rag/mode_controller.py b/src/modules/rag/mode_controller.py deleted file mode 100644 index 9ed0606..0000000 --- a/src/modules/rag/mode_controller.py +++ /dev/null @@ -1,37 +0,0 @@ -from enum import Enum - -from src.core.module import Module - - -class Modes(Enum): - LLM = 0 - CONTEXT = 1 - RAG = 2 - - -class ModeController(Module): - def __init__(self, default_mode: Modes = Modes.LLM): - super().__init__() - self.mode = default_mode - - def switchMode(self, mode: str) -> None: - self.mode = mode - - def processTextInput(self, text: str): - if "switch llm" in text.lower(): - self.switchMode(Modes.LLM) - elif "switch context" in text.lower(): - self.switchMode(Modes.CONTEXT) - elif "switch rag" in text.lower(): - self.switchMode(Modes.RAG) - elif "bye bye" in text.lower(): - self.publish("exit", "") # TODO handle (manager being a module) usefull ? - elif text.strip() == "": - return - else: - topic = f"{str(self.mode.name).lower()}.in" - self.publish(topic, text) - - def set_subscriptions(self): - self.subscribe("text.in", self.processTextInput) - self.subscribe("mode.switch", self.switchMode) diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py deleted file mode 100644 index f2fc736..0000000 --- a/src/modules/rag/rag.py +++ /dev/null @@ -1,96 +0,0 @@ -import json -import pathlib - -from langchain.chains import create_retrieval_chain -from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain_chroma import Chroma -from langchain_community.document_loaders import TextLoader -from langchain_core.documents import Document -from langchain_core.prompts import ChatPromptTemplate -from langchain_ollama.embeddings import OllamaEmbeddings -from langchain_ollama.llms import OllamaLLM -from langgraph.checkpoint.memory import MemorySaver - -from src.core.module import Module - - -class Rag(Module): - def __init__( - self, - model: str = "deepseek-v2:16b", - collectionName: str = "vectorStore", - vectorstorePath: str = "src/rag/vectorStore", - ): - super().__init__() - self.memory = MemorySaver() - self.embeddings = OllamaEmbeddings(model=model) - self.llm = OllamaLLM(model=model) - self.vectorstore = Chroma( - collection_name=collectionName, - embedding_function=self.embeddings, - persist_directory=vectorstorePath, - ) - self.textSplitter = RecursiveCharacterTextSplitter( - chunk_size=1000, chunk_overlap=200 - ) - self.retriever = self.vectorstore.as_retriever() - self.systemPrompt = "Conversation history:\n{history}\n\nContext:\n{context}" - self.prompt = ChatPromptTemplate.from_messages( - [ - ("system", self.systemPrompt), - ("human", "{input}"), - ] - ) - self.questionChain = create_stuff_documents_chain(self.llm, self.prompt) - self.qaChain = create_retrieval_chain(self.retriever, self.questionChain) - self.documents = [] - self.docs = [] - self.conversation = [] - self.conversation_log = {"conversation": []} - - def ragFill(self, text: str) -> None: - self.documents += self.textSplitter.split_documents( - [Document(page_content=text)] - ) - self.vectorstore.add_documents(self.documents) - - def ragLoad(self, folderPath: str, fileType: str) -> None: - if fileType == "txt": - for file in pathlib.Path(folderPath).rglob("*.txt"): - fileLoader = TextLoader(file_path=folderPath + "/" + file.name) - self.documents += self.textSplitter.split_documents(fileLoader.load()) - self.vectorstore.add_documents(self.documents) - - def ragQuestion(self, question: str) -> None: - self.logger.debug("question:", question) - history = "\n".join( - [ - f"Human: {qa['question']}\nAI: {qa['answer']}" - for qa in self.conversation_log["conversation"] - ] - ) - helpingContext = "Answer with just your message like in a conversation. " - question = helpingContext + question - self.logger.debug("full question:", question) - response = self.qaChain.invoke({"history": history, "input": question}) - answer = response["answer"] - self.logger.debug("answer:", answer) - self.conversation_log["conversation"].append( - {"question": question.split(helpingContext)[1:], "answer": answer} - ) - self.publish("llm.response", answer) - - def saveConversation(self, filename: str = "conversation_log.json"): - with open(filename, "w") as f: - json.dump(self.conversation_log, f, indent=4) - - def set_subscriptions(self) -> None: - self.subscribe("rag.load", self.ragLoad) - self.subscribe("llm.in", self.ragQuestion) - self.subscribe("rag.in", self.ragFill) - self.subscribe("rag.save", self.saveConversation) - - def run_module(self, stop_event=None) -> None: - self.ragLoad("tests/rag/docsRag", "txt") - super().run_module(stop_event) diff --git a/src/modules/speech_to_text/record_speech.py b/src/modules/speech_to_text/record_speech.py index a361ac3..0f28755 100644 --- a/src/modules/speech_to_text/record_speech.py +++ b/src/modules/speech_to_text/record_speech.py @@ -1,107 +1,25 @@ -import queue -import threading -import time -from typing import List, Optional +from typing import Optional import numpy as np -import sounddevice as sd -from src.core.module import Event, Module +from src.core.module import Module -class RecordSpeech(Module): +class MIC(Module): + input_type = "chunk" + output_type = "voice" + def __init__( self, threshold: int = 0, - silence_duration: float = 1.0, - chunk_duration: float = 0.5, - sample_rate: int = 16000, ): super().__init__() self.THRESHOLD: int = threshold - self.SILENCE_DURATION: float = silence_duration - self.CHUNK_DURATION: float = chunk_duration - self.SAMPLE_RATE: int = sample_rate - self.running: bool = False - self.audio_queue: queue.Queue = queue.Queue() - self.transcriptions: queue.Queue = queue.Queue() - self.pause_record = threading.Semaphore(1) - self.audio_to_process = threading.Semaphore(0) - self.prompt_available = threading.Semaphore(0) - self.noise_profile: np.ndarray - - def reduce_noise(self, chunk: np.ndarray) -> np.ndarray: - if np.abs(chunk).mean() <= self.THRESHOLD: - return chunk - - return np.clip(chunk - self.noise_profile, -32768, 32767).astype(np.int16) - - def record_chunk(self) -> np.ndarray: - self.pause_record.acquire() - chunk: np.ndarray = sd.rec( - int(self.CHUNK_DURATION * self.SAMPLE_RATE), - samplerate=self.SAMPLE_RATE, - channels=1, - dtype="int16", - ).ravel() - sd.wait() - self.pause_record.release() - return self.reduce_noise(chunk) - - def calculate_noise_level(self) -> None: - self.logger.info("Listening for 10 seconds to calculate noise level...") - noise_chunk: np.ndarray = sd.rec( - int(10 * self.SAMPLE_RATE), - samplerate=self.SAMPLE_RATE, - channels=1, - dtype="int16", - ).ravel() - sd.wait() - self.noise_profile = noise_chunk.mean(axis=0) - self.THRESHOLD = np.abs(self.reduce_noise(noise_chunk)).mean() - self.logger.info(f"Threshold: {self.THRESHOLD}") - - def record_audio(self, starting_chunk, stop_event: Event = None) -> None: - buffer: List[np.ndarray] = [starting_chunk] - silence_start: Optional[float] = None - - while stop_event is None or not stop_event.is_set(): - chunk = self.record_chunk() - buffer.append(chunk) - - if np.abs(chunk).mean() <= self.THRESHOLD: - if silence_start is None: - silence_start = time.time() - elif time.time() - silence_start >= self.SILENCE_DURATION: - if buffer == []: - break - speech = np.concatenate(buffer, axis=0) - self.publish("speech.in", speech.tobytes(), "bytes") - break - else: - silence_start = None - - def set_subscriptions(self) -> None: - self.subscribe("speech.in.pause", self.pause()) - self.subscribe("speech.in.resume", self.pause(False)) - - def run_module(self, stop_event: Event = None) -> None: - if not self.THRESHOLD: - self.calculate_noise_level() - else: - self.noise_profile = np.zeros( - int(self.CHUNK_DURATION * self.SAMPLE_RATE), dtype=np.int16 - ) - - while stop_event is None or not stop_event.is_set(): - chunk: np.ndarray = self.record_chunk() - - if np.abs(chunk).mean() > self.THRESHOLD: - self.record_audio(chunk, stop_event) - def pause(self, true: bool = True) -> None: - if true: - self.pause_record.acquire() - else: - self.pause_record.release() + async def process(self, data: bytes) -> Optional[np.ndarray]: + audio_array = np.frombuffer(data, dtype=np.int16) + if np.abs(audio_array).mean() > self.THRESHOLD: + audio_array_float = audio_array.astype(np.float32) / 32768.0 + return audio_array_float + return None diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index a086a2c..7925f3c 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -1,50 +1,50 @@ -import queue -import threading +from typing import Any, List, Optional import numpy as np import whisper +from ray import serve +from ray.serve import handle from src.core.module import Module -class SpeechToText(Module): +@serve.deployment(num_replicas=5) +class STTHandle: def __init__( self, - model_name: str = "base.en", - device: str = "cpu", - sample_rate: int = 16000, + model_name: str = "base", ): super().__init__() - print(model_name) - if device == "cpu": - import warnings - - warnings.filterwarnings( - "ignore", message="FP16 is not supported on CPU; using FP32 instead" - ) - self.model: whisper.Whisper = whisper.load_model(model_name, device=device) - self.SAMPLE_RATE: int = sample_rate - self.running: bool = False - self.audio_queue: queue.Queue = queue.Queue() - self.transcriptions: queue.Queue = queue.Queue() - self.pause_record = threading.Semaphore(1) - self.audio_to_process = threading.Semaphore(0) - self.prompt_available = threading.Semaphore(0) - self.noise_profile: np.ndarray - - def process_audio(self, buffer: bytes) -> None: - if not buffer: - return - - audio_array = np.frombuffer(buffer, dtype=np.int16) - audio_array = audio_array.astype(np.float32) / 32768.0 - - result: dict = self.model.transcribe(audio_array, language="en") + + self.model: whisper.Whisper = whisper.load_model(model_name) + + async def process(self, audio_array: np.ndarray) -> Optional[Any]: + result: dict = self.model.transcribe( + audio_array.copy(), condition_on_previous_text=False, fp16=False + ) result["text"] = result["text"].strip() if not result["text"] or result["text"] == "": - return + return None + + return result["text"] + + +class STT(Module): + input_type = "voice" + output_type = "text" + + def __init__(self, stt_handle: handle.DeploymentHandle[STTHandle]): + self.stt = stt_handle - self.publish("text.in", result["text"]) + self.chunks: List[np.ndarray] = [] + self.running = False - def set_subscriptions(self) -> None: - self.subscribe("speech.in", self.process_audio) + async def process(self, audio: np.ndarray) -> Optional[Any]: + self.chunks.append(audio) + if self.running is True: + return None + self.running = True + text = await self.stt.process.remote(np.concatenate(self.chunks, axis=0)) + self.chunks.clear() + self.running = False + return text diff --git a/src/modules/textIO/input.py b/src/modules/textIO/input.py deleted file mode 100644 index 6440238..0000000 --- a/src/modules/textIO/input.py +++ /dev/null @@ -1,17 +0,0 @@ -from src.core.module import Module - - -class TextInput(Module): - def set_subscriptions(self): - self.subscribe("std.in", self.stdin_to_text) - self.subscribe("std.out", lambda _: print(">> ", end="", flush=True)) - - def stdin_to_text(self, data): - print(">> ", end="", flush=True) - if data == "": - return - self.publish("text.in", data) - - def run_module(self, stop_event=None): - print(">> ", end="", flush=True) - stop_event.wait() diff --git a/src/modules/textIO/output.py b/src/modules/textIO/output.py deleted file mode 100644 index c68ca76..0000000 --- a/src/modules/textIO/output.py +++ /dev/null @@ -1,10 +0,0 @@ -from src.core.module import Module - - -class TextOutput(Module): - def set_subscriptions(self) -> None: - self.subscribe("llm.response", self.print_response) - - def print_response(self, text: str) -> None: - print(f"\r<< {text}") - self.publish("std.out", "") diff --git a/src/core/zmq/__init__.py b/src/modules/utils/__init__.py similarity index 100% rename from src/core/zmq/__init__.py rename to src/modules/utils/__init__.py diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py new file mode 100644 index 0000000..e4c45a2 --- /dev/null +++ b/src/modules/utils/sender.py @@ -0,0 +1,19 @@ +from typing import Any + +from src.core.huri import WebSocket +from src.core.module import Module + + +class Sender(Module): + """Module to send output data to the client""" + + input_type = None + output_type = None + + def __init__(self, ws: WebSocket, type: str): + super().__init__() + self.ws: WebSocket = ws + self.input_type = type + + async def process(self, data: Any): + await self.ws.send_text(data) diff --git a/src/tools/logger.py b/src/tools/logger.py deleted file mode 100644 index 1b7d5c5..0000000 --- a/src/tools/logger.py +++ /dev/null @@ -1,104 +0,0 @@ -import logging -import multiprocessing as mp -from logging.handlers import QueueHandler, QueueListener -from typing import IO, Dict, Optional - - -def setup_handler( - stream: Optional[IO] = None, - filename: Optional[str] = None, - log_queue: Optional[mp.Queue] = None, - formatter: logging.Formatter = logging.Formatter( - "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", datefmt="%H:%M:%S" - ), -) -> logging.Handler: - if stream is not None: - handler = logging.StreamHandler(stream) - elif filename is not None: - handler = logging.FileHandler(filename) - elif log_queue is not None: - return QueueHandler(log_queue) - else: - # Default: stdout - handler = logging.StreamHandler() - - handler.setFormatter(formatter) - - return handler - - -def setup_logger( - name: str, - level: int = logging.DEBUG, - stream: Optional[IO] = None, - filename: Optional[str] = None, - log_queue: Optional[mp.Queue] = None, -) -> logging.Logger: - """ - Creates and returns a logger with optional output: - - log_queue (multiprocessing-safe queue, preferred for child processes) - - stream (e.g., sys.stdout) - - filename (log file) - - defaults to stdout if none is given - """ - logger = logging.getLogger(name) - logger.setLevel(level) - if log_queue: - logger.propagate = False - - logger.handlers.clear() - handler = setup_handler(stream, filename, log_queue) - logger.addHandler(handler) - - return logger - - -class LevelFilter(logging.Filter): - def __init__(self, root_level: int = logging.WARNING): - self.root_level = root_level - self.log_levels: Dict[str, int] = {} - - def filter(self, record: logging.LogRecord) -> bool: - """the root level has priority over custom levels""" - level = self.log_levels.get(record.name, self.root_level) - - return self.root_level <= record.levelno and level <= record.levelno - - def set_root_level(self, level: int) -> None: - self.root_level = level - - def add_level(self, name: str) -> None: - self.log_levels[name] = self.root_level - - def set_level(self, name: str, level: int) -> None: - if name not in self.log_levels: - raise ValueError(f"{name} has no linked log level") - self.log_levels[name] = level - - def set_levels(self, level: int) -> None: - self.set_root_level(level) - for name in self.log_levels: - self.set_level(name, level) - - def del_level(self, name: str) -> None: - del self.log_levels[name] - - -def setup_log_listener( - log_queue: mp.Queue, - filter: logging.Filter, - custom_handler: Optional[logging.Handler] = None, -) -> QueueListener: - """ - Starts a central logging listener that reads LogRecords from a queue - and emits them using normal loggers/handlers. - """ - formatter = logging.Formatter( - "[%(asctime)s] [%(processName)s] [%(name)s] [%(levelname)s] %(message)s", - datefmt="%H:%M:%S", - ) - handler = custom_handler or setup_handler(formatter=formatter) - handler.addFilter(filter) - - listener = QueueListener(log_queue, handler) - return listener