# Hugging Face Space page residue (status: "Sleeping") — kept as a comment so the file parses.
import torch

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from pdf2image import convert_from_path
from PIL import Image
# OCR replacement: LayoutLMv3 for text extraction from PDFs
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    LayoutLMv3Processor,
)
class LayoutLMv3OCR:
    """Extract text from PDF pages via LayoutLMv3 token classification.

    Each page is rendered to an image, passed through the LayoutLMv3
    processor, and every token whose predicted class is 1 is kept as
    part of that page's text.
    """

    def __init__(self):
        self.processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
        # Bug fix: the original loaded AutoModelForSeq2SeqLM, whose forward
        # pass needs decoder inputs and does not produce per-token logits.
        # The per-token argmax in extract_text() requires a token-
        # classification head.
        # NOTE(review): "microsoft/layoutlmv3-base-finetuned" does not look
        # like a real Hub checkpoint id — confirm the intended model name.
        self.model = AutoModelForTokenClassification.from_pretrained(
            "microsoft/layoutlmv3-base-finetuned", num_labels=2
        )

    def extract_text(self, pdf_path):
        """Return a list with one extracted-text string per PDF page.

        :param pdf_path: path to the PDF file to process.
        :return: list[str], one entry per page, in page order.
        """
        pages = convert_from_path(pdf_path)
        extracted_texts = []
        for page in pages:
            # The processor runs its built-in OCR on the page image and
            # tokenizes the detected words.
            encoding = self.processor(images=page, return_tensors="pt")
            outputs = self.model(**encoding)
            # One predicted label per token; keep tokens labelled 1.
            predictions = torch.argmax(outputs.logits, dim=-1).squeeze()
            tokens = self.processor.tokenizer.convert_ids_to_tokens(encoding.input_ids[0])
            page_text = " ".join(
                token for token, pred in zip(tokens, predictions) if pred == 1
            )
            extracted_texts.append(page_text)
        return extracted_texts
# Initialize the OCR helper (downloads LayoutLMv3 weights on first run).
ocr_tool = LayoutLMv3OCR()
# Configure embeddings and the LLM.
# Multilingual sentence-transformer — chosen so German documents embed well.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# FLAN-T5 base serves as the answer-generation model (used by flan_generate).
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def flan_generate(input_text):
    """Run the FLAN-T5 model on *input_text* and return the decoded answer."""
    encoded = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
    generated_ids = model.generate(**encoded, max_length=512)
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return answer
def process_pdf_and_create_rag(pdf_path):
    """Build a retrieval-augmented QA object for one PDF.

    OCRs the PDF, splits the text into overlapping chunks, indexes them
    in an in-memory Chroma store, and returns an object exposing
    ``run(question)`` and ``.retriever`` (the surface chatbot_response uses).

    :param pdf_path: path to the uploaded PDF file.
    :return: a chain-like object with ``run`` and ``retriever``.
    """
    extracted_text = ocr_tool.extract_text(pdf_path)
    # One Document per page, keeping the page number for later citation.
    documents = [
        Document(page_content=text, metadata={"page": page_num})
        for page_num, text in enumerate(extracted_text, start=1)
    ]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)
    vector_store = Chroma.from_documents(split_docs, embedding=embeddings)
    retriever = vector_store.as_retriever()

    # Bug fix: the original passed the bare function flan_generate as
    # RetrievalQA's combine_documents_chain, but LangChain validates that
    # field as a BaseCombineDocumentsChain and rejects a plain callable at
    # construction time. This minimal wrapper provides the same interface
    # the rest of this file relies on (.run and .retriever) and "stuffs"
    # the retrieved chunks into a single prompt for FLAN-T5.
    class _SimpleRAGChain:
        def __init__(self, retriever):
            self.retriever = retriever

        def run(self, question):
            docs = self.retriever.get_relevant_documents(question)
            context = "\n\n".join(doc.page_content for doc in docs)
            prompt = (
                "Beantworte die Frage anhand des folgenden Kontexts.\n\n"
                f"Kontext:\n{context}\n\nFrage: {question}"
            )
            return flan_generate(prompt)

    return _SimpleRAGChain(retriever)
def chatbot_response(pdf_file, question):
    """Answer *question* from the uploaded PDF and append page citations.

    :param pdf_file: uploaded file object (Gradio File; ``.name`` is the path).
    :param question: the user's question.
    :return: the model answer, followed by a page-reference suffix when
        any source pages were retrieved.
    """
    qa_chain = process_pdf_and_create_rag(pdf_file.name)
    response = qa_chain.run(question)
    # Collect the page numbers of the retrieved chunks for the citation.
    relevant_pages = set()
    for doc in qa_chain.retriever.get_relevant_documents(question):
        relevant_pages.add(doc.metadata.get("page", "Unbekannt"))
    # Robustness fix: with no retrieved pages the original emitted an
    # empty "(Referenz: Seite(n) )" suffix.
    if not relevant_pages:
        return response
    # Bug fix: set iteration order is arbitrary, so the citation order was
    # nondeterministic. Sort by string form since pages may mix ints with
    # the "Unbekannt" fallback.
    ordered_pages = sorted(relevant_pages, key=str)
    page_info = f" (Referenz: Seite(n) {', '.join(map(str, ordered_pages))})"
    return response + page_info
def gradio_interface():
    """Assemble and return the Gradio UI for the German RAG chatbot."""
    return gr.Interface(
        fn=chatbot_response,
        inputs=[
            gr.File(label="PDF hochladen"),
            gr.Textbox(label="Ihre Frage", placeholder="Geben Sie Ihre Frage hier ein..."),
        ],
        outputs=gr.Textbox(label="Antwort"),
        title="RAG Chatbot (Deutsch)",
    )
# Launch the web UI only when this file is executed as a script.
if __name__ == "__main__":
    gradio_interface().launch()