import logging
import os
import sys

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from jinja2 import Environment, FileSystemLoader
from litellm.router import Router

from schemas import (
    CriticizeSolutionsRequest,
    ReqGroupingCategory,
    ReqGroupingRequest,
    ReqGroupingResponse,
    _ReqGroupingCategory,
    _ReqGroupingOutput,
)

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

load_dotenv()

if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "Missing `LLM_MODEL` and/or `LLM_API_KEY` in the environment variables. Exiting.")
    sys.exit(1)
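
# Sketch of the expected .env file. The model string follows litellm's
# "<provider>/<model>" convention; the values below are illustrative only:
#
#   LLM_MODEL=openai/gpt-4o
#   LLM_API_KEY=sk-...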

llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "max_retries": 5
        }
    }
])
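
# The Router load-balances across every `model_list` entry that shares a
# `model_name`, so a second deployment could be added with a sketch like the
# one below (LLM_FALLBACK_MODEL / LLM_FALLBACK_API_KEY are hypothetical env
# vars, not read by this service):
#
#   {
#       "model_name": "chat",
#       "litellm_params": {
#           "model": os.environ["LLM_FALLBACK_MODEL"],
#           "api_key": os.environ["LLM_FALLBACK_API_KEY"],
#       }
#   }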

prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)

api = FastAPI()


@api.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories."""
    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # Render the classification prompt with the requirements, the category
    # budget, and the JSON schema the model's answer must follow.
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    logging.info(req_prompt)
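
    # For reference, prompts/classify.txt is a Jinja2 template that receives
    # `requirements`, `max_n_categories`, and `response_schema`. A minimal
    # sketch of such a template (illustrative only, not the actual prompt file
    # shipped in this repo) could look like:
    #
    #   Group the following requirements into at most
    #   {{ max_n_categories }} categories.
    #   {% for req in requirements %}{{ loop.index0 }}. {{ req }}
    #   {% endfor %}
    #   Answer with JSON matching this schema: {{ response_schema }}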

    messages.append({"role": "user", "content": req_prompt})

    # Ask the model to categorize the requirements, retrying until every
    # requirement has been assigned to at least one category.
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(
            model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # Requirement IDs are 0-based indices into `params.requirements`;
        # collect any that the model left out of every category.
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        unassigned_ids = set(range(len(params.requirements))) - assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements in at least one category: {unassigned_ids}. Please do so."})

        if attempt == MAX_ATTEMPTS - 1:
            raise HTTPException(
                status_code=500, detail="Failed to categorize all requirements")

    # Map the model's categories back onto the full requirement objects.
    final_categories = []
    for idx, cat in enumerate(categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
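

# Example request to the endpoint above (a sketch; each requirement object's
# exact fields are defined by RequirementInfo in schemas.py):
#
#   curl -X POST http://localhost:8000/categorize_requirements \
#     -H 'Content-Type: application/json' \
#     -d '{"requirements": [...], "max_n_categories": 5}'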


@api.post("/criticize_solution")
async def criticize_solution(params: CriticizeSolutionsRequest) -> str:
    """Criticize the given solutions and return the critique as raw text."""
    req_prompt = await prompt_env.get_template("criticize.txt").render_async(
        solutions=[sol.model_dump() for sol in params.solutions])
    req_completion = await llm_router.acompletion(
        model="chat", messages=[{"role": "user", "content": req_prompt}])

    return req_completion.choices[0].message.content
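

# Example request (sketch; the shape of each solution object comes from
# CriticizeSolutionsRequest in schemas.py):
#
#   curl -X POST http://localhost:8000/criticize_solution \
#     -H 'Content-Type: application/json' \
#     -d '{"solutions": [...]}'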


if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)
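
# Alternatively, serve the app with the uvicorn CLI (assuming this module is
# saved as main.py, which is an assumption about the file name):
#
#   uvicorn main:api --host 0.0.0.0 --port 8000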