Cheselle commited on
Commit
6cca36c
·
verified ·
1 Parent(s): 5ebf50b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -13,12 +13,19 @@ from langchain_community.vectorstores import Qdrant
13
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel
14
  import chainlit as cl
15
  from pathlib import Path
16
- from sentence_transformers import SentenceTransformer # Ensure this import is correct
17
 
18
  load_dotenv()
19
 
20
  os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
21
 
 
 
 
 
 
 
 
22
  @cl.on_chat_start
23
  async def on_chat_start():
24
  model = ChatOpenAI(streaming=True)
@@ -53,10 +60,10 @@ async def on_chat_start():
53
 
54
  sentence_combined_documents = sentence_framework + sentence_blueprint
55
 
56
- # Initialize the SentenceTransformer model properly
57
- embedding_model = SentenceTransformer('Cheselle/finetuned-arctic-sentence')
58
 
59
- # Create the Qdrant vector store using the initialized embedding model
60
  sentence_vectorstore = Qdrant.from_documents(
61
  documents=sentence_combined_documents,
62
  embedding=embedding_model, # Ensure this is an instance
@@ -65,7 +72,11 @@ async def on_chat_start():
65
  )
66
 
67
  sentence_retriever = sentence_vectorstore.as_retriever()
68
-
 
 
 
 
69
  # Set the retriever and prompt into session for reuse
70
  cl.user_session.set("runnable", model)
71
  cl.user_session.set("retriever", sentence_retriever)
@@ -83,6 +94,11 @@ async def on_message(message: cl.Message):
83
  print(f"Received message: {message.content}")
84
 
85
  # Retrieve relevant context from documents based on the user's message
 
 
 
 
 
86
  relevant_docs = retriever.get_relevant_documents(message.content)
87
  print(f"Retrieved {len(relevant_docs)} documents.")
88
 
 
13
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel
14
  import chainlit as cl
15
  from pathlib import Path
16
+ from sentence_transformers import SentenceTransformer
17
 
18
  load_dotenv()
19
 
20
  os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
21
 
22
+ class SentenceTransformerEmbedding:
23
+ def __init__(self, model_name):
24
+ self.model = SentenceTransformer(model_name)
25
+
26
+ def embed_documents(self, texts):
27
+ return self.model.encode(texts, convert_to_tensor=True).tolist() # Convert to list for compatibility
28
+
29
  @cl.on_chat_start
30
  async def on_chat_start():
31
  model = ChatOpenAI(streaming=True)
 
60
 
61
  sentence_combined_documents = sentence_framework + sentence_blueprint
62
 
63
+ # Initialize the custom embedding class
64
+ embedding_model = SentenceTransformerEmbedding('Cheselle/finetuned-arctic-sentence')
65
 
66
+ # Create the Qdrant vector store using the custom embedding model
67
  sentence_vectorstore = Qdrant.from_documents(
68
  documents=sentence_combined_documents,
69
  embedding=embedding_model, # Ensure this is an instance
 
72
  )
73
 
74
  sentence_retriever = sentence_vectorstore.as_retriever()
75
+
76
+ # Check if retriever is initialized correctly
77
+ if sentence_retriever is None:
78
+ raise ValueError("Retriever is not initialized correctly.")
79
+
80
  # Set the retriever and prompt into session for reuse
81
  cl.user_session.set("runnable", model)
82
  cl.user_session.set("retriever", sentence_retriever)
 
94
  print(f"Received message: {message.content}")
95
 
96
  # Retrieve relevant context from documents based on the user's message
97
+ if retriever is None:
98
+ print("Retriever is not available.")
99
+ await cl.Message(content="Sorry, the retriever is not initialized.").send()
100
+ return
101
+
102
  relevant_docs = retriever.get_relevant_documents(message.content)
103
  print(f"Retrieved {len(relevant_docs)} documents.")
104