Update app.py
app.py
CHANGED
@@ -5,6 +5,7 @@ from typing import List, Dict, Tuple
 import re
 import os
 import torch
+from math import ceil
 
 # Load the RAG model and tokenizer
 rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
@@ -22,6 +23,12 @@ doc_chunks = re.split(split_pattern, pdf_text)
 # Preprocess the corpus
 corpus = rag_tokenizer(doc_chunks, return_tensors="pt", padding=True, truncation=True).input_ids
 
+# Pad the corpus to be a multiple of `n_docs`
+n_docs = rag_model.config.n_docs
+corpus_length = corpus.size(-1)
+pad_length = ceil(corpus_length / n_docs) * n_docs - corpus_length
+corpus = torch.nn.functional.pad(corpus, (0, pad_length), mode='constant', value=rag_tokenizer.pad_token_id)
+
 """
 For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
|