# SY23 / app.py
# Last commit not found
# raw
# history blame
# 5.68 kB
import gradio as gr
import torch
from transformers import (
BlipProcessor,
BlipForQuestionAnswering,
pipeline,
AutoTokenizer,
AutoModelForCausalLM
)
from PIL import Image
import os
import logging
# Configure root logging at INFO level (stdlib: a no-op if the root logger
# already has handlers), then grab this module's own logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class MultimodalProcessor:
    """Turn image, audio and text inputs into one generated description.

    Images are described with BLIP visual question answering
    (``Salesforce/blip-vqa-base``), audio is transcribed with Whisper
    (``openai/whisper-small``) and the combined text is continued with GPT-2.
    """

    # Questions asked to BLIP for every image, keyed by the placeholder they
    # fill in ``_IMAGE_TEMPLATE``; keeping both together avoids repeating the
    # question strings as dictionary keys.
    _IMAGE_QUESTIONS = {
        "subject": "What is in the picture?",
        "colors": "What are the main colors?",
        "setting": "What is the setting or background?",
        "action": "What is happening in the image?",
    }
    _IMAGE_TEMPLATE = (
        "This image shows {subject}. "
        "The main colors are {colors}. "
        "The setting is {setting}. "
        "In the scene, {action}."
    )

    # Whisper handles 30 s of audio per pass; chunking lets the ASR pipeline
    # transcribe longer recordings instead of failing/truncating.
    _ASR_CHUNK_LENGTH_S = 30

    def __init__(self):
        """Load all models at construction time (see ``load_models``)."""
        self.load_models()

    def load_models(self):
        """Load the BLIP VQA model, the Whisper ASR pipeline and GPT-2.

        Raises:
            Exception: whatever ``from_pretrained``/``pipeline`` raised; it is
                logged with its traceback and re-raised unchanged.
        """
        try:
            logger.info("Chargement des modèles...")
            self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.blip_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
            self.audio_transcriber = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-small",
                chunk_length_s=self._ASR_CHUNK_LENGTH_S,
            )
            self.text_generator = pipeline("text-generation", model="gpt2")
            logger.info("Modèles chargés avec succès")
        except Exception as e:
            logger.exception("Erreur lors du chargement des modèles: %s", e)
            raise

    def analyze_image(self, image):
        """Describe *image* by asking BLIP each question in ``_IMAGE_QUESTIONS``.

        Args:
            image: image accepted by ``BlipProcessor`` (the Gradio UI passes a
                PIL image), or None.

        Returns:
            An English description built from ``_IMAGE_TEMPLATE``; "" when
            *image* is None; a French error message if inference fails (the
            exception is logged).
        """
        if image is None:
            return ""
        try:
            answers = {}
            for slot, question in self._IMAGE_QUESTIONS.items():
                inputs = self.blip_processor(images=image, text=question, return_tensors="pt")
                outputs = self.blip_model.generate(**inputs)
                answers[slot] = self.blip_processor.decode(outputs[0], skip_special_tokens=True)
            return self._IMAGE_TEMPLATE.format(**answers)
        except Exception as e:
            logger.exception("Erreur lors de l'analyse de l'image: %s", e)
            return "Erreur lors de l'analyse de l'image."

    def transcribe_audio(self, audio_path):
        """Transcribe the audio file at *audio_path* with the Whisper pipeline.

        Returns:
            The pipeline's ``"text"`` output; "" when *audio_path* is None; a
            French error message if transcription fails (the exception is
            logged).
        """
        if audio_path is None:
            return ""
        try:
            return self.audio_transcriber(audio_path)["text"]
        except Exception as e:
            logger.exception("Erreur lors de la transcription audio: %s", e)
            return "Erreur lors de la transcription audio."

    def generate_text(self, prompt, max_new_tokens=150):
        """Continue *prompt* with the GPT-2 text-generation pipeline.

        Args:
            prompt: text to continue; a falsy value returns "".
            max_new_tokens: number of tokens to generate after the prompt.

        Returns:
            The pipeline's ``generated_text`` (prompt followed by the
            continuation), or a French error message on failure (logged).
        """
        if not prompt:
            return ""
        try:
            # max_new_tokens rather than max_length: max_length also counts
            # the prompt, so a long multimodal prompt would leave no room to
            # generate. "hole" left-truncates the prompt when prompt + new
            # tokens would exceed the model's maximum input length.
            result = self.text_generator(
                prompt,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                handle_long_generation="hole",
            )
            return result[0]["generated_text"]
        except Exception as e:
            logger.exception("Erreur lors de la génération de texte: %s", e)
            return "Erreur lors de la génération de texte."

    def process_inputs(self, image, audio, text):
        """Combine the three Gradio inputs and generate a description.

        Args:
            image: image for ``analyze_image``, or None.
            audio: audio file path for ``transcribe_audio``, or None.
            text: free-form extra context; blank or whitespace-only is ignored.

        Returns:
            ``generate_text`` output for the combined prompt;
            "Aucune entrée fournie." when there is no usable input; a French
            error message on unexpected failure (logged).
        """
        try:
            sections = []
            # NOTE(review): on failure analyze_image/transcribe_audio return an
            # error-message string, which is then fed to GPT-2 as content.
            image_description = self.analyze_image(image)
            if image_description:
                sections.append(f"Visual description: {image_description}\n")
            audio_text = self.transcribe_audio(audio)
            if audio_text:
                sections.append(f"Audio content: {audio_text}\n")
            extra_context = (text or "").strip()
            if extra_context:
                sections.append(f"Additional context: {extra_context}\n")
            if not sections:
                return "Aucune entrée fournie."
            return self.generate_text("".join(sections))
        except Exception as e:
            logger.exception("Erreur lors du traitement des entrées: %s", e)
            return "Une erreur est survenue lors du traitement des entrées."
def create_interface():
    """Build the Gradio UI around a freshly constructed MultimodalProcessor.

    Constructing the processor loads all models. The returned interface is
    not launched.

    Returns:
        gr.Interface: image + audio + text inputs, one textbox output.
    """
    processor = MultimodalProcessor()

    description_markdown = (
        "\n"
        "Cette application analyse vos contenus multimodaux :\n"
        "- Images : génère une description détaillée\n"
        "- Audio : transcrit le contenu\n"
        "- Texte : enrichit la description\n"
        "La sortie combine toutes ces informations en une description cohérente.\n"
    )
    input_components = [
        gr.Image(type="pil", label="Télécharger une image"),
        gr.Audio(type="filepath", label="Télécharger un fichier audio"),
        gr.Textbox(label="Entrez du texte additionnel"),
    ]
    output_components = [gr.Textbox(label="Description générée")]

    return gr.Interface(
        fn=processor.process_inputs,
        inputs=input_components,
        outputs=output_components,
        title="Analyseur de Contenu Multimodal",
        description=description_markdown,
    )
if __name__ == "__main__":
    # Script entry point: build the UI (this loads every model) and serve it.
    create_interface().launch()