0504ankitsharma committed
Commit 3144702 · verified · 1 Parent(s): e098a3b

Update app/main.py

Files changed (1):
  1. app/main.py +132 -61
app/main.py CHANGED
@@ -1,29 +1,64 @@
  import os
  import re
- from langchain.llms import OpenAI
- from langchain.chat_models import ChatOpenAI
+ from openai import OpenAI
+ from langchain_openai import ChatOpenAI
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain.prompts.chat import ChatPromptTemplate
- from langchain.chains import RetrievalQA
- from langchain.vectorstores import FAISS
- from langchain.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain.chains import create_retrieval_chain
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi import FastAPI
  from pydantic import BaseModel
- from langchain.embeddings.huggingface import HuggingFaceEmbeddings
- import nltk
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ import nltk  # Importing NLTK
  import time

- # Ensure necessary directories are writable
- cache_dir = "/app/cache"  # Update this to a writable directory path
- os.makedirs(cache_dir, exist_ok=True)
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
+ # Set writable paths for cache and data
+ cache_dir = '/tmp'
+ nltk_data_path = os.path.join(cache_dir, 'nltk_data')
+
+ # Configure NLTK and other library paths
+ os.environ['TRANSFORMERS_CACHE'] = os.path.join(cache_dir, 'transformers_cache')
+ os.environ['HF_HOME'] = os.path.join(cache_dir, 'huggingface')
+ os.environ['XDG_CACHE_HOME'] = cache_dir
+
+ # Add NLTK data path
+ nltk.data.path.append(nltk_data_path)
+
+ # Ensure the directory exists
+ try:
+     os.makedirs(nltk_data_path, exist_ok=True)
+ except OSError as e:
+     print(f"Error creating directory {nltk_data_path}: {e}")
+     raise
+
+ # Download required NLTK resources
+ try:
+     nltk.download('punkt', download_dir=nltk_data_path)
+     print("NLTK 'punkt' resource downloaded successfully.")
+ except Exception as e:
+     print(f"Error downloading NLTK resources: {e}")
+     raise
+
+ def clean_response(response):
+     # Remove any leading/trailing whitespace, including newlines
+     cleaned = response.strip()
+
+     # Remove any enclosing quotation marks
+     cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
+
+     # Replace multiple newlines with a single newline
+     cleaned = re.sub(r'\n+', '\n', cleaned)
+
+     # Remove any remaining '\n' characters
+     cleaned = cleaned.replace('\\n', '')
+
+     return cleaned

- # Initialize FastAPI app
  app = FastAPI()

- # Configure CORS (if required)
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -32,68 +67,104 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # Get OpenAI API key from environment
- openai_api_key = os.environ.get("OPENAI_API_KEY")
- if not openai_api_key:
-     raise ValueError("Please set the OPENAI_API_KEY environment variable.")
-
- # Initialize LLM
+ openai_api_key = os.environ.get('OPENAI_API_KEY')
+
  llm = ChatOpenAI(
      api_key=openai_api_key,
-     model_name="gpt-4-turbo-preview",
+     model_name="gpt-4-turbo-preview",  # or "gpt-3.5-turbo" for a more economical option
      temperature=0.7,
+     max_tokens=200
  )

  @app.get("/")
  def read_root():
      return {"Hello": "World"}

- # Define Pydantic model for query input
  class Query(BaseModel):
      query_text: str

- # Utility function to clean responses
- def clean_response(response):
-     cleaned = response.strip()
-     cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
-     cleaned = re.sub(r'\n+', '\n', cleaned)
-     cleaned = cleaned.replace('\\n', '')
-     return cleaned
-
- # Define the prompt for the chatbot
  prompt = ChatPromptTemplate.from_template(
-     """
-     You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET),
-     a renowned technical college. Your task is to answer all queries related to TIET.
-     If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.'
-     For more information, please contact our toll-free number: 18002024100 or email us at admissions@thapar.edu.
-     <context>
-     {context}
-     </context>
-     Question: {query}
-     Answer:
-     """
+     """
+     You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET in a concise manner. Every response you provide should be relevant to the context of TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' If you do not know the answer to a question, do not attempt to fabricate a response; instead, politely decline, but avoid sounding boastful or exaggerating. Stay focused on the context provided.
+     If the query is not related to TIET or falls outside the context of education, respond with:
+     "Sorry, I cannot help with that. I'm specifically designed to answer questions about the Thapar Institute of Engineering and Technology.
+     For more information, please contact our toll-free number: 18002024100 or e-mail us at admissions@thapar.edu."
+     <context>
+     {context}
+     </context>
+     Question: {input}
+     """
  )

- # Load embeddings
- try:
-     embeddings = HuggingFaceEmbeddings(
-         model_name="sentence-transformers/all-MiniLM-L6-v2",  # Ensure this model is valid
-         cache_folder=cache_dir,
-     )
- except Exception as e:
-     raise RuntimeError(f"Failed to initialize embeddings: {e}")
+ def vector_embedding():
+     try:
+         file_path = "./data/Data.docx"
+         if not os.path.exists(file_path):
+             print(f"The file {file_path} does not exist.")
+             return {"response": "Error: Data file not found"}
+
+         loader = DocxLoader(file_path)
+         documents = loader.load()
+
+         print(f"Loaded document: {file_path}")
+
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
+         chunks = text_splitter.split_documents(documents)
+
+         print(f"Created {len(chunks)} chunks.")
+
+         model_name = "BAAI/bge-base-en"
+         encode_kwargs = {'normalize_embeddings': True}
+         model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
+
+         db = FAISS.from_documents(chunks, model_norm)
+         db.save_local("./vectors_db")
+
+         print("Vector store created and saved successfully.")
+         return {"response": "Vector Store DB Is Ready"}
+
+     except Exception as e:
+         print(f"An error occurred: {str(e)}")
+         return {"response": f"Error: {str(e)}"}
+
+ def get_embeddings():
+     model_name = "BAAI/bge-base-en"
+     encode_kwargs = {'normalize_embeddings': True}
+     model_norm = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
+     return model_norm

- # Example endpoint for handling queries
- @app.post("/chat")
- async def chat(query: Query):
-     context = "Thapar Institute of Engineering and Technology information."
+ @app.post("/chat")  # Changed from /anthropic to /chat
+ def read_item(query: Query):
      try:
-         # Use the LLM to generate a response
-         response = llm.generate(
-             prompt.format(context=context, query=query.query_text)
-         )
-         cleaned_response = clean_response(response)
-         return {"response": cleaned_response}
+         embeddings = get_embeddings()
+         vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
      except Exception as e:
-         return {"error": str(e)}
+         print(f"Error loading vector store: {str(e)}")
+         return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}
+
+     prompt1 = query.query_text
+     if prompt1:
+         start = time.process_time()
+         document_chain = create_stuff_documents_chain(llm, prompt)
+         retriever = vectors.as_retriever()
+         retrieval_chain = create_retrieval_chain(retriever, document_chain)
+         response = retrieval_chain.invoke({'input': prompt1})
+         print("Response time:", time.process_time() - start)
+
+         # Apply the cleaning function to the response
+         cleaned_response = clean_response(response['answer'])
+
+         # For debugging, print the cleaned response
+         print("Cleaned response:", repr(cleaned_response))
+
+         return cleaned_response
+     else:
+         return "No Query Found"
+
+ @app.get("/setup")
+ def setup():
+     return vector_embedding()
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
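
As a quick sanity check, the two endpoints introduced in this revision can be exercised with a small client script. This is a minimal sketch, not part of the commit: it assumes the server is running with the host and port from the `__main__` block above and that the `requests` package is installed; the sample question is illustrative only.

import requests

BASE_URL = "http://localhost:8000"  # assumption: host/port from the __main__ block above

# One-time setup: build and save the FAISS vector store from ./data/Data.docx.
setup_resp = requests.get(f"{BASE_URL}/setup")
print(setup_resp.json())  # on success: {"response": "Vector Store DB Is Ready"}

# Query the RAG endpoint; /chat returns the cleaned answer as a JSON-encoded string.
payload = {"query_text": "What programmes does TIET offer?"}  # illustrative query
chat_resp = requests.post(f"{BASE_URL}/chat", json=payload)
print(chat_resp.json())

Note that /chat returns a bare string rather than a {"response": ...} object, so callers should expect a JSON-encoded string body.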