Update app/main.py
app/main.py  CHANGED  +59 -60
@@ -1,27 +1,43 @@
 import os
 import re
-from …
-from …
+from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.combine_documents import create_stuff_documents_chain
-from …
-from langchain.chains import create_retrieval_chain
-from …
-from …
+from langchain.prompts.chat import ChatPromptTemplate
+from langchain.chains import RetrievalQA
+from langchain.vectorstores import FAISS
+from langchain.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi import FastAPI
 from pydantic import BaseModel
-from …
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 import nltk
 import time
 
-…
+# Ensure necessary directories are writable
+cache_dir = "/app/cache"  # Update this to a writable directory path
+os.makedirs(cache_dir, exist_ok=True)
+os.environ["TRANSFORMERS_CACHE"] = cache_dir
 
-# …
+# Initialize FastAPI app
 app = FastAPI()
 
-# …
-…
+# Configure CORS (if required)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Get OpenAI API key from environment
+openai_api_key = os.environ.get("OPENAI_API_KEY")
+if not openai_api_key:
+    raise ValueError("Please set the OPENAI_API_KEY environment variable.")
+
+# Initialize LLM
 llm = ChatOpenAI(
     api_key=openai_api_key,
     model_name="gpt-4-turbo-preview",
@@ -32,9 +48,11 @@ llm = ChatOpenAI(
 def read_root():
     return {"Hello": "World"}
 
+# Define Pydantic model for query input
 class Query(BaseModel):
     query_text: str
 
+# Utility function to clean responses
 def clean_response(response):
     cleaned = response.strip()
     cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
@@ -42,59 +60,40 @@ def clean_response(response):
     cleaned = cleaned.replace('\\n', '')
     return cleaned
 
+# Define the prompt for the chatbot
 prompt = ChatPromptTemplate.from_template(
-    """
-    You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET),
-    …
-    …
-    …
-    …
-    …
+    """
+    You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET),
+    a renowned technical college. Your task is to answer all queries related to TIET.
+    If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.'
+    For more information, please contact our toll-free number: 18002024100 or email us at [email protected].
+    <context>
+    {context}
+    </context>
+    Question: {query}
+    Answer:
+    """
 )
 
+# Load embeddings
+try:
+    embeddings = HuggingFaceEmbeddings(
+        model_name="sentence-transformers/all-MiniLM-L6-v2",  # Ensure this model is valid
+        cache_folder=cache_dir,
+    )
+except Exception as e:
+    raise RuntimeError(f"Failed to initialize embeddings: {e}")
+
+# Example endpoint for handling queries
 @app.post("/chat")
-def chat(query: Query):
+async def chat(query: Query):
+    context = "Thapar Institute of Engineering and Technology information."
     try:
-        # …
-        embeddings = HuggingFaceBgeEmbeddings(
-            model_name="BAAI/bge-base-en",
-            encode_kwargs={'normalize_embeddings': True}
+        # Use the LLM to generate a response
+        response = llm.generate(
+            prompt.format(context=context, query=query.query_text)
         )
-        vectors = FAISS.load_local("./vectors_db", embeddings)
-    except Exception as e:
-        print(f"Error loading vector store: {str(e)}")
-        return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}
-
-    prompt1 = query.query_text
-    if prompt1:
-        start = time.process_time()
-        document_chain = create_stuff_documents_chain(llm, prompt)
-        retriever = vectors.as_retriever()
-        retrieval_chain = create_retrieval_chain(retriever, document_chain)
-        response = retrieval_chain.invoke({'input': prompt1})
-
-        cleaned_response = clean_response(response['answer'])
-        print("Response time:", time.process_time() - start)
+        cleaned_response = clean_response(response)
         return {"response": cleaned_response}
-
-    return {"…
-
-@app.get("/setup")
-def setup():
-    # Example setup function for vector embedding
-    documents = []  # Load your documents here
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
-    chunks = text_splitter.split_documents(documents)
-
-    model_name = "BAAI/bge-base-en"
-    encode_kwargs = {'normalize_embeddings': True}
-    embeddings = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
-
-    db = FAISS.from_documents(chunks, embeddings)
-    db.save_local("./vectors_db")
-    print("Vector store created and saved successfully.")
-    return {"response": "Vector Store DB Is Ready"}
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    except Exception as e:
+        return {"error": str(e)}
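For reference, the prompt-to-model wiring inside the new /chat handler can also be exercised outside FastAPI. On LangChain chat models, generate() expects batched message lists rather than a single formatted string, so this sketch goes through format_messages() and invoke() instead; treat the exact method names as assumptions tied to whichever LangChain version the Space pins:

import os
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate

llm = ChatOpenAI(api_key=os.environ["OPENAI_API_KEY"], model_name="gpt-4-turbo-preview")
prompt = ChatPromptTemplate.from_template("Context: {context}\nQuestion: {query}\nAnswer:")

# format_messages fills the template slots and returns chat messages
messages = prompt.format_messages(
    context="Thapar Institute of Engineering and Technology information.",
    query="Where is TIET located?",
)
reply = llm.invoke(messages)  # returns an AIMessage
print(reply.content)  # the raw text that clean_response() would strip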
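Once the Space is up, the updated endpoint can be smoke-tested with a plain HTTP request. The URL below is a placeholder (this commit removes the uvicorn.run block, so the actual host and port depend on how the Space serves the app); the query_text field matches the Query model:

import requests

resp = requests.post(
    "http://localhost:8000/chat",  # placeholder URL; point at the running Space instead
    json={"query_text": "What courses does TIET offer?"},
    timeout=60,
)
print(resp.json())  # {"response": "..."} on success, {"error": "..."} if the handler raised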