import os
import base64
from io import BytesIO
from functools import lru_cache
import gradio as gr
import pdfplumber # For PDF document parsing
import pytesseract # OCR for extracting text from images
from PIL import Image
from huggingface_hub import InferenceClient
from mistralai import Mistral
# Initialize clients that don't require heavy model loading
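# HF_TOKEN and MISTRAL_API_KEY are expected in the environment (e.g. as Space secrets).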
client = InferenceClient(api_key=os.getenv("HF_TOKEN"))
client.headers["x-use-cache"] = "0"  # disable HF inference caching so repeated prompts are regenerated
mistral_client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
### Lazy Loading and Caching for Transformers Pipelines ###
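# Each factory below defers the transformers import and, thanks to lru_cache,
# downloads and loads its model only once, on first use.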
@lru_cache(maxsize=1)
def get_summarizer():
    from transformers import pipeline
    # Use a smaller distilled model for faster loading
    return pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

@lru_cache(maxsize=1)
def get_sentiment_analyzer():
    from transformers import pipeline
    return pipeline("sentiment-analysis")

@lru_cache(maxsize=1)
def get_ner_tagger():
    from transformers import pipeline
    return pipeline("ner")
### Helper Functions ###
def encode_image(image_path):
    """Resizes an image to a 512px height (preserving aspect ratio) and encodes it as base64 JPEG."""
    try:
        image = Image.open(image_path).convert("RGB")
        base_height = 512
        h_percent = base_height / float(image.size[1])
        w_size = int(float(image.size[0]) * h_percent)
        image = image.resize((w_size, base_height), Image.LANCZOS)
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"Image encoding error: {e}")
        return None
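# Example with a hypothetical file: encode_image("scan.jpg") -> "/9j/..." (JPEG base64), or None on failure.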
def extract_text_from_document(file_path):
    """Extracts text from a PDF via pdfplumber, or from an image via OCR."""
    text = ""
    if file_path.lower().endswith(".pdf"):
        try:
            with pdfplumber.open(file_path) as pdf:
                for page in pdf.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text += page_text + "\n"
        except Exception as e:
            print(f"PDF parsing error: {e}")
        # pytesseract cannot read PDFs directly, so return whatever the parser
        # found (empty for scanned PDFs without a text layer).
        return text.strip()
    # OCR for image files
    try:
        image = Image.open(file_path)
        text = pytesseract.image_to_string(image)
    except Exception as e:
        print(f"OCR error: {e}")
    return text.strip()
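# Note: pytesseract shells out to the Tesseract binary, which must be installed
# separately (on Spaces, e.g. via a packages.txt entry for tesseract-ocr).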
def perform_semantic_analysis(text, analysis_type):
    """Applies the selected semantic analysis task to the text using cached pipelines."""
    if analysis_type == "Summarization":
        summarizer = get_summarizer()
        return summarizer(text, max_length=150, min_length=40, do_sample=False)[0]["summary_text"]
    elif analysis_type == "Sentiment Analysis":
        sentiment_analyzer = get_sentiment_analyzer()
        return sentiment_analyzer(text)[0]  # e.g. {"label": "POSITIVE", "score": 0.99}
    elif analysis_type == "Named Entity Recognition":
        ner_tagger = get_ner_tagger()
        return ner_tagger(text)  # list of entity dicts
    return text
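# Sample call: perform_semantic_analysis(long_article_text, "Summarization") -> a 40-150 token summary string.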
def process_text_input(message_text, history, model_choice, analysis_type):
    """Processes text-only input with the selected model and optional semantic analysis."""
    if analysis_type and analysis_type != "None":
        analysis_result = perform_semantic_analysis(message_text, analysis_type)
        message_text += f"\n\n[Analysis Result]: {analysis_result}"
    input_prompt = [{"role": "user", "content": message_text}]
    if model_choice == "mistralai/Mistral-Nemo-Instruct-2411":
        # Route this choice to the Mistral API rather than the HF Inference API
        stream_response = mistral_client.chat.stream(model="mistral-large-2411", messages=input_prompt)
        partial_message = ""
        for chunk in stream_response:
            if chunk.data.choices[0].delta.content:
                partial_message += chunk.data.choices[0].delta.content
                yield partial_message
    else:
        stream = client.chat.completions.create(
            model=model_choice,
            messages=input_prompt,
            temperature=0.5,
            max_tokens=1024,
            top_p=0.7,
            stream=True
        )
        partial_message = ""
        for chunk in stream:
            if chunk.choices[0].delta.content:
                partial_message += chunk.choices[0].delta.content
                yield partial_message
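# gr.ChatInterface replaces the displayed reply with each yielded value, which is
# why both handlers yield the accumulated text rather than individual deltas.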
def process_image_input(image_file, message_text, image_mod, model_choice, analysis_type):
    """Processes image input: OCR extraction, optional semantic analysis, then a vision model."""
    # Gradio's multimodal ChatInterface passes uploaded files as local file paths
    image_path = image_file
    extracted_text = extract_text_from_document(image_path)
    if extracted_text:
        message_text += f"\n\n[Extracted Text]: {extracted_text}"
        if analysis_type and analysis_type != "None":
            analysis_result = perform_semantic_analysis(extracted_text, analysis_type)
            message_text += f"\n\n[Analysis Result]: {analysis_result}"
    base64_image = encode_image(image_path)
    if not base64_image:
        yield "Failed to process image."
        return
    image_url = f"data:image/jpeg;base64,{base64_image}"
    if image_mod == "Vision":
        # The HF Inference API expects the OpenAI-style nested image_url object
        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": message_text},
                {"type": "image_url", "image_url": {"url": image_url}}
            ]
        }]
        stream = client.chat.completions.create(
            model="meta-llama/Llama-3.2-11B-Vision-Instruct",
            messages=messages,
            max_tokens=500,
            stream=True
        )
        partial_message = ""
        for chunk in stream:
            if chunk.choices[0].delta.content:
                partial_message += chunk.choices[0].delta.content
                yield partial_message
    else:
        # The Mistral SDK accepts the data URL as a plain string
        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": message_text},
                {"type": "image_url", "image_url": image_url}
            ]
        }]
        partial_message = ""
        for chunk in mistral_client.chat.stream(model="pixtral-large-2411", messages=messages):
            if chunk.data.choices[0].delta.content:
                partial_message += chunk.data.choices[0].delta.content
                yield partial_message
def multimodal_response(message, history, analyzer_mode, model_choice, image_mod, analysis_type):
    """Main dispatcher: routes the incoming message to the text or image handler."""
    # Gradio's multimodal ChatInterface delivers message as {"text": str, "files": [path, ...]};
    # analyzer_mode is wired to the "Enable Analyzer Mode" checkbox but is not used yet.
    message_text = message.get("text", "")
    message_files = message.get("files", [])
    if message_files:
        yield from process_image_input(message_files[0], message_text, image_mod, model_choice, analysis_type)
    else:
        yield from process_text_input(message_text, history, model_choice, analysis_type)
# Set up the Gradio interface with user customization options
MultiModalAnalyzer = gr.ChatInterface(
    fn=multimodal_response,
    type="messages",
    multimodal=True,
    additional_inputs=[
        gr.Checkbox(label="Enable Analyzer Mode", value=True),
        gr.Dropdown(
            choices=[
                "meta-llama/Llama-3.3-70B-Instruct",
                "CohereForAI/c4ai-command-r-plus-08-2024",
                "Qwen/Qwen2.5-72B-Instruct",
                "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
                "NousResearch/Hermes-3-Llama-3.1-8B",
                "mistralai/Mistral-Nemo-Instruct-2411",
                "microsoft/phi-4"
            ],
            value="mistralai/Mistral-Nemo-Instruct-2411",
            show_label=False,
            container=False
        ),
        gr.Radio(
            choices=["pixtral", "Vision"],
            value="pixtral",
            show_label=False,
            container=False
        ),
        gr.Dropdown(
            choices=["None", "Summarization", "Sentiment Analysis", "Named Entity Recognition"],
            value="None",
            label="Select Analysis Type",
            container=False
        )
    ],
    title="MultiModal Analyzer",
    description="Upload documents or images, then select a model and analysis type to interact with your content."
)
if __name__ == "__main__":
    MultiModalAnalyzer.launch()