om4r932 committed on
Commit
4673c0a
·
1 Parent(s): 43b72e8

Change data provider

Browse files
Files changed (3) hide show
  1. app.py +35 -54
  2. requirements.txt +3 -3
  3. schemas.py +16 -0
app.py CHANGED
@@ -1,34 +1,29 @@
1
- from typing import List, Dict, Any
2
- import zipfile
3
- import os
4
- import warnings
5
- from openai import OpenAI
6
  from dotenv import load_dotenv
 
 
 
 
 
 
 
7
  import bm25s
8
- from fastapi.staticfiles import StaticFiles
9
- from nltk.stem import WordNetLemmatizer
10
- import nltk
11
  from fastapi import FastAPI
12
- from fastapi.responses import FileResponse
13
  from fastapi.middleware.cors import CORSMiddleware
14
- import numpy as np
15
- from pydantic import BaseModel
16
- from sklearn.preprocessing import MinMaxScaler
17
 
18
- load_dotenv()
 
19
 
20
- nltk.download('wordnet')
21
- if os.path.exists("bm25s.zip"):
22
- with zipfile.ZipFile("bm25s.zip", 'r') as zip_ref:
23
- zip_ref.extractall(".")
24
- bm25_engine = bm25s.BM25.load("3gpp_bm25_docs", load_corpus=True)
25
- lemmatizer = WordNetLemmatizer()
26
- llm = OpenAI(api_key=os.environ.get("GEMINI"), base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
27
 
28
- warnings.filterwarnings("ignore")
29
 
30
  app = FastAPI(title="RAGnarok",
31
- description="API to search specifications for RAG")
32
 
33
  app.mount("/static", StaticFiles(directory="static"), name="static")
34
 
@@ -44,46 +39,22 @@ app.add_middleware(
44
  allow_headers=["*"],
45
  )
46
 
47
- class SearchRequest(BaseModel):
48
- keyword: str
49
- threshold: int
50
-
51
- class SearchResponse(BaseModel):
52
- results: List[Dict[str, Any]]
53
-
54
- class ChatRequest(BaseModel):
55
- messages: List[Dict[str, str]]
56
- model: str
57
-
58
- class ChatResponse(BaseModel):
59
- response: str
60
-
61
  @app.get("/")
62
- async def main_menu():
63
  return FileResponse(os.path.join("templates", "index.html"))
64
 
65
- @app.post("/chat", response_model=ChatResponse)
66
- def question_the_sources(req: ChatRequest):
67
- model = req.model
68
- resp = llm.chat.completions.create(
69
- messages=req.messages,
70
- model=model
71
- )
72
- return ChatResponse(response=resp.choices[0].message.content)
73
-
74
  @app.post("/search", response_model=SearchResponse)
75
  def search_specifications(req: SearchRequest):
76
  keywords = req.keyword
77
  threshold = req.threshold
78
- query = lemmatizer.lemmatize(keywords)
79
  results_out = []
80
- query_tokens = bm25s.tokenize(query)
81
- results, scores = bm25_engine.retrieve(query_tokens, k=len(bm25_engine.corpus))
82
 
83
  def calculate_boosted_score(metadata, score, query):
84
- title = {lemmatizer.lemmatize(metadata['title']).lower()}
85
- q = {query.lower()}
86
- spec_id_presence = 0.5 if len(q & {metadata['id']}) > 0 else 0
87
  booster = len(q & title) * 0.5
88
  return score + spec_id_presence + booster
89
 
@@ -96,7 +67,7 @@ def search_specifications(req: SearchRequest):
96
  score = scores[0, i]
97
  spec = doc["metadata"]["id"]
98
 
99
- boosted_score = calculate_boosted_score(doc['metadata'], score, query)
100
 
101
  if spec not in spec_scores or boosted_score > spec_scores[spec]:
102
  spec_scores[spec] = boosted_score
@@ -135,4 +106,14 @@ def search_specifications(req: SearchRequest):
135
  break
136
  results_out.append({'id': metadata['id'], 'title': metadata['title'], 'section': metadata['section_title'], 'content': details['doc']['text'], 'similarity': int(details['normalized_score']*100)})
137
 
138
- return SearchResponse(results=results_out)
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, warnings
 
 
 
 
2
  from dotenv import load_dotenv
3
+ from schemas import *
4
+
5
+ os.environ["CURL_CA_BUNDLE"] = ""
6
+ warnings.filterwarnings("ignore")
7
+ load_dotenv()
8
+
9
+ from datasets import load_dataset
10
  import bm25s
11
+ from bm25s.hf import BM25HF
12
+
 
13
  from fastapi import FastAPI
 
14
  from fastapi.middleware.cors import CORSMiddleware
15
+ from fastapi.responses import FileResponse
16
+ from fastapi.staticfiles import StaticFiles
 
17
 
18
+ from sklearn.preprocessing import MinMaxScaler
19
+ import numpy as np
20
 
21
+ import litellm
 
 
 
 
 
 
22
 
23
+ bm25_index = BM25HF.load_from_hub("OrganizedProgrammers/3GPPBM25IndexSections", load_corpus=True, token=os.environ["HF_TOKEN"])
24
 
25
  app = FastAPI(title="RAGnarok",
26
+ description="Speak with the specifications")
27
 
28
  app.mount("/static", StaticFiles(directory="static"), name="static")
29
 
 
39
  allow_headers=["*"],
40
  )
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @app.get("/")
43
+ def main_menu():
44
  return FileResponse(os.path.join("templates", "index.html"))
45
 
 
 
 
 
 
 
 
 
 
46
  @app.post("/search", response_model=SearchResponse)
47
  def search_specifications(req: SearchRequest):
48
  keywords = req.keyword
49
  threshold = req.threshold
 
50
  results_out = []
51
+ query_tokens = bm25s.tokenize(keywords)
52
+ results, scores = bm25_index.retrieve(query_tokens, k=len(bm25_index.corpus))
53
 
54
  def calculate_boosted_score(metadata, score, query):
55
+ title = set(metadata['title'].lower().split())
56
+ q = set(query.lower().split())
57
+ spec_id_presence = 0.5 if metadata['id'].lower() in q else 0
58
  booster = len(q & title) * 0.5
59
  return score + spec_id_presence + booster
60
 
 
67
  score = scores[0, i]
68
  spec = doc["metadata"]["id"]
69
 
70
+ boosted_score = calculate_boosted_score(doc['metadata'], score, keywords)
71
 
72
  if spec not in spec_scores or boosted_score > spec_scores[spec]:
73
  spec_scores[spec] = boosted_score
 
106
  break
107
  results_out.append({'id': metadata['id'], 'title': metadata['title'], 'section': metadata['section_title'], 'content': details['doc']['text'], 'similarity': int(details['normalized_score']*100)})
108
 
109
+ return SearchResponse(results=results_out)
110
+
111
+ @app.post("/chat", response_model=ChatResponse)
112
+ def questions_the_sources(req: ChatRequest):
113
+ model = req.model
114
+ resp = litellm.completion(
115
+ model=f"gemini/{model}",
116
+ messages=req.messages,
117
+ api_key=os.environ["GEMINI"]
118
+ )
119
+ return ChatResponse(response=resp.choices[0].message.content)
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- openai
2
  fastapi
3
  uvicorn[standard]
4
  python-dotenv
5
  bm25s[full]
6
- nltk
 
7
  numpy
8
- scikit-learn
 
 
1
  fastapi
2
  uvicorn[standard]
3
  python-dotenv
4
  bm25s[full]
5
+ scikit-learn
6
+ litellm
7
  numpy
8
+ datasets
schemas.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from pydantic import BaseModel
3
+
4
+ class SearchRequest(BaseModel):
5
+ keyword: str
6
+ threshold: int
7
+
8
+ class SearchResponse(BaseModel):
9
+ results: List[Dict[str, Any]]
10
+
11
+ class ChatRequest(BaseModel):
12
+ messages: List[Dict[str, str]]
13
+ model: str
14
+
15
+ class ChatResponse(BaseModel):
16
+ response: str