ikraamkb committed on
Commit 4be6e3a · verified · 1 Parent(s): f7e80da

Update app.py

Files changed (1)
  1. app.py +83 -111
app.py CHANGED
@@ -1,141 +1,113 @@
  import os
- from fastapi import FastAPI, UploadFile, File
- import gradio as gr
- from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
  import pdfplumber
  import docx
  import openpyxl
  import pytesseract
- from PIL import Image
- from easyocr import Reader

- # Initialize FastAPI app
  app = FastAPI()

- # Load the question-answering models
- doc_qa_model_name = "microsoft/phi-2"
- image_qa_model_name = "Salesforce/blip-vqa-base"

- # Load the models and tokenizers for each use case
- doc_qa_model = AutoModelForQuestionAnswering.from_pretrained(doc_qa_model_name)
- doc_qa_tokenizer = AutoTokenizer.from_pretrained(doc_qa_model_name)

- image_qa_model = AutoModelForQuestionAnswering.from_pretrained(image_qa_model_name)
- image_qa_tokenizer = AutoTokenizer.from_pretrained(image_qa_model_name)

- # Load VQA pipeline for images
- image_qa_pipeline = pipeline("image-captioning", model=image_qa_model, tokenizer=image_qa_tokenizer)

- # Load OCR Reader
- ocr_reader = Reader(lang_list=['en'])
-
- # Helper function to process DOCX files
- def extract_text_from_docx(file_path):
-     doc = docx.Document(file_path)
-     text = "\n".join([para.text for para in doc.paragraphs])
-     return text
-
- # Helper function to process Excel files
- def extract_text_from_excel(file_path):
-     wb = openpyxl.load_workbook(file_path)
-     sheet = wb.active
-     text = ""
-     for row in sheet.iter_rows():
-         text += " ".join([str(cell.value) for cell in row]) + "\n"
-     return text
-
- # Helper function to process PDF files
- def extract_text_from_pdf(file_path):
      with pdfplumber.open(file_path) as pdf:
          text = ""
          for page in pdf.pages:
              text += page.extract_text()
      return text

- # Helper function to process images (OCR)
- def extract_text_from_image(image_path):
-     image = Image.open(image_path)
-     text = pytesseract.image_to_string(image)
      return text

- # AI-powered Question Answering for documents
- def document_question_answering(document_text, question):
-     inputs = doc_qa_tokenizer.encode(question, document_text, return_tensors='pt')
-     outputs = doc_qa_model(inputs)
-     answer_start = outputs.start_logits.argmax()
-     answer_end = outputs.end_logits.argmax()
-     answer = doc_qa_tokenizer.decode(inputs[0][answer_start:answer_end + 1])
-     return answer

- # AI-powered Question Answering for images
- def image_question_answering(image, question):
-     # Generate caption for the image first
-     caption = image_qa_pipeline(image)[0]['caption']
-     # Combine caption with the question to form input to the QA model
-     return caption # For simplicity, you may further process this if needed

- # FastAPI route for document-based question answering
- @app.post("/qa/document")
- async def qa_document(file: UploadFile = File(...), question: str = None):
-     file_location = f"temp_files/{file.filename}"
-     with open(file_location, "wb") as f:
          f.write(await file.read())

-     # Extract text from the document based on its format
-     if file.filename.endswith('.pdf'):
-         document_text = extract_text_from_pdf(file_location)
-     elif file.filename.endswith('.docx'):
-         document_text = extract_text_from_docx(file_location)
-     elif file.filename.endswith('.xlsx'):
-         document_text = extract_text_from_excel(file_location)
      else:
-         return {"error": "Unsupported file format."}

-     # Get the answer using the document QA model
-     answer = document_question_answering(document_text, question)
-     return {"answer": answer}

- # FastAPI route for image-based question answering
- @app.post("/qa/image")
- async def qa_image(file: UploadFile = File(...), question: str = None):
-     file_location = f"temp_files/{file.filename}"
-     with open(file_location, "wb") as f:
-         f.write(await file.read())

-     # Extract text from the image using OCR
-     image_text = extract_text_from_image(file_location)
-
-     # Get the answer using the image QA model (BLIP VQA)
-     image_answer = image_question_answering(file_location, question)
-     return {"answer": image_answer}
-
- # Gradio Interface for Document QA
- def document_qa_interface(file, question):
-     file_path = file.name
-     if file_path.endswith(".pdf"):
-         document_text = extract_text_from_pdf(file_path)
-     elif file_path.endswith(".docx"):
-         document_text = extract_text_from_docx(file_path)
-     elif file_path.endswith(".xlsx"):
-         document_text = extract_text_from_excel(file_path)
-     else:
-         return "Unsupported document format."
-
-     return document_question_answering(document_text, question)
-
- # Gradio Interface for Image QA
- def image_qa_interface(image, question):
-     return image_question_answering(image, question)
-
- # Gradio Web Interface
- iface_document = gr.Interface(fn=document_qa_interface, inputs=["file", "text"], outputs="text")
- iface_image = gr.Interface(fn=image_qa_interface, inputs=["image", "text"], outputs="text")

- # Run FastAPI app
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)

- # Launch Gradio Interface
- iface_document.launch(share=True)
- iface_image.launch(share=True)
+ from fastapi import FastAPI, Form, File, UploadFile
+ from fastapi.responses import RedirectResponse
+ from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
+ from transformers import pipeline
  import os
+ from PIL import Image
+ import io
  import pdfplumber
  import docx
  import openpyxl
  import pytesseract
+ from io import BytesIO
+ import fitz # PyMuPDF
+ import easyocr
+ from fastapi.templating import Jinja2Templates
+ from starlette.requests import Request

+ # Initialize the app
  app = FastAPI()

+ # Mount the static directory to serve HTML, CSS, JS files
+ app.mount("/static", StaticFiles(directory="static"), name="static")

+ # Initialize transformers pipelines
+ qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
+ image_qa_pipeline = pipeline("image-question-answering", model="Salesforce/blip-vqa-base", tokenizer="Salesforce/blip-vqa-base")

+ # Initialize EasyOCR for image-based text extraction
+ reader = easyocr.Reader(['en'])

+ # Define a template for rendering HTML
+ templates = Jinja2Templates(directory="templates")

+ # Function to process PDFs
+ def extract_pdf_text(file_path: str):
      with pdfplumber.open(file_path) as pdf:
          text = ""
          for page in pdf.pages:
              text += page.extract_text()
      return text

+ # Function to process DOCX files
+ def extract_docx_text(file_path: str):
+     doc = docx.Document(file_path)
+     text = ""
+     for para in doc.paragraphs:
+         text += para.text
      return text

+ # Function to process PPTX files
+ def extract_pptx_text(file_path: str):
+     from pptx import Presentation
+     prs = Presentation(file_path)
+     text = ""
+     for slide in prs.slides:
+         for shape in slide.shapes:
+             if hasattr(shape, "text"):
+                 text += shape.text
+     return text

+ # Function to extract text from images using OCR
+ def extract_text_from_image(image: Image):
+     text = pytesseract.image_to_string(image)
+     return text

+ # Home route
+ @app.get("/")
+ def home():
+     return RedirectResponse(url="/docs")
+
+ # Function to answer questions based on document content
+ @app.post("/question-answering-doc")
+ async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
+     # Save the uploaded file temporarily
+     file_path = f"temp_files/{file.filename}"
+     os.makedirs(os.path.dirname(file_path), exist_ok=True)
+     with open(file_path, "wb") as f:
          f.write(await file.read())

+     # Extract text based on file type
+     if file.filename.endswith(".pdf"):
+         text = extract_pdf_text(file_path)
+     elif file.filename.endswith(".docx"):
+         text = extract_docx_text(file_path)
+     elif file.filename.endswith(".pptx"):
+         text = extract_pptx_text(file_path)
      else:
+         return {"error": "Unsupported file format"}

+     # Use the model for question answering
+     qa_result = qa_pipeline(question=question, context=text)
+     return {"answer": qa_result['answer']}

+ # Function to answer questions based on images
+ @app.post("/question-answering-image")
+ async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
+     # Open the uploaded image
+     image = Image.open(BytesIO(await image_file.read()))

+     # Use EasyOCR to extract text if the image has textual content
+     image_text = extract_text_from_image(image)
+
+     # Use the BLIP VQA model for question answering on the image
+     image_qa_result = image_qa_pipeline(image=image, question=question)
+
+     return {"answer": image_qa_result[0]['answer'], "image_text": image_text}

+ # Serve the application in Hugging Face space
+ @app.get("/docs")
+ async def get_docs(request: Request):
+     return templates.TemplateResponse("index.html", {"request": request})
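
For reference, a minimal client sketch for the two endpoints this commit introduces. It is not part of the commit itself: the base URL, port, and file names are assumptions, and only the form field names (`question`, `file`, `image_file`) are taken from the route signatures above.

# Hypothetical client for the new routes in app.py (uses the requests library).
import requests

BASE_URL = "http://localhost:8000"  # assumption: adjust to wherever the app is served

# Document QA: multipart form with a "question" field and a "file" upload,
# matching question_answering_doc(question: str = Form(...), file: UploadFile = File(...))
with open("report.pdf", "rb") as f:  # placeholder file name
    resp = requests.post(
        f"{BASE_URL}/question-answering-doc",
        data={"question": "What is the total revenue?"},
        files={"file": ("report.pdf", f, "application/pdf")},
    )
print(resp.json())  # {"answer": "..."} or {"error": "Unsupported file format"}

# Image QA: the upload field is named "image_file" in question_answering_image
with open("chart.png", "rb") as f:  # placeholder file name
    resp = requests.post(
        f"{BASE_URL}/question-answering-image",
        data={"question": "What does the chart show?"},
        files={"image_file": ("chart.png", f, "image/png")},
    )
print(resp.json())  # {"answer": "...", "image_text": "..."}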