2 changes: 1 addition & 1 deletion graphgen/bases/base_searcher.py
@@ -8,7 +8,7 @@ class BaseSearcher(ABC):
"""

@abstractmethod
def search(self, query: str, **kwargs) -> Optional[Dict[str, Any]]:
async def search(self, query: str, **kwargs) -> Optional[Dict[str, Any]]:
"""
Search for data based on the given query.

89 changes: 37 additions & 52 deletions graphgen/models/searcher/db/interpro_searcher.py
@@ -1,16 +1,8 @@
import re
from typing import Dict, Optional

import requests
from requests.exceptions import RequestException
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from graphgen.bases import BaseSearcher
from graphgen.models.searcher.http_client import HTTPClient
from graphgen.utils import logger


@@ -26,75 +18,77 @@ class InterProSearch(BaseSearcher):
def __init__(
self,
api_timeout: int = 30,
qps: float = 10,
max_concurrent: int = 5,
):
"""
Initialize the InterPro Search client.

Args:
api_timeout (int): Request timeout in seconds.
"""
self.api_timeout = api_timeout
self.BASE_URL = "https://www.ebi.ac.uk/interpro/api"
self.http_client = HTTPClient(
base_url="https://www.ebi.ac.uk/interpro/api",
timeout=api_timeout,
qps=qps,
max_concurrent=max_concurrent,
headers={
"Accept": "application/json",
},
)

@staticmethod
def _is_uniprot_accession(text: str) -> bool:
"""Check if text looks like a UniProt accession number."""
# UniProt: 6-10 chars starting with letter, e.g., P01308, Q96KN2
return bool(re.fullmatch(r"[A-Z][A-Z0-9]{5,9}", text.strip(), re.I))

def search_by_uniprot_id(self, accession: str) -> Optional[Dict]:
async def search_by_uniprot_id(self, accession: str) -> Optional[Dict]:
"""
Search InterPro database by UniProt accession number.

This method queries the EBI API to get pre-computed domain information
for a known UniProt entry.

Args:
accession (str): UniProt accession number.

Returns:
Dictionary with domain information or None if not found.
"""
if not accession or not isinstance(accession, str) or not self._is_uniprot_accession(accession):
if (
not accession
or not isinstance(accession, str)
or not self._is_uniprot_accession(accession)
):
logger.error("Invalid accession provided")
return None

accession = accession.strip().upper()
endpoint = f"entry/interpro/protein/uniprot/{accession}/"

# Query InterPro REST API for UniProt entry
url = f"{self.BASE_URL}/entry/interpro/protein/uniprot/{accession}/"

response = requests.get(url, timeout=self.api_timeout)

if response.status_code != 200:
try:
data = await self.http_client.aget(endpoint)
except Exception as e:
logger.warning(
"Failed to search InterPro for accession %s: %d",
"Failed to search InterPro for accession %s: %s",
accession,
response.status_code,
str(e),
)
return None

data = response.json()

# Get entry details for each InterPro entry found
for result in data.get("results", []):
interpro_acc = result.get("metadata", {}).get("accession")
if interpro_acc:
entry_details = self.get_entry_details(interpro_acc)
entry_details = await self.get_entry_details(interpro_acc)
if entry_details:
result["entry_details"] = entry_details
Comment on lines 76 to 81
Contributor

high

The await inside this for-loop causes the API calls to get_entry_details to be executed sequentially. This can be a significant performance bottleneck as each API call will wait for the previous one to complete. To take full advantage of asyncio for I/O-bound operations, these calls should be made concurrently.

You can refactor this part to use asyncio.gather to perform the requests in parallel. You'll also need to add import asyncio at the top of the file.

Example refactoring to replace lines 76-81:

        results = data.get("results", [])

        async def fetch_and_update(result):
            interpro_acc = result.get("metadata", {}).get("accession")
            if interpro_acc:
                entry_details = await self.get_entry_details(interpro_acc)
                if entry_details:
                    result["entry_details"] = entry_details
        
        tasks = [fetch_and_update(r) for r in results]
        if tasks:
            await asyncio.gather(*tasks)


result = {
return {
"molecule_type": "protein",
"database": "InterPro",
"id": accession,
"content": data.get("results", []),
"url": f"https://www.ebi.ac.uk/interpro/protein/uniprot/{accession}/",
}

return result

def get_entry_details(self, interpro_accession: str) -> Optional[Dict]:
async def get_entry_details(self, interpro_accession: str) -> Optional[Dict]:
"""
Get detailed information for a specific InterPro entry.

@@ -106,26 +100,18 @@ def get_entry_details(self, interpro_accession: str) -> Optional[Dict]:
if not interpro_accession or not isinstance(interpro_accession, str):
return None

url = f"{self.BASE_URL}/entry/interpro/{interpro_accession}/"

response = requests.get(url, timeout=self.api_timeout)
if response.status_code != 200:
endpoint = f"entry/interpro/{interpro_accession}/"
try:
return await self.http_client.aget(endpoint)
except Exception as e:
logger.warning(
"Failed to get InterPro entry %s: %d",
"Failed to get InterPro entry %s: %s",
interpro_accession,
response.status_code,
str(e),
)
return None

return response.json()

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=5),
retry=retry_if_exception_type(RequestException),
reraise=True,
)
def search(self, query: str, **kwargs) -> Optional[Dict]:
async def search(self, query: str, **kwargs) -> Optional[Dict]:
"""
Search InterPro for protein domain information by UniProt accession.

@@ -141,11 +127,10 @@ def search(self, query: str, **kwargs) -> Optional[Dict]:
return None

query = query.strip()
logger.debug("InterPro search query: %s", query[:100])

# Search by UniProt ID
logger.debug("Searching for UniProt accession: %s", query)
result = self.search_by_uniprot_id(query)
logger.debug("InterPro search query: %s", query[:100])
result = await self.search_by_uniprot_id(query)
logger.debug("InterPro search result: %s", str(result)[:100])

if result:
result["_search_query"] = query
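Below is a minimal sketch of how the now-async searcher might be driven end to end; it is not part of the diff, the import path is inferred from the file header above, the accession is purely illustrative, and network access is assumed:

    import asyncio

    from graphgen.models.searcher.db.interpro_searcher import InterProSearch

    async def main():
        searcher = InterProSearch(api_timeout=30, qps=10, max_concurrent=5)
        # P01308 (human insulin) is one of the accessions cited in the regex comment above
        result = await searcher.search("P01308")
        if result:
            print(result["url"], len(result["content"]))

    asyncio.run(main())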
99 changes: 99 additions & 0 deletions graphgen/models/searcher/http_client.py
@@ -0,0 +1,99 @@
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional

import requests
from requests.adapters import HTTPAdapter
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)


class RateLimiter:
def __init__(self, qps: float = 10, burst: int = 5):
self.qps = qps
self.burst = burst
self.tokens = burst
self.last = time.time()
self._lock = threading.Lock()

def acquire(self):
with self._lock:
now = time.time()
self.tokens = min(
float(self.burst), self.tokens + (now - self.last) * self.qps
)
self.last = now

if self.tokens < 1:
sleep_time = (1 - self.tokens) / self.qps
time.sleep(sleep_time)
self.tokens = 0
else:
self.tokens -= 1


class HTTPClient:
def __init__(
self,
base_url: str,
timeout: float = 30,
qps: float = 10,
max_concurrent: int = 5,
headers: Optional[Dict] = None,
):
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.limiter = RateLimiter(qps=qps, burst=max_concurrent)

self.session = requests.Session()
self.session.headers.update(
headers or {"User-Agent": "GraphGen/1.0", "Accept": "application/json"}
)

adapter = HTTPAdapter(
pool_connections=max_concurrent,
pool_maxsize=max_concurrent * 2,
max_retries=3,
)
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)

self._executor = None
self._max_workers = max_concurrent

@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=2, max=5),
retry=retry_if_exception_type(requests.RequestException),
reraise=True,
)
def get(self, endpoint: str) -> dict:
self.limiter.acquire()
url = f"{self.base_url}/{endpoint.lstrip('/')}"
resp = self.session.get(url, timeout=self.timeout)
resp.raise_for_status()
return resp.json()

async def aget(self, endpoint: str) -> dict:
if self._executor is None:
self._executor = ThreadPoolExecutor(max_workers=self._max_workers)

loop = asyncio.get_event_loop()
Contributor

medium

asyncio.get_event_loop() has been deprecated when there is no running event loop (since Python 3.10), and its use in new async code is discouraged. It's recommended to use asyncio.get_running_loop() instead, which is the modern and safer way to get the current event loop from within a coroutine.

Suggested change
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()

return await loop.run_in_executor(self._executor, self.get, endpoint)

def close(self):
self.session.close()
if self._executor:
self._executor.shutdown(wait=True)

def __enter__(self):
return self

def __exit__(self, *args):
self.close()
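Below is a minimal usage sketch of the new HTTPClient (not part of the diff; the endpoint and printed field are illustrative, and network access is assumed):

    import asyncio

    from graphgen.models.searcher.http_client import HTTPClient

    async def main():
        client = HTTPClient(
            base_url="https://www.ebi.ac.uk/interpro/api",
            timeout=30,
            qps=10,
            max_concurrent=5,
            headers={"Accept": "application/json"},
        )
        try:
            # aget() runs the rate-limited, retried blocking GET in a worker thread
            data = await client.aget("entry/interpro/protein/uniprot/P01308/")
            print(len(data.get("results", [])))
        finally:
            client.close()

    asyncio.run(main())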
24 changes: 14 additions & 10 deletions graphgen/operators/search/search_service.py
@@ -1,4 +1,3 @@
from functools import partial
from typing import TYPE_CHECKING, Optional, Tuple

from graphgen.bases import BaseOperator
@@ -111,26 +110,31 @@ def process(self, batch: list) -> Tuple[list, dict]:
item for item in batch if item and "content" in item and "_trace_id" in item
]

if not seed_data:
logger.warning("No valid seeds in batch")
valid_seeds = []
queries = []
for seed in seed_data:
query = seed.get("content", "")
if not query:
logger.warning("Empty query for seed: %s", seed)
continue
valid_seeds.append(seed)
queries.append(query)

if not queries:
return [], {}

# Perform concurrent searches
results = run_concurrent(
partial(
self._perform_search,
searcher_obj=self.searcher,
data_source=self.data_source,
),
seed_data,
Comment on lines +120 to -125
Contributor

medium

This refactoring is a good improvement. As a result of this change, the _perform_search method is no longer used and can be removed to clean up the code. Additionally, the now-unused _perform_search method contains a bug where it calls the new async search method without await, so removing it would also prevent this latent bug from being reintroduced later.
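For reference, a minimal sketch of why calling the now-async search method without await is a latent bug (the class and helper names here are illustrative, not taken from the repository): the call returns an unawaited coroutine object instead of the search result.

    import asyncio

    class DemoSearcher:
        async def search(self, query: str) -> dict:
            # stand-in for the real async search
            return {"id": query}

    def broken_helper(searcher: DemoSearcher, query: str):
        # Bug: search() is a coroutine function, so this returns a coroutine
        # object that is never awaited and never executed.
        return searcher.search(query)

    async def fixed_helper(searcher: DemoSearcher, query: str) -> dict:
        # Fix: await the coroutine (or, as suggested above, delete the helper).
        return await searcher.search(query)

    async def main():
        s = DemoSearcher()
        print(broken_helper(s, "P01308"))       # <coroutine object ...> plus a RuntimeWarning
        print(await fixed_helper(s, "P01308"))  # {'id': 'P01308'}

    asyncio.run(main())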

self.searcher.search,
queries,
desc=f"Searching {self.data_source} database",
unit="keyword",
)

# Filter out None results and add _trace_id from original seeds
final_results = []
meta_updates = {}
for result, seed in zip(results, seed_data):
for result, seed in zip(results, valid_seeds):
if result is None:
continue
result["_trace_id"] = self.get_trace_id(result)
Expand Down