feat: async search #181

base: main
New file (`@@ -0,0 +1,99 @@`):

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional

import requests
from requests.adapters import HTTPAdapter
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


class RateLimiter:
    """Thread-safe token bucket: allows bursts of up to `burst` requests,
    then throttles to a sustained `qps` requests per second."""

    def __init__(self, qps: float = 10, burst: int = 5):
        self.qps = qps
        self.burst = burst
        self.tokens = burst
        self.last = time.time()
        self._lock = threading.Lock()

    def acquire(self):
        with self._lock:
            now = time.time()
            # Refill tokens for the elapsed time, capped at the burst size.
            self.tokens = min(
                float(self.burst), self.tokens + (now - self.last) * self.qps
            )
            self.last = now

            if self.tokens < 1:
                # Out of budget: sleep until a full token has accrued.
                sleep_time = (1 - self.tokens) / self.qps
                time.sleep(sleep_time)
                self.tokens = 0
            else:
                self.tokens -= 1


class HTTPClient:
    """Rate-limited requests.Session wrapper with pooled connections,
    tenacity retries, and a thread-pool-backed async interface."""

    def __init__(
        self,
        base_url: str,
        timeout: float = 30,
        qps: float = 10,
        max_concurrent: int = 5,
        headers: Optional[Dict] = None,
    ):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.limiter = RateLimiter(qps=qps, burst=max_concurrent)

        self.session = requests.Session()
        self.session.headers.update(
            headers or {"User-Agent": "GraphGen/1.0", "Accept": "application/json"}
        )

        adapter = HTTPAdapter(
            pool_connections=max_concurrent,
            pool_maxsize=max_concurrent * 2,
            max_retries=3,
        )
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

        # The executor is created lazily on the first aget() call.
        self._executor = None
        self._max_workers = max_concurrent

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=5),
        retry=retry_if_exception_type(requests.RequestException),
        reraise=True,
    )
    def get(self, endpoint: str) -> dict:
        self.limiter.acquire()
        url = f"{self.base_url}/{endpoint.lstrip('/')}"
        resp = self.session.get(url, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    async def aget(self, endpoint: str) -> dict:
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)

        loop = asyncio.get_event_loop()
        # Run the blocking get() in the pool so the event loop stays free.
        return await loop.run_in_executor(self._executor, self.get, endpoint)

    def close(self):
        self.session.close()
        if self._executor:
            self._executor.shutdown(wait=True)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
```

A contributor left an inline suggested change on the `loop = asyncio.get_event_loop()` line; the suggestion body is not shown in this view.
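A minimal usage sketch for the new client; the base URL and endpoint paths here are hypothetical, and a real run needs a reachable server:

```python
import asyncio

async def fetch_all(client: HTTPClient) -> list:
    # Fire several lookups concurrently: each aget() runs the blocking
    # get() in the client's thread pool, throttled by the shared limiter.
    return await asyncio.gather(*(client.aget(f"/items/{i}") for i in range(3)))

# Hypothetical base URL, for illustration only.
with HTTPClient("https://api.example.com", qps=5, max_concurrent=3) as client:
    print(asyncio.run(fetch_all(client)))
```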
Second file, the search operator's `process` method, two hunks:

```diff
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import TYPE_CHECKING, Optional, Tuple

 from graphgen.bases import BaseOperator
```

```diff
@@ -111,26 +110,31 @@ def process(self, batch: list) -> Tuple[list, dict]:
         item for item in batch if item and "content" in item and "_trace_id" in item
     ]

     if not seed_data:
         logger.warning("No valid seeds in batch")
+
+    valid_seeds = []
+    queries = []
+    for seed in seed_data:
+        query = seed.get("content", "")
+        if not query:
+            logger.warning("Empty query for seed: %s", seed)
+            continue
+        valid_seeds.append(seed)
+        queries.append(query)
+
+    if not queries:
         return [], {}

     # Perform concurrent searches
     results = run_concurrent(
-        partial(
-            self._perform_search,
-            searcher_obj=self.searcher,
-            data_source=self.data_source,
-        ),
-        seed_data,
+        self.searcher.search,
+        queries,
         desc=f"Searching {self.data_source} database",
         unit="keyword",
     )

     # Filter out None results and add _trace_id from original seeds
     final_results = []
     meta_updates = {}
-    for result, seed in zip(results, seed_data):
+    for result, seed in zip(results, valid_seeds):
         if result is None:
             continue
         result["_trace_id"] = self.get_trace_id(result)
```

Contributor comment on lines +120 to -125: This refactoring is a good improvement. As a result of this change, the `_perform_search` helper appears to be unused and can likely be removed.
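The refactor works because `run_concurrent` maps a single-argument callable over a list of items: before, `partial` had to bind the extra keyword arguments of `_perform_search`; now `self.searcher.search` already takes one query string. A minimal sketch of that calling pattern, with a hypothetical stand-in for GraphGen's `run_concurrent` (the real helper also renders a progress bar and may differ in signature):

```python
from concurrent.futures import ThreadPoolExecutor

def run_concurrent(func, items, desc="", unit="", max_workers=8):
    # Hypothetical stand-in: map func over items in a thread pool,
    # preserving input order so results can be zipped with the inputs.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(func, items))

# The mapped callable takes exactly one item, so no partial is needed:
results = run_concurrent(str.upper, ["alpha", "beta"], desc="demo", unit="item")
print(results)  # ['ALPHA', 'BETA']
```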
Contributor comment: The `await` inside this for-loop causes the API calls to `get_entry_details` to execute sequentially; each call waits for the previous one to complete, which can be a significant performance bottleneck. To take full advantage of `asyncio` for I/O-bound operations, these calls should be made concurrently. You can refactor this part to use `asyncio.gather` to perform the requests in parallel. You'll also need to add `import asyncio` at the top of the file. Example refactoring to replace lines 76-81:
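The file under review is not shown here, so the loop body below is an assumption; a minimal sketch of the suggested `asyncio.gather` refactor, with `get_entry_details` stubbed out for illustration, might look like:

```python
import asyncio

async def get_entry_details(entry_id: str) -> dict:
    # Stand-in for the real API call named in the comment.
    await asyncio.sleep(0.1)
    return {"id": entry_id}

async def fetch_details(entry_ids: list) -> list:
    # Sequential version (the bottleneck): one request at a time.
    #   details = []
    #   for entry_id in entry_ids:
    #       details.append(await get_entry_details(entry_id))
    # Concurrent version: schedule every call, then await them together.
    return await asyncio.gather(*(get_entry_details(e) for e in entry_ids))

print(asyncio.run(fetch_details(["a", "b", "c"])))
```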