Trabis commited on
Commit
fe6115e
·
verified ·
1 Parent(s): 1ac7dd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -39
app.py CHANGED
@@ -443,77 +443,154 @@
443
 
444
 
445
 
446
- from fastapi import FastAPI, Request
447
  from fastapi.responses import StreamingResponse, HTMLResponse
448
  from fastapi.staticfiles import StaticFiles
449
  import uvicorn
450
- import asyncio # Pour le streaming
451
-
452
- # ... (imports OptimizedRAGLoader, LLM, etc.)
453
-
454
- # --- Initialisation (comme avant) ---
455
- # rag_loader = OptimizedRAGLoader()
456
- # llm = ChatGoogleGenerativeAI(...) # ou autre
457
- # retriever = rag_loader.get_retriever(...)
458
- # prompt_template = ChatPromptTemplate.from_messages(...)
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  app = FastAPI()
461
 
462
- # --- Fonction backend modifiée (pour API) ---
463
- # Doit être adaptable pour streaming ou réponse unique
464
- def get_llm_response_stream(question: str):
465
- # Réutilise la logique de process_question mais génère des chunks de texte
466
- # Peut nécessiter des ajustements pour le format de streaming API (e.g., Server-Sent Events)
 
 
 
 
 
 
 
 
467
  print(f"API processing question: {question}")
468
  try:
 
469
  relevant_docs = retriever(question)
470
- # ... (logique pour créer context, sources) ...
471
- context_str = "..."
472
- sources_str = "..."
 
 
 
473
 
474
  if not relevant_docs:
475
- yield "data: لم أتمكن من العثور على معلومات ذات صلة.\n\n" # Format SSE
 
 
476
  return
477
 
 
478
  prompt = prompt_template.format_messages(context=context_str, question=question)
479
 
480
  full_response = ""
 
481
  stream = llm.stream(prompt)
482
  for chunk in stream:
483
  content = chunk.content if hasattr(chunk, 'content') else str(chunk)
484
  if content:
485
- # Format pour Server-Sent Events (SSE)
486
- # Chaque message doit être préfixé par "data: " et finir par "\n\n"
487
- formatted_chunk = content.replace('\n', '\ndata: ') # Gère les sauts de ligne dans le chunk
488
- yield f"data: {formatted_chunk}\n\n"
489
- full_response += content # Accumule pour référence interne si besoin
490
-
491
- # Envoyer les sources à la fin (aussi en format SSE)
492
- yield f"data: {sources_str}\n\n"
 
 
 
493
 
494
  except Exception as e:
495
  print(f"Error during API LLM generation: {e}")
496
- yield f"data: Erreur: {str(e)}\n\n"
 
 
 
 
497
 
498
  # --- Endpoint API ---
499
  @app.post("/ask")
500
  async def handle_ask(request: Request):
501
- data = await request.json()
502
- question = data.get("question")
503
- if not question:
504
- return {"error": "Question manquante"}, 400
505
 
506
- # Pour une réponse non-streamée (plus simple au début)
507
- # response_content = "".join(list(get_llm_response_stream(question))) # Collecter tout le stream
508
- # return {"answer": response_content}
 
 
 
 
 
 
 
 
 
509
 
510
- # Pour une réponse streamée (Server-Sent Events)
511
- return StreamingResponse(get_llm_response_stream(question), media_type="text/event-stream")
512
 
513
  # --- Servir les fichiers statiques (HTML/JS/CSS) ---
 
514
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
515
 
516
- # --- Démarrage du serveur (pour exécution locale/Spaces) ---
517
- # La commande de démarrage dans Spaces sera typiquement `uvicorn app:app --host 0.0.0.0 --port 7860`
518
  if __name__ == "__main__":
519
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
443
 
444
 
445
 
446
+ from fastapi import FastAPI, Request, HTTPException
447
  from fastapi.responses import StreamingResponse, HTMLResponse
448
  from fastapi.staticfiles import StaticFiles
449
  import uvicorn
450
+ import asyncio
451
+ import os # Assurez-vous que 'os' est importé si vous l'utilisez pour les clés API, etc.
 
 
 
 
 
 
 
452
 
453
+ # --- Vos imports (Document, LLM, PromptTemplate, etc.) ---
454
+ # from langchain_google_genai import ChatGoogleGenerativeAI
455
+ # from langchain.prompts import ChatPromptTemplate
456
+ # ... autres imports nécessaires ...
457
+ # from your_rag_module import OptimizedRAGLoader # Assurez-vous que la classe est importable
458
+
459
+ # --- Variables globales (initialisées à None) ---
460
+ rag_loader = None
461
+ llm = None
462
+ retriever = None
463
+ prompt_template = None
464
+ initialization_error = None # Pour stocker une erreur d'initialisation
465
+
466
+ # --- Bloc d'initialisation robuste ---
467
+ print("--- Starting Application Initialization ---")
468
+ try:
469
+ # Initialisation du LLM
470
+ print("Initializing LLM...")
471
+ gemini_api_key = os.getenv("GEMINI_KEY")
472
+ if not gemini_api_key:
473
+ raise ValueError("GEMINI_KEY environment variable not set.")
474
+ llm = ChatGoogleGenerativeAI(
475
+ model="gemini-1.5-flash",
476
+ temperature=0.1,
477
+ google_api_key=gemini_api_key,
478
+ )
479
+ print("LLM Initialized.")
480
+
481
+ # Initialisation RAG Loader et Retriever
482
+ print("Initializing RAG Loader...")
483
+ # Assurez-vous que OptimizedRAGLoader est défini ou importé correctement
484
+ rag_loader = OptimizedRAGLoader() # Cette ligne peut échouer (chargement modèles/index)
485
+ print("RAG Loader Initialized. Getting Retriever...")
486
+ retriever = rag_loader.get_retriever(k=15, rerank_k=5) # Cette ligne dépend de rag_loader
487
+ print("Retriever Initialized.")
488
+
489
+ # Initialisation du Prompt Template
490
+ print("Initializing Prompt Template...")
491
+ prompt_template = ChatPromptTemplate.from_messages([
492
+ ("system", """أنت مساعد قانوني خبير... (votre prompt système complet ici) ...السؤال المطلوب الإجابة عليه: {question}"""),
493
+ ("human", "{question}")
494
+ ])
495
+ print("Prompt Template Initialized.")
496
+
497
+ print("--- Application Initialization Successful ---")
498
+
499
+ except Exception as e:
500
+ print(f"!!!!!!!!!! FATAL INITIALIZATION ERROR !!!!!!!!!!")
501
+ print(f"Error during startup: {e}")
502
+ import traceback
503
+ traceback.print_exc() # Affiche la trace complète de l'erreur dans les logs
504
+ initialization_error = str(e) # Stocke l'erreur pour l'API
505
+ # On laisse les variables globales à None si l'initialisation échoue
506
+
507
+ # --- FastAPI App ---
508
  app = FastAPI()
509
 
510
+ # --- Fonction backend modifiée ---
511
+ # (get_llm_response_stream - Gardez la version précédente qui gère le streaming SSE)
512
+ # Assurez-vous qu'elle utilise les variables globales llm, retriever, prompt_template
513
+ async def get_llm_response_stream(question: str):
514
+ # *** Vérification cruciale au début de la fonction ***
515
+ if initialization_error:
516
+ yield f"data: Erreur critique lors de l'initialisation du serveur: {initialization_error}\n\n"
517
+ return
518
+ if not retriever or not llm or not prompt_template:
519
+ yield f"data: Erreur: Un ou plusieurs composants serveur (LLM, Retriever, Prompt) ne sont pas initialisés.\n\n"
520
+ return
521
+ # *** Fin de la vérification ***
522
+
523
  print(f"API processing question: {question}")
524
  try:
525
+ # Utilisation de la variable globale 'retriever'
526
  relevant_docs = retriever(question)
527
+ # ... (le reste de votre logique pour context, sources, llm.stream) ...
528
+
529
+ context_str = "\n\n".join([f"المصدر: {doc.metadata.get('source', 'غير معروف')}\nالمحتوى: {doc.page_content}" for doc in relevant_docs]) if relevant_docs else "لا يوجد سياق"
530
+ sources = sorted(list(set([os.path.splitext(doc.metadata.get("source", "غير معروف"))[0] for doc in relevant_docs]))) if relevant_docs else []
531
+ sources_str = "\n\n\nالمصادر المحتملة التي تم الرجوع إليها:\n- " + "\n- ".join(sources) if sources else ""
532
+
533
 
534
  if not relevant_docs:
535
+ # Gérer le cas il n'y a pas de documents
536
+ yield f"data: لم أتمكن من العثور على معلومات ذات صلة في المستندات.\n\n"
537
+ # Optionnel: appeler le LLM sans contexte ou s'arrêter ici
538
  return
539
 
540
+ # Utilisation de la variable globale 'prompt_template'
541
  prompt = prompt_template.format_messages(context=context_str, question=question)
542
 
543
  full_response = ""
544
+ # Utilisation de la variable globale 'llm'
545
  stream = llm.stream(prompt)
546
  for chunk in stream:
547
  content = chunk.content if hasattr(chunk, 'content') else str(chunk)
548
  if content:
549
+ formatted_chunk = content.replace('\n', '\ndata: ')
550
+ yield f"data: {formatted_chunk}\n\n" # Format SSE
551
+ full_response += content
552
+
553
+ # Envoyer les sources à la fin
554
+ if sources_str:
555
+ # Assurez-vous que sources_str est bien formaté pour SSE s'il contient des sauts de ligne
556
+ sources_sse = sources_str.replace('\n', '\ndata: ')
557
+ yield f"data: {sources_sse}\n\n"
558
+ # Signal de fin (optionnel mais utile pour le client JS)
559
+ yield "event: end\ndata: Stream finished\n\n"
560
 
561
  except Exception as e:
562
  print(f"Error during API LLM generation: {e}")
563
+ import traceback
564
+ traceback.print_exc() # Affiche l'erreur dans les logs serveur
565
+ yield f"data: حدث خطأ أثناء معالجة طلبك: {str(e)}\n\n"
566
+ yield "event: error\ndata: Stream error\n\n" # Signale une erreur au client
567
+
568
 
569
  # --- Endpoint API ---
570
  @app.post("/ask")
571
  async def handle_ask(request: Request):
572
+ # Vérifie si l'initialisation globale a échoué dès le début
573
+ if initialization_error:
574
+ raise HTTPException(status_code=500, detail=f"Erreur d'initialisation serveur: {initialization_error}")
 
575
 
576
+ try:
577
+ data = await request.json()
578
+ question = data.get("question")
579
+ if not question:
580
+ raise HTTPException(status_code=400, detail="Question manquante dans la requête JSON")
581
+
582
+ # Retourne la réponse streamée
583
+ return StreamingResponse(get_llm_response_stream(question), media_type="text/event-stream")
584
+
585
+ except Exception as e:
586
+ print(f"Error in /ask endpoint: {e}")
587
+ raise HTTPException(status_code=500, detail=f"Erreur interne du serveur: {str(e)}")
588
 
 
 
589
 
590
  # --- Servir les fichiers statiques (HTML/JS/CSS) ---
591
+ # Assurez-vous que le dossier 'static' existe et contient index.html, script.js, style.css
592
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
593
 
594
+ # --- Démarrage du serveur ---
 
595
  if __name__ == "__main__":
596
  uvicorn.run(app, host="0.0.0.0", port=7860)