import asyncio import logging import os import sys import uvicorn from fastapi import APIRouter, FastAPI from schemas import _RefinedSolutionModel, _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism, SolutionModel, SolutionSearchResponse from jinja2 import Environment, FileSystemLoader, StrictUndefined from litellm.router import Router from dotenv import load_dotenv logging.basicConfig( level=logging.INFO, format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) # Load .env files load_dotenv() if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ: logging.error( "No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting") sys.exit(-1) # LiteLLM router llm_router = Router(model_list=[ { "model_name": "chat", "litellm_params": { "model": os.environ.get("LLM_MODEL"), "api_key": os.environ.get("LLM_API_KEY"), "rpm": 15, "max_parallel_requests": 4, "allowed_fails": 1, "cooldown_time": 60, "max_retries": 10, } } ], num_retries=10, retry_after=30) # Jinja2 environment to load prompt templates prompt_env = Environment(loader=FileSystemLoader( 'prompts'), enable_async=True, undefined=StrictUndefined) api = FastAPI(docs_url="/") # requirements routes requirements_router = APIRouter(prefix="/reqs", tags=["requirements"]) # solution routes solution_router = APIRouter(prefix="/solution", tags=["solution"]) @requirements_router.post("/categorize_requirements") async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse: """Categorize the given service requirements into categories""" MAX_ATTEMPTS = 5 categories: list[_ReqGroupingCategory] = [] messages = [] # categorize the requirements using their indices req_prompt = await prompt_env.get_template("classify.txt").render_async(**{ "requirements": [rq.model_dump() for rq in params.requirements], "max_n_categories": params.max_n_categories, "response_schema": _ReqGroupingOutput.model_json_schema()}) # add system prompt with requirements messages.append({"role": "user", "content": req_prompt}) # ensure all requirements items are processed for attempt in range(MAX_ATTEMPTS): req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput) output = _ReqGroupingOutput.model_validate_json( req_completion.choices[0].message.content) # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category valid_ids_universe = set(range(0, len(params.requirements))) assigned_ids = { req_id for cat in output.categories for req_id in cat.items} # keep only non-hallucinated, valid assigned ids valid_assigned_ids = assigned_ids.intersection(valid_ids_universe) # check for remaining requirements assigned to none of the categories unassigned_ids = valid_ids_universe - valid_assigned_ids if len(unassigned_ids) == 0: categories.extend(output.categories) break else: messages.append(req_completion.choices[0].message) messages.append( {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."}) if attempt == MAX_ATTEMPTS - 1: raise Exception("Failed to classify all requirements") # build the final category objects # remove the invalid (likely hallucinated) requirement IDs final_categories = [] for idx, cat in enumerate(output.categories): final_categories.append(ReqGroupingCategory( id=idx, title=cat.title, requirements=[params.requirements[i] for i in cat.items if i < len(params.requirements)] )) return ReqGroupingResponse(categories=final_categories) # ========================================================= Solution Endpoints =========================================================== @solution_router.post("/search_solutions_gemini", response_model=SolutionSearchResponse) async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse: """Searches solutions using Gemini and grounded on google search""" async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel: # ================== generate the solution with web grounding req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{ "category": cat.model_dump(), }) # generate the completion in non-structured mode. # the googleSearch tool enables grounding gemini with google search # this also forces gemini to perform a tool call req_completion = await llm_router.acompletion(model="chat", messages=[ {"role": "user", "content": req_prompt} ], tools=[{"googleSearch": {}}], tool_choice="required") # ==================== structure the solution as a json =================================== structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{ "solution": req_completion.choices[0].message.content, "response_schema": _SearchedSolutionModel.model_json_schema() }) structured_completion = await llm_router.acompletion(model="chat", messages=[ {"role": "user", "content": structured_prompt} ], response_format=_SearchedSolutionModel) solution_model = _SearchedSolutionModel.model_validate_json( structured_completion.choices[0].message.content) # ======================== build the final solution object ================================ # extract the source metadata from the search items sources_metadata = [ f'{a["web"]["title"]} - {a["web"]["uri"]}' for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']] final_sol = SolutionModel( Context="", Requirements=[ cat.requirements[i].requirement for i in solution_model.requirement_ids ], Problem_Description=solution_model.problem_description, Solution_Description=solution_model.solution_description, References=sources_metadata, Category_Id=cat.id, ) return final_sol solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True) logging.info(solutions) final_solutions = [ sol for sol in solutions if not isinstance(sol, Exception)] return SolutionSearchResponse(solutions=final_solutions) @solution_router.post("/criticize_solution", response_model=CritiqueResponse) async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse: """Criticize the challenges, weaknesses and limitations of the provided solutions.""" async def __criticize_single(solution: SolutionModel): req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{ "solutions": [solution.model_dump()], "response_schema": _SolutionCriticismOutput.model_json_schema() }) req_completion = await llm_router.acompletion( model="chat", messages=[{"role": "user", "content": req_prompt}], response_format=_SolutionCriticismOutput ) criticism_out = _SolutionCriticismOutput.model_validate_json( req_completion.choices[0].message.content ) return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0]) critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False) return CritiqueResponse(critiques=critiques) @solution_router.post("/refine_solutions", response_model=SolutionSearchResponse) async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse: """Refines the previously critiqued solutions.""" async def __refine_solution(crit: SolutionCriticism): req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{ "solution": crit.solution.model_dump(), "criticism": crit.criticism, "response_schema": _RefinedSolutionModel.model_json_schema(), }) req_completion = await llm_router.acompletion(model="chat", messages=[ {"role": "user", "content": req_prompt} ], response_format=_RefinedSolutionModel) req_model = _RefinedSolutionModel.model_validate_json( req_completion.choices[0].message.content) # copy previous solution model refined_solution = crit.solution.model_copy(deep=True) refined_solution.Problem_Description = req_model.problem_description refined_solution.Solution_Description = req_model.solution_description return refined_solution refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False) return SolutionSearchResponse(solutions=refined_solutions) api.include_router(requirements_router) api.include_router(solution_router) uvicorn.run(api, host="0.0.0.0", port=8000)