Lucas ARRIESSE
committed on
Commit
·
72683de
1
Parent(s):
f6a7399
Fix off-by-one issue with requirement IDs + add /search_solutions_gemini endpoint
Browse files- app.py +86 -19
- prompts/classify.txt +1 -1
- prompts/criticize.txt +2 -2
- prompts/search_solution.txt +21 -0
- prompts/structure_solution.txt +12 -0
- schemas.py +38 -7
app.py
CHANGED
@@ -4,8 +4,8 @@ import os
|
|
4 |
import sys
|
5 |
import uvicorn
|
6 |
from fastapi import FastAPI
|
7 |
-
from schemas import _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism
|
8 |
-
from jinja2 import Environment, FileSystemLoader
|
9 |
from litellm.router import Router
|
10 |
from dotenv import load_dotenv
|
11 |
|
@@ -33,12 +33,13 @@ llm_router = Router(model_list=[
|
|
33 |
"max_retries": 5
|
34 |
}
|
35 |
}
|
36 |
-
])
|
37 |
|
38 |
# Jinja2 environment to load prompt templates
|
39 |
-
prompt_env = Environment(loader=FileSystemLoader(
|
|
|
40 |
|
41 |
-
api = FastAPI()
|
42 |
|
43 |
|
44 |
@api.post("/categorize_requirements")
|
@@ -65,8 +66,8 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
65 |
output = _ReqGroupingOutput.model_validate_json(
|
66 |
req_completion.choices[0].message.content)
|
67 |
|
68 |
-
#
|
69 |
-
valid_ids_universe = set(range(
|
70 |
assigned_ids = {
|
71 |
req_id for cat in output.categories for req_id in cat.items}
|
72 |
|
@@ -101,19 +102,85 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
101 |
return ReqGroupingResponse(categories=final_categories)
|
102 |
|
103 |
|
104 |
-
@api.post("/criticize_solution")
|
105 |
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
|
106 |
"""Criticize the challenges, weaknesses and limitations of the provided solutions."""
|
107 |
-
|
108 |
-
|
109 |
-
"
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
uvicorn.run(api, host="0.0.0.0", port=8000)
|
|
|
4 |
import sys
|
5 |
import uvicorn
|
6 |
from fastapi import FastAPI
|
7 |
+
from schemas import _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism, SolutionModel, SolutionSearchResponse
|
8 |
+
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
9 |
from litellm.router import Router
|
10 |
from dotenv import load_dotenv
|
11 |
|
|
|
33 |
"max_retries": 5
|
34 |
}
|
35 |
}
|
36 |
+
], cooldown_time=30)
|
37 |
|
38 |
# Jinja2 environment to load prompt templates
|
39 |
+
prompt_env = Environment(loader=FileSystemLoader(
|
40 |
+
'prompts'), enable_async=True, undefined=StrictUndefined)
|
41 |
|
42 |
+
api = FastAPI(docs_url="/")
|
43 |
|
44 |
|
45 |
@api.post("/categorize_requirements")
|
|
|
66 |
output = _ReqGroupingOutput.model_validate_json(
|
67 |
req_completion.choices[0].message.content)
|
68 |
|
69 |
+
# Sanity check that the LLM left no requirement out: every requirement ID must appear in at least one category.
|
70 |
+
valid_ids_universe = set(range(0, len(params.requirements)))
|
71 |
assigned_ids = {
|
72 |
req_id for cat in output.categories for req_id in cat.items}
|
73 |
|
|
|
102 |
return ReqGroupingResponse(categories=final_categories)
|
103 |
|
104 |
|
105 |
+
@api.post("/criticize_solution", response_model=CritiqueResponse)
|
106 |
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
|
107 |
"""Criticize the challenges, weaknesses and limitations of the provided solutions."""
|
108 |
+
|
109 |
+
async def __criticize_single(solution: SolutionModel):
|
110 |
+
req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
|
111 |
+
"solutions": [solution.model_dump()],
|
112 |
+
"response_schema": _SolutionCriticismOutput.model_json_schema()
|
113 |
+
})
|
114 |
+
|
115 |
+
req_completion = await llm_router.acompletion(
|
116 |
+
model="chat",
|
117 |
+
messages=[{"role": "user", "content": req_prompt}],
|
118 |
+
response_format=_SolutionCriticismOutput
|
119 |
+
)
|
120 |
+
|
121 |
+
criticism_out = _SolutionCriticismOutput.model_validate_json(
|
122 |
+
req_completion.choices[0].message.content
|
123 |
+
)
|
124 |
+
|
125 |
+
return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])
|
126 |
+
|
127 |
+
critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
|
128 |
+
return CritiqueResponse(critiques=critiques)
|
129 |
+
|
130 |
+
|
131 |
+
@api.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
|
132 |
+
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
|
133 |
+
"""Search for solutions using Gemini, grounded on Google Search."""
|
134 |
+
|
135 |
+
async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
|
136 |
+
# ================== generate the solution with web grounding
|
137 |
+
req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
|
138 |
+
"category": cat.model_dump(),
|
139 |
+
})
|
140 |
+
|
141 |
+
# generate the completion in non-structured mode.
|
142 |
+
# the googleSearch tool enables grounding gemini with google search
|
143 |
+
req_completion = await llm_router.acompletion(model="chat", messages=[
|
144 |
+
{"role": "user", "content": req_prompt}
|
145 |
+
], tools=[{"googleSearch": {}}])
|
146 |
+
|
147 |
+
# ==================== structure the solution as a json ===================================
|
148 |
+
|
149 |
+
structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
|
150 |
+
"solution": req_completion.choices[0].message.content,
|
151 |
+
"response_schema": _SearchedSolutionModel.model_json_schema()
|
152 |
+
})
|
153 |
+
|
154 |
+
structured_completion = await llm_router.acompletion(model="chat", messages=[
|
155 |
+
{"role": "user", "content": structured_prompt}
|
156 |
+
], response_format=_SearchedSolutionModel)
|
157 |
+
solution_model = _SearchedSolutionModel.model_validate_json(
|
158 |
+
structured_completion.choices[0].message.content)
|
159 |
+
|
160 |
+
# ======================== build the final solution object ================================
|
161 |
+
|
162 |
+
# extract the source metadata from the search items
|
163 |
+
sources_metadata = [
|
164 |
+
f'{a["web"]["title"]} - {a["web"]["uri"]}' for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']]
|
165 |
+
|
166 |
+
final_sol = SolutionModel(
|
167 |
+
Context="",
|
168 |
+
Requirements=[
|
169 |
+
cat.requirements[i].requirement for i in solution_model.requirement_ids
|
170 |
+
],
|
171 |
+
Problem_Description=solution_model.problem_description,
|
172 |
+
Solution_Description=solution_model.solution_description,
|
173 |
+
References=sources_metadata,
|
174 |
+
Category_Id=cat.id,
|
175 |
+
)
|
176 |
+
return final_sol
|
177 |
+
|
178 |
+
solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)
|
179 |
+
logging.info(solutions)
|
180 |
+
final_solutions = [
|
181 |
+
sol for sol in solutions if not isinstance(sol, Exception)]
|
182 |
+
|
183 |
+
return SolutionSearchResponse(solutions=final_solutions)
|
184 |
+
|
185 |
|
186 |
uvicorn.run(api, host="0.0.0.0", port=8000)
|
prompts/classify.txt
CHANGED
@@ -11,7 +11,7 @@ For each category indicate which requirements belong in that category using thei
|
|
11 |
Here are the requirements:
|
12 |
<requirements>
|
13 |
{% for req in requirements -%}
|
14 |
-
- {{ loop.
|
15 |
{% endfor -%}
|
16 |
</requirements>
|
17 |
|
|
|
11 |
Here are the requirements:
|
12 |
<requirements>
|
13 |
{% for req in requirements -%}
|
14 |
+
- {{ loop.index0 }}. {{ req["requirement"] }}
|
15 |
{% endfor -%}
|
16 |
</requirements>
|
17 |
|
prompts/criticize.txt
CHANGED
@@ -9,8 +9,8 @@ Here are the solutions:
|
|
9 |
{% for solution in solutions %}
|
10 |
## Solution
|
11 |
- Context: {{solution["Context"]}}
|
12 |
-
- Problem description: {{solution["
|
13 |
-
- Solution description: {{solution["
|
14 |
---
|
15 |
{% endfor -%}
|
16 |
</solutions>
|
|
|
9 |
{% for solution in solutions %}
|
10 |
## Solution
|
11 |
- Context: {{solution["Context"]}}
|
12 |
+
- Problem description: {{solution["Problem_Description"]}}
|
13 |
+
- Solution description: {{solution["Solution_Description"]}}
|
14 |
---
|
15 |
{% endfor -%}
|
16 |
</solutions>
|
prompts/search_solution.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<role>You are an expert system designer</role>
|
2 |
+
<task>
|
3 |
+
Your task is to create a solution, i.e. a combination of mechanisms, that addresses as many of the provided requirements of a category as possible. Build it by searching the web, while carefully considering the given context.
|
4 |
+
Please actually make searches and do not simulate them.
|
5 |
+
</task>
|
6 |
+
|
7 |
+
Here is the category item and the associated requirements:
|
8 |
+
<requirements>
|
9 |
+
Category Title: {{category["title"]}}
|
10 |
+
Context: {{category["requirements"][0]["context"]}}
|
11 |
+
Requirements:
|
12 |
+
{% for req in category["requirements"] -%}
|
13 |
+
- {{loop.index0}} {{req["requirement"]}}
|
14 |
+
{% endfor -%}
|
15 |
+
</requirements>
|
16 |
+
|
17 |
+
<additional_instructions>
|
18 |
+
- The solution must aim to maximize requirement satisfaction while respecting the context.
|
19 |
+
- Provide a list of requirements addressed by the solution (provide only the requirement IDs)
|
20 |
+
- Please also detail each mechanism used in the final solution
|
21 |
+
</additional_instructions>
|
prompts/structure_solution.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<role>You are an expert system designer</role>
|
2 |
+
<task>Your task is to take a solution you've created previously and structure it into a JSON object.</task>
|
3 |
+
|
4 |
+
Here is the solution
|
5 |
+
<solution>
|
6 |
+
{{solution}}
|
7 |
+
</solution>
|
8 |
+
|
9 |
+
<response_format>
|
10 |
+
Reply in JSON using the following format:
|
11 |
+
{{response_schema}}
|
12 |
+
</response_format>
|
schemas.py
CHANGED
@@ -21,13 +21,23 @@ class ReqGroupingCategory(BaseModel):
|
|
21 |
|
22 |
|
23 |
class SolutionModel(BaseModel):
|
24 |
-
Context: str
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
|
33 |
class ReqGroupingRequest(BaseModel):
|
@@ -54,7 +64,7 @@ class _ReqGroupingOutput(BaseModel):
|
|
54 |
..., description="List of grouping categories")
|
55 |
|
56 |
|
57 |
-
# Criticize solution endpoint
|
58 |
|
59 |
class CriticizeSolutionsRequest(BaseModel):
|
60 |
solutions: list[SolutionModel]
|
@@ -82,3 +92,24 @@ class SolutionCriticism(BaseModel):
|
|
82 |
|
83 |
class CritiqueResponse(BaseModel):
|
84 |
critiques: List[SolutionCriticism]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
class SolutionModel(BaseModel):
|
24 |
+
Context: str = Field(...,
|
25 |
+
description="Full context provided for this category.")
|
26 |
+
Requirements: List[str] = Field(...,
|
27 |
+
description="List of each requirement as string.")
|
28 |
+
Problem_Description: str = Field(..., alias="Problem Description",
|
29 |
+
description="Description of the problem being solved.")
|
30 |
+
Solution_Description: str = Field(..., alias="Solution Description",
|
31 |
+
description="Detailed description of the solution.")
|
32 |
+
References: list[str] = Field(
|
33 |
+
..., description="References to documents used for the solution.")
|
34 |
+
Category_Id: int = Field(
|
35 |
+
..., description="ID of the requirements category the solution is based on")
|
36 |
|
37 |
+
class Config:
|
38 |
+
validate_by_name = True  # Lets aliased fields also be populated by their attribute name (not only by alias)
|
39 |
+
|
40 |
+
# ============================================================= Categorize requirements endpoint
|
41 |
|
42 |
|
43 |
class ReqGroupingRequest(BaseModel):
|
|
|
64 |
..., description="List of grouping categories")
|
65 |
|
66 |
|
67 |
+
# =========================================================== Criticize solution endpoint
|
68 |
|
69 |
class CriticizeSolutionsRequest(BaseModel):
|
70 |
solutions: list[SolutionModel]
|
|
|
92 |
|
93 |
class CritiqueResponse(BaseModel):
|
94 |
critiques: List[SolutionCriticism]
|
95 |
+
|
96 |
+
|
97 |
+
# =================================================================== search solution response endpoint
|
98 |
+
|
99 |
+
class _SolutionSearchOutput(BaseModel):
|
100 |
+
solution: SolutionModel
|
101 |
+
|
102 |
+
|
103 |
+
class _SearchedSolutionModel(BaseModel):
|
104 |
+
"""Internal model used for solutions searched using gemini"""
|
105 |
+
requirement_ids: List[int] = Field(...,
|
106 |
+
description="List of each requirement ID addressed by the solution")
|
107 |
+
problem_description: str = Field(...,
|
108 |
+
description="Description of the problem being solved.")
|
109 |
+
solution_description: str = Field(...,
|
110 |
+
description="Detailed description of the solution.")
|
111 |
+
|
112 |
+
|
113 |
+
class SolutionSearchResponse(BaseModel):
|
114 |
+
"""Response model for solution search"""
|
115 |
+
solutions: list[SolutionModel]
|