John Graham Reynolds committed on
Commit 8bb66b9 · 1 Parent(s): 7f097f8

update app to use Langchain for retrieval

Files changed (1):
  1. app.py  +31 -15
app.py CHANGED
@@ -1,7 +1,7 @@
-import streamlit as st
 import os
-from mlflow import deployments
-from databricks.vector_search.client import VectorSearchClient
+import streamlit as st
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_databricks.vectorstores import DatabricksVectorSearch
 
 DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
 DATABRICKS_API_TOKEN = os.environ.get("DATABRICKS_API_TOKEN")
@@ -39,24 +39,40 @@ st.markdown("\n")
 # with open("style.css") as css:
 # st.markdown( f'<style>{css.read()}</style>' , unsafe_allow_html= True)
 
+# Same embedding model we used to create embeddings of terms
+# make sure we cache this so that it doesnt redownload each time, hindering Space start time if sleeping
+embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")
+
+vector_store = DatabricksVectorSearch(
+    endpoint=VS_ENDPOINT_NAME,
+    index_name=VS_INDEX_NAME,
+    embedding=embeddings,
+    text_column="name",
+    columns=["name", "description"],
+)
+
+results = vector_store.similarity_search(query="Tell me about what a data lake is.", k=5)
+st.write(results)
+
+
 
 # TODO *** configure to run only on prompt for verification?
-vsc = VectorSearchClient()
+# vsc = VectorSearchClient()
 
-question = "What is the data lake?"
-# question_2 = "What does EDW stand for?"
-# question_3 = "What does AIDET stand for?"
+# question = "What is the data lake?"
+# # question_2 = "What does EDW stand for?"
+# # question_3 = "What does AIDET stand for?"
 
-deploy_client = deployments.get_deploy_client("databricks")
-response = deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": [question]})
-embeddings = [e['embedding'] for e in response.data]
+# deploy_client = deployments.get_deploy_client("databricks")
+# response = deploy_client.predict(endpoint="databricks-bge-large-en", inputs={"input": [question]})
+# embeddings = [e['embedding'] for e in response.data]
 
-results = vsc.get_index(VS_ENDPOINT_NAME, VS_INDEX_NAME).similarity_search(
-    query_vector=embeddings[0],
-    columns=["name", "description"],
-    num_results=5)
+# results = vsc.get_index(VS_ENDPOINT_NAME, VS_INDEX_NAME).similarity_search(
+#     query_vector=embeddings[0],
+#     columns=["name", "description"],
+#     num_results=5)
 
-st.write(results)
+# st.write(results)
 
 
 # print(results)
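
For reference, a minimal standalone sketch of the retrieval path this commit introduces, outside the Streamlit app. The VS_ENDPOINT_NAME / VS_INDEX_NAME values below are hypothetical placeholders (app.py defines its own elsewhere), and Databricks credentials (e.g. DATABRICKS_HOST and a token) are assumed to be configured in the environment.

# Standalone sketch of the LangChain-based retrieval from this commit; placeholder names marked below.
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch

# Hypothetical endpoint/index names for illustration only.
VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME", "vs_endpoint")
VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME", "catalog.schema.vs_index")

# Query-time embeddings must come from the same model used to embed the indexed terms.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")

# Wrap the existing Databricks Vector Search index: text_column is the embedded column,
# columns are the fields returned with each hit.
vector_store = DatabricksVectorSearch(
    endpoint=VS_ENDPOINT_NAME,
    index_name=VS_INDEX_NAME,
    embedding=embeddings,
    text_column="name",
    columns=["name", "description"],
)

# Top-5 nearest neighbors for a natural-language query; returns LangChain Document objects.
for doc in vector_store.similarity_search(query="What is the data lake?", k=5):
    print(doc.page_content, doc.metadata)

Compared with the commented-out approach (embedding the query via the databricks-bge-large-en serving endpoint and calling VectorSearchClient directly), the vector store embeds queries locally with the cached BGE model and returns Document objects that can plug into downstream LangChain retrieval chains.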