rtabrizi commited on
Commit
7e4f428
·
1 Parent(s): a7a8f80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -67,7 +67,7 @@ class Retriever:
67
  def load_chunks(self):
68
  self.text = self.extract_text_from_pdf(self.file_path)
69
  text_splitter = RecursiveCharacterTextSplitter(
70
- chunk_size=150,
71
  chunk_overlap=20,
72
  length_function=self.token_len,
73
  separators=["Section", "\n\n", "\n", ".", " ", ""]
@@ -76,7 +76,7 @@ class Retriever:
76
  self.chunks = text_splitter.split_text(self.text)
77
 
78
  def load_context_embeddings(self):
79
- encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=150).to(device)
80
 
81
  with torch.no_grad():
82
  model_output = self.context_model(**encoded_input)
@@ -86,7 +86,7 @@ class Retriever:
86
  self.index.add(self.token_embeddings)
87
 
88
  def retrieve_top_k(self, query_prompt, k=10):
89
- encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", max_length=150, truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
  model_output = self.question_model(**encoded_query)
@@ -127,8 +127,8 @@ class RAG:
127
 
128
  input_text = "answer: " + " ".join(context) + " " + question
129
 
130
- inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=150, truncation=True).to(device)
131
- outputs = self.generator_model.generate(inputs, max_length=150, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
132
 
133
  answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
134
  return answer
 
67
  def load_chunks(self):
68
  self.text = self.extract_text_from_pdf(self.file_path)
69
  text_splitter = RecursiveCharacterTextSplitter(
70
+ chunk_size=300,
71
  chunk_overlap=20,
72
  length_function=self.token_len,
73
  separators=["Section", "\n\n", "\n", ".", " ", ""]
 
76
  self.chunks = text_splitter.split_text(self.text)
77
 
78
  def load_context_embeddings(self):
79
+ encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=300).to(device)
80
 
81
  with torch.no_grad():
82
  model_output = self.context_model(**encoded_input)
 
86
  self.index.add(self.token_embeddings)
87
 
88
  def retrieve_top_k(self, query_prompt, k=10):
89
+ encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", max_length=300, truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
  model_output = self.question_model(**encoded_query)
 
127
 
128
  input_text = "answer: " + " ".join(context) + " " + question
129
 
130
+ inputs = self.generator_tokenizer.encode(input_text, return_tensors='pt', max_length=300, truncation=True).to(device)
131
+ outputs = self.generator_model.generate(inputs, max_length=300, min_length=2, length_penalty=2.0, num_beams=4, early_stopping=True)
132
 
133
  answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
134
  return answer