ikraamkb committed on
Commit
8e24199
·
verified ·
1 Parent(s): 39b3aed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -42
app.py CHANGED
@@ -11,71 +11,104 @@ from transformers import pipeline
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
  import numpy as np
 
14
 
15
  # Initialize FastAPI
16
  app = FastAPI()
17
 
18
- # Load AI Model for Question Answering (Summarization-based approach)
19
- qa_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn")
20
 
21
- # Load Pretrained Object Detection Model (Torchvision)
 
 
 
22
  model = fasterrcnn_resnet50_fpn(pretrained=True)
23
  model.eval()
24
 
 
 
 
25
  # Image Transformations
26
  transform = transforms.Compose([
27
  transforms.ToTensor()
28
  ])
29
 
 
 
 
 
 
 
 
 
 
30
  # Function to truncate text to 450 tokens
31
  def truncate_text(text, max_tokens=450):
32
  words = text.split()
33
  return " ".join(words[:max_tokens])
34
 
35
- # Functions to extract text from different file formats
36
  def extract_text_from_pdf(pdf_file):
37
  text = ""
38
- with pdfplumber.open(pdf_file) as pdf:
39
- for page in pdf.pages:
40
- text += page.extract_text() + "\n"
41
- return text.strip()
 
 
 
 
 
42
 
43
  def extract_text_from_docx(docx_file):
44
- doc = docx.Document(docx_file)
45
- return "\n".join([para.text for para in doc.paragraphs])
 
 
 
46
 
47
  def extract_text_from_pptx(pptx_file):
48
- ppt = Presentation(pptx_file)
49
- text = []
50
- for slide in ppt.slides:
51
- for shape in slide.shapes:
52
- if hasattr(shape, "text"):
53
- text.append(shape.text)
54
- return "\n".join(text)
 
 
 
55
 
56
  def extract_text_from_excel(excel_file):
57
- wb = openpyxl.load_workbook(excel_file)
58
- text = []
59
- for sheet in wb.worksheets:
60
- for row in sheet.iter_rows(values_only=True):
61
- text.append(" ".join(map(str, row)))
62
- return "\n".join(text)
63
-
64
- # Function to extract text from image
 
 
65
  def extract_text_from_image(image_file):
66
- if isinstance(image_file, np.ndarray): # Check if input is a NumPy array
67
- image = Image.fromarray(image_file) # Convert NumPy array to PIL image
68
- else:
69
- image = Image.open(image_file).convert("RGB") # Handle file input
 
 
70
 
71
- reader = easyocr.Reader(["en"])
72
- result = reader.readtext(np.array(image)) # Convert PIL image back to NumPy array
73
- return " ".join([res[1] for res in result])
74
 
75
- # Function to answer questions based on document content using BART summarization
76
  def answer_question_from_document(file, question):
77
- file_ext = file.name.split(".")[-1].lower()
 
 
78
 
 
79
  if file_ext == "pdf":
80
  text = extract_text_from_pdf(file)
81
  elif file_ext == "docx":
@@ -86,27 +119,26 @@ def answer_question_from_document(file, question):
86
  text = extract_text_from_excel(file)
87
  else:
88
  return "Unsupported file format!"
89
-
90
  if not text:
91
  return "No text extracted from the document."
92
 
 
93
  truncated_text = truncate_text(text)
94
- input_text = f"Context: {truncated_text} Question: {question}"
95
- response = qa_pipeline(input_text, max_length=100, min_length=30, do_sample=False)
96
 
97
- return response[0]["summary_text"]
98
 
99
- # Function to answer questions based on image content
100
  def answer_question_from_image(image, question):
101
  image_text = extract_text_from_image(image)
102
  if not image_text:
103
  return "No meaningful content detected in the image."
104
 
 
105
  truncated_text = truncate_text(image_text)
106
- input_text = f"Context: {truncated_text} Question: {question}"
107
- response = qa_pipeline(input_text, max_length=100, min_length=30, do_sample=False)
108
 
109
- return response[0]["summary_text"]
110
 
111
  # Gradio UI for Document & Image QA
112
  doc_interface = gr.Interface(
 
11
  import gradio as gr
12
  from fastapi.responses import RedirectResponse
13
  import numpy as np
14
+ import easyocr
15
 
16
  # Initialize FastAPI
17
  app = FastAPI()
18
 
19
+ # Load AI Model for Question Answering (Proper Extractive QA Model)
20
+ qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
21
 
22
+ # Initialize Translator for Multilingual Support
23
+ translator = pipeline("translation", model="facebook/m2m100_418M")
24
+
25
+ # Load Pretrained Object Detection Model (if needed)
26
  model = fasterrcnn_resnet50_fpn(pretrained=True)
27
  model.eval()
28
 
29
+ # Initialize OCR Model (loaded once at import and shared by all requests)
30
+ reader = easyocr.Reader(["en"], gpu=True)
31
+
32
  # Image Transformations
33
  transform = transforms.Compose([
34
  transforms.ToTensor()
35
  ])
36
 
37
# Allowed file extensions (lower-case, without the leading dot)
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}


def validate_file_type(file):
    """Check that an uploaded file has a supported extension.

    Args:
        file: Any object exposing the file path or name as ``file.name``.

    Returns:
        ``None`` when the extension is in ``ALLOWED_EXTENSIONS``; otherwise a
        human-readable error message.
    """
    import os  # local import: keeps this block self-contained

    # splitext on the basename only: a name without a dot has NO extension
    # (previously a file literally called "pdf" was accepted), and dots in
    # parent directories are ignored.
    base_name = os.path.basename(file.name)
    ext = os.path.splitext(base_name)[1].lstrip(".").lower()
    if ext not in ALLOWED_EXTENSIONS:
        return f"Unsupported file format: {ext or 'no extension'}"
    return None
45
+
46
# Truncate text to at most `max_tokens` whitespace-separated words
def truncate_text(text, max_tokens=450):
    """Return the first ``max_tokens`` whitespace-separated words of *text*.

    "Tokens" here are plain words produced by ``str.split()``, not model
    tokens. Runs of whitespace are collapsed to single spaces.

    Args:
        text: Input string; an empty or ``None`` value yields ``""``.
        max_tokens: Maximum number of words to keep. Zero or negative values
            yield ``""`` (a negative slice bound would otherwise keep all but
            the last N words).

    Returns:
        The kept words joined by single spaces.
    """
    if not text or max_tokens <= 0:
        return ""
    # maxsplit stops scanning once enough words are found; the unsplit
    # remainder lands in the extra last element, which the slice drops.
    words = text.split(None, max_tokens)
    return " ".join(words[:max_tokens])
50
 
51
+ # Text Extraction Functions
52
def extract_text_from_pdf(pdf_file):
    """Extract the text of every page of a PDF using ``pdfplumber``.

    Args:
        pdf_file: Whatever ``pdfplumber.open`` accepts (path or file object).

    Returns:
        The page texts joined by newlines and stripped. An empty string is
        returned when no page yields text, so a caller's ``if not text`` check
        can detect it. If reading fails, a string starting with
        ``"Error reading PDF: "`` is returned instead of raising.
    """
    page_texts = []
    try:
        with pdfplumber.open(pdf_file) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                # Skip pages for which extract_text() gives nothing.
                if page_text:
                    page_texts.append(page_text)
    except Exception as e:
        # NOTE(review): the error is returned as ordinary text; callers that
        # pass the result on to a model should check for this prefix.
        return f"Error reading PDF: {e}"
    return "\n".join(page_texts).strip()
63
 
64
def extract_text_from_docx(docx_file):
    """Return the paragraph texts of a DOCX file joined by newlines.

    Args:
        docx_file: Whatever ``docx.Document`` accepts.

    Returns:
        The joined paragraph text, or a string starting with
        ``"Error reading DOCX: "`` if the document cannot be read.
    """
    try:
        document = docx.Document(docx_file)
        paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    except Exception as exc:
        return f"Error reading DOCX: {exc!s}"
    return "\n".join(paragraph_texts)
70
 
71
def extract_text_from_pptx(pptx_file):
    """Collect the text of every shape on every slide of a PPTX file.

    Args:
        pptx_file: Whatever ``Presentation`` accepts.

    Returns:
        Shape texts joined by newlines. An empty string is returned when no
        shape carries text, so a caller's ``if not text`` check can detect it.
        If reading fails, a string starting with ``"Error reading PPTX: "`` is
        returned instead of raising.
    """
    try:
        ppt = Presentation(pptx_file)
        shape_texts = []
        for slide in ppt.slides:
            for shape in slide.shapes:
                # Only shapes exposing a `text` attribute contribute, and
                # empty texts are skipped so they do not add blank lines.
                shape_text = getattr(shape, "text", None)
                if shape_text:
                    shape_texts.append(shape_text)
        return "\n".join(shape_texts)
    except Exception as e:
        return f"Error reading PPTX: {e}"
82
 
83
def extract_text_from_excel(excel_file):
    """Flatten all worksheets of an Excel workbook into plain text.

    Each non-empty row becomes one line whose cell values are separated by
    single spaces.

    Args:
        excel_file: Whatever ``openpyxl.load_workbook`` accepts.

    Returns:
        The rows joined by newlines. An empty string is returned when the
        workbook has no cell values, so a caller's ``if not text`` check can
        detect it. If reading fails, a string starting with
        ``"Error reading Excel: "`` is returned instead of raising.
    """
    wb = None
    try:
        wb = openpyxl.load_workbook(excel_file, read_only=True)
        lines = []
        for sheet in wb.worksheets:
            for row in sheet.iter_rows(values_only=True):
                # Empty cells come back as None; skipping them keeps the
                # literal word "None" out of the extracted text.
                cells = [str(cell) for cell in row if cell is not None]
                if cells:
                    lines.append(" ".join(cells))
        return "\n".join(lines)
    except Exception as e:
        return f"Error reading Excel: {e}"
    finally:
        # Workbooks opened with read_only=True must be closed explicitly.
        if wb is not None:
            wb.close()
93
+
94
def extract_text_from_image(image_file):
    """Run OCR on an image with the module-level ``reader``; return its text.

    Args:
        image_file: A NumPy array, a PIL image, or anything ``Image.open``
            accepts (path or file object). ``None`` yields ``""``.

    Returns:
        The recognised text fragments joined by spaces. An empty string is
        returned when the image looks blank or OCR finds nothing, so a
        caller's ``if not image_text`` check can detect it.
    """
    if image_file is None:
        return ""

    # Accept in-memory images as well as files; Image.open() only handles
    # paths and file objects.
    # NOTE(review): assumes `Image` is the PIL.Image module — confirm import.
    if isinstance(image_file, np.ndarray):
        image = Image.fromarray(image_file).convert("RGB")
    elif isinstance(image_file, Image.Image):
        image = image_file.convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")

    pixels = np.array(image)  # converted once, reused for both steps below
    if pixels.std() < 10:  # Low contrast = likely empty
        return ""

    result = reader.readtext(pixels)
    # Element 1 of each OCR result entry is the recognised string.
    return " ".join(res[1] for res in result)
101
 
102
def translate_text(text, target_lang="en", source_lang=None):
    """Translate *text* into *target_lang* with the module-level ``translator``.

    Args:
        text: Text to translate; empty input is returned unchanged.
        target_lang: Language code to translate into.
        source_lang: Language code of *text*. When ``None`` (the default) the
            text is returned untranslated.

    Returns:
        The translated string, or *text* itself when no translation is done.
    """
    # NOTE(review): "auto" is not an M2M100 language code, so the previous
    # src_lang="auto" call is expected to fail in the tokenizer — confirm
    # against the M2M100 docs. Without a known source language the text is
    # passed through unchanged; add language detection upstream and pass
    # source_lang to enable real translation.
    if not text or source_lang is None or source_lang == target_lang:
        return text
    result = translator(text, src_lang=source_lang, tgt_lang=target_lang)
    return result[0]["translation_text"]
 
104
 
105
+ # Function to answer questions based on document content
106
  def answer_question_from_document(file, question):
107
+ validation_error = validate_file_type(file)
108
+ if validation_error:
109
+ return validation_error
110
 
111
+ file_ext = file.name.split(".")[-1].lower()
112
  if file_ext == "pdf":
113
  text = extract_text_from_pdf(file)
114
  elif file_ext == "docx":
 
119
  text = extract_text_from_excel(file)
120
  else:
121
  return "Unsupported file format!"
122
+
123
  if not text:
124
  return "No text extracted from the document."
125
 
126
+ text = translate_text(text) # Translate non-English text to English
127
  truncated_text = truncate_text(text)
128
+ response = qa_pipeline({"question": question, "context": truncated_text})
 
129
 
130
+ return response["answer"]
131
 
 
132
def answer_question_from_image(image, question):
    """Answer *question* using the text recognised in *image*.

    The image text is extracted, translated, truncated, and then passed as
    context to the module-level ``qa_pipeline``.

    Args:
        image: Image input forwarded to ``extract_text_from_image``.
        question: The user's question.

    Returns:
        The answer string from the QA pipeline, or a short message when the
        question is empty, no text was found, or the model gives no answer.
    """
    # The question-answering pipeline rejects empty questions, so check first.
    if not question or not question.strip():
        return "Please enter a question."

    image_text = extract_text_from_image(image)
    # The extractor may signal "nothing found" with a placeholder sentence
    # rather than an empty string; neither should be used as QA context.
    placeholders = (
        "No text found.",
        "No meaningful content detected in the image.",
    )
    if not image_text or image_text in placeholders:
        return "No meaningful content detected in the image."

    image_text = translate_text(image_text)  # Translate non-English text to English
    truncated_text = truncate_text(image_text)
    response = qa_pipeline(question=question, context=truncated_text)

    # NOTE(review): SQuAD2-style models can return an empty answer when the
    # context does not contain one — confirm for this model.
    return response["answer"] or "No answer found in the image text."
142
 
143
  # Gradio UI for Document & Image QA
144
  doc_interface = gr.Interface(