la04 commited on
Commit
803ac17
·
verified ·
1 Parent(s): 1d58bcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -1,14 +1,14 @@
1
  import gradio as gr
2
- import fitz # PyMuPDF
3
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
4
 
5
- # Initialisiere das deutsche Modell und den Tokenizer für RAG
6
- model_name = "deepset/gbert-base" # Beispiel für ein deutsches Modell
7
  tokenizer = RagTokenizer.from_pretrained(model_name)
8
  model = RagSequenceForGeneration.from_pretrained(model_name)
9
  retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
10
 
11
- # Funktion zur Textextraktion aus PDF (ohne OCR)
12
  def extract_text_from_pdf(file):
13
  # Öffne die PDF-Datei mit PyMuPDF
14
  doc = fitz.open(file.name)
@@ -21,16 +21,18 @@ def extract_text_from_pdf(file):
21
 
22
  return text
23
 
24
- # Funktion zur Generierung einer Antwort basierend auf dem hochgeladenen Dokument
25
  def get_rag_answer(input_message, uploaded_file):
26
  # Extrahiere den Text aus dem hochgeladenen PDF-Dokument
27
  document_text = extract_text_from_pdf(uploaded_file)
28
 
29
- # Simuliere den Retrieval-Mechanismus, indem wir den extrahierten Text verwenden
30
  inputs = tokenizer(input_message, return_tensors="pt")
 
 
31
  retrieved_docs = retriever.retrieve(input_ids=inputs["input_ids"])
32
 
33
- # Kombiniere die extrahierten Dokumente und frage das Modell zur Generierung einer Antwort
34
  input_ids = tokenizer(input_message, return_tensors="pt").input_ids
35
  generated_ids = model.generate(input_ids=input_ids,
36
  decoder_start_token_id=model.config.pad_token_id,
 
1
  import gradio as gr
 
2
  from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
+ import fitz # PyMuPDF
4
 
5
+ # Lade das RAG-Modell, Tokenizer und Retriever
6
+ model_name = "facebook/rag-token-nq" # Funktionierendes RAG-Modell mit Encoder und Generator
7
  tokenizer = RagTokenizer.from_pretrained(model_name)
8
  model = RagSequenceForGeneration.from_pretrained(model_name)
9
  retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
10
 
11
+ # Funktion zur Textextraktion aus PDF
12
  def extract_text_from_pdf(file):
13
  # Öffne die PDF-Datei mit PyMuPDF
14
  doc = fitz.open(file.name)
 
21
 
22
  return text
23
 
24
+ # Funktion zur Beantwortung der Frage durch das Modell
25
  def get_rag_answer(input_message, uploaded_file):
26
  # Extrahiere den Text aus dem hochgeladenen PDF-Dokument
27
  document_text = extract_text_from_pdf(uploaded_file)
28
 
29
+ # Hier verwenden wir den extrahierten Text für das Abrufen von Informationen
30
  inputs = tokenizer(input_message, return_tensors="pt")
31
+
32
+ # Abrufen von relevanten Dokumenten mit dem RagRetriever
33
  retrieved_docs = retriever.retrieve(input_ids=inputs["input_ids"])
34
 
35
+ # Kombiniere die abgerufenen Dokumente und frage das Modell zur Generierung einer Antwort
36
  input_ids = tokenizer(input_message, return_tensors="pt").input_ids
37
  generated_ids = model.generate(input_ids=input_ids,
38
  decoder_start_token_id=model.config.pad_token_id,