"""Reqxtract-API (app.py)

FastAPI application exposing requirement-categorization endpoints and
solution search / critique / refinement endpoints, backed by an LLM (via a
LiteLLM router) and the Insight Finder API.
"""
import asyncio
import logging
import os
import sys
import uvicorn
from fastapi import APIRouter, FastAPI
from schemas import (
    _RefinedSolutionModel, _SearchedSolutionModel, _SolutionCriticismOutput,
    CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList,
    RequirementInfo, ReqGroupingCategory, ReqGroupingResponse,
    ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput,
    SolutionCriticism, SolutionModel, SolutionSearchResponse,
    SolutionSearchV2Request, TechnologyData
)
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router
from dotenv import load_dotenv
from util import retry_until
from httpx import AsyncClient
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# Load .env files
load_dotenv()
if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
logging.error(
"No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting")
sys.exit(-1)
# LiteLLM router
llm_router = Router(model_list=[
{
"model_name": "chat",
"litellm_params": {
"model": os.environ.get("LLM_MODEL"),
"api_key": os.environ.get("LLM_API_KEY"),
"rpm": 15,
"max_parallel_requests": 4,
"allowed_fails": 1,
"cooldown_time": 60,
"max_retries": 10,
}
}
], num_retries=10, retry_after=30)
# HTTP client
INSIGHT_FINDER_BASE_URL = "https://organizedprogrammers-insight-finder.hf.space/"
# note: TLS certificate verification is disabled only when NO_SSL=1 is set
http_client = AsyncClient(
    verify=os.environ.get("NO_SSL", "0") != "1", timeout=None)
# Jinja2 environment to load prompt templates
prompt_env = Environment(loader=FileSystemLoader(
'prompts'), enable_async=True, undefined=StrictUndefined)
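# FastAPI application; the interactive API docs are served at the root path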
api = FastAPI(docs_url="/", title="Reqxtract-API",
description=open("docs/docs.md").read())
# requirements routes
requirements_router = APIRouter(prefix="/reqs", tags=["requirements"])
# solution routes
solution_router = APIRouter(prefix="/solution", tags=["solution"])
async def format_prompt(prompt_name: str, **kwargs) -> str:
    """Render the given prompt template with the supplied keyword arguments."""
    return await prompt_env.get_template(prompt_name).render_async(**kwargs)
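# Example usage (mirrors the Insight Finder endpoint further below):
#   prompt = await format_prompt("if/format_requirements.txt",
#                                category=cat.model_dump(),
#                                response_schema=InsightFinderConstraintsList.model_json_schema())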
@requirements_router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
"""Categorize the given service requirements into categories"""
MAX_ATTEMPTS = 5
categories: list[_ReqGroupingCategory] = []
messages = []
# categorize the requirements using their indices
req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
"requirements": [rq.model_dump() for rq in params.requirements],
"max_n_categories": params.max_n_categories,
"response_schema": _ReqGroupingOutput.model_json_schema()})
    # add the prompt containing the requirements as the first user message
    messages.append({"role": "user", "content": req_prompt})
    # retry until every requirement has been assigned to at least one category
for attempt in range(MAX_ATTEMPTS):
req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
output = _ReqGroupingOutput.model_validate_json(
req_completion.choices[0].message.content)
        # sanity check: make sure the LLM did not leave any requirement out, i.e.
        # every requirement ID appears in at least one category
valid_ids_universe = set(range(0, len(params.requirements)))
assigned_ids = {
req_id for cat in output.categories for req_id in cat.items}
# keep only non-hallucinated, valid assigned ids
valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)
# check for remaining requirements assigned to none of the categories
unassigned_ids = valid_ids_universe - valid_assigned_ids
if len(unassigned_ids) == 0:
categories.extend(output.categories)
break
else:
messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"The following requirements were not assigned to any category: {unassigned_ids}. Please assign each of them to at least one category."})
if attempt == MAX_ATTEMPTS - 1:
raise Exception("Failed to classify all requirements")
# build the final category objects
# remove the invalid (likely hallucinated) requirement IDs
final_categories = []
for idx, cat in enumerate(output.categories):
final_categories.append(ReqGroupingCategory(
id=idx,
title=cat.title,
requirements=[params.requirements[i]
for i in cat.items if i < len(params.requirements)]
))
return ReqGroupingResponse(categories=final_categories)
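# Illustrative client call (a sketch, not executed by the app; the payload shape
# is inferred from how ReqGroupingRequest is used above and may omit fields
# required by the actual schema):
#
#   async with AsyncClient(base_url="http://localhost:8000") as client:
#       resp = await client.post("/reqs/categorize_requirements", json={
#           "requirements": [{"requirement": "Users must be able to reset their password"}],
#           "max_n_categories": 3,
#       })
#       grouping = resp.json()  # -> {"categories": [...]}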
# ========================================================= Solution Endpoints ===========================================================
@solution_router.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
"""Searches solutions solving the given grouping params using Gemini and grounded on google search"""
logging.info(f"Searching solutions for categories: {params.categories}")
async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
# ================== generate the solution with web grounding
req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
"category": cat.model_dump(),
})
# generate the completion in non-structured mode.
# the googleSearch tool enables grounding gemini with google search
# this also forces gemini to perform a tool call
req_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": req_prompt}
], tools=[{"googleSearch": {}}], tool_choice="required")
# ==================== structure the solution as a json ===================================
structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
"solution": req_completion.choices[0].message.content,
"response_schema": _SearchedSolutionModel.model_json_schema()
})
structured_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": structured_prompt}
], response_format=_SearchedSolutionModel)
solution_model = _SearchedSolutionModel.model_validate_json(
structured_completion.choices[0].message.content)
# ======================== build the final solution object ================================
sources_metadata = []
        # extract source metadata from the grounding results, if Gemini actually
        # performed the search tool call and didn't hallucinate
try:
sources_metadata.extend([{"name": a["web"]["title"], "url": a["web"]["uri"]}
for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']])
        except KeyError:
            # no grounding metadata was returned for this completion
            pass
final_sol = SolutionModel(
Context="",
Requirements=[
cat.requirements[i].requirement for i in solution_model.requirement_ids
],
Problem_Description=solution_model.problem_description,
Solution_Description=solution_model.solution_description,
References=sources_metadata,
Category_Id=cat.id,
)
return final_sol
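    # search all categories concurrently; retry_until re-runs a category search
    # until the returned solution has at least one reference, and categories that
    # still fail are filtered out of the response below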
solutions = await asyncio.gather(*[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories], return_exceptions=True)
logging.info(solutions)
final_solutions = [
sol for sol in solutions if not isinstance(sol, Exception)]
return SolutionSearchResponse(solutions=final_solutions)
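# Illustrative chaining with the categorization endpoint (sketch only; `client`
# and `grouping` are the hypothetical objects from the example above):
#
#   solutions = await client.post("/solution/search_solutions_gemini", json=grouping)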
@solution_router.post("/search_solutions_gemini/v2", response_model=SolutionSearchResponse)
async def search_solutions_v2(params: SolutionSearchV2Request) -> SolutionSearchResponse:
    """Searches for solutions addressing each requirement category while respecting the user constraints, using Gemini grounded with Google Search."""
logging.info(f"Searching solutions for categories: {params}")
async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
# ================== generate the solution with web grounding
req_prompt = await prompt_env.get_template("search_solution_v2.txt").render_async(**{
"category": cat.model_dump(),
"user_constraints": params.user_constraints,
})
# generate the completion in non-structured mode.
# the googleSearch tool enables grounding gemini with google search
# this also forces gemini to perform a tool call
req_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": req_prompt}
], tools=[{"googleSearch": {}}], tool_choice="required")
# ==================== structure the solution as a json ===================================
structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
"solution": req_completion.choices[0].message.content,
"response_schema": _SearchedSolutionModel.model_json_schema()
})
structured_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": structured_prompt}
], response_format=_SearchedSolutionModel)
solution_model = _SearchedSolutionModel.model_validate_json(
structured_completion.choices[0].message.content)
# ======================== build the final solution object ================================
sources_metadata = []
        # extract source metadata from the grounding results, if Gemini actually
        # performed the search tool call and didn't hallucinate
try:
sources_metadata.extend([{"name": a["web"]["title"], "url": a["web"]["uri"]}
for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']])
        except KeyError:
            # no grounding metadata was returned for this completion
            pass
final_sol = SolutionModel(
Context="",
Requirements=[
cat.requirements[i].requirement for i in solution_model.requirement_ids
],
Problem_Description=solution_model.problem_description,
Solution_Description=solution_model.solution_description,
References=sources_metadata,
Category_Id=cat.id,
)
return final_sol
solutions = await asyncio.gather(*[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories], return_exceptions=True)
logging.info(solutions)
final_solutions = [
sol for sol in solutions if not isinstance(sol, Exception)]
return SolutionSearchResponse(solutions=final_solutions)
# =================================================================================================================
@solution_router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
"""Criticize the challenges, weaknesses and limitations of the provided solutions."""
async def __criticize_single(solution: SolutionModel):
req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
"solutions": [solution.model_dump()],
"response_schema": _SolutionCriticismOutput.model_json_schema()
})
req_completion = await llm_router.acompletion(
model="chat",
messages=[{"role": "user", "content": req_prompt}],
response_format=_SolutionCriticismOutput
)
criticism_out = _SolutionCriticismOutput.model_validate_json(
req_completion.choices[0].message.content
)
return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])
critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
return CritiqueResponse(critiques=critiques)
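# Illustrative chaining with the refinement endpoint below (sketch only; `client`
# and `solutions` are the hypothetical objects from the examples above):
#
#   critiques = (await client.post("/solution/criticize_solution",
#                                  json={"solutions": solutions.json()["solutions"]})).json()
#   refined = await client.post("/solution/refine_solutions", json=critiques)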
@solution_router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse:
"""Refines the previously critiqued solutions."""
async def __refine_solution(crit: SolutionCriticism):
req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
"solution": crit.solution.model_dump(),
"criticism": crit.criticism,
"response_schema": _RefinedSolutionModel.model_json_schema(),
})
req_completion = await llm_router.acompletion(model="chat", messages=[
{"role": "user", "content": req_prompt}
], response_format=_RefinedSolutionModel)
req_model = _RefinedSolutionModel.model_validate_json(
req_completion.choices[0].message.content)
# copy previous solution model
refined_solution = crit.solution.model_copy(deep=True)
refined_solution.Problem_Description = req_model.problem_description
refined_solution.Solution_Description = req_model.solution_description
return refined_solution
refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False)
return SolutionSearchResponse(solutions=refined_solutions)
# ======================================== Solution generation using Insight Finder ==================
@solution_router.post("/search_solutions_if")
async def search_solutions_if(req: SolutionSearchV2Request) -> SolutionSearchResponse:
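    """Searches for solutions per category by matching the requirements against technologies returned by the Insight Finder API, then synthesizing a solution with the LLM."""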
async def _search_solution_inner(cat: ReqGroupingCategory):
# process requirements into insight finder format
fmt_completion = await llm_router.acompletion("chat", messages=[
{
"role": "user",
"content": await format_prompt("if/format_requirements.txt", **{
"category": cat.model_dump(),
"response_schema": InsightFinderConstraintsList.model_json_schema()
})
}], response_format=InsightFinderConstraintsList)
fmt_model = InsightFinderConstraintsList.model_validate_json(
fmt_completion.choices[0].message.content)
# fetch technologies from insight finder
technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", content=fmt_model.model_dump_json())
technologies = TechnologyData.model_validate(technologies_req.json())
# =============================================================== synthesize solution using LLM =========================================
format_solution = await llm_router.acompletion("chat", messages=[{
"role": "user",
"content": await format_prompt("if/synthesize_solution.txt", **{
"category": cat.model_dump(),
"technologies": technologies.model_dump()["technologies"],
"user_constraints": None,
"response_schema": _SearchedSolutionModel.model_json_schema()
})}
], response_format=_SearchedSolutionModel)
format_solution_model = _SearchedSolutionModel.model_validate_json(
format_solution.choices[0].message.content)
final_solution = SolutionModel(
Context="",
Requirements=[
cat.requirements[i].requirement for i in format_solution_model.requirement_ids
],
Problem_Description=format_solution_model.problem_description,
Solution_Description=format_solution_model.solution_description,
References=[],
Category_Id=cat.id,
)
# ========================================================================================================================================
return final_solution
    results = await asyncio.gather(*[_search_solution_inner(cat) for cat in req.categories], return_exceptions=True)
    final_solutions = [sol for sol in results if not isinstance(sol, Exception)]
return SolutionSearchResponse(solutions=final_solutions)
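# Illustrative call (sketch only; SolutionSearchV2Request is assumed to carry the
# grouped categories plus optional user constraints, as used above):
#
#   resp = await client.post("/solution/search_solutions_if", json={
#       "categories": grouping["categories"],
#       # ...plus any user constraints supported by SolutionSearchV2Request
#   })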
api.include_router(requirements_router)
api.include_router(solution_router)

if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)