-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
64 lines (47 loc) · 1.57 KB
/
main.py
File metadata and controls
64 lines (47 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging
import os
from enum import Enum
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from models.align_retriver import ALIGNRetriever
from models.blip_retriever import BLIPRetriever
from models.clip_retriever import CLIPRetriever
from models.flava_retriever import FLAVARetriever
class ModelType(str, Enum):
    """Identifiers of the retrieval models a client can select.

    Mixing in ``str`` makes every member compare equal to its lowercase
    string value (for example ``ModelType.CLIP == "clip"``).
    """

    CLIP = "clip"
    BLIP = "blip"
    FLAVA = "flava"
    ALIGN = "align"
class RetrievalRequest(BaseModel):
    """Request body accepted by the ``/retrieve/`` endpoint."""

    # Text query handed to the selected retriever.
    query: str
    # Number of results to ask the retriever for; 60 when the client omits it.
    n: int = 60
    # Which retrieval model should serve the query (required).
    model: ModelType
app = FastAPI(title="Image Retrieval API")

# Directory holding the image collection handed to every retriever. The
# NEEDLE_IMAGE_DIR environment variable overrides the built-in default so the
# service is not tied to one machine's filesystem layout.
IMAGE_DIR = os.environ.get(
    "NEEDLE_IMAGE_DIR",
    "/home/mahdi/Projects/Needle/backend/resources/nocaps",
)

# One retriever instance per supported model, constructed once at import time
# and looked up by the request handler.
retrievers = {
    ModelType.CLIP: CLIPRetriever(IMAGE_DIR),
    ModelType.BLIP: BLIPRetriever(IMAGE_DIR),
    ModelType.FLAVA: FLAVARetriever(IMAGE_DIR),
    ModelType.ALIGN: ALIGNRetriever(IMAGE_DIR),
}
@app.post("/retrieve/", response_model=List[str])
async def retrieve_images(request: RetrievalRequest):
    """
    Retrieve the top n most relevant images for the given query using the specified model.

    Args:
        request: Parsed request body carrying the query text, the number of
            results wanted and the model to use.

    Returns:
        Whatever ``retrieve`` on the chosen retriever returns; the route
        declares the response as a list of strings.

    Raises:
        HTTPException: 422 when ``request.n`` is negative; 500 (with the
            exception text as detail) when retrieval fails.
    """
    # A negative result count is meaningless; reject it up front instead of
    # passing it on to the retriever.
    if request.n < 0:
        raise HTTPException(status_code=422, detail="n must be non-negative")

    try:
        retriever = retrievers[request.model]
        return retriever.retrieve(request.query, request.n)
    except HTTPException:
        # Deliberate HTTP errors keep their own status code rather than being
        # rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        # logger.exception records the full traceback alongside the message.
        logging.getLogger(__name__).exception(
            "Retrieval failed for model %s", request.model.value
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/models/", response_model=List[str])
async def list_models():
    """
    Return the string value of every ``ModelType`` member, in definition order.
    """
    available = [kind.value for kind in ModelType]
    return available
if __name__ == "__main__":
    # uvicorn is imported here so it is only needed when this module is run
    # directly as a script.
    import uvicorn

    # "0.0.0.0" makes the server listen on all IPv4 interfaces.
    bind_host = "0.0.0.0"
    bind_port = 8020
    uvicorn.run(app, host=bind_host, port=bind_port)