Lucas ARRIESSE
committed on
Commit
·
594f2fe
1
Parent(s):
d9eeeac
WIP
Browse files- app.py +30 -9
- prompts/classify.txt +5 -6
- prompts/criticize.txt +16 -0
- schemas.py +17 -1
app.py
CHANGED
@@ -1,11 +1,13 @@
|
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
import sys
|
4 |
import uvicorn
|
5 |
from fastapi import FastAPI
|
6 |
-
from schemas import RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
|
7 |
from jinja2 import Environment, FileSystemLoader
|
8 |
from litellm import acompletion
|
|
|
9 |
from dotenv import load_dotenv
|
10 |
|
11 |
logging.basicConfig(
|
@@ -17,21 +19,30 @@ logging.basicConfig(
|
|
17 |
# Load .env files
|
18 |
load_dotenv()
|
19 |
|
20 |
-
if "
|
21 |
logging.error(
|
22 |
-
"No LLM token (`
|
23 |
sys.exit(-1)
|
24 |
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# Jinja2 environment to load prompt templates
|
29 |
prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)
|
30 |
|
31 |
-
|
32 |
|
33 |
|
34 |
-
@
|
35 |
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
36 |
"""Categorize the given service requirements into categories"""
|
37 |
|
@@ -46,12 +57,14 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
46 |
"max_n_categories": params.max_n_categories,
|
47 |
"response_schema": _ReqGroupingOutput.model_json_schema()})
|
48 |
|
|
|
|
|
49 |
# add system prompt with requirements
|
50 |
messages.append({"role": "user", "content": req_prompt})
|
51 |
|
52 |
# ensure all requirements items are processed
|
53 |
for attempt in range(MAX_ATTEMPTS):
|
54 |
-
req_completion = await acompletion(model=
|
55 |
output = _ReqGroupingOutput.model_validate_json(
|
56 |
req_completion.choices[0].message.content)
|
57 |
|
@@ -85,4 +98,12 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
|
85 |
return ReqGroupingResponse(categories=final_categories)
|
86 |
|
87 |
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
import logging
|
3 |
import os
|
4 |
import sys
|
5 |
import uvicorn
|
6 |
from fastapi import FastAPI
|
7 |
+
from schemas import CriticizeSolutionsRequest, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
|
8 |
from jinja2 import Environment, FileSystemLoader
|
9 |
from litellm import acompletion
|
10 |
+
from litellm.router import Router
|
11 |
from dotenv import load_dotenv
|
12 |
|
13 |
logging.basicConfig(
|
|
|
19 |
# Load .env files
|
20 |
load_dotenv()
|
21 |
|
22 |
# Fail fast at startup when the LLM configuration is absent from the
# environment: both the model name and the API key are required.
if not ("LLM_MODEL" in os.environ and "LLM_API_KEY" in os.environ):
    logging.error(
        "No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting")
    sys.exit(-1)
|
26 |
|
27 |
# LiteLLM router: one logical deployment named "chat"; the concrete
# backend model and its API key are read from the environment
# (presence was validated at startup above).
_chat_deployment = {
    "model_name": "chat",
    "litellm_params": {
        "model": os.environ.get("LLM_MODEL"),
        "api_key": os.environ.get("LLM_API_KEY"),
        "max_retries": 5,
    },
}

llm_router = Router(model_list=[_chat_deployment])
|
38 |
|
39 |
# Jinja2 environment to load prompt templates
|
40 |
prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)
|
41 |
|
42 |
+
api = FastAPI()
|
43 |
|
44 |
|
45 |
+
@api.post("/categorize_requirements")
|
46 |
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
|
47 |
"""Categorize the given service requirements into categories"""
|
48 |
|
|
|
57 |
"max_n_categories": params.max_n_categories,
|
58 |
"response_schema": _ReqGroupingOutput.model_json_schema()})
|
59 |
|
60 |
+
logging.info(req_prompt)
|
61 |
+
|
62 |
# add system prompt with requirements
|
63 |
messages.append({"role": "user", "content": req_prompt})
|
64 |
|
65 |
# ensure all requirements items are processed
|
66 |
for attempt in range(MAX_ATTEMPTS):
|
67 |
+
req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
|
68 |
output = _ReqGroupingOutput.model_validate_json(
|
69 |
req_completion.choices[0].message.content)
|
70 |
|
|
|
98 |
return ReqGroupingResponse(categories=final_categories)
|
99 |
|
100 |
|
101 |
@api.post("/criticize_solution")
async def criticize_solution(params: CriticizeSolutionsRequest) -> str:
    """Criticize the given solutions using the LLM.

    Renders the `criticize.txt` prompt with every submitted solution and
    returns the model's critique (technical challenges, weaknesses and
    limitations) as plain text.
    """
    # Each solution is serialized with model_dump(); the template looks
    # up the resulting dict by the pydantic field names of
    # SolutionSearchResult (e.g. "ProblemDescription").
    req_prompt = await prompt_env.get_template("criticize.txt").render_async(
        solutions=[sol.model_dump() for sol in params.solutions])

    # Single-turn completion against the router's "chat" deployment.
    req_completion = await llm_router.acompletion(
        model="chat", messages=[{"role": "user", "content": req_prompt}])

    return req_completion.choices[0].message.content
|
107 |
+
|
108 |
+
|
109 |
# Start the development server only when this file is executed directly
# (`python app.py`). The guard prevents the server from launching as a
# side effect of the module being imported (e.g. by `uvicorn app:api`
# or by a test runner).
if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)
|
prompts/classify.txt
CHANGED
@@ -3,20 +3,19 @@
|
|
3 |
For each category indicate which requirements belong in that category using their ID. An item may appear in one category at a time.
|
4 |
Please make each category title indicative of whats in it.
|
5 |
</task>
|
6 |
-
{% if max_n_categories is none
|
7 |
<number_of_categories>You may have at most as much categories as you think is needed</number_of_categories>
|
8 |
-
{% else
|
9 |
<number_of_categories>You may have at most {{max_n_categories}} categories</number_of_categories>
|
10 |
-
{%endif
|
11 |
|
12 |
Here are the requirements:
|
13 |
<requirements>
|
14 |
-
{% for req in requirements
|
15 |
- {{ loop.index }}. {{ req["requirement"] }}
|
16 |
-
{% endfor
|
17 |
</requirements>
|
18 |
|
19 |
-
|
20 |
<response_format>
|
21 |
Reply in JSON using the following format:
|
22 |
{{response_schema}}
|
|
|
3 |
For each category indicate which requirements belong in that category using their ID. An item may appear in one category at a time.
|
4 |
Please make each category title indicative of whats in it.
|
5 |
</task>
|
6 |
+
{% if max_n_categories is none -%}
|
7 |
<number_of_categories>You may have at most as much categories as you think is needed</number_of_categories>
|
8 |
+
{% else -%}
|
9 |
<number_of_categories>You may have at most {{max_n_categories}} categories</number_of_categories>
|
10 |
+
{% endif -%}
|
11 |
|
12 |
Here are the requirements:
|
13 |
<requirements>
|
14 |
+
{% for req in requirements -%}
|
15 |
- {{ loop.index }}. {{ req["requirement"] }}
|
16 |
+
{% endfor -%}
|
17 |
</requirements>
|
18 |
|
|
|
19 |
<response_format>
|
20 |
Reply in JSON using the following format:
|
21 |
{{response_schema}}
|
prompts/criticize.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
<role>You are an useful engineering assistant for innovation</role>
<task>
You are tasked with criticizing multiple solutions solving a set of requirements on different points,
namely the technical challenges of the solutions, the weaknesses and limitations of the proposed solutions.
</task>

Here are the solutions:
<solutions>
{% for solution in solutions %}
## Solution
- Context: {{solution["Context"]}}
{# Keys must match SolutionSearchResult field names exactly: model_dump()
   emits "ProblemDescription"/"SolutionDescription" (no spaces); the
   previous spaced keys rendered as empty strings. #}
- Problem description: {{solution["ProblemDescription"]}}
- Solution description: {{solution["SolutionDescription"]}}
---
{% endfor -%}
</solutions>
|
schemas.py
CHANGED
@@ -3,6 +3,7 @@ from typing import List, Optional
|
|
3 |
|
4 |
# Shared model schemas
|
5 |
|
|
|
6 |
class RequirementInfo(BaseModel):
|
7 |
"""Represents an extracted requirement info"""
|
8 |
context: str = Field(..., description="Context for the requirement.")
|
@@ -18,7 +19,16 @@ class ReqGroupingCategory(BaseModel):
|
|
18 |
requirements: List[RequirementInfo] = Field(
|
19 |
..., description="List of grouped requirements")
|
20 |
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
class ReqGroupingRequest(BaseModel):
|
24 |
"""Request schema of a requirement grouping call."""
|
@@ -42,3 +52,9 @@ class _ReqGroupingCategory(BaseModel):
|
|
42 |
class _ReqGroupingOutput(BaseModel):
|
43 |
categories: list[_ReqGroupingCategory] = Field(
|
44 |
..., description="List of grouping categories")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# Shared model schemas
|
5 |
|
6 |
+
|
7 |
class RequirementInfo(BaseModel):
|
8 |
"""Represents an extracted requirement info"""
|
9 |
context: str = Field(..., description="Context for the requirement.")
|
|
|
19 |
requirements: List[RequirementInfo] = Field(
|
20 |
..., description="List of grouped requirements")
|
21 |
|
22 |
+
|
23 |
class SolutionSearchResult(BaseModel):
    """A single candidate solution returned by the solution-search step.

    NOTE(review): the PascalCase field names are load-bearing — the
    criticize prompt template indexes the ``model_dump()`` output of this
    model by these exact names. Do not rename without updating the template.
    """
    # Context the solution applies to.
    Context: str
    # Requirement texts this solution addresses.
    Requirements: List[str]
    # Statement of the underlying problem.
    ProblemDescription: str
    # Description of the proposed solution.
    SolutionDescription: str
    # Optional supporting references; empty string when absent.
    References: Optional[str] = ""
|
29 |
+
|
30 |
+
# Categorize requirements endpoint
|
31 |
+
|
32 |
|
33 |
class ReqGroupingRequest(BaseModel):
|
34 |
"""Request schema of a requirement grouping call."""
|
|
|
52 |
class _ReqGroupingOutput(BaseModel):
|
53 |
categories: list[_ReqGroupingCategory] = Field(
|
54 |
..., description="List of grouping categories")
|
55 |
+
|
56 |
+
|
57 |
+
# Criticize solution endpoint
|
58 |
+
|
59 |
class CriticizeSolutionsRequest(BaseModel):
    """Request body of the /criticize_solution endpoint."""
    # Solutions to be critiqued by the LLM.
    solutions: list[SolutionSearchResult]
|