0504ankitsharma commited on
Commit
e098a3b
·
verified ·
1 Parent(s): 40efbe7

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +59 -60
app/main.py CHANGED
@@ -1,27 +1,43 @@
1
  import os
2
  import re
3
- from openai import OpenAI
4
- from langchain_openai import ChatOpenAI
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.chains.combine_documents import create_stuff_documents_chain
7
- from langchain_core.prompts import ChatPromptTemplate
8
- from langchain.chains import create_retrieval_chain
9
- from langchain_community.vectorstores import FAISS
10
- from langchain_community.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi import FastAPI
13
  from pydantic import BaseModel
14
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
15
  import nltk
16
  import time
17
 
18
- os.environ["TRANSFORMERS_CACHE"] = "/path/to/writable/cache"
 
 
 
19
 
20
- # Set up FastAPI app
21
  app = FastAPI()
22
 
23
- # Get OpenAI API key
24
- openai_api_key = os.environ.get('OPENAI_API_KEY')
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  llm = ChatOpenAI(
26
  api_key=openai_api_key,
27
  model_name="gpt-4-turbo-preview",
@@ -32,9 +48,11 @@ llm = ChatOpenAI(
32
  def read_root():
33
  return {"Hello": "World"}
34
 
 
35
  class Query(BaseModel):
36
  query_text: str
37
 
 
38
  def clean_response(response):
39
  cleaned = response.strip()
40
  cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
@@ -42,59 +60,40 @@ def clean_response(response):
42
  cleaned = cleaned.replace('\\n', '')
43
  return cleaned
44
 
 
45
  prompt = ChatPromptTemplate.from_template(
46
- """
47
- You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET), a renowned technical college. Your task is to answer all queries related to TIET. If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.' For more information, please contact our toll-free number: 18002024100 or email us at [email protected].
48
- <context>
49
- {context}
50
- </context>
51
- Question: {input}
52
- """
 
 
 
 
53
  )
54
 
 
 
 
 
 
 
 
 
 
 
55
  @app.post("/chat")
56
- def read_item(query: Query):
 
57
  try:
58
- # Load vector store
59
- embeddings = HuggingFaceBgeEmbeddings(
60
- model_name="BAAI/bge-base-en",
61
- encode_kwargs={'normalize_embeddings': True}
62
  )
63
- vectors = FAISS.load_local("./vectors_db", embeddings, allow_dangerous_deserialization=True)
64
- except Exception as e:
65
- print(f"Error loading vector store: {str(e)}")
66
- return {"response": "Vector Store Not Found or Error Loading. Please run /setup first."}
67
-
68
- prompt1 = query.query_text
69
- if prompt1:
70
- start = time.process_time()
71
- document_chain = create_stuff_documents_chain(llm, prompt)
72
- retriever = vectors.as_retriever()
73
- retrieval_chain = create_retrieval_chain(retriever, document_chain)
74
- response = retrieval_chain.invoke({'input': prompt1})
75
-
76
- cleaned_response = clean_response(response['answer'])
77
- print("Response time:", time.process_time() - start)
78
  return {"response": cleaned_response}
79
- else:
80
- return {"response": "No Query Found"}
81
-
82
- @app.get("/setup")
83
- def setup():
84
- # Example setup function for vector embedding
85
- documents = [] # Load your documents here
86
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
87
- chunks = text_splitter.split_documents(documents)
88
-
89
- model_name = "BAAI/bge-base-en"
90
- encode_kwargs = {'normalize_embeddings': True}
91
- embeddings = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)
92
-
93
- db = FAISS.from_documents(chunks, embeddings)
94
- db.save_local("./vectors_db")
95
- print("Vector store created and saved successfully.")
96
- return {"response": "Vector Store DB Is Ready"}
97
-
98
- if __name__ == "__main__":
99
- import uvicorn
100
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  import re
3
+ from langchain.llms import OpenAI
4
+ from langchain.chat_models import ChatOpenAI
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.chains.combine_documents import create_stuff_documents_chain
7
+ from langchain.prompts.chat import ChatPromptTemplate
8
+ from langchain.chains import RetrievalQA
9
+ from langchain.vectorstores import FAISS
10
+ from langchain.document_loaders import UnstructuredWordDocumentLoader as DocxLoader
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi import FastAPI
13
  from pydantic import BaseModel
14
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
15
  import nltk
16
  import time
17
 
18
+ # Ensure necessary directories are writable
19
+ cache_dir = "/app/cache" # Update this to a writable directory path
20
+ os.makedirs(cache_dir, exist_ok=True)
21
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
22
 
23
+ # Initialize FastAPI app
24
  app = FastAPI()
25
 
26
+ # Configure CORS (if required)
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ # Get OpenAI API key from environment
36
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
37
+ if not openai_api_key:
38
+ raise ValueError("Please set the OPENAI_API_KEY environment variable.")
39
+
40
+ # Initialize LLM
41
  llm = ChatOpenAI(
42
  api_key=openai_api_key,
43
  model_name="gpt-4-turbo-preview",
 
48
  def read_root():
49
  return {"Hello": "World"}
50
 
51
+ # Define Pydantic model for query input
52
  class Query(BaseModel):
53
  query_text: str
54
 
55
+ # Utility function to clean responses
56
  def clean_response(response):
57
  cleaned = response.strip()
58
  cleaned = re.sub(r'^["\']+|["\']+$', '', cleaned)
 
60
  cleaned = cleaned.replace('\\n', '')
61
  return cleaned
62
 
63
+ # Define the prompt for the chatbot
64
  prompt = ChatPromptTemplate.from_template(
65
+ """
66
+ You are a helpful assistant designed specifically for the Thapar Institute of Engineering and Technology (TIET),
67
+ a renowned technical college. Your task is to answer all queries related to TIET.
68
+ If a question falls outside of this context, please decline by stating, 'Sorry, I cannot help with that.'
69
+ For more information, please contact our toll-free number: 18002024100 or email us at [email protected].
70
+ <context>
71
+ {context}
72
+ </context>
73
+ Question: {query}
74
+ Answer:
75
+ """
76
  )
77
 
78
+ # Load embeddings
79
+ try:
80
+ embeddings = HuggingFaceEmbeddings(
81
+ model_name="sentence-transformers/all-MiniLM-L6-v2", # Ensure this model is valid
82
+ cache_folder=cache_dir,
83
+ )
84
+ except Exception as e:
85
+ raise RuntimeError(f"Failed to initialize embeddings: {e}")
86
+
87
+ # Example endpoint for handling queries
88
  @app.post("/chat")
89
+ async def chat(query: Query):
90
+ context = "Thapar Institute of Engineering and Technology information."
91
  try:
92
+ # Use the LLM to generate a response
93
+ response = llm.generate(
94
+ prompt.format(context=context, query=query.query_text)
 
95
  )
96
+ cleaned_response = clean_response(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  return {"response": cleaned_response}
98
+ except Exception as e:
99
+ return {"error": str(e)}