Baweja commited on
Commit
6c30bcb
·
verified ·
1 Parent(s): d584b19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -64,6 +64,7 @@
64
 
65
 
66
 
 
67
  import gradio as gr
68
  import torch
69
  from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
@@ -71,7 +72,7 @@ from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
71
  """
72
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
73
  """
74
- device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
75
 
76
  def strip_title(title):
77
  if title.startswith('"'):
@@ -113,6 +114,7 @@ def retrieved_info(rag_model, query):
113
  retrieved_context.append(f"{title}: {text}")
114
 
115
  answer = retrieved_context
 
116
 
117
 
118
 
@@ -121,33 +123,33 @@ def respond(
121
  message,
122
  history: list[tuple[str, str]],
123
  system_message,
124
- max_tokens = None,
125
- temperature = None,
126
- top_p = None,
127
  ):
128
  # Load model
129
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130
-
131
  dataset_path = "./sample/my_knowledge_dataset"
132
  index_path = "./sample/my_knowledge_dataset_hnsw_index.faiss"
133
-
134
  tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
135
  retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
136
  passages_path = dataset_path,
137
  index_path = index_path,
138
- n_docs = 1)
139
  rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
140
  rag_model.retriever.init_retrieval()
141
  rag_model.to(device)
142
 
143
  if message: # If there's a user query
144
  response = retrieved_info(rag_model, message) # Get the answer from your local FAISS and Q&A model
145
- return response
146
 
147
  # In case no message, return an empty string
148
  return ""
149
-
150
-
151
 
152
  """
153
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -170,4 +172,4 @@ demo = gr.ChatInterface(
170
 
171
 
172
  if __name__ == "__main__":
173
- demo.launch()
 
64
 
65
 
66
 
67
+
68
  import gradio as gr
69
  import torch
70
  from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
 
72
  """
73
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
74
  """
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
 
77
  def strip_title(title):
78
  if title.startswith('"'):
 
114
  retrieved_context.append(f"{title}: {text}")
115
 
116
  answer = retrieved_context
117
+ return answer
118
 
119
 
120
 
 
123
  message,
124
  history: list[tuple[str, str]],
125
  system_message,
126
+ max_tokens ,
127
+ temperature,
128
+ top_p,
129
  ):
130
  # Load model
131
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
132
+
133
  dataset_path = "./sample/my_knowledge_dataset"
134
  index_path = "./sample/my_knowledge_dataset_hnsw_index.faiss"
135
+
136
  tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
137
  retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
138
  passages_path = dataset_path,
139
  index_path = index_path,
140
+ n_docs = 5)
141
  rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
142
  rag_model.retriever.init_retrieval()
143
  rag_model.to(device)
144
 
145
  if message: # If there's a user query
146
  response = retrieved_info(rag_model, message) # Get the answer from your local FAISS and Q&A model
147
+ return response[0]
148
 
149
  # In case no message, return an empty string
150
  return ""
151
+
152
+
153
 
154
  """
155
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 
172
 
173
 
174
  if __name__ == "__main__":
175
+ demo.launch( )