OmniMed_SIA / app.py
analist's picture
Update app.py
7e39dd0 verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Définir le modèle et le tokenizer
# Utilisation d'un modèle français pour le domaine médical
MODEL_NAME = "analist/llama3.1-8B-omnimed-rl" # Vous pouvez utiliser un modèle plus adapté au français comme "camembert" ou un modèle médical spécifique
# Fonction pour charger le modèle et le tokenizer
def load_model():
print("Chargement du modèle et du tokenizer...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = MODEL_NAME,
max_seq_length = 8192,
load_in_4bit = True,
token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
return model, tokenizer
# Charger le modèle et le tokenizer
model, tokenizer = load_model()
# Créer un pipeline de génération de texte
generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_length=512,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
# Message initial pour orienter le modèle vers le domaine médical
SYSTEM_PROMPT = """Tu es un assistant médical IA nommé MediBot. Tu fournis des informations médicales générales et des conseils de santé basés sur des données scientifiques.
Important: Tu n'es pas un médecin et tu dois toujours recommander à l'utilisateur de consulter un professionnel de santé pour un diagnostic ou un traitement spécifique.
Réponds aux questions médicales de manière précise et claire, en te basant sur des informations médicales vérifiées."""
# Fonction pour générer une réponse
def generate_response(message, chat_history):
# Construire le prompt avec l'historique de conversation et le nouveau message
prompt = SYSTEM_PROMPT + "\n\n"
# Ajouter l'historique du chat
for user_msg, bot_msg in chat_history:
prompt += f"Utilisateur: {user_msg}\nMediBot: {bot_msg}\n\n"
# Ajouter le nouveau message
prompt += f"Utilisateur: {message}\nMediBot:"
# Générer la réponse
response = generator(prompt, max_new_tokens=256)[0]["generated_text"]
# Extraire seulement la partie réponse du modèle
response_only = response.split("MediBot:")[-1].strip()
# Ajouter un rappel de consulter un professionnel
if len(response_only) > 0 and not "consulter un professionnel" in response_only.lower():
response_only += "\n\n(N'oubliez pas que ces informations sont générales. Pour des conseils personnalisés, consultez un professionnel de santé.)"
return response_only
# Fonction pour gérer l'historique du chat Gradio
def chatbot_response(message, history):
bot_message = generate_response(message, history)
history.append((message, bot_message))
return "", history
# Interface Gradio
with gr.Blocks(title="MediBot - Assistant Médical IA") as demo:
gr.Markdown("# MediBot - Votre Assistant Médical basé sur l'IA")
gr.Markdown("""
### ⚠️ IMPORTANT - AVERTISSEMENT MÉDICAL ⚠️
Ce chatbot utilise l'intelligence artificielle pour fournir des informations médicales générales.
Il ne remplace en aucun cas l'avis d'un professionnel de santé qualifié.
Pour toute question médicale sérieuse, veuillez consulter un médecin.
""")
chatbot = gr.Chatbot(height=500)
msg = gr.Textbox(placeholder="Posez votre question médicale ici...", label="Votre question")
clear = gr.Button("Effacer la conversation")
# Configurer les événements
msg.submit(chatbot_response, [msg, chatbot], [msg, chatbot])
clear.click(lambda: None, None, chatbot, queue=False)
gr.Markdown("""
### Exemples de questions que vous pouvez poser:
- Quels sont les symptômes courants de la grippe?
- Comment prévenir l'hypertension artérielle?
- Quels aliments sont recommandés pour les diabétiques?
- Quelles sont les causes fréquentes des migraines?
""")
# Lancer l'application
if __name__ == "__main__":
demo.launch(share=True) # share=True permet de générer un lien public temporaire