0504ankitsharma commited on
Commit
8c44cf7
·
verified ·
1 Parent(s): c4e6640

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +97 -93
app/main.py CHANGED
@@ -1,10 +1,6 @@
1
  import os
2
  import re
3
- import time
4
- import nltk
5
- from fastapi import FastAPI, HTTPException
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from pydantic import BaseModel
8
  from langchain_openai import ChatOpenAI
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
  from langchain.chains.combine_documents import create_stuff_documents_chain
@@ -12,30 +8,44 @@ from langchain_core.prompts import ChatPromptTemplate
12
  from langchain.chains import create_retrieval_chain
13
  from langchain_community.vectorstores import FAISS
14
  from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
 
 
 
15
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 
 
 
 
 
 
16
 
17
- # Configure NLTK custom download directory
18
- NLTK_DATA_PATH = os.getenv("NLTK_DATA_PATH", os.path.join(os.getcwd(), "nltk_data"))
19
- os.makedirs(NLTK_DATA_PATH, exist_ok=True)
20
- nltk.data.path.append(NLTK_DATA_PATH)
21
 
22
- # Download necessary NLTK resources
23
- nltk.download("punkt", download_dir=NLTK_DATA_PATH)
 
 
 
24
 
25
- # Utility function to clean the response
26
  def clean_response(response):
27
- if not response:
28
- return "Sorry, I couldn't generate a response."
29
  cleaned = response.strip()
 
 
30
  cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
 
 
31
  cleaned = re.sub(r'\n+', '\n', cleaned)
 
 
32
  cleaned = cleaned.replace('\\n', '')
 
33
  return cleaned
34
 
35
- # Initialize FastAPI app
36
  app = FastAPI()
37
 
38
- # CORS Middleware setup
39
  app.add_middleware(
40
  CORSMiddleware,
41
  allow_origins=["*"],
@@ -44,109 +54,103 @@ app.add_middleware(
44
  allow_headers=["*"],
45
  )
46
 
47
- # Global Variables
48
- openai_api_key = os.getenv('OPENAI_API_KEY') # Ensure this is set in your environment
49
- VECTOR_DB_PATH = "./vectors_db"
50
- DATA_FILE_PATH = "./data/Data.docx"
51
- MODEL_NAME = "BAAI/bge-base-en"
52
-
53
- # Initialize OpenAI LLM
54
  llm = ChatOpenAI(
55
  api_key=openai_api_key,
56
- model_name="gpt-4-turbo-preview", # Use "gpt-3.5-turbo" for cost efficiency if required
57
- temperature=0.7,
58
- )
59
-
60
- # Prompt template
61
- prompt = ChatPromptTemplate.from_template(
62
- """
63
- You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
64
- If the query is not related to TIET or falls outside the context of education, respond with:
65
- "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
66
- For more information, please contact at our toll-free number: 18002024100 or E-mail us at [email protected]
67
- <context>
68
- {context}
69
- </context>
70
- Question: {input}
71
- """
72
  )
73
 
74
- # Route: Home
75
  @app.get("/")
76
  def read_root():
77
- return {"message": "Welcome to the ThaparGPT API!"}
78
 
79
- # Route: Chat Endpoint
80
  class Query(BaseModel):
81
  query_text: str
82
 
83
- @app.post("/chat")
84
- def chat(query: Query):
85
- try:
86
- # Load the vector store
87
- embeddings = get_embeddings()
88
- vectors = FAISS.load_local(VECTOR_DB_PATH, embeddings, allow_dangerous_deserialization=True)
89
- except Exception as e:
90
- print(f"Error loading vector store: {str(e)}")
91
- raise HTTPException(status_code=500, detail="Vector Store not found or loading failed. Please run /setup first.")
92
-
93
- # Retrieve and process the query
94
- query_text = query.query_text
95
- if query_text:
96
- start_time = time.process_time()
97
- document_chain = create_stuff_documents_chain(llm, prompt)
98
- retriever = vectors.as_retriever()
99
- retrieval_chain = create_retrieval_chain(retriever, document_chain)
100
-
101
- try:
102
- response = retrieval_chain.invoke({'input': query_text})
103
- except Exception as e:
104
- print(f"Error during query processing: {str(e)}")
105
- raise HTTPException(status_code=500, detail="Error processing the query.")
106
-
107
- print("Response time:", time.process_time() - start_time)
108
- cleaned_response = clean_response(response.get('answer', ''))
109
- return {"response": cleaned_response}
110
- else:
111
- raise HTTPException(status_code=400, detail="No query found in the request.")
112
-
113
- # Route: Setup Endpoint
114
- @app.get("/setup")
115
- def setup():
116
- return vector_embedding()
117
 
118
- # Utility: Create Vector Embeddings
119
  def vector_embedding():
120
  try:
121
- if not os.path.exists(DATA_FILE_PATH):
122
- print(f"The file {DATA_FILE_PATH} does not exist.")
123
- raise HTTPException(status_code=404, detail="Data file not found.")
 
124
 
125
- # Load and split document
126
- loader = DocxLoader(DATA_FILE_PATH)
127
  documents = loader.load()
128
- print(f"Loaded document: {DATA_FILE_PATH}")
 
129
 
130
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
131
  chunks = text_splitter.split_documents(documents)
 
132
  print(f"Created {len(chunks)} chunks.")
133
 
134
- # Create vector store
135
- embeddings = get_embeddings()
136
- db = FAISS.from_documents(chunks, embeddings)
137
- db.save_local(VECTOR_DB_PATH)
 
 
 
138
  print("Vector store created and saved successfully.")
139
- return {"response": "Vector Store DB is ready."}
 
140
  except Exception as e:
141
- print(f"Error during setup: {str(e)}")
142
- raise HTTPException(status_code=500, detail=f"Error during setup: {str(e)}")
143
 
144
- # Utility: Load Embedding Model
145
  def get_embeddings():
 
146
  encode_kwargs = {'normalize_embeddings': True}
147
- return HuggingFaceBgeEmbeddings(model_name=MODEL_NAME, encode_kwargs=encode_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- # Main entry point
150
  if __name__ == "__main__":
151
  import uvicorn
152
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  import re
3
+ from openai import OpenAI
 
 
 
 
4
  from langchain_openai import ChatOpenAI
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
8
  from langchain.chains import create_retrieval_chain
9
  from langchain_community.vectorstores import FAISS
10
  from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi import FastAPI
13
+ from pydantic import BaseModel
14
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
15
+ import nltk # Importing NLTK
16
+ import time
17
+
18
+ # Configure NLTK data directory
19
+ nltk_data_path = os.path.join(os.getcwd(), 'nltk_data') # Use a writable directory
20
+ nltk.data.path.append(nltk_data_path)
21
 
22
+ # Ensure the directory exists
23
+ if not os.path.exists(nltk_data_path):
24
+ os.makedirs(nltk_data_path)
 
25
 
26
+ # Download required NLTK resources
27
+ try:
28
+ nltk.download('punkt', download_dir=nltk_data_path)
29
+ except Exception as e:
30
+ print(f"Error downloading NLTK resources: {e}")
31
 
 
32
  def clean_response(response):
33
+ # Remove any leading/trailing whitespace, including newlines
 
34
  cleaned = response.strip()
35
+
36
+ # Remove any enclosing quotation marks
37
  cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
38
+
39
+ # Replace multiple newlines with a single newline
40
  cleaned = re.sub(r'\n+', '\n', cleaned)
41
+
42
+ # Remove any remaining '\n' characters
43
  cleaned = cleaned.replace('\\n', '')
44
+
45
  return cleaned
46
 
 
47
  app = FastAPI()
48
 
 
49
  app.add_middleware(
50
  CORSMiddleware,
51
  allow_origins=["*"],
 
54
  allow_headers=["*"],
55
  )
56
 
57
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
 
 
 
 
 
 
58
  llm = ChatOpenAI(
59
  api_key=openai_api_key,
60
+ model_name="gpt-4-turbo-preview", # or "gpt-3.5-turbo" for a more economical option
61
+ temperature=0.7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
 
 
64
  @app.get("/")
65
  def read_root():
66
+ return {"Hello": "World"}
67
 
 
68
  class Query(BaseModel):
69
  query_text: str
70
 
71
+ prompt = ChatPromptTemplate.from_template(
72
+ """
73
+ You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline.
74
+ You may elaborate on your answers slightly to provide more information, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
75
+ If the query is not related to TIET or falls outside the context of education, respond with:
76
+ "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
77
+ For more information, please contact at our toll-free number: 18002024100 or E-mail us at [email protected]
78
+ <context>
79
+ {context}
80
+ </context>
81
+ Question: {input}
82
+ """
83
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
 
85
  def vector_embedding():
86
  try:
87
+ file_path = "./data/Data.docx"
88
+ if not os.path.exists(file_path):
89
+ print(f"The file {file_path} does not exist.")
90
+ return {"response": "Error: Data file not found"}
91
 
92
+ loader = DocxLoader(file_path)
 
93
  documents = loader.load()
94
+
95
+ print(f"Loaded document: {file_path}")
96
 
97
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
98
  chunks = text_splitter.split_documents(documents)
99
+
100
  print(f"Created {len(chunks)} chunks.")
101
 
102
+ model_name = "BAAI/bge-base-en"
103
+ encode_kwargs = {'normalize_embeddings': True}
104
+ model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
105
+
106
+ db = FAISS.from_documents(chunks, model_norm)
107
+ db.save_local("./vectors_db")
108
+
109
  print("Vector store created and saved successfully.")
110
+ return {"response": "Vector Store DB Is Ready"}
111
+
112
  except Exception as e:
113
+ print(f"An error occurred: {str(e)}")
114
+ return {"response": f"Error: {str(e)}"}
115
 
 
116
  def get_embeddings():
117
+ model_name = "BAAI/bge-base-en"
118
  encode_kwargs = {'normalize_embeddings': True}
119
+ model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
120
+ return model_norm
121
+
122
+ @app.post("/chat") # Changed from /anthropic to /chat
123
+ def read_item(query: Query):
124
+ try:
125
+ embeddings = get_embeddings()
126
+ vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
127
+ except Exception as e:
128
+ print(f"Error loading vector store: {str(e)}")
129
+ return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}
130
+
131
+ prompt1 = query.query_text
132
+ if prompt1:
133
+ start = time.process_time()
134
+ document_chain = create_stuff_documents_chain(llm, prompt)
135
+ retriever = vectors.as_retriever()
136
+ retrieval_chain = create_retrieval_chain(retriever, document_chain)
137
+ response = retrieval_chain.invoke({'input': prompt1})
138
+ print("Response time:", time.process_time() - start)
139
+
140
+ # Apply the cleaning function to the response
141
+ cleaned_response = clean_response(response['answer'])
142
+
143
+ # For debugging, print the cleaned response
144
+ print("Cleaned response:", repr(cleaned_response))
145
+
146
+ return cleaned_response
147
+ else:
148
+ return "No Query Found"
149
+
150
+ @app.get("/setup")
151
+ def setup():
152
+ return vector_embedding()
153
 
 
154
  if __name__ == "__main__":
155
  import uvicorn
156
  uvicorn.run(app, host="0.0.0.0", port=8000)