pvanand committed on
Commit
11f96c1
1 Parent(s): de5a712

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +29 -18
main.py CHANGED
@@ -1,6 +1,6 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
  from typing import List
5
  import json
6
  import os
@@ -11,7 +11,11 @@ from txtai.embeddings import Embeddings
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
- app = FastAPI()
 
 
 
 
15
 
16
  # Enable CORS
17
  app.add_middleware(
@@ -25,22 +29,20 @@ app.add_middleware(
25
  embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})
26
 
27
  class DocumentRequest(BaseModel):
28
- index_id: str
29
- documents: List[str]
30
 
31
  class QueryRequest(BaseModel):
32
- index_id: str
33
- query: str
34
- num_results: int
35
 
36
- def save_embeddings(index_id, document_list):
37
  try:
38
  folder_path = f"/app/indexes/{index_id}"
39
  os.makedirs(folder_path, exist_ok=True)
40
-
41
  # Save embeddings
42
  embeddings.save(f"{folder_path}/embeddings")
43
-
44
  # Save document_list
45
  with open(f"{folder_path}/document_list.json", "w") as f:
46
  json.dump(document_list, f)
@@ -49,29 +51,31 @@ def save_embeddings(index_id, document_list):
49
  logger.error(f"Error saving embeddings for index_id {index_id}: {str(e)}")
50
  raise HTTPException(status_code=500, detail=f"Error saving embeddings: {str(e)}")
51
 
52
- def load_embeddings(index_id):
53
  try:
54
  folder_path = f"/app/indexes/{index_id}"
55
-
56
  if not os.path.exists(folder_path):
57
  logger.error(f"Index not found for index_id: {index_id}")
58
  raise HTTPException(status_code=404, detail="Index not found")
59
-
60
  # Load embeddings
61
  embeddings.load(f"{folder_path}/embeddings")
62
-
63
  # Load document_list
64
  with open(f"{folder_path}/document_list.json", "r") as f:
65
  document_list = json.load(f)
66
-
67
  logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
68
  return document_list
69
  except Exception as e:
70
  logger.error(f"Error loading embeddings for index_id {index_id}: {str(e)}")
71
  raise HTTPException(status_code=500, detail=f"Error loading embeddings: {str(e)}")
72
 
73
- @app.post("/create_index/")
74
  async def create_index(request: DocumentRequest):
 
 
 
 
 
 
75
  try:
76
  document_list = [(i, text, None) for i, text in enumerate(request.documents)]
77
  embeddings.index(document_list)
@@ -82,8 +86,15 @@ async def create_index(request: DocumentRequest):
82
  logger.error(f"Error creating index: {str(e)}")
83
  raise HTTPException(status_code=500, detail=f"Error creating index: {str(e)}")
84
 
85
- @app.post("/query_index/")
86
  async def query_index(request: QueryRequest):
 
 
 
 
 
 
 
87
  try:
88
  document_list = load_embeddings(request.index_id)
89
  results = embeddings.search(request.query, request.num_results)
 
1
+ from fastapi import FastAPI, HTTPException, Query, Path
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel, Field
4
  from typing import List
5
  import json
6
  import os
 
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
+ app = FastAPI(
15
+ title="Embeddings API",
16
+ description="An API for creating and querying text embeddings indexes.",
17
+ version="1.0.0"
18
+ )
19
 
20
  # Enable CORS
21
  app.add_middleware(
 
29
  embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})
30
 
31
class DocumentRequest(BaseModel):
    # Request body for POST /create_index/.
    # Comments are used instead of a class docstring so the generated
    # schema description stays exactly as it was.

    # Name of the index; used as the folder name under /app/indexes.
    index_id: str = Field(
        ...,
        description="Unique identifier for the index",
    )

    # Raw texts to embed; each one is indexed under its list position.
    documents: List[str] = Field(
        ...,
        description="List of documents to be indexed",
    )
34
 
35
class QueryRequest(BaseModel):
    # Request body for POST /query_index/.
    # Comments are used instead of a class docstring so the generated
    # schema description stays exactly as it was.

    # Which saved index to load before searching.
    index_id: str = Field(
        ...,
        description="Unique identifier for the index to query",
    )

    # Free-text query passed to embeddings.search().
    query: str = Field(
        ...,
        description="The search query",
    )

    # Result limit passed to embeddings.search(); ge=1 rejects zero and
    # negative values at validation time.
    num_results: int = Field(
        ...,
        description="Number of results to return",
        ge=1,
    )
39
 
40
+ def save_embeddings(index_id: str, document_list: List[str]):
41
  try:
42
  folder_path = f"/app/indexes/{index_id}"
43
  os.makedirs(folder_path, exist_ok=True)
 
44
  # Save embeddings
45
  embeddings.save(f"{folder_path}/embeddings")
 
46
  # Save document_list
47
  with open(f"{folder_path}/document_list.json", "w") as f:
48
  json.dump(document_list, f)
 
51
  logger.error(f"Error saving embeddings for index_id {index_id}: {str(e)}")
52
  raise HTTPException(status_code=500, detail=f"Error saving embeddings: {str(e)}")
53
 
54
def load_embeddings(index_id: str) -> List[str]:
    """Load the saved embeddings and document list for ``index_id``.

    Reads from ``/app/indexes/{index_id}``: the embeddings are loaded into
    the module-level ``embeddings`` object (replacing whatever it held
    before) and ``document_list.json`` is parsed and returned.

    Args:
        index_id: Identifier of the index; used as the folder name.

    Returns:
        The parsed contents of ``document_list.json``.

    Raises:
        HTTPException: 404 if the index folder does not exist; 500 if the
            embeddings or the document list cannot be loaded.
    """
    # NOTE(review): index_id is interpolated into a filesystem path and
    # query_index passes it straight from the request body -- consider
    # rejecting path separators / ".." before it reaches this point.
    folder_path = f"/app/indexes/{index_id}"

    # The existence check must sit outside the try block below:
    # HTTPException is itself an Exception, so raising the 404 inside the
    # try would be caught by the generic handler and reported as a 500.
    if not os.path.exists(folder_path):
        logger.error(f"Index not found for index_id: {index_id}")
        raise HTTPException(status_code=404, detail="Index not found")

    try:
        # Load embeddings
        embeddings.load(f"{folder_path}/embeddings")

        # Load document_list
        with open(f"{folder_path}/document_list.json", "r") as f:
            document_list = json.load(f)
    except Exception as e:
        logger.error(f"Error loading embeddings for index_id {index_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error loading embeddings: {str(e)}"
        ) from e

    logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
    return document_list
70
 
71
+ @app.post("/create_index/", response_model=dict, tags=["Index Operations"])
72
  async def create_index(request: DocumentRequest):
73
+ """
74
+ Create a new index with the given documents.
75
+
76
+ - **index_id**: Unique identifier for the index
77
+ - **documents**: List of documents to be indexed
78
+ """
79
  try:
80
  document_list = [(i, text, None) for i, text in enumerate(request.documents)]
81
  embeddings.index(document_list)
 
86
  logger.error(f"Error creating index: {str(e)}")
87
  raise HTTPException(status_code=500, detail=f"Error creating index: {str(e)}")
88
 
89
+ @app.post("/query_index/", response_model=dict, tags=["Index Operations"])
90
  async def query_index(request: QueryRequest):
91
+ """
92
+ Query an existing index with the given search query.
93
+
94
+ - **index_id**: Unique identifier for the index to query
95
+ - **query**: The search query
96
+ - **num_results**: Number of results to return
97
+ """
98
  try:
99
  document_list = load_embeddings(request.index_id)
100
  results = embeddings.search(request.query, request.num_results)