Commit 2e5203c · rtabrizi committed · 1 Parent(s): d6f6f10

Update app.py

Files changed (1): app.py (+4 -4)
app.py CHANGED
@@ -76,7 +76,7 @@ class Retriever:
         self.chunks = text_splitter.split_text(self.text)
 
     def load_context_embeddings(self):
-        encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=300).to(device)
+        encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=150).to(device)
 
         with torch.no_grad():
             model_output = self.context_model(**encoded_input)
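For context on what this hunk touches: load_context_embeddings encodes each text chunk and adds the vectors to a FAISS index (the self.index.add call in the next hunk). Below is a minimal, self-contained sketch of that flow, assuming DPR encoders and an inner-product index; the checkpoint names, pooling choice, and index type are assumptions, not shown in the diff.

import torch
import faiss
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed DPR context-encoder checkpoint (the diff only names context_tokenizer/context_model)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)

chunks = ["first passage ...", "second passage ..."]  # stand-in for self.chunks

# Tokenize every chunk in one padded batch, mirroring the call in the hunk
encoded_input = ctx_tokenizer(chunks, return_tensors="pt", padding=True, truncation=True, max_length=150).to(device)
with torch.no_grad():
    embeddings = ctx_model(**encoded_input).pooler_output  # (num_chunks, 768)

# DPR is trained for dot-product similarity, hence an inner-product index
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings.cpu().numpy())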
@@ -86,7 +86,7 @@ class Retriever:
         self.index.add(self.token_embeddings)
 
     def retrieve_top_k(self, query_prompt, k=10):
-        encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", max_length=300, truncation=True, padding=True).to(device)
+        encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", max_length=150, truncation=True, padding=True).to(device)
 
         with torch.no_grad():
             model_output = self.question_model(**encoded_query)
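retrieve_top_k applies the same 150-token budget on the query side. A sketch of the search step, continuing from the index built above; the question-encoder checkpoint is again an assumption.

from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)

def retrieve_top_k(query_prompt, k=10):
    # Same max_length budget (150) the commit applies to the context side
    encoded_query = q_tokenizer(query_prompt, return_tensors="pt", max_length=150, truncation=True, padding=True).to(device)
    with torch.no_grad():
        query_embedding = q_model(**encoded_query).pooler_output.cpu().numpy()
    scores, ids = index.search(query_embedding, k)  # FAISS nearest-neighbour lookup
    return [chunks[i] for i in ids[0]]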
@@ -128,7 +128,7 @@ class RAG:
 
         input_text = "answer: " + " ".join(context) + " " + question
 
-        inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=1024, truncation=True).to(device)
+        inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=150, truncation=True).to(device)
         outputs = self.generator_model.generate(inputs, max_length=150, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
 
         answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
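This hunk is the abstractive path: the retrieved chunks are concatenated behind an "answer: " prefix and fed to a seq2seq generator with beam search, and the edit shrinks the prompt budget from 1024 to 150 tokens to match the retriever encoders. A sketch continuing from the retriever above, with an assumed BART checkpoint (the diff only names generator_tokenizer/generator_model); the generate() arguments are copied from the hunk.

from transformers import BartForConditionalGeneration, BartTokenizer

gen_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
gen_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(device)

question = "What does the Retriever class index?"
context = retrieve_top_k(question, k=10)
input_text = "answer: " + " ".join(context) + " " + question

# max_length=150 is the new prompt budget introduced by this commit
inputs = gen_tokenizer.encode(input_text, return_tensors="pt", max_length=150, truncation=True).to(device)
outputs = gen_model.generate(inputs, max_length=150, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)

Note that with only 150 tokens, most of the joined context is truncated before generation.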
@@ -138,7 +138,7 @@ class RAG:
         context = self.retriever.retrieve_top_k(question, k=15)
 
 
-        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300, padding="max_length")
+        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=150, padding="max_length")
         with torch.no_grad():
             model_inputs = inputs.to(device)
             outputs = self.generator_model(**model_inputs)
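The last hunk reads like an extractive path: the tokenizer is called on a (question, context) pair and the model is run directly rather than through generate(). A sketch assuming a span-prediction QA head, continuing from the retriever above; the checkpoint and the start/end decoding are assumptions, not shown in the diff.

from transformers import AutoModelForQuestionAnswering, AutoTokenizer

qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2").to(device)

question = "What does the Retriever class index?"
context = retrieve_top_k(question, k=15)

# Pair call: question as the first segment, joined chunks as the second
inputs = qa_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=150, padding="max_length")
with torch.no_grad():
    model_inputs = inputs.to(device)
    outputs = qa_model(**model_inputs)

# Greedy span decoding (assumed; the diff does not show how outputs are used)
start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax() + 1
answer = qa_tokenizer.decode(inputs["input_ids"][0][start:end], skip_special_tokens=True)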
 