Moha782 committed on
Commit
986054a
·
verified ·
1 Parent(s): daeb152

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -0
app.py CHANGED
@@ -5,6 +5,7 @@ from typing import List, Dict, Tuple
5
  import re
6
  import os
7
  import torch
 
8
 
9
  # Load the RAG model and tokenizer
10
  rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
@@ -22,6 +23,12 @@ doc_chunks = re.split(split_pattern, pdf_text)
22
  # Preprocess the corpus
23
  corpus = rag_tokenizer(doc_chunks, return_tensors="pt", padding=True, truncation=True).input_ids
24
 
 
 
 
 
 
 
25
  """
26
  For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
27
  """
 
5
  import re
6
  import os
7
  import torch
8
+ from math import ceil
9
 
10
  # Load the RAG model and tokenizer
11
  rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
 
23
  # Preprocess the corpus
24
  corpus = rag_tokenizer(doc_chunks, return_tensors="pt", padding=True, truncation=True).input_ids
25
 
26
+ # Right-pad the corpus's last dimension with the tokenizer's pad token so its length is a multiple of `n_docs`
27
+ n_docs = rag_model.config.n_docs
28
+ corpus_length = corpus.size(-1)
29
+ pad_length = ceil(corpus_length / n_docs) * n_docs - corpus_length
30
+ corpus = torch.nn.functional.pad(corpus, (0, pad_length), mode='constant', value=rag_tokenizer.pad_token_id)
31
+
32
  """
33
  For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
34
  """