# Captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import base64
import os
import re
import tempfile
from functools import lru_cache
from io import BytesIO

import gradio as gr
import pdfplumber  # For PDF document parsing
import pytesseract  # OCR for extracting text from images
from huggingface_hub import InferenceClient
from mistralai import Mistral
from PIL import Image
# Initialize clients that don't require heavy model loading.
# Mistral's hosted API serves the "Mistral-Nemo" menu entry and the pixtral image mode.
api_key = os.environ.get("MISTRAL_API_KEY")
Mistralclient = Mistral(api_key=api_key)
# Hugging Face Inference client, authenticated with the HF_TOKEN secret.
client = InferenceClient(api_key=os.environ.get("HF_TOKEN"))
# Send "x-use-cache: 0" with every request (HF Inference API header that asks
# the service not to return cached generations).
client.headers["x-use-cache"] = "0"
### Lazy Loading and Caching for Transformers Pipelines ###
@lru_cache(maxsize=1)
def get_summarizer():
    """Return the summarization pipeline, building it on first use.

    ``transformers`` is imported inside the function so app start-up stays
    fast. The result is memoized with ``lru_cache``, so the model is loaded
    once per process instead of on every request.
    """
    from transformers import pipeline
    # Use a smaller model for faster loading
    return pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
@lru_cache(maxsize=1)
def get_sentiment_analyzer():
    """Return the sentiment-analysis pipeline, building it on first use.

    No model is named, so the library's default sentiment model is used.
    The pipeline is memoized with ``lru_cache`` so it is loaded only once.
    """
    from transformers import pipeline
    return pipeline("sentiment-analysis")
@lru_cache(maxsize=1)
def get_ner_tagger():
    """Return the named-entity-recognition pipeline, building it on first use.

    No model is named, so the library's default NER model is used.
    The pipeline is memoized with ``lru_cache`` so it is loaded only once.
    """
    from transformers import pipeline
    return pipeline("ner")
### Helper Functions ###
def encode_image(image_path, base_height=512):
    """Load an image, scale it to a fixed height and return it as base64 JPEG.

    The aspect ratio is preserved: the width is scaled by the same factor as
    the height. Images shorter than ``base_height`` are upscaled.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a path).
        base_height: Target height in pixels (default 512).

    Returns:
        The base64-encoded JPEG bytes as a str, or ``None`` if the image could
        not be opened or encoded (the error is printed).
    """
    try:
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that outlives the with-block.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        width, height = image.size
        h_percent = base_height / float(height)
        # Clamp to 1px so extremely tall images don't yield a zero width.
        new_width = max(1, int(float(width) * h_percent))
        image = image.resize((new_width, base_height), Image.LANCZOS)
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"Image encoding error: {e}")
        return None
def extract_text_from_document(file_path, ocr_resolution=300):
    """Extract text from a PDF or an image file.

    PDFs are first read through their embedded text layer with pdfplumber.
    If that yields nothing (e.g. a scanned PDF), each page is rendered to an
    image and passed to Tesseract OCR, because PIL cannot open a PDF
    directly. Any other file is opened with PIL and OCR'd.

    Args:
        file_path: Path to the document. A ``.pdf`` suffix (case-insensitive)
            selects the PDF route.
        ocr_resolution: DPI used when rendering PDF pages for OCR.

    Returns:
        The extracted text with surrounding whitespace stripped, or ``""`` if
        nothing could be extracted. Errors are printed, not raised.
    """
    if file_path.lower().endswith(".pdf"):
        text = ""
        try:
            with pdfplumber.open(file_path) as pdf:
                page_texts = (page.extract_text() for page in pdf.pages)
                text = "\n".join(t for t in page_texts if t).strip()
                if not text:
                    # No usable text layer: rasterize each page and OCR it.
                    ocr_pages = (
                        pytesseract.image_to_string(
                            page.to_image(resolution=ocr_resolution).original
                        )
                        for page in pdf.pages
                    )
                    text = "\n".join(ocr_pages).strip()
        except Exception as e:
            print(f"PDF parsing error: {e}")
        return text

    try:
        # Close the image file once OCR is done.
        with Image.open(file_path) as image:
            return pytesseract.image_to_string(image).strip()
    except Exception as e:
        print(f"OCR error: {e}")
        return ""
def perform_semantic_analysis(text, analysis_type):
    """Run the selected Transformers analysis on ``text``.

    Args:
        text: Input text to analyse.
        analysis_type: "Summarization", "Sentiment Analysis" or
            "Named Entity Recognition". Any other value (including "None")
            returns ``text`` unchanged.

    Returns:
        The summary string, the first sentiment prediction returned by the
        pipeline, the list of entities returned by the NER pipeline, or
        ``text`` itself for an unrecognised ``analysis_type``.
    """
    if analysis_type == "Summarization":
        summarizer = get_summarizer()
        # truncation=True cuts inputs longer than the model's maximum length
        # instead of letting the pipeline fail on them.
        result = summarizer(
            text, max_length=150, min_length=40, do_sample=False, truncation=True
        )
        return result[0]["summary_text"]
    if analysis_type == "Sentiment Analysis":
        sentiment_analyzer = get_sentiment_analyzer()
        return sentiment_analyzer(text, truncation=True)[0]
    if analysis_type == "Named Entity Recognition":
        ner_tagger = get_ner_tagger()
        return ner_tagger(text)
    return text
def process_text_input(message_text, history, model_choice, analysis_type):
    """Stream a reply to a text-only message, optionally enriched by analysis.

    Args:
        message_text: The user's message.
        history: Chat history from Gradio. It is not used here; each request
            is sent as a single user turn.
        model_choice: Model identifier from the dropdown. The
            "mistralai/Mistral-Nemo-Instruct-2411" entry is served by Mistral's
            API (model "mistral-large-2411"); every other entry goes to the
            Hugging Face Inference API.
        analysis_type: Semantic analysis to run on the message first, or
            "None" to skip it.

    Yields:
        The accumulated response text after each streamed chunk, which is
        what Gradio's streaming ChatInterface displays.
    """
    if analysis_type and analysis_type != "None":
        analysis_result = perform_semantic_analysis(message_text, analysis_type)
        message_text += f"\n\n[Analysis Result]: {analysis_result}"
    input_prompt = [{"role": "user", "content": message_text}]

    response = ""
    if model_choice == "mistralai/Mistral-Nemo-Instruct-2411":
        model = "mistral-large-2411"
        stream_response = Mistralclient.chat.stream(model=model, messages=input_prompt)
        for chunk in stream_response:
            choices = chunk.data.choices
            if choices and choices[0].delta.content:
                # Accumulate, like the HF branch, so the UI shows the full text.
                response += choices[0].delta.content
                yield response
    else:
        stream = client.chat.completions.create(
            model=model_choice,
            messages=input_prompt,
            temperature=0.5,
            max_tokens=1024,
            top_p=0.7,
            stream=True,
        )
        for chunk in stream:
            # Some streamed chunks can carry no choices; skip them.
            if chunk.choices and chunk.choices[0].delta.content:
                response += chunk.choices[0].delta.content
                yield response
def process_image_input(image_file, message_text, image_mod, model_choice, analysis_type):
    """Stream a reply to a message that comes with an uploaded image/document.

    Text is pulled from the upload (PDF text layer or OCR) and appended to the
    prompt, optionally followed by a semantic-analysis result. The image is
    sent base64-encoded to Llama 3.2 Vision on the HF Inference API
    (``image_mod == "Vision"``) or to Mistral's Pixtral (any other value).
    If the upload cannot be encoded as an image (e.g. a PDF) but text was
    extracted, the reply is generated from the text alone with ``model_choice``.

    Args:
        image_file: Path to the upload (what Gradio's multimodal ChatInterface
            passes), a ``{"path": ...}`` dict, or an object with a PIL-style
            ``save(path)`` method.
        message_text: The user's typed message.
        image_mod: "Vision" or "pixtral".
        model_choice: Text model used for the text-only fallback.
        analysis_type: Semantic analysis to run on the extracted text, or "None".

    Yields:
        The accumulated response text after each streamed chunk, or a single
        "Failed to process image." message.
    """
    if isinstance(image_file, dict):
        image_file = image_file.get("path")

    temp_path = None
    try:
        if hasattr(image_file, "save"):
            # In-memory image: write it to a unique temp file so concurrent
            # requests never overwrite each other's upload. PNG accepts every
            # PIL mode (JPEG rejects e.g. RGBA).
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                temp_path = tmp.name
            image_file.save(temp_path)
            image_path = temp_path
        else:
            image_path = os.fspath(image_file)

        extracted_text = extract_text_from_document(image_path)
        if extracted_text:
            message_text += f"\n\n[Extracted Text]: {extracted_text}"
            if analysis_type and analysis_type != "None":
                analysis_result = perform_semantic_analysis(extracted_text, analysis_type)
                message_text += f"\n\n[Analysis Result]: {analysis_result}"

        base64_image = encode_image(image_path)
        if not base64_image:
            if extracted_text:
                # Not an encodable image (e.g. a PDF), but its text is in the
                # prompt already: answer from text only, without re-analysing.
                yield from process_text_input(message_text, [], model_choice, "None")
            else:
                yield "Failed to process image."
            return

        data_url = f"data:image/jpeg;base64,{base64_image}"
        response = ""
        if image_mod == "Vision":
            # The HF OpenAI-compatible chat API expects image_url as {"url": ...}.
            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": message_text},
                    {"type": "image_url", "image_url": {"url": data_url}},
                ],
            }]
            stream = client.chat.completions.create(
                model="meta-llama/Llama-3.2-11B-Vision-Instruct",
                messages=messages,
                max_tokens=500,
                stream=True,
            )
            for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    response += chunk.choices[0].delta.content
                    yield response
        else:
            messages = [{
                "role": "user",
                "content": [
                    {"type": "text", "text": message_text},
                    {"type": "image_url", "image_url": data_url},
                ],
            }]
            model = "pixtral-large-2411"
            for chunk in Mistralclient.chat.stream(model=model, messages=messages):
                choices = chunk.data.choices
                if choices and choices[0].delta.content:
                    response += choices[0].delta.content
                    yield response
    finally:
        # Remove only the temp file this call created, never the user's upload.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
def multimodal_response(message, history, analyzer_mode, model_choice, image_mod, analysis_type):
    """Route a ChatInterface message to the image or the text handler.

    Args:
        message: With ``multimodal=True``, a dict with "text" (str) and
            "files" (list of uploaded files). A plain string is also accepted.
        history: Prior chat turns, forwarded to the text handler.
        analyzer_mode: Value of the "Enable Analyzer Mode" checkbox. When
            False, no semantic analysis runs, whatever ``analysis_type`` says.
        model_choice: Text model identifier from the dropdown.
        image_mod: "pixtral" or "Vision" backend for image inputs.
        analysis_type: Selected semantic analysis, or "None".

    Yields:
        The progressively accumulated assistant response.
    """
    if isinstance(message, str):
        message = {"text": message, "files": []}
    message_text = message.get("text") or ""
    message_files = message.get("files") or []
    # Semantic analysis only runs while the analyzer checkbox is ticked.
    effective_analysis = analysis_type if analyzer_mode else "None"
    if message_files:
        # Only the first uploaded file is processed.
        image_file = message_files[0]
        yield from process_image_input(
            image_file, message_text, image_mod, model_choice, effective_analysis
        )
    else:
        yield from process_text_input(message_text, history, model_choice, effective_analysis)
# Set up the Gradio interface with user customization options.
# additional_inputs are passed to multimodal_response positionally after
# (message, history): analyzer_mode, model_choice, image_mod, analysis_type.
MultiModalAnalyzer = gr.ChatInterface(
    fn=multimodal_response,
    type="messages",
    multimodal=True,
    additional_inputs=[
        gr.Checkbox(label="Enable Analyzer Mode", value=True),
        gr.Dropdown(
            choices=[
                "meta-llama/Llama-3.3-70B-Instruct",
                "CohereForAI/c4ai-command-r-plus-08-2024",
                "Qwen/Qwen2.5-72B-Instruct",
                "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
                "NousResearch/Hermes-3-Llama-3.1-8B",
                "mistralai/Mistral-Nemo-Instruct-2411",
                "microsoft/phi-4",
            ],
            value="mistralai/Mistral-Nemo-Instruct-2411",
            label="Text Model",
            container=False,
        ),
        gr.Radio(
            choices=["pixtral", "Vision"],
            value="pixtral",
            label="Image Model",
            container=False,
        ),
        gr.Dropdown(
            choices=["None", "Summarization", "Sentiment Analysis", "Named Entity Recognition"],
            value="None",
            label="Select Analysis Type",
            container=False,
        ),
    ],
    title="MultiModal Analyzer",
    description="Upload documents or images, select a model and analysis type to interact with your content.",
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    MultiModalAnalyzer.launch()