0504ankitsharma committed on
Commit
fac79dc
·
verified ·
1 Parent(s): 7faf51e

Upload main.py

Browse files
Files changed (1) hide show
  1. app/main.py +144 -0
app/main.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from openai import OpenAI
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.chains.combine_documents import create_stuff_documents_chain
7
+ from langchain_core.prompts import ChatPromptTemplate
8
+ from langchain.chains import create_retrieval_chain
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi import FastAPI
13
+ from pydantic import BaseModel
14
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
15
+ import time
16
+
17
def clean_response(response):
    """Normalize an LLM answer for the API response.

    Strips surrounding whitespace and quotation marks, collapses runs of
    newlines into one, and drops any literal backslash-n sequences the
    model may have emitted as text.
    """
    text = response.strip()
    # Drop any quote characters wrapping the whole answer.
    text = re.sub(r'^["\']+|["\']+$', '', text)
    # Collapse consecutive blank lines to a single newline.
    text = re.sub(r'\n+', '\n', text)
    # Remove literal "\n" two-character sequences (not real newlines).
    return text.replace('\\n', '')
31
+
32
# FastAPI application instance for the TIET assistant service.
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is an
# over-permissive CORS configuration (browsers reject the combination for
# credentialed requests) — consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Chat model used by the retrieval chain; the key is read from the
# environment, so `llm` is unauthenticated if OPENAI_API_KEY is unset.
openai_api_key = os.environ.get('OPENAI_API_KEY')
llm = ChatOpenAI(
    api_key=openai_api_key,
    model_name="gpt-4-turbo-preview",  # or "gpt-3.5-turbo" for a more economical option
    temperature=0.7
)
48
+
49
@app.get("/")
def read_root():
    """Simple liveness endpoint."""
    payload = {"Hello": "World"}
    return payload
52
+
53
class Query(BaseModel):
    """Request body for POST /chat."""
    query_text: str  # the user's question, forwarded to the retrieval chain
55
+
56
# Prompt constraining the assistant to TIET-related questions only.
# {context} is filled with retrieved document chunks and {input} with the
# user's query by the retrieval chain in the /chat handler.
prompt = ChatPromptTemplate.from_template(
    """
You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
If the query is not related to TIET or falls outside the context of education, respond with:
"Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
For more information, please contact at our toll-free number: 18002024100 or E-mail us at [email protected]
<context>
{context}
</context>
Question: {input}
"""
)
69
+
70
def vector_embedding():
    """Build and persist the FAISS vector store from the TIET data document.

    Loads ./data/Data.docx, splits it into overlapping chunks, embeds the
    chunks, and saves the resulting index to ./vectors_db.

    Returns:
        dict: {"response": <status message>} — a success message, or an
        error description if the file is missing or any step fails.
    """
    try:
        file_path = "./data/Data.docx"
        if not os.path.exists(file_path):
            print(f"The file {file_path} does not exist.")
            return {"response": "Error: Data file not found"}

        loader = DocxLoader(file_path)
        documents = loader.load()
        print(f"Loaded document: {file_path}")

        # Overlap keeps sentence context intact across chunk boundaries.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} chunks.")

        # Reuse the shared embedding factory instead of duplicating the
        # model setup here — keeps indexing and retrieval embeddings in sync.
        db = FAISS.from_documents(chunks, get_embeddings())
        db.save_local("./vectors_db")

        print("Vector store created and saved successfully.")
        return {"response": "Vector Store DB Is Ready"}

    except Exception as e:
        # Best-effort endpoint: report the failure instead of raising.
        print(f"An error occurred: {str(e)}")
        return {"response": f"Error: {str(e)}"}
100
+
101
def get_embeddings():
    """Return the normalized BGE embedding model used for the FAISS index."""
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en",
        encode_kwargs={'normalize_embeddings': True},
    )
106
+
107
@app.post("/chat")  # Changed from /anthropic to /chat
def read_item(query: Query):
    """Answer a TIET question via retrieval-augmented generation.

    Loads the persisted FAISS index, runs the retrieval chain over the
    query text, and returns the cleaned answer string.
    """
    # The index must exist before the query is examined (order preserved:
    # an empty query against a missing index still reports the index error).
    try:
        vector_store = FAISS.load_local(
            "./vectors_db", get_embeddings(), allow_dangerous_deserialization=True
        )
    except Exception as e:
        print(f"Error loading vector store: {str(e)}")
        return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}

    question = query.query_text
    if not question:
        return "No Query Found"

    start = time.process_time()
    doc_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(vector_store.as_retriever(), doc_chain)
    result = rag_chain.invoke({'input': question})
    print("Response time:", time.process_time() - start)

    # Strip quotes / stray newlines from the model output before returning.
    answer = clean_response(result['answer'])
    print("Cleaned response:", repr(answer))
    return answer
134
+
135
@app.get("/setup")
def setup():
    """Trigger (re)building of the FAISS vector store from the data file."""
    result = vector_embedding()
    return result
138
+
139
# Debugging aid: uncomment to verify the API key is configured.
# print(f"API key set: {'Yes' if os.environ.get('OPENAI_API_KEY') else 'No'}")

if __name__ == "__main__":
    # Start a local Uvicorn development server when run as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)