"""FastAPI service that groups requirements into categories and then searches for,
critiques, and refines candidate solutions using an LLM routed through litellm."""

import asyncio
import logging
import os
import sys

import uvicorn
from dotenv import load_dotenv
from fastapi import APIRouter, FastAPI
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router

from schemas import (
    CriticizeSolutionsRequest, CritiqueResponse, ReqGroupingCategory,
    ReqGroupingRequest, ReqGroupingResponse, RequirementInfo,
    SolutionCriticism, SolutionModel, SolutionSearchResponse,
    _RefinedSolutionModel, _ReqGroupingCategory, _ReqGroupingOutput,
    _SearchedSolutionModel, _SolutionCriticismOutput,
)

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

load_dotenv()

# Fail fast if the LLM credentials are not configured.
if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "`LLM_API_KEY` (LLM token) and/or `LLM_MODEL` (LLM model) were not provided in the environment variables. Exiting.")
    sys.exit(-1)
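
# The two variables checked above are typically supplied through a `.env` file
# picked up by load_dotenv(). Illustrative contents (the model id is only an
# example; any litellm-supported model string should work):
#   LLM_MODEL=gemini/gemini-2.0-flash
#   LLM_API_KEY=<your provider API key>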

# Single "chat" deployment with conservative rate limiting and retry behaviour.
llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "rpm": 15,
            "max_parallel_requests": 4,
            "allowed_fails": 1,
            "cooldown_time": 60,
            "max_retries": 10,
        }
    }
], num_retries=10, retry_after=30)

# Async Jinja2 environment for the prompt templates stored in ./prompts.
prompt_env = Environment(loader=FileSystemLoader(
    'prompts'), enable_async=True, undefined=StrictUndefined)

api = FastAPI(docs_url="/")

requirements_router = APIRouter(prefix="/reqs", tags=["requirements"])
solution_router = APIRouter(prefix="/solution", tags=["solution"])
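
# Endpoint pipeline implemented below:
#   1. /reqs/categorize_requirements      - group the raw requirements into categories
#   2. /solution/search_solutions_gemini  - draft one solution per category (search-grounded)
#   3. /solution/criticize_solution       - critique each drafted solution
#   4. /solution/refine_solutions         - rewrite each solution using its critique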

@requirements_router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories."""

    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    messages.append({"role": "user", "content": req_prompt})

    # Ask the LLM to classify the requirements, retrying until every requirement
    # is assigned to at least one category or MAX_ATTEMPTS is reached.
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(
            model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # Check that every requirement id appears in at least one category.
        valid_ids_universe = set(range(len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            # Feed the model's own answer back and point out what is missing.
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements into at least one category: {sorted(unassigned_ids)}. Please do so."})

        if attempt == MAX_ATTEMPTS - 1:
            raise Exception("Failed to classify all requirements")

    # Convert the validated output into the public response schema, dropping any
    # out-of-range requirement ids.
    final_categories = []
    for idx, cat in enumerate(categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)

@solution_router.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
    """Search for one solution per requirement category using Gemini, grounded with Google Search."""

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # Draft a solution for the category, letting the model ground itself
        # with Google Search results.
        req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
            "category": cat.model_dump(),
        })

        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # Convert the free-form draft into the structured solution schema.
        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)
        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # Collect the grounding sources returned by the search-grounded completion.
        sources_metadata = [
            f'{a["web"]["title"]} - {a["web"]["uri"]}'
            for a in req_completion["vertex_ai_grounding_metadata"][0]["groundingChunks"]]

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    # Search all categories concurrently and drop the ones that failed.
    solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)
    logging.info(solutions)
    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)

@solution_router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions."""

    async def __criticize_single(solution: SolutionModel):
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })

        req_completion = await llm_router.acompletion(
            model="chat",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )

        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )

        # A single solution is sent per request, so the first criticism is the one we want.
        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
    return CritiqueResponse(critiques=critiques)

@solution_router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse:
    """Refines the previously critiqued solutions."""

    async def __refine_solution(crit: SolutionCriticism):
        req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
            "solution": crit.solution.model_dump(),
            "criticism": crit.criticism,
            "response_schema": _RefinedSolutionModel.model_json_schema(),
        })

        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], response_format=_RefinedSolutionModel)

        req_model = _RefinedSolutionModel.model_validate_json(
            req_completion.choices[0].message.content)

        refined_solution = crit.solution.model_copy(deep=True)
        refined_solution.Problem_Description = req_model.problem_description
        refined_solution.Solution_Description = req_model.solution_description

        return refined_solution

    refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False)

    return SolutionSearchResponse(solutions=refined_solutions)

api.include_router(requirements_router)
api.include_router(solution_router)

# Only start the server when this file is executed directly (not when the
# module is imported, e.g. by an external ASGI runner).
if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)
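
# Illustrative ways to start the service (the module name `app` below is only an
# assumed example; substitute this file's actual name):
#   python app.py
#   uvicorn app:api --host 0.0.0.0 --port 8000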
|