Lucas ARRIESSE
commited on
Commit
·
f6a7399
1
Parent(s):
23cca30
Finish criticize_solutions endpoint
Browse files- .gitignore +2 -1
- app.py +22 -12
- prompts/criticize.txt +6 -1
- schemas.py +26 -2
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
.venv
|
2 |
-
__pycache__
|
|
|
|
1 |
.venv
|
2 |
+
__pycache__
|
3 |
+
.env
|
app.py
CHANGED
@@ -4,9 +4,8 @@ import os
|
|
4 |
import sys
|
5 |
import uvicorn
|
6 |
from fastapi import FastAPI
|
7 |
-
from schemas import CriticizeSolutionsRequest, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
|
8 |
from jinja2 import Environment, FileSystemLoader
|
9 |
-
from litellm import acompletion
|
10 |
from litellm.router import Router
|
11 |
from dotenv import load_dotenv
|
12 |
|
@@ -57,8 +56,6 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
57 |
"max_n_categories": params.max_n_categories,
|
58 |
"response_schema": _ReqGroupingOutput.model_json_schema()})
|
59 |
|
60 |
-
logging.info(req_prompt)
|
61 |
-
|
62 |
# add system prompt with requirements
|
63 |
messages.append({"role": "user", "content": req_prompt})
|
64 |
|
@@ -68,10 +65,16 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
68 |
output = _ReqGroupingOutput.model_validate_json(
|
69 |
req_completion.choices[0].message.content)
|
70 |
|
71 |
-
# quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
|
|
|
72 |
assigned_ids = {
|
73 |
req_id for cat in output.categories for req_id in cat.items}
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
if len(unassigned_ids) == 0:
|
77 |
categories.extend(output.categories)
|
@@ -99,11 +102,18 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
99 |
|
100 |
|
101 |
@api.post("/criticize_solution")
|
102 |
-
async def criticize_solution(params: CriticizeSolutionsRequest) ->
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
uvicorn.run(api, host="0.0.0.0", port=8000)
|
|
|
4 |
import sys
|
5 |
import uvicorn
|
6 |
from fastapi import FastAPI
|
7 |
+
from schemas import _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput, SolutionCriticism
|
8 |
from jinja2 import Environment, FileSystemLoader
|
|
|
9 |
from litellm.router import Router
|
10 |
from dotenv import load_dotenv
|
11 |
|
|
|
56 |
"max_n_categories": params.max_n_categories,
|
57 |
"response_schema": _ReqGroupingOutput.model_json_schema()})
|
58 |
|
|
|
|
|
59 |
# add system prompt with requirements
|
60 |
messages.append({"role": "user", "content": req_prompt})
|
61 |
|
|
|
65 |
output = _ReqGroupingOutput.model_validate_json(
|
66 |
req_completion.choices[0].message.content)
|
67 |
|
68 |
+
# # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
|
69 |
+
valid_ids_universe = set(range(1, len(params.requirements)))
|
70 |
assigned_ids = {
|
71 |
req_id for cat in output.categories for req_id in cat.items}
|
72 |
+
|
73 |
+
# keep only non-hallucinated, valid assigned ids
|
74 |
+
valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)
|
75 |
+
|
76 |
+
# check for remaining requirements assigned to none of the categories
|
77 |
+
unassigned_ids = valid_ids_universe - valid_assigned_ids
|
78 |
|
79 |
if len(unassigned_ids) == 0:
|
80 |
categories.extend(output.categories)
|
|
|
102 |
|
103 |
|
104 |
@api.post("/criticize_solution")
|
105 |
+
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
|
106 |
+
"""Criticize the challenges, weaknesses and limitations of the provided solutions."""
|
107 |
+
req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
|
108 |
+
"solutions": [sol.model_dump() for sol in params.solutions],
|
109 |
+
"response_schema": _SolutionCriticismOutput.model_json_schema()
|
110 |
+
})
|
111 |
+
req_completion = await llm_router.acompletion(model="chat", messages=[{"role": "user", "content": req_prompt}], response_format=_SolutionCriticismOutput)
|
112 |
+
criticism_out = _SolutionCriticismOutput.model_validate_json(
|
113 |
+
req_completion.choices[0].message.content)
|
114 |
+
|
115 |
+
return CritiqueResponse(critiques=[
|
116 |
+
SolutionCriticism(solution=sol, criticism=crit) for (sol, crit) in zip(params.solutions, criticism_out.criticisms)
|
117 |
+
])
|
118 |
|
119 |
uvicorn.run(api, host="0.0.0.0", port=8000)
|
prompts/criticize.txt
CHANGED
@@ -13,4 +13,9 @@ Here are the solutions:
|
|
13 |
- Solution description: {{solution["Solution Description"]}}
|
14 |
---
|
15 |
{% endfor -%}
|
16 |
-
</solutions>
|
|
|
|
|
|
|
|
|
|
|
|
13 |
- Solution description: {{solution["Solution Description"]}}
|
14 |
---
|
15 |
{% endfor -%}
|
16 |
+
</solutions>
|
17 |
+
|
18 |
+
<response_format>
|
19 |
+
Reply in JSON using the following format:
|
20 |
+
{{response_schema}}
|
21 |
+
</response_format>
|
schemas.py
CHANGED
@@ -20,7 +20,7 @@ class ReqGroupingCategory(BaseModel):
|
|
20 |
..., description="List of grouped requirements")
|
21 |
|
22 |
|
23 |
-
class
|
24 |
Context: str
|
25 |
Requirements: List[str]
|
26 |
ProblemDescription: str
|
@@ -57,4 +57,28 @@ class _ReqGroupingOutput(BaseModel):
|
|
57 |
# Criticize solution endpoint
|
58 |
|
59 |
class CriticizeSolutionsRequest(BaseModel):
|
60 |
-
solutions: list[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
..., description="List of grouped requirements")
|
21 |
|
22 |
|
23 |
+
class SolutionModel(BaseModel):
|
24 |
Context: str
|
25 |
Requirements: List[str]
|
26 |
ProblemDescription: str
|
|
|
57 |
# Criticize solution endpoint
|
58 |
|
59 |
class CriticizeSolutionsRequest(BaseModel):
|
60 |
+
solutions: list[SolutionModel]
|
61 |
+
|
62 |
+
|
63 |
+
class _SolutionCriticism(BaseModel):
|
64 |
+
technical_challenges: List[str] = Field(
|
65 |
+
..., description="Technical challenges encountered by the solution")
|
66 |
+
weaknesses: List[str] = Field(...,
|
67 |
+
description="Identified weaknesses of the solution")
|
68 |
+
limitations: List[str] = Field(...,
|
69 |
+
description="Identified limitations of the solution")
|
70 |
+
|
71 |
+
|
72 |
+
class _SolutionCriticismOutput(BaseModel):
|
73 |
+
criticisms: List[_SolutionCriticism]
|
74 |
+
|
75 |
+
# response format
|
76 |
+
|
77 |
+
|
78 |
+
class SolutionCriticism(BaseModel):
|
79 |
+
solution: SolutionModel
|
80 |
+
criticism: _SolutionCriticism
|
81 |
+
|
82 |
+
|
83 |
+
class CritiqueResponse(BaseModel):
|
84 |
+
critiques: List[SolutionCriticism]
|