IA2_model / app.py
AshenClock's picture
Update app.py
34e60a5 verified
raw
history blame
3.9 kB
import os
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from rdflib import Graph
from pydantic import BaseModel
# Configurazione API Hugging Face
API_KEY = os.getenv("HF_API_KEY")
client = InferenceClient(api_key=API_KEY)
# File RDF
RDF_FILE = "Progetto.rdf"
# Carica il file RDF
def load_rdf():
if os.path.exists(RDF_FILE):
with open(RDF_FILE, "r") as f:
return f.read()
return ""
rdf_context = load_rdf()
# Valida le query SPARQL
def validate_sparql_query(query, rdf_data):
try:
g = Graph()
g.parse(data=rdf_data, format="xml")
g.query(query)
return True
except Exception:
return False
# FastAPI app
app = FastAPI()
# Modello di input per richieste POST
class QueryRequest(BaseModel):
message: str
max_tokens: int = 512
temperature: float = 0.7
# Messaggio di sistema con RDF incluso
def create_system_message(rdf_context):
return f"""
Sei un assistente specializzato nella generazione e spiegazione di query SPARQL basate su dati RDF.
La base di conoscenza RDF è la seguente:
{rdf_context}
Il tuo compito principale è:
1. Analizzare lo schema RDF o i dati RDF forniti e la domanda in linguaggio naturale posta dall'utente.
2. Generare una query SPARQL valida che recuperi le informazioni richieste dai dati RDF.
3. Spiegare in modo prolisso e naturale il significato dei risultati, come farebbe una guida in un museo.
Regole:
- Genera solo query SPARQL in una singola riga, senza formattazioni aggiuntive.
- Se la domanda non può essere soddisfatta con una query SPARQL, rispondi con: \"Non posso generare una query per questa domanda.\"
"""
# Funzione per inviare la richiesta al modello Hugging Face
async def generate_response(message, max_tokens, temperature):
system_message = create_system_message(rdf_context)
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": message}
]
try:
# Usa il metodo chat.completions.create per lo streaming dei risultati
stream = client.chat.completions.create(
model="Qwen/Qwen2.5-72B-Instruct",
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=0.7,
stream=True
)
response = ""
for chunk in stream:
if "choices" in chunk and len(chunk["choices"]) > 0:
response += chunk["choices"][0]["delta"]["content"]
return response.strip()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
# Endpoint per generare query SPARQL
@app.post("/generate-query/")
async def generate_query(request: QueryRequest):
# Genera risposta
response = await generate_response(request.message, request.max_tokens, request.temperature)
# Valida la query se possibile
if response.startswith("SELECT") or response.startswith("ASK"):
is_valid = validate_sparql_query(response, rdf_context)
if not is_valid:
raise HTTPException(status_code=400, detail="La query generata non è valida rispetto al file RDF fornito.")
# Correzione f-string: usiamo `.replace()` in modo sicuro
explanation = f"La query generata è: {response.replace('\\n', ' ').strip()}. "
explanation += "Questa query è stata progettata per estrarre informazioni specifiche dai dati RDF, consentendo di rispondere alla tua domanda. Risultati ottenuti da questa query possono includere dettagli come entità, relazioni o attributi collegati al contesto fornito."
return {"query": response.replace("\\n", " ").strip(), "explanation": explanation}
# Endpoint per verificare se il server è attivo
@app.get("/")
async def root():
return {"message": "Il server è attivo e pronto a generare query SPARQL!"}