Lucas ARRIESSE
Fix off-by-one issue with requirement IDs + add /search_solutions_gemini endpoint
72683de
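"""
FastAPI service exposing LLM-backed helpers for requirement analysis:
requirement categorization, solution critique, and Gemini-grounded solution
search (see /search_solutions_gemini). Prompts are rendered from Jinja2
templates in the prompts/ directory and completions are routed through LiteLLM.
"""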
import asyncio
import logging
import os
import sys

import uvicorn
from fastapi import FastAPI
from schemas import _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism, SolutionModel, SolutionSearchResponse
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router
from dotenv import load_dotenv

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load .env files
load_dotenv()

if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "No LLM API key (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting.")
    sys.exit(-1)

# LiteLLM router
llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "max_retries": 5
        }
    }
], cooldown_time=30)
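
# NOTE: the googleSearch tool and the vertex_ai_grounding_metadata field read
# in search_solutions below assume that LLM_MODEL points at a Gemini model
# served through LiteLLM (Vertex AI / Google AI Studio).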

# Jinja2 environment to load prompt templates
prompt_env = Environment(loader=FileSystemLoader('prompts'),
                         enable_async=True, undefined=StrictUndefined)

api = FastAPI(docs_url="/")
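
# Route path assumed here for illustration; only /search_solutions_gemini is
# named explicitly in the commit message.
@api.post("/categorize_requirements")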
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories."""
    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # add the user prompt with the requirements
    messages.append({"role": "user", "content": req_prompt})

    # retry until every requirement item has been assigned to a category
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # quick check to ensure no requirement was left out by the LLM, by checking that every ID is contained in at least one category
        valid_ids_universe = set(range(len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # keep only non-hallucinated, valid assigned ids
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)

        # check for remaining requirements assigned to none of the categories
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements in at least one category: {unassigned_ids}. Please do so."})

        # give up after the final attempt
        if attempt == MAX_ATTEMPTS - 1:
            raise Exception("Failed to classify all requirements")

    # build the final category objects,
    # dropping invalid (likely hallucinated) requirement IDs
    final_categories = []
    for idx, cat in enumerate(categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
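

# Route path assumed here as well (see note above).
@api.post("/criticize_solutions")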
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions."""

    async def __criticize_single(solution: SolutionModel):
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })

        req_completion = await llm_router.acompletion(
            model="chat",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )

        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )

        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
    return CritiqueResponse(critiques=critiques)
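

# Endpoint added in this commit (named in the commit message).
@api.post("/search_solutions_gemini")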
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
    """Searches for solutions using Gemini, grounded on Google Search."""

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding ==================
        req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
            "category": cat.model_dump(),
        })

        # generate the completion in non-structured mode.
        # the googleSearch tool enables grounding Gemini with Google Search
        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}])

        # ==================== structure the solution as JSON ==========================
        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)

        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ======================== build the final solution object =====================
        # extract the source metadata from the grounding chunks
        sources_metadata = [
            f'{a["web"]["title"]} - {a["web"]["uri"]}' for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']]

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)
    logging.info(solutions)

    # drop the categories whose search failed with an exception
    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)

if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)
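
# A minimal client-side sketch (not part of the service) of how the new
# /search_solutions_gemini endpoint could be exercised once the server runs.
# The payload mirrors ReqGroupingResponse as used above; the exact
# RequirementInfo fields are defined in schemas.py and only assumed here.
#
#   import httpx
#
#   payload = {"categories": [{"id": 0, "title": "Example category",
#                              "requirements": [{"requirement": "The service shall ..."}]}]}
#   resp = httpx.post("http://localhost:8000/search_solutions_gemini", json=payload, timeout=120)
#   print(resp.json())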