RAG_test_1 / app.py
la04's picture
Update app.py
e84a97b verified
raw
history blame
2.12 kB
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
import gradio as gr
from PyPDF2 import PdfReader
# T5-Modell und Tokenizer laden
model = T5ForConditionalGeneration.from_pretrained('t5-small')
tokenizer = T5Tokenizer.from_pretrained('t5-small')
qa_model = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
# Funktion zum Extrahieren von Text aus der PDF
def extract_text_from_pdf(pdf_path):
reader = PdfReader(pdf_path)
text = ""
for page in reader.pages:
text += page.extract_text()
return text
# Funktion für das Bearbeiten der Frage und des Kontextes
def chatbot_response(pdf_path, question):
# PDF-Text extrahieren
context = extract_text_from_pdf(pdf_path)
# Text aufteilen, falls er zu lang ist
max_input_length = 512
context_parts = [context[i:i + max_input_length] for i in range(0, len(context), max_input_length)]
# Die Frage als Prompt an T5 schicken
answers = []
for part in context_parts:
input_text = f"question: {question} context: {part}"
input_ids = tokenizer.encode(input_text, return_tensors="pt", truncation=True, max_length=512)
output = model.generate(input_ids, max_length=150, num_beams=4, early_stopping=True)
answer = tokenizer.decode(output[0], skip_special_tokens=True)
answers.append(answer.strip())
# Gib die letzte Antwort zurück
return answers[-1] if answers else "Keine Antwort gefunden"
# Gradio-Interface erstellen
pdf_input = gr.File(label="PDF-Datei hochladen", type="filepath")
question_input = gr.Textbox(label="Frage eingeben", placeholder="Stelle eine Frage zu dem PDF-Dokument")
response_output = gr.Textbox(label="Antwort")
# Gradio-Interface
interface = gr.Interface(
fn=chatbot_response,
inputs=[pdf_input, question_input],
outputs=response_output,
title="PDF-Fragebeantwortung mit T5 und Gradio",
description="Lade eine PDF-Datei hoch und stelle Fragen zu ihrem Inhalt. Das System verwendet T5, um passende Antworten zu finden."
)
if __name__ == "__main__":
interface.launch()