Update app.py
Browse files
app.py
CHANGED
@@ -64,6 +64,7 @@
|
|
64 |
|
65 |
|
66 |
|
|
|
67 |
import gradio as gr
|
68 |
import torch
|
69 |
from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
|
@@ -71,7 +72,7 @@ from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
|
|
71 |
"""
|
72 |
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
73 |
"""
|
74 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "
|
75 |
|
76 |
def strip_title(title):
|
77 |
if title.startswith('"'):
|
@@ -113,6 +114,7 @@ def retrieved_info(rag_model, query):
|
|
113 |
retrieved_context.append(f"{title}: {text}")
|
114 |
|
115 |
answer = retrieved_context
|
|
|
116 |
|
117 |
|
118 |
|
@@ -121,33 +123,33 @@ def respond(
|
|
121 |
message,
|
122 |
history: list[tuple[str, str]],
|
123 |
system_message,
|
124 |
-
max_tokens
|
125 |
-
temperature
|
126 |
-
top_p
|
127 |
):
|
128 |
# Load model
|
129 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
130 |
-
|
131 |
dataset_path = "./sample/my_knowledge_dataset"
|
132 |
index_path = "./sample/my_knowledge_dataset_hnsw_index.faiss"
|
133 |
-
|
134 |
tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
|
135 |
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
|
136 |
passages_path = dataset_path,
|
137 |
index_path = index_path,
|
138 |
-
n_docs =
|
139 |
rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
|
140 |
rag_model.retriever.init_retrieval()
|
141 |
rag_model.to(device)
|
142 |
|
143 |
if message: # If there's a user query
|
144 |
response = retrieved_info(rag_model, message) # Get the answer from your local FAISS and Q&A model
|
145 |
-
return response
|
146 |
|
147 |
# In case no message, return an empty string
|
148 |
return ""
|
149 |
-
|
150 |
-
|
151 |
|
152 |
"""
|
153 |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
@@ -170,4 +172,4 @@ demo = gr.ChatInterface(
|
|
170 |
|
171 |
|
172 |
if __name__ == "__main__":
|
173 |
-
demo.launch()
|
|
|
64 |
|
65 |
|
66 |
|
67 |
+
|
68 |
import gradio as gr
|
69 |
import torch
|
70 |
from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
|
|
|
72 |
"""
|
73 |
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
74 |
"""
|
75 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
76 |
|
77 |
def strip_title(title):
|
78 |
if title.startswith('"'):
|
|
|
114 |
retrieved_context.append(f"{title}: {text}")
|
115 |
|
116 |
answer = retrieved_context
|
117 |
+
return answer
|
118 |
|
119 |
|
120 |
|
|
|
123 |
message,
|
124 |
history: list[tuple[str, str]],
|
125 |
system_message,
|
126 |
+
max_tokens ,
|
127 |
+
temperature,
|
128 |
+
top_p,
|
129 |
):
|
130 |
# Load model
|
131 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
132 |
+
|
133 |
dataset_path = "./sample/my_knowledge_dataset"
|
134 |
index_path = "./sample/my_knowledge_dataset_hnsw_index.faiss"
|
135 |
+
|
136 |
tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
|
137 |
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
|
138 |
passages_path = dataset_path,
|
139 |
index_path = index_path,
|
140 |
+
n_docs = 5)
|
141 |
rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
|
142 |
rag_model.retriever.init_retrieval()
|
143 |
rag_model.to(device)
|
144 |
|
145 |
if message: # If there's a user query
|
146 |
response = retrieved_info(rag_model, message) # Get the answer from your local FAISS and Q&A model
|
147 |
+
return response[0]
|
148 |
|
149 |
# In case no message, return an empty string
|
150 |
return ""
|
151 |
+
|
152 |
+
|
153 |
|
154 |
"""
|
155 |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
|
|
172 |
|
173 |
|
174 |
if __name__ == "__main__":
|
175 |
+
demo.launch( )
|