la04 committed on
Commit
b6d30d1
·
verified ·
1 Parent(s): f965a1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  from langchain.vectorstores import Chroma
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain_community.embeddings import HuggingFaceEmbeddings
5
- from transformers import LayoutLMv3Processor, AutoModelForTokenClassification
6
  from langchain.chains import RetrievalQA
7
  from langchain.prompts import PromptTemplate
8
  from pdf2image import convert_from_path
@@ -11,15 +11,15 @@ import os
11
  class LayoutLMv3OCR:
12
  def __init__(self):
13
  self.processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
14
- self.model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
15
 
16
  def extract_text(self, pdf_path):
17
  images = convert_from_path(pdf_path)
18
  text_pages = []
19
  for image in images:
20
  inputs = self.processor(images=image, return_tensors="pt")
21
- outputs = self.model(**inputs)
22
- text = self.processor.batch_decode(outputs.logits, skip_special_tokens=True)[0]
23
  text_pages.append(text)
24
  return text_pages
25
 
@@ -41,14 +41,21 @@ def process_pdf_and_query(pdf_path, question):
41
  return response
42
 
43
  def chatbot_response(pdf, question):
 
44
  pdf_path = "uploaded_pdf.pdf"
45
- pdf.save(pdf_path)
 
 
 
 
46
  extracted_text = ocr_tool.extract_text(pdf_path)
47
  answer = process_pdf_and_query(pdf_path, question)
 
 
48
  os.remove(pdf_path)
 
49
  return answer
50
 
51
- # Ändere 'inputs' und 'outputs' zur neuen Gradio API
52
  pdf_input = gr.File(label="PDF-Datei hochladen")
53
  question_input = gr.Textbox(label="Frage eingeben")
54
  response_output = gr.Textbox(label="Antwort")
@@ -62,4 +69,4 @@ interface = gr.Interface(
62
  )
63
 
64
  if __name__ == "__main__":
65
- interface.launch()
 
2
  from langchain.vectorstores import Chroma
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain_community.embeddings import HuggingFaceEmbeddings
5
+ from transformers import LayoutLMv3Processor, AutoModelForSeq2SeqLM
6
  from langchain.chains import RetrievalQA
7
  from langchain.prompts import PromptTemplate
8
  from pdf2image import convert_from_path
 
11
class LayoutLMv3OCR:
    """OCR helper that extracts the text of every page in a PDF.

    NOTE(review): "microsoft/layoutlmv3-base" is an encoder-only model — it
    has no sequence-to-sequence head, so the previous
    ``AutoModelForSeq2SeqLM.from_pretrained(...)`` raises at load time and
    ``model.generate()`` could never produce text. Instead we rely on the
    processor's built-in Tesseract OCR (``apply_ocr=True`` by default): it
    tokenizes the words it recognizes on each page image, and decoding those
    ``input_ids`` yields the page text. No model is needed just to recover
    raw text.
    """

    def __init__(self):
        # The processor runs Tesseract OCR on each image (apply_ocr defaults
        # to True for this checkpoint), producing input_ids for the words it
        # detects on the page.
        self.processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")

    def extract_text(self, pdf_path):
        """Return a list with the OCR'd text of each page of *pdf_path*.

        pdf_path -- filesystem path to a PDF file; each page is rendered to
                    an image via pdf2image before OCR.
        """
        images = convert_from_path(pdf_path)
        text_pages = []
        for image in images:
            # apply_ocr=True: input_ids already encode the words Tesseract
            # found on this page; decode them back into a plain string.
            inputs = self.processor(images=image, return_tensors="pt")
            text = self.processor.batch_decode(
                inputs["input_ids"], skip_special_tokens=True
            )[0]
            text_pages.append(text)
        return text_pages
25
 
 
41
  return response
42
 
43
def chatbot_response(pdf, question):
    """Gradio handler: persist the uploaded PDF locally, answer *question*
    about it, and always clean up the temporary copy.

    pdf      -- value delivered by gr.File. Depending on the Gradio version
                this is a file-like object with .read(), an object exposing
                a .name path, or a plain path string — all three are
                handled below.
    question -- the user's question as a string.

    Returns the answer produced by process_pdf_and_query.
    """
    pdf_path = "uploaded_pdf.pdf"

    # Normalize the various shapes gr.File can deliver into raw bytes.
    if hasattr(pdf, "read"):
        data = pdf.read()
    else:
        with open(getattr(pdf, "name", pdf), "rb") as src:
            data = src.read()

    # Write the upload to a local file for the PDF loader to consume.
    with open(pdf_path, "wb") as f:
        f.write(data)

    try:
        # NOTE(review): the original also ran ocr_tool.extract_text(pdf_path)
        # here, but its result was never used — dropped to avoid a redundant
        # full-document OCR pass. Restore it if the OCR text is needed.
        answer = process_pdf_and_query(pdf_path, question)
    finally:
        # Delete the saved PDF even when processing raises, so failed
        # requests do not leak temp files.
        os.remove(pdf_path)

    return answer
58
 
 
59
  pdf_input = gr.File(label="PDF-Datei hochladen")
60
  question_input = gr.Textbox(label="Frage eingeben")
61
  response_output = gr.Textbox(label="Antwort")
 
69
  )
70
 
71
  if __name__ == "__main__":
72
+ interface.launch(share=True)