timefullytrue committed

Commit 0722f22 · verified · 1 parent: f50f0bc

Update app.py

Files changed (1)
  1. app.py +6 -12
app.py CHANGED
@@ -12,22 +12,20 @@ import numpy as np
 from chromadb.utils import embedding_functions
 from huggingface_hub import InferenceClient
 
-
-
 dfs = pd.read_csv('Patents.csv')
 ids= [str(x) for x in dfs.index.tolist()]
 docs = dfs['text'].tolist()
 client = chromadb.Client()
 collection = client.get_or_create_collection("patents")
-
 collection.add(documents=docs,ids=ids)
 
+def text_embedding(input):
+    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
+    return model.encode(input)
+
 def gen_context(query):
-    vector = text_embedding(query).tolist()
-
-    results = collection.query(
-        query_embeddings=vector,n_results=15,include=["documents"])
-
+    vector = text_embedding(query).tolist()
+    results = collection.query(query_embeddings=vector,n_results=15,include=["documents"])
     res = "\n".join(str(item) for item in results['documents'][0])
     return res
 
@@ -38,12 +36,8 @@ def chat_completion(user_prompt):
 
     return client.text_generation(prompt=final_prompt,max_new_tokens = length).strip()
 
-
-
 client = InferenceClient(model = "mistralai/Mixtral-8x7B-Instruct-v0.1")
 
-
-
 demo = gr.Interface(fn=chat_completion,
                     inputs=[gr.Textbox(label="Query", lines=2)],
                     outputs=[gr.Textbox(label="Result", lines=16)],
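For orientation, a minimal, self-contained sketch of the retrieval path this commit sets up (the resulting file region is reproduced below). This is not the committed file: the SentenceTransformer import and the collection's embedding setup are not visible in this diff, so the sketch embeds the documents explicitly with the same all-mpnet-base-v2 model so that document and query vectors live in the same space, and it assumes Patents.csv has a 'text' column.

import chromadb
import pandas as pd
from sentence_transformers import SentenceTransformer  # assumed; import not shown in the diff

# Load the corpus: one patent text per row in a 'text' column (assumption).
dfs = pd.read_csv('Patents.csv')
ids = [str(x) for x in dfs.index.tolist()]   # Chroma ids must be strings
docs = dfs['text'].tolist()

# Use one model for both documents and queries so the vectors are comparable.
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

client = chromadb.Client()
collection = client.get_or_create_collection("patents")
collection.add(documents=docs, ids=ids, embeddings=model.encode(docs).tolist())

def text_embedding(text):
    # encode() returns a numpy array; .tolist() gives the plain float list Chroma expects
    return model.encode(text)

def gen_context(query):
    vector = text_embedding(query).tolist()
    results = collection.query(query_embeddings=vector, n_results=15, include=["documents"])
    # results['documents'] holds one list of retrieved documents per query embedding
    return "\n".join(str(item) for item in results['documents'][0])

Calling gen_context("battery electrode coating"), for example, would return the 15 nearest patent texts joined by newlines, ready to be pasted into a prompt.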
 
app.py after this commit (lines 12–43; lines 32–35 are unchanged and not shown):

from chromadb.utils import embedding_functions
from huggingface_hub import InferenceClient

dfs = pd.read_csv('Patents.csv')
ids= [str(x) for x in dfs.index.tolist()]
docs = dfs['text'].tolist()
client = chromadb.Client()
collection = client.get_or_create_collection("patents")
collection.add(documents=docs,ids=ids)

def text_embedding(input):
    model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    return model.encode(input)

def gen_context(query):
    vector = text_embedding(query).tolist()
    results = collection.query(query_embeddings=vector,n_results=15,include=["documents"])
    res = "\n".join(str(item) for item in results['documents'][0])
    return res

# ... lines 32–35 (the chat_completion definition and prompt construction) not shown ...

    return client.text_generation(prompt=final_prompt,max_new_tokens = length).strip()

client = InferenceClient(model = "mistralai/Mixtral-8x7B-Instruct-v0.1")

demo = gr.Interface(fn=chat_completion,
                    inputs=[gr.Textbox(label="Query", lines=2)],
                    outputs=[gr.Textbox(label="Result", lines=16)],
                    # ... remaining gr.Interface arguments not shown in the diff ...
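The generation and UI side, sketched under stated assumptions: the prompt template and the `length` variable are defined in parts of app.py this diff does not show, so the template and the 512-token cap below are placeholders rather than the committed values, and gen_context is assumed to be the retrieval function from the sketch above.

import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
length = 512  # placeholder: the committed value of `length` is not visible in this diff

def chat_completion(user_prompt):
    context = gen_context(user_prompt)  # retrieved patent passages (see sketch above)
    # Hypothetical Mixtral-instruct style prompt; the real template lives in unshown lines.
    final_prompt = (
        "[INST] Answer the question using only the context below.\n"
        f"Context:\n{context}\n\nQuestion: {user_prompt} [/INST]"
    )
    return client.text_generation(prompt=final_prompt, max_new_tokens=length).strip()

demo = gr.Interface(fn=chat_completion,
                    inputs=[gr.Textbox(label="Query", lines=2)],
                    outputs=[gr.Textbox(label="Result", lines=16)],
                    title="Patent search demo")  # placeholder for the unshown trailing arguments

demo.launch()

The query box feeds chat_completion directly, so each submission runs one retrieval against the "patents" collection followed by one text_generation call against the hosted Mixtral endpoint.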