Lucas ARRIESSE commited on
Commit
03acc5b
·
1 Parent(s): 32239ae

Reorganize endpoints + add refine_solution endpoint

Browse files
app.py CHANGED
@@ -3,8 +3,8 @@ import logging
3
  import os
4
  import sys
5
  import uvicorn
6
- from fastapi import FastAPI
7
- from schemas import _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism, SolutionModel, SolutionSearchResponse
8
  from jinja2 import Environment, FileSystemLoader, StrictUndefined
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
@@ -44,9 +44,13 @@ prompt_env = Environment(loader=FileSystemLoader(
44
  'prompts'), enable_async=True, undefined=StrictUndefined)
45
 
46
  api = FastAPI(docs_url="/")
 
 
 
 
47
 
48
 
49
- @api.post("/categorize_requirements")
50
  async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
51
  """Categorize the given service requirements into categories"""
52
 
@@ -105,34 +109,9 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
105
 
106
  return ReqGroupingResponse(categories=final_categories)
107
 
 
108
 
109
- @api.post("/criticize_solution", response_model=CritiqueResponse)
110
- async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
111
- """Criticize the challenges, weaknesses and limitations of the provided solutions."""
112
-
113
- async def __criticize_single(solution: SolutionModel):
114
- req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
115
- "solutions": [solution.model_dump()],
116
- "response_schema": _SolutionCriticismOutput.model_json_schema()
117
- })
118
-
119
- req_completion = await llm_router.acompletion(
120
- model="chat",
121
- messages=[{"role": "user", "content": req_prompt}],
122
- response_format=_SolutionCriticismOutput
123
- )
124
-
125
- criticism_out = _SolutionCriticismOutput.model_validate_json(
126
- req_completion.choices[0].message.content
127
- )
128
-
129
- return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])
130
-
131
- critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
132
- return CritiqueResponse(critiques=critiques)
133
-
134
-
135
- @api.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
136
  async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
137
  """Searches solutions using Gemini and grounded on google search"""
138
 
@@ -187,5 +166,63 @@ async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchRespons
187
 
188
  return SolutionSearchResponse(solutions=final_solutions)
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  uvicorn.run(api, host="0.0.0.0", port=8000)
 
3
  import os
4
  import sys
5
  import uvicorn
6
+ from fastapi import APIRouter, FastAPI
7
+ from schemas import _RefinedSolutionModel, _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism, SolutionModel, SolutionSearchResponse
8
  from jinja2 import Environment, FileSystemLoader, StrictUndefined
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
 
44
  'prompts'), enable_async=True, undefined=StrictUndefined)
45
 
46
  api = FastAPI(docs_url="/")
47
+ # requirements routes
48
+ requirements_router = APIRouter(prefix="/reqs", tags=["requirements"])
49
+ # solution routes
50
+ solution_router = APIRouter(prefix="/solution", tags=["solution"])
51
 
52
 
53
+ @requirements_router.post("/categorize_requirements")
54
  async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
55
  """Categorize the given service requirements into categories"""
56
 
 
109
 
110
  return ReqGroupingResponse(categories=final_categories)
111
 
112
+ # ========================================================= Solution Endpoints ===========================================================
113
 
114
+ @solution_router.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
116
  """Searches solutions using Gemini and grounded on google search"""
117
 
 
166
 
167
  return SolutionSearchResponse(solutions=final_solutions)
168
 
169
+ @solution_router.post("/criticize_solution", response_model=CritiqueResponse)
170
+ async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
171
+ """Criticize the challenges, weaknesses and limitations of the provided solutions."""
172
+
173
+ async def __criticize_single(solution: SolutionModel):
174
+ req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
175
+ "solutions": [solution.model_dump()],
176
+ "response_schema": _SolutionCriticismOutput.model_json_schema()
177
+ })
178
+
179
+ req_completion = await llm_router.acompletion(
180
+ model="chat",
181
+ messages=[{"role": "user", "content": req_prompt}],
182
+ response_format=_SolutionCriticismOutput
183
+ )
184
+
185
+ criticism_out = _SolutionCriticismOutput.model_validate_json(
186
+ req_completion.choices[0].message.content
187
+ )
188
+
189
+ return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])
190
+
191
+ critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
192
+ return CritiqueResponse(critiques=critiques)
193
+
194
+
195
+ @solution_router.post("/refine_solutions", response_model=SolutionSearchResponse)
196
+ async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse:
197
+ """Refines the previously critiqued solutions."""
198
+
199
+ async def __refine_solution(crit: SolutionCriticism):
200
+ req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
201
+ "solution": crit.solution.model_dump(),
202
+ "criticism": crit.criticism,
203
+ "response_schema": _RefinedSolutionModel.model_json_schema(),
204
+ })
205
+
206
+ req_completion = await llm_router.acompletion(model="chat", messages=[
207
+ {"role": "user", "content": req_prompt}
208
+ ], response_format=_RefinedSolutionModel)
209
+
210
+ req_model = _RefinedSolutionModel.model_validate_json(
211
+ req_completion.choices[0].message.content)
212
+
213
+ # copy previous solution model
214
+ refined_solution = crit.solution.model_copy(deep=True)
215
+ refined_solution.Problem_Description = req_model.problem_description
216
+ refined_solution.Solution_Description = req_model.solution_description
217
+
218
+ return refined_solution
219
+
220
+ refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False)
221
+
222
+ return SolutionSearchResponse(solutions=refined_solutions)
223
+
224
+
225
+ api.include_router(requirements_router)
226
+ api.include_router(solution_router)
227
 
228
  uvicorn.run(api, host="0.0.0.0", port=8000)
prompts/classify.txt CHANGED
@@ -16,6 +16,6 @@ Here are the requirements:
16
  </requirements>
17
 
18
  <response_format>
19
- Reply in JSON using the following format:
20
  {{response_schema}}
21
  </response_format>
 
16
  </requirements>
17
 
18
  <response_format>
19
+ Reply in JSON using the following schema:
20
  {{response_schema}}
21
  </response_format>
prompts/criticize.txt CHANGED
@@ -16,6 +16,6 @@ Here are the solutions:
16
  </solutions>
17
 
18
  <response_format>
19
- Reply in JSON using the following format:
20
  {{response_schema}}
21
  </response_format>
 
16
  </solutions>
17
 
18
  <response_format>
19
+ Reply in JSON using the following schema:
20
  {{response_schema}}
21
  </response_format>
prompts/refine_solution.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <role>You are an expert system designer</role>
2
+ <task>
3
+ Your task is to refine a solution to account for the technical challenges, weaknesses and limitations that were critiqued.
4
+ No need to include that the solution is refined.
5
+ </task>
6
+
7
+ <solution>
8
+ Here is the solution:
9
+
10
+ # Solution Context:
11
+ {{solution['Context']}}
12
+
13
+ # Requirements solved by the solution
14
+ {% for req in solution['Requirements'] -%}
15
+ - {{req}}
16
+ {% endfor %}
17
+
18
+ # Problem description associated to the solution
19
+ {{solution['Problem_Description']}}
20
+
21
+ # Description of the solution
22
+ {{solution['Solution_Description']}}
23
+ </solution>
24
+
25
+ <criticisim>
26
+ Here is the criticism:
27
+ {{criticism}}
28
+ </criticism>
29
+
30
+ <response_format>
31
+ Reply in JSON using the following response schema:
32
+ {{response_schema}}
33
+ </response_format>
prompts/structure_solution.txt CHANGED
@@ -7,6 +7,6 @@ Here is the solution
7
  </solution>
8
 
9
  <response_format>
10
- Reply in JSON using the following format:
11
  {{response_schema}}
12
  </response_format>
 
7
  </solution>
8
 
9
  <response_format>
10
+ Reply in JSON using the following schema:
11
  {{response_schema}}
12
  </response_format>
schemas.py CHANGED
@@ -113,3 +113,15 @@ class _SearchedSolutionModel(BaseModel):
113
  class SolutionSearchResponse(BaseModel):
114
  """Response model for solution search"""
115
  solutions: list[SolutionModel]
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  class SolutionSearchResponse(BaseModel):
114
  """Response model for solution search"""
115
  solutions: list[SolutionModel]
116
+
117
+
118
+ # ================================================================= refine solution endpoints
119
+
120
+
121
+ class _RefinedSolutionModel(BaseModel):
122
+ """Internal model used for solution refining"""
123
+
124
+ problem_description: str = Field(...,
125
+ description="New description of the problem being solved.")
126
+ solution_description: str = Field(...,
127
+ description="New detailed description of the solution.")