text_prompt / app.py
bambadij's picture
fixe
f821485
raw
history blame
2.37 kB
from fastapi import FastAPI, HTTPException, status, UploadFile, File
from pydantic import BaseModel
import uvicorn
import logging
import os
import requests
from fastapi.middleware.cors import CORSMiddleware
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
os.environ['HF_HOME'] = '/app/.cache'
Informations = """
-text : Texte à résumer
output:
- Text summary : texte résumé
"""
app = FastAPI(
title='Text Summary',
description=Informations
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
DEFAULT_PROMPT = "Résumez la plainte suivante en 5 phrases concises, en vous concentrant sur les faits principaux et en évitant toute introduction générique : "
class TextSummary(BaseModel):
prompt: str
class RequestModel(BaseModel):
text: str
OLLAMA_URL = "http://localhost:11434" # URL d'Ollama dans le conteneur
@app.get("/")
async def home():
return 'STN BIG DATA'
@app.post("/generate/")
async def generate_text(request: RequestModel):
try:
full_prompt = DEFAULT_PROMPT + request.text
response = requests.post(f"{OLLAMA_URL}/api/generate", json={
"prompt": full_prompt,
"stream": False,
"model": "llama3"
})
if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail="Erreur de l'API Ollama")
generated_text = response.json().get('response', '')
intro_phrases = [
"Voici un résumé de la plainte en 5 phrases :",
"Résumé :",
"Voici ce qui s'est passé :",
"Cette plainte a été déposée par"
]
for phrase in intro_phrases:
if generated_text.startswith(phrase):
generated_text = generated_text[len(phrase):].strip()
break
return {"summary_text_2": generated_text}
except requests.RequestException as e:
raise HTTPException(status_code=500, detail=f"Erreur de requête : {str(e)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur inattendue : {str(e)}")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)