diff --git a/graphgen/bases/base_searcher.py b/graphgen/bases/base_searcher.py
index 61845e32..d8ccd556 100644
--- a/graphgen/bases/base_searcher.py
+++ b/graphgen/bases/base_searcher.py
@@ -8,7 +8,7 @@ class BaseSearcher(ABC):
     """
 
     @abstractmethod
-    def search(self, query: str, **kwargs) -> Optional[Dict[str, Any]]:
+    async def search(self, query: str, **kwargs) -> Optional[Dict[str, Any]]:
         """
         Search for data based on the given query.
 
diff --git a/graphgen/models/searcher/db/interpro_searcher.py b/graphgen/models/searcher/db/interpro_searcher.py
index 9d3e7c06..cb48883d 100644
--- a/graphgen/models/searcher/db/interpro_searcher.py
+++ b/graphgen/models/searcher/db/interpro_searcher.py
@@ -1,16 +1,8 @@
 import re
 from typing import Dict, Optional
 
-import requests
-from requests.exceptions import RequestException
-from tenacity import (
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-)
-
 from graphgen.bases import BaseSearcher
+from graphgen.models.searcher.http_client import HTTPClient
 from graphgen.utils import logger
 
 
@@ -26,6 +18,8 @@ class InterProSearch(BaseSearcher):
     def __init__(
         self,
         api_timeout: int = 30,
+        qps: float = 10,
+        max_concurrent: int = 5,
     ):
         """
         Initialize the InterPro Search client.
@@ -33,58 +27,60 @@ def __init__(
         Args:
             api_timeout (int): Request timeout in seconds.
         """
-        self.api_timeout = api_timeout
-        self.BASE_URL = "https://www.ebi.ac.uk/interpro/api"
+        self.http_client = HTTPClient(
+            base_url="https://www.ebi.ac.uk/interpro/api",
+            timeout=api_timeout,
+            qps=qps,
+            max_concurrent=max_concurrent,
+            headers={
+                "Accept": "application/json",
+            },
+        )
 
     @staticmethod
     def _is_uniprot_accession(text: str) -> bool:
         """Check if text looks like a UniProt accession number."""
-        # UniProt: 6-10 chars starting with letter, e.g., P01308, Q96KN2
         return bool(re.fullmatch(r"[A-Z][A-Z0-9]{5,9}", text.strip(), re.I))
 
-    def search_by_uniprot_id(self, accession: str) -> Optional[Dict]:
+    async def search_by_uniprot_id(self, accession: str) -> Optional[Dict]:
        """
         Search InterPro database by UniProt accession number.
 
-        This method queries the EBI API to get pre-computed domain information
-        for a known UniProt entry.
-
         Args:
             accession (str): UniProt accession number.
 
         Returns:
             Dictionary with domain information or None if not found.
""" - if not accession or not isinstance(accession, str) or not self._is_uniprot_accession(accession): + if ( + not accession + or not isinstance(accession, str) + or not self._is_uniprot_accession(accession) + ): logger.error("Invalid accession provided") return None accession = accession.strip().upper() + endpoint = f"entry/interpro/protein/uniprot/{accession}/" - # Query InterPro REST API for UniProt entry - url = f"{self.BASE_URL}/entry/interpro/protein/uniprot/{accession}/" - - response = requests.get(url, timeout=self.api_timeout) - - if response.status_code != 200: + try: + data = await self.http_client.aget(endpoint) + except Exception as e: logger.warning( - "Failed to search InterPro for accession %s: %d", + "Failed to search InterPro for accession %s: %s", accession, - response.status_code, + str(e), ) return None - data = response.json() - - # Get entry details for each InterPro entry found for result in data.get("results", []): interpro_acc = result.get("metadata", {}).get("accession") if interpro_acc: - entry_details = self.get_entry_details(interpro_acc) + entry_details = await self.get_entry_details(interpro_acc) if entry_details: result["entry_details"] = entry_details - result = { + return { "molecule_type": "protein", "database": "InterPro", "id": accession, @@ -92,9 +88,7 @@ def search_by_uniprot_id(self, accession: str) -> Optional[Dict]: "url": f"https://www.ebi.ac.uk/interpro/protein/uniprot/{accession}/", } - return result - - def get_entry_details(self, interpro_accession: str) -> Optional[Dict]: + async def get_entry_details(self, interpro_accession: str) -> Optional[Dict]: """ Get detailed information for a specific InterPro entry. @@ -106,26 +100,18 @@ def get_entry_details(self, interpro_accession: str) -> Optional[Dict]: if not interpro_accession or not isinstance(interpro_accession, str): return None - url = f"{self.BASE_URL}/entry/interpro/{interpro_accession}/" - - response = requests.get(url, timeout=self.api_timeout) - if response.status_code != 200: + endpoint = f"entry/interpro/{interpro_accession}/" + try: + return await self.http_client.aget(endpoint) + except Exception as e: logger.warning( - "Failed to get InterPro entry %s: %d", + "Failed to get InterPro entry %s: %s", interpro_accession, - response.status_code, + str(e), ) return None - return response.json() - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=2, max=5), - retry=retry_if_exception_type(RequestException), - reraise=True, - ) - def search(self, query: str, **kwargs) -> Optional[Dict]: + async def search(self, query: str, **kwargs) -> Optional[Dict]: """ Search InterPro for protein domain information by UniProt accession. 
@@ -141,11 +127,10 @@ def search(self, query: str, **kwargs) -> Optional[Dict]:
             return None
 
         query = query.strip()
-        logger.debug("InterPro search query: %s", query[:100])
 
-        # Search by UniProt ID
-        logger.debug("Searching for UniProt accession: %s", query)
-        result = self.search_by_uniprot_id(query)
+        logger.debug("InterPro search query: %s", query[:100])
+        result = await self.search_by_uniprot_id(query)
+        logger.debug("InterPro search result: %s", str(result)[:100])
 
         if result:
             result["_search_query"] = query
diff --git a/graphgen/models/searcher/http_client.py b/graphgen/models/searcher/http_client.py
new file mode 100644
index 00000000..61533e9d
--- /dev/null
+++ b/graphgen/models/searcher/http_client.py
@@ -0,0 +1,100 @@
+import asyncio
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, Optional
+
+import requests
+from requests.adapters import HTTPAdapter
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+
+class RateLimiter:
+    def __init__(self, qps: float = 10, burst: int = 5):
+        self.qps = qps
+        self.burst = burst
+        self.tokens = burst
+        self.last = time.time()
+        self._lock = threading.Lock()
+
+    def acquire(self):
+        with self._lock:
+            now = time.time()
+            self.tokens = min(
+                float(self.burst), self.tokens + (now - self.last) * self.qps
+            )
+            self.last = now
+
+            if self.tokens < 1:
+                sleep_time = (1 - self.tokens) / self.qps
+                time.sleep(sleep_time)
+                self.last = time.time()  # consume the wait, so it is not credited again
+                self.tokens = 0
+            else:
+                self.tokens -= 1
+
+
+class HTTPClient:
+    def __init__(
+        self,
+        base_url: str,
+        timeout: float = 30,
+        qps: float = 10,
+        max_concurrent: int = 5,
+        headers: Optional[Dict] = None,
+    ):
+        self.base_url = base_url.rstrip("/")
+        self.timeout = timeout
+        self.limiter = RateLimiter(qps=qps, burst=max_concurrent)
+
+        self.session = requests.Session()
+        self.session.headers.update(
+            headers or {"User-Agent": "GraphGen/1.0", "Accept": "application/json"}
+        )
+
+        adapter = HTTPAdapter(
+            pool_connections=max_concurrent,
+            pool_maxsize=max_concurrent * 2,
+            max_retries=3,
+        )
+        self.session.mount("https://", adapter)
+        self.session.mount("http://", adapter)
+
+        self._executor = None
+        self._max_workers = max_concurrent
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=2, max=5),
+        retry=retry_if_exception_type(requests.RequestException),
+        reraise=True,
+    )
+    def get(self, endpoint: str) -> dict:
+        self.limiter.acquire()
+        url = f"{self.base_url}/{endpoint.lstrip('/')}"
+        resp = self.session.get(url, timeout=self.timeout)
+        resp.raise_for_status()
+        return resp.json()
+
+    async def aget(self, endpoint: str) -> dict:
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)
+
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(self._executor, self.get, endpoint)
+
+    def close(self):
+        self.session.close()
+        if self._executor:
+            self._executor.shutdown(wait=True)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
diff --git a/graphgen/operators/search/search_service.py b/graphgen/operators/search/search_service.py
index 220db049..b5089427 100644
--- a/graphgen/operators/search/search_service.py
+++ b/graphgen/operators/search/search_service.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import TYPE_CHECKING, Optional, Tuple
 
 from graphgen.bases import BaseOperator
@@ -111,18 +110,23 @@ def process(self, batch: list) -> Tuple[list, dict]:
             item for item in batch
and "content" in item and "_trace_id" in item ] - if not seed_data: - logger.warning("No valid seeds in batch") + valid_seeds = [] + queries = [] + for seed in seed_data: + query = seed.get("content", "") + if not query: + logger.warning("Empty query for seed: %s", seed) + continue + valid_seeds.append(seed) + queries.append(query) + + if not queries: return [], {} # Perform concurrent searches results = run_concurrent( - partial( - self._perform_search, - searcher_obj=self.searcher, - data_source=self.data_source, - ), - seed_data, + self.searcher.search, + queries, desc=f"Searching {self.data_source} database", unit="keyword", ) @@ -130,7 +134,7 @@ def process(self, batch: list) -> Tuple[list, dict]: # Filter out None results and add _trace_id from original seeds final_results = [] meta_updates = {} - for result, seed in zip(results, seed_data): + for result, seed in zip(results, valid_seeds): if result is None: continue result["_trace_id"] = self.get_trace_id(result)