"""FastAPI service that groups requirements into categories using an LLM."""

import logging
import os
import sys

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from jinja2 import Environment, FileSystemLoader
from litellm import acompletion

from schemas import (
    RequirementInfo,
    ReqGroupingCategory,
    ReqGroupingResponse,
    ReqGroupingRequest,
    _ReqGroupingCategory,
    _ReqGroupingOutput,
)

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load .env files
load_dotenv()

# Both the model name and the API key are mandatory; refuse to start without them.
if "LLM_API_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "No LLM model (`LLM_API_MODEL`) and/or LLM API key (`LLM_API_KEY`) "
        "were provided in the env vars. Exiting")
    sys.exit(-1)

LLM_API_MODEL = os.environ.get("LLM_API_MODEL")
LLM_API_KEY = os.environ.get("LLM_API_KEY")

# Jinja2 environment to load prompt templates
prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)

fastapi = FastAPI()


@fastapi.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    Renders the ``classify.txt`` prompt with the requirements, asks the LLM
    for a ``_ReqGroupingOutput`` and re-prompts (up to ``MAX_ATTEMPTS`` times)
    for as long as some requirement index is missing from every category.

    Args:
        params: The request carrying the requirements to group and the
            maximum number of categories to produce.

    Returns:
        A ``ReqGroupingResponse`` whose categories hold the requirement
        objects referenced by the LLM output.

    Raises:
        Exception: If the LLM still leaves requirements uncategorized after
            ``MAX_ATTEMPTS`` attempts.
    """
    MAX_ATTEMPTS = 5

    n_requirements = len(params.requirements)
    # Requirements are referenced by their position in `params.requirements`
    # (the same 0-based index used below to build the response), so every
    # index from 0 to n-1 must show up in at least one category.
    # NOTE(review): assumes `classify.txt` numbers requirements from 0 - confirm.
    expected_ids = set(range(n_requirements))

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(
        requirements=[rq.model_dump() for rq in params.requirements],
        max_n_categories=params.max_n_categories,
        response_schema=_ReqGroupingOutput.model_json_schema(),
    )

    # add the initial prompt (sent with the user role) containing the requirements
    messages.append({"role": "user", "content": req_prompt})

    # ensure all requirements items are processed
    for _ in range(MAX_ATTEMPTS):
        req_completion = await acompletion(
            model=LLM_API_MODEL,
            api_key=LLM_API_KEY,
            messages=messages,
            response_format=_ReqGroupingOutput,
        )
        reply = req_completion.choices[0].message
        output = _ReqGroupingOutput.model_validate_json(reply.content)

        # quick check to ensure no requirement was left out by the LLM by
        # checking all IDs are contained in at least a single category
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        unassigned_ids = expected_ids - assigned_ids

        if not unassigned_ids:
            categories = list(output.categories)
            break

        # Keep the model's answer in the conversation and ask it to fix the gaps.
        messages.append(reply)
        messages.append(
            {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
    else:
        # Every attempt left at least one requirement uncategorized.
        raise Exception("Failed to classify all requirements")

    # build the final category objects
    # remove the invalid (likely hallucinated) requirement IDs; negative IDs
    # are rejected too, since they would silently index from the end.
    final_categories = [
        ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[
                params.requirements[i]
                for i in cat.items
                if 0 <= i < n_requirements
            ],
        )
        for idx, cat in enumerate(categories)
    ]

    return ReqGroupingResponse(categories=final_categories)


# NOTE(review): this starts the server whenever the module is loaded, not only
# when it is run as a script - consider an `if __name__ == "__main__":` guard.
uvicorn.run(fastapi, host="0.0.0.0", port=8000)