Spaces:
Build error
Build error
File size: 2,200 Bytes
d660b02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import opik
from fastapi import FastAPI, HTTPException
from loguru import logger
from opik import opik_context
from pydantic import BaseModel
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from llm_engineering import settings
from llm_engineering.application.rag.retriever import ContextRetriever
from llm_engineering.application.utils import misc
from llm_engineering.domain.embedded_chunks import EmbeddedChunk
from llm_engineering.infrastructure.opik_utils import configure_opik
from llm_engineering.model.inference import InferenceExecutor, LLMInferenceOLLAMA
# Run the project's Opik setup at import time.
# NOTE(review): presumably this must happen before the @opik.track-decorated
# functions below are invoked — confirm against opik_utils.configure_opik.
configure_opik()
# FastAPI application instance; the /rag route below is registered on it.
app = FastAPI()
class QueryRequest(BaseModel):
    """Request body accepted by the ``/rag`` endpoint."""

    # Text of the user's query.
    query: str
class QueryResponse(BaseModel):
    """Response body returned by the ``/rag`` endpoint."""

    # Answer text produced by the RAG pipeline.
    answer: str
@opik.track
def call_llm_service(query: HumanMessage, history: list, context: str | None = None) -> str:
    """Run inference for *query* with the Ollama model named in settings.

    Args:
        query: The user message to answer.
        history: Prior conversation messages. Accepted as part of the
            signature but not used by this function.
        context: Optional context string handed to the inference executor.

    Returns:
        The result of ``InferenceExecutor.execute()``.
    """
    executor = InferenceExecutor(
        LLMInferenceOLLAMA(model_name=settings.LLAMA_MODEL_ID),
        query,
        context,
    )
    return executor.execute()
@opik.track
def rag(query, history: list) -> str:
    """Answer *query* using retrieved context (retrieval-augmented generation).

    The retrieval text is the query's content, extended with the content of
    the most recent history message when one exists. The top 3 retrieved
    documents are converted to a context string and passed, together with
    the query and history, to ``call_llm_service``.

    Args:
        query: Message object exposing a ``.content`` attribute.
        history: Prior conversation messages (each exposing ``.content``);
            may be empty.

    Returns:
        The answer returned by ``call_llm_service``.
    """
    retriever = ContextRetriever(mock=False)

    search_text = query.content
    if history:
        # Separate the two texts with a space so the last word of the query
        # and the first word of the previous message are not fused together.
        search_text = f"{search_text} {history[-1].content}"

    documents = retriever.search(search_text, k=3)
    context = EmbeddedChunk.to_context(documents)
    answer = call_llm_service(query, history, context)

    # TODO: trace metadata update is currently disabled; kept for reference.
    # NOTE(review): it references settings.HF_MODEL_ID, while inference uses
    # settings.LLAMA_MODEL_ID — reconcile before re-enabling.
    # opik_context.update_current_trace(
    #     tags=["rag"],
    #     metadata={
    #         "model_id": settings.HF_MODEL_ID,
    #         "embedding_model_id": settings.TEXT_EMBEDDING_MODEL_ID,
    #         "temperature": settings.TEMPERATURE_INFERENCE,
    #         "query_tokens": misc.compute_num_tokens(query),
    #         "context_tokens": misc.compute_num_tokens(context),
    #         "answer_tokens": misc.compute_num_tokens(answer),
    #     },
    # )

    return answer
@app.post("/rag", response_model=QueryResponse)
async def rag_endpoint(request: QueryRequest):
    """Handle ``POST /rag``: run the RAG pipeline for the submitted query.

    Args:
        request: Parsed request body carrying the query text.

    Returns:
        A dict with a single ``"answer"`` key, serialized via ``QueryResponse``.

    Raises:
        HTTPException: With status 500 and the error text as detail when the
            pipeline raises.
    """
    try:
        # rag() reads ``query.content`` and requires a ``history`` argument.
        # The request model carries only the query text, so wrap it in a
        # HumanMessage and start from an empty history.
        answer = rag(query=HumanMessage(content=request.query), history=[])
    except Exception as e:
        # Record the full traceback through the module's logger, not stdout.
        logger.exception("RAG request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"answer": answer}
|