Lucas ARRIESSE committed on
Commit
f6a7399
·
1 Parent(s): 23cca30

Finish criticize_solutions endpoint

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. app.py +22 -12
  3. prompts/criticize.txt +6 -1
  4. schemas.py +26 -2
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .venv
2
- __pycache__
 
 
1
  .venv
2
+ __pycache__
3
+ .env
app.py CHANGED
@@ -4,9 +4,8 @@ import os
4
  import sys
5
  import uvicorn
6
  from fastapi import FastAPI
7
- from schemas import CriticizeSolutionsRequest, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
8
  from jinja2 import Environment, FileSystemLoader
9
- from litellm import acompletion
10
  from litellm.router import Router
11
  from dotenv import load_dotenv
12
 
@@ -57,8 +56,6 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
57
  "max_n_categories": params.max_n_categories,
58
  "response_schema": _ReqGroupingOutput.model_json_schema()})
59
 
60
- logging.info(req_prompt)
61
-
62
  # add system prompt with requirements
63
  messages.append({"role": "user", "content": req_prompt})
64
 
@@ -68,10 +65,16 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
68
  output = _ReqGroupingOutput.model_validate_json(
69
  req_completion.choices[0].message.content)
70
 
71
- # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
 
72
  assigned_ids = {
73
  req_id for cat in output.categories for req_id in cat.items}
74
- unassigned_ids = set(range(1, len(params.requirements))) - assigned_ids
 
 
 
 
 
75
 
76
  if len(unassigned_ids) == 0:
77
  categories.extend(output.categories)
@@ -99,11 +102,18 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
99
 
100
 
101
  @api.post("/criticize_solution")
102
- async def criticize_solution(params: CriticizeSolutionsRequest) -> str:
103
- req_prompt = await prompt_env.get_template("criticize.txt").render_async(solutions=[sol.model_dump() for sol in params.solutions])
104
- req_completion = await llm_router.acompletion(model="chat", messages=[{"role": "user", "content": req_prompt}])
105
-
106
- return req_completion.choices[0].message.content
107
-
 
 
 
 
 
 
 
108
 
109
  uvicorn.run(api, host="0.0.0.0", port=8000)
 
4
  import sys
5
  import uvicorn
6
  from fastapi import FastAPI
7
+ from schemas import _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism
8
  from jinja2 import Environment, FileSystemLoader
 
9
  from litellm.router import Router
10
  from dotenv import load_dotenv
11
 
 
56
  "max_n_categories": params.max_n_categories,
57
  "response_schema": _ReqGroupingOutput.model_json_schema()})
58
 
 
 
59
  # add system prompt with requirements
60
  messages.append({"role": "user", "content": req_prompt})
61
 
 
65
  output = _ReqGroupingOutput.model_validate_json(
66
  req_completion.choices[0].message.content)
67
 
68
+ # # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
69
+ valid_ids_universe = set(range(1, len(params.requirements)))
70
  assigned_ids = {
71
  req_id for cat in output.categories for req_id in cat.items}
72
+
73
+ # keep only non-hallucinated, valid assigned ids
74
+ valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)
75
+
76
+ # check for remaining requirements assigned to none of the categories
77
+ unassigned_ids = valid_ids_universe - valid_assigned_ids
78
 
79
  if len(unassigned_ids) == 0:
80
  categories.extend(output.categories)
 
102
 
103
 
104
@api.post("/criticize_solution")
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions.

    Renders the `criticize.txt` prompt with the solutions and the expected
    JSON schema, requests a structured critique from the LLM, then pairs each
    returned criticism back with its source solution.

    Raises:
        ValueError: if the LLM returns a different number of criticisms than
            the number of solutions submitted (instead of silently dropping
            or mis-pairing entries, which a plain zip() would do).
    """
    req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
        "solutions": [sol.model_dump() for sol in params.solutions],
        "response_schema": _SolutionCriticismOutput.model_json_schema(),
    })

    req_completion = await llm_router.acompletion(
        model="chat",
        messages=[{"role": "user", "content": req_prompt}],
        response_format=_SolutionCriticismOutput,
    )

    criticism_out = _SolutionCriticismOutput.model_validate_json(
        req_completion.choices[0].message.content)

    # Guard against the LLM hallucinating extra entries or omitting some:
    # zip() would otherwise silently truncate to the shorter list.
    if len(criticism_out.criticisms) != len(params.solutions):
        raise ValueError(
            f"LLM returned {len(criticism_out.criticisms)} criticisms "
            f"for {len(params.solutions)} solutions")

    return CritiqueResponse(critiques=[
        SolutionCriticism(solution=sol, criticism=crit)
        for sol, crit in zip(params.solutions, criticism_out.criticisms)
    ])
118
 
119
  uvicorn.run(api, host="0.0.0.0", port=8000)
prompts/criticize.txt CHANGED
@@ -13,4 +13,9 @@ Here are the solutions:
13
  - Solution description: {{solution["Solution Description"]}}
14
  ---
15
  {% endfor -%}
16
- </solutions>
 
 
 
 
 
 
13
  - Solution description: {{solution["Solution Description"]}}
14
  ---
15
  {% endfor -%}
16
+ </solutions>
17
+
18
+ <response_format>
19
+ Reply in JSON using the following format:
20
+ {{response_schema}}
21
+ </response_format>
schemas.py CHANGED
@@ -20,7 +20,7 @@ class ReqGroupingCategory(BaseModel):
20
  ..., description="List of grouped requirements")
21
 
22
 
23
- class SolutionSearchResult(BaseModel):
24
  Context: str
25
  Requirements: List[str]
26
  ProblemDescription: str
@@ -57,4 +57,28 @@ class _ReqGroupingOutput(BaseModel):
57
  # Criticize solution endpoint
58
 
59
  class CriticizeSolutionsRequest(BaseModel):
60
- solutions: list[SolutionSearchResult]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  ..., description="List of grouped requirements")
21
 
22
 
23
+ class SolutionModel(BaseModel):
24
  Context: str
25
  Requirements: List[str]
26
  ProblemDescription: str
 
57
  # Criticize solution endpoint
58
 
59
class CriticizeSolutionsRequest(BaseModel):
    """Request body for the /criticize_solution endpoint."""

    # Solutions to be critiqued by the LLM.
    solutions: List[SolutionModel]
61
+
62
+
63
class _SolutionCriticism(BaseModel):
    """Structured critique of a single solution, as produced by the LLM."""

    technical_challenges: List[str] = Field(
        ...,
        description="Technical challenges encountered by the solution",
    )
    weaknesses: List[str] = Field(
        ...,
        description="Identified weaknesses of the solution",
    )
    limitations: List[str] = Field(
        ...,
        description="Identified limitations of the solution",
    )
70
+
71
+
72
class _SolutionCriticismOutput(BaseModel):
    """Raw LLM output schema: one criticism per submitted solution."""

    criticisms: List[_SolutionCriticism]
74
+
75
+ # response format
76
+
77
+
78
class SolutionCriticism(BaseModel):
    """A solution paired with the criticism generated for it."""

    # The solution that was critiqued.
    solution: SolutionModel
    # The structured critique produced by the LLM for that solution.
    criticism: _SolutionCriticism
81
+
82
+
83
class CritiqueResponse(BaseModel):
    """Response body of the /criticize_solution endpoint."""

    # One entry per submitted solution, in submission order.
    critiques: List[SolutionCriticism]