rtabrizi committed on
Commit
d6f6f10
·
1 Parent(s): fa7567e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -86,7 +86,7 @@ class Retriever:
86
  self.index.add(self.token_embeddings)
87
 
88
  def retrieve_top_k(self, query_prompt, k=10):
89
- encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
  model_output = self.question_model(**encoded_query)
@@ -138,7 +138,7 @@ class RAG:
138
  context = self.retriever.retrieve_top_k(question, k=15)
139
 
140
 
141
- inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300 , padding="max_length")
142
  with torch.no_grad():
143
  model_inputs = inputs.to(device)
144
  outputs = self.generator_model(**model_inputs)
 
86
  self.index.add(self.token_embeddings)
87
 
88
  def retrieve_top_k(self, query_prompt, k=10):
89
+ encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", max_length=300, truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
  model_output = self.question_model(**encoded_query)
 
138
  context = self.retriever.retrieve_top_k(question, k=15)
139
 
140
 
141
+ inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300, padding="max_length")
142
  with torch.no_grad():
143
  model_inputs = inputs.to(device)
144
  outputs = self.generator_model(**model_inputs)