# reqroup / app.py
import asyncio
import logging
import os
import sys
import uvicorn
from fastapi import APIRouter, FastAPI
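# Pydantic models from schemas.py (not shown here): the underscore-prefixed ones appear
# to be internal structured-output schemas passed to the LLM, while the others are the
# public request/response models of the API.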
from schemas import (_RefinedSolutionModel, _SearchedSolutionModel, _SolutionCriticismOutput,
                     CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo,
                     ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest,
                     _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism,
                     SolutionModel, SolutionSearchResponse)
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router
from dotenv import load_dotenv
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Load .env files
load_dotenv()
if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
logging.error(
"No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting")
sys.exit(-1)
# LiteLLM router
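# Note: the rpm / max_parallel_requests values below throttle calls to the upstream
# provider, while allowed_fails / cooldown_time temporarily take the deployment out of
# rotation after repeated failures (see the LiteLLM Router docs for exact semantics).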
llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "rpm": 15,
            "max_parallel_requests": 4,
            "allowed_fails": 1,
            "cooldown_time": 60,
            "max_retries": 10,
        }
    }
], num_retries=10, retry_after=30)
# Jinja2 environment to load prompt templates
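# Templates are loaded from the local `prompts/` directory and rendered asynchronously;
# StrictUndefined makes rendering fail loudly if a template variable is missing.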
prompt_env = Environment(loader=FileSystemLoader(
    'prompts'), enable_async=True, undefined=StrictUndefined)
api = FastAPI(docs_url="/")
# requirements routes
requirements_router = APIRouter(prefix="/reqs", tags=["requirements"])
# solution routes
solution_router = APIRouter(prefix="/solution", tags=["solution"])
@requirements_router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
"""Categorize the given service requirements into categories"""
MAX_ATTEMPTS = 5
categories: list[_ReqGroupingCategory] = []
messages = []
# categorize the requirements using their indices
req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
"requirements": [rq.model_dump() for rq in params.requirements],
"max_n_categories": params.max_n_categories,
"response_schema": _ReqGroupingOutput.model_json_schema()})
# add system prompt with requirements
messages.append({"role": "user", "content": req_prompt})
# ensure all requirements items are processed
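    # On each attempt the model's answer is validated; if some requirement IDs are
    # missing, the assistant message is fed back together with the missing IDs so the
    # model can correct itself, up to MAX_ATTEMPTS times.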
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # quick check that no requirement was left out by the LLM: every ID must be
        # contained in at least one category
        valid_ids_universe = set(range(0, len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # keep only non-hallucinated, valid assigned ids
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)

        # check for remaining requirements assigned to none of the categories
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements into at least one category: {unassigned_ids}. Please do so."})

        if attempt == MAX_ATTEMPTS - 1:
            raise Exception("Failed to classify all requirements")

    # build the final category objects,
    # dropping invalid (likely hallucinated) requirement IDs
    final_categories = []
    for idx, cat in enumerate(output.categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
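# Example request body for /reqs/categorize_requirements (illustrative only; the exact
# fields of RequirementInfo / ReqGroupingRequest are defined in schemas.py, not shown here):
#   {
#     "requirements": [{"requirement": "The service shall expose a REST API"}],
#     "max_n_categories": 3
#   }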
# ========================================================= Solution Endpoints ===========================================================
@solution_router.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
"""Searches solutions using Gemini and grounded on google search"""
async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
# ================== generate the solution with web grounding
req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
"category": cat.model_dump(),
})
# generate the completion in non-structured mode.
# the googleSearch tool enables grounding gemini with google search
# this also forces gemini to perform a tool call
req_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": req_prompt}
], tools=[{"googleSearch": {}}], tool_choice="required")
# ==================== structure the solution as a json ===================================
structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
"solution": req_completion.choices[0].message.content,
"response_schema": _SearchedSolutionModel.model_json_schema()
})
structured_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": structured_prompt}
], response_format=_SearchedSolutionModel)
solution_model = _SearchedSolutionModel.model_validate_json(
structured_completion.choices[0].message.content)
# ======================== build the final solution object ================================
# extract the source metadata from the search items
sources_metadata = [
f'{a["web"]["title"]} - {a["web"]["uri"]}' for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']]
        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)
    logging.info(solutions)
    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)
@solution_router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
"""Criticize the challenges, weaknesses and limitations of the provided solutions."""
async def __criticize_single(solution: SolutionModel):
req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
"solutions": [solution.model_dump()],
"response_schema": _SolutionCriticismOutput.model_json_schema()
})
req_completion = await llm_router.acompletion(
model="chat",
messages=[{"role": "user", "content": req_prompt}],
response_format=_SolutionCriticismOutput
)
criticism_out = _SolutionCriticismOutput.model_validate_json(
req_completion.choices[0].message.content
)
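        # a single solution is passed to the template, so only the first criticism applies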
        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
    return CritiqueResponse(critiques=critiques)
@solution_router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse:
"""Refines the previously critiqued solutions."""
async def __refine_solution(crit: SolutionCriticism):
req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
"solution": crit.solution.model_dump(),
"criticism": crit.criticism,
"response_schema": _RefinedSolutionModel.model_json_schema(),
})
req_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": req_prompt}
], response_format=_RefinedSolutionModel)
req_model = _RefinedSolutionModel.model_validate_json(
req_completion.choices[0].message.content)
# copy previous solution model
refined_solution = crit.solution.model_copy(deep=True)
refined_solution.Problem_Description = req_model.problem_description
refined_solution.Solution_Description = req_model.solution_description
return refined_solution
refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False)
return SolutionSearchResponse(solutions=refined_solutions)
api.include_router(requirements_router)
api.include_router(solution_router)

if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)
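# To try the service out (assuming the prompt templates and a .env providing LLM_MODEL /
# LLM_API_KEY are in place): run `python app.py` and open http://localhost:8000/ for the
# interactive API docs (docs_url is set to "/").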