Lucas ARRIESSE committed on
Commit
72683de
·
1 Parent(s): f6a7399

Fix off-by-one issue with requirement IDs + add /search_solutions_gemini endpoint

Browse files
app.py CHANGED
@@ -4,8 +4,8 @@ import os
4
  import sys
5
  import uvicorn
6
  from fastapi import FastAPI
7
- from schemas import _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism
8
- from jinja2 import Environment, FileSystemLoader
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
11
 
@@ -33,12 +33,13 @@ llm_router = Router(model_list=[
33
  "max_retries": 5
34
  }
35
  }
36
- ])
37
 
38
  # Jinja2 environment to load prompt templates
39
- prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)
 
40
 
41
- api = FastAPI()
42
 
43
 
44
  @api.post("/categorize_requirements")
@@ -65,8 +66,8 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
65
  output = _ReqGroupingOutput.model_validate_json(
66
  req_completion.choices[0].message.content)
67
 
68
- # # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
69
- valid_ids_universe = set(range(1, len(params.requirements)))
70
  assigned_ids = {
71
  req_id for cat in output.categories for req_id in cat.items}
72
 
@@ -101,19 +102,85 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
101
  return ReqGroupingResponse(categories=final_categories)
102
 
103
 
104
- @api.post("/criticize_solution")
105
  async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
106
  """Criticize the challenges, weaknesses and limitations of the provided solutions."""
107
- req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
108
- "solutions": [sol.model_dump() for sol in params.solutions],
109
- "response_schema": _SolutionCriticismOutput.model_json_schema()
110
- })
111
- req_completion = await llm_router.acompletion(model="chat", messages=[{"role": "user", "content": req_prompt}], response_format=_SolutionCriticismOutput)
112
- criticism_out = _SolutionCriticismOutput.model_validate_json(
113
- req_completion.choices[0].message.content)
114
-
115
- return CritiqueResponse(critiques=[
116
- SolutionCriticism(solution=sol, criticism=crit) for (sol, crit) in zip(params.solutions, criticism_out.criticisms)
117
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  uvicorn.run(api, host="0.0.0.0", port=8000)
 
4
  import sys
5
  import uvicorn
6
  from fastapi import FastAPI
7
+ from schemas import _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism, SolutionModel, SolutionSearchResponse
8
+ from jinja2 import Environment, FileSystemLoader, StrictUndefined
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
11
 
 
33
  "max_retries": 5
34
  }
35
  }
36
+ ], cooldown_time=30)
37
 
38
  # Jinja2 environment to load prompt templates
39
+ prompt_env = Environment(loader=FileSystemLoader(
40
+ 'prompts'), enable_async=True, undefined=StrictUndefined)
41
 
42
+ api = FastAPI(docs_url="/")
43
 
44
 
45
  @api.post("/categorize_requirements")
 
66
  output = _ReqGroupingOutput.model_validate_json(
67
  req_completion.choices[0].message.content)
68
 
69
+ # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
70
+ valid_ids_universe = set(range(0, len(params.requirements)))
71
  assigned_ids = {
72
  req_id for cat in output.categories for req_id in cat.items}
73
 
 
102
  return ReqGroupingResponse(categories=final_categories)
103
 
104
 
105
+ @api.post("/criticize_solution", response_model=CritiqueResponse)
106
  async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
107
  """Criticize the challenges, weaknesses and limitations of the provided solutions."""
108
+
109
+ async def __criticize_single(solution: SolutionModel):
110
+ req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
111
+ "solutions": [solution.model_dump()],
112
+ "response_schema": _SolutionCriticismOutput.model_json_schema()
113
+ })
114
+
115
+ req_completion = await llm_router.acompletion(
116
+ model="chat",
117
+ messages=[{"role": "user", "content": req_prompt}],
118
+ response_format=_SolutionCriticismOutput
119
+ )
120
+
121
+ criticism_out = _SolutionCriticismOutput.model_validate_json(
122
+ req_completion.choices[0].message.content
123
+ )
124
+
125
+ return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])
126
+
127
+ critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
128
+ return CritiqueResponse(critiques=critiques)
129
+
130
+
131
+ @api.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
132
+ async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
133
+ """Searches solutions using Gemini and grounded on google search"""
134
+
135
+ async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
136
+ # ================== generate the solution with web grounding
137
+ req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
138
+ "category": cat.model_dump(),
139
+ })
140
+
141
+ # generate the completion in non-structured mode.
142
+ # the googleSearch tool enables grounding gemini with google search
143
+ req_completion = await llm_router.acompletion(model="chat", messages=[
144
+ {"role": "user", "content": req_prompt}
145
+ ], tools=[{"googleSearch": {}}])
146
+
147
+ # ==================== structure the solution as a json ===================================
148
+
149
+ structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
150
+ "solution": req_completion.choices[0].message.content,
151
+ "response_schema": _SearchedSolutionModel.model_json_schema()
152
+ })
153
+
154
+ structured_completion = await llm_router.acompletion(model="chat", messages=[
155
+ {"role": "user", "content": structured_prompt}
156
+ ], response_format=_SearchedSolutionModel)
157
+ solution_model = _SearchedSolutionModel.model_validate_json(
158
+ structured_completion.choices[0].message.content)
159
+
160
+ # ======================== build the final solution object ================================
161
+
162
+ # extract the source metadata from the search items
163
+ sources_metadata = [
164
+ f'{a["web"]["title"]} - {a["web"]["uri"]}' for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']]
165
+
166
+ final_sol = SolutionModel(
167
+ Context="",
168
+ Requirements=[
169
+ cat.requirements[i].requirement for i in solution_model.requirement_ids
170
+ ],
171
+ Problem_Description=solution_model.problem_description,
172
+ Solution_Description=solution_model.solution_description,
173
+ References=sources_metadata,
174
+ Category_Id=cat.id,
175
+ )
176
+ return final_sol
177
+
178
+ solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)
179
+ logging.info(solutions)
180
+ final_solutions = [
181
+ sol for sol in solutions if not isinstance(sol, Exception)]
182
+
183
+ return SolutionSearchResponse(solutions=final_solutions)
184
+
185
 
186
  uvicorn.run(api, host="0.0.0.0", port=8000)
prompts/classify.txt CHANGED
@@ -11,7 +11,7 @@ For each category indicate which requirements belong in that category using thei
11
  Here are the requirements:
12
  <requirements>
13
  {% for req in requirements -%}
14
- - {{ loop.index }}. {{ req["requirement"] }}
15
  {% endfor -%}
16
  </requirements>
17
 
 
11
  Here are the requirements:
12
  <requirements>
13
  {% for req in requirements -%}
14
+ - {{ loop.index0 }}. {{ req["requirement"] }}
15
  {% endfor -%}
16
  </requirements>
17
 
prompts/criticize.txt CHANGED
@@ -9,8 +9,8 @@ Here are the solutions:
9
  {% for solution in solutions %}
10
  ## Solution
11
  - Context: {{solution["Context"]}}
12
- - Problem description: {{solution["Problem Description"]}}
13
- - Solution description: {{solution["Solution Description"]}}
14
  ---
15
  {% endfor -%}
16
  </solutions>
 
9
  {% for solution in solutions %}
10
  ## Solution
11
  - Context: {{solution["Context"]}}
12
+ - Problem description: {{solution["Problem_Description"]}}
13
+ - Solution description: {{solution["Solution_Description"]}}
14
  ---
15
  {% endfor -%}
16
  </solutions>
prompts/search_solution.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <role>You are an expert system designer</role>
2
+ <task>
3
+ Your task is to create a solution which is a combination of mechanisms that addresses as many of the provided requirements of a category as possible and that by searching the web, while carefully considering the given context.
4
+ Please actually make searches and do not simulate them.
5
+ </task>
6
+
7
+ Here is the category item and the associated requirements:
8
+ <requirements>
9
+ Category Title: {{category["title"]}}
10
+ Context: {{category["requirements"][0]["context"]}}
11
+ Requirements:
12
+ {% for req in category["requirements"] -%}
13
+ - {{loop.index0}} {{req["requirement"]}}
14
+ {% endfor -%}
15
+ </requirements>
16
+
17
+ <additional_instructions>
18
+ - The solution must aim to maximize requirement satisfaction while respecting the context.
19
+ - Provide a list of requirements addressed by the solution (provide only the requirement IDs)
20
+ - Please also detail each mechanism used in final solution
21
+ </additional_instructions>
prompts/structure_solution.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <role>You are an expert system designer</role>
2
+ <task>Your task is to take a solution you've created previously and structure it into a JSON object.</task>
3
+
4
+ Here is the solution
5
+ <solution>
6
+ {{solution}}
7
+ </solution>
8
+
9
+ <response_format>
10
+ Reply in JSON using the following format:
11
+ {{response_schema}}
12
+ </response_format>
schemas.py CHANGED
@@ -21,13 +21,23 @@ class ReqGroupingCategory(BaseModel):
21
 
22
 
23
  class SolutionModel(BaseModel):
24
- Context: str
25
- Requirements: List[str]
26
- ProblemDescription: str
27
- SolutionDescription: str
28
- References: Optional[str] = ""
 
 
 
 
 
 
 
29
 
30
- # Categorize requirements endpoint
 
 
 
31
 
32
 
33
  class ReqGroupingRequest(BaseModel):
@@ -54,7 +64,7 @@ class _ReqGroupingOutput(BaseModel):
54
  ..., description="List of grouping categories")
55
 
56
 
57
- # Criticize solution endpoint
58
 
59
  class CriticizeSolutionsRequest(BaseModel):
60
  solutions: list[SolutionModel]
@@ -82,3 +92,24 @@ class SolutionCriticism(BaseModel):
82
 
83
  class CritiqueResponse(BaseModel):
84
  critiques: List[SolutionCriticism]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  class SolutionModel(BaseModel):
24
+ Context: str = Field(...,
25
+ description="Full context provided for this category.")
26
+ Requirements: List[str] = Field(...,
27
+ description="List of each requirement as string.")
28
+ Problem_Description: str = Field(..., alias="Problem Description",
29
+ description="Description of the problem being solved.")
30
+ Solution_Description: str = Field(..., alias="Solution Description",
31
+ description="Detailed description of the solution.")
32
+ References: list[str] = Field(
33
+ ..., description="References to documents used for the solution.")
34
+ Category_Id: int = Field(
35
+ ..., description="ID of the requirements category the solution is based on")
36
 
37
+ class Config:
38
+ validate_by_name = True # Enables alias handling on input/output
39
+
40
+ # ============================================================= Categorize requirements endpoint
41
 
42
 
43
  class ReqGroupingRequest(BaseModel):
 
64
  ..., description="List of grouping categories")
65
 
66
 
67
+ # =========================================================== Criticize solution endpoint
68
 
69
  class CriticizeSolutionsRequest(BaseModel):
70
  solutions: list[SolutionModel]
 
92
 
93
  class CritiqueResponse(BaseModel):
94
  critiques: List[SolutionCriticism]
95
+
96
+
97
+ # =================================================================== search solution response endpoint
98
+
99
+ class _SolutionSearchOutput(BaseModel):
100
+ solution: SolutionModel
101
+
102
+
103
+ class _SearchedSolutionModel(BaseModel):
104
+ """"Internal model used for solutions searched using gemini"""
105
+ requirement_ids: List[int] = Field(...,
106
+ description="List of each requirement ID addressed by the solution")
107
+ problem_description: str = Field(...,
108
+ description="Description of the problem being solved.")
109
+ solution_description: str = Field(...,
110
+ description="Detailed description of the solution.")
111
+
112
+
113
+ class SolutionSearchResponse(BaseModel):
114
+ """Response model for solution search"""
115
+ solutions: list[SolutionModel]