Lucas ARRIESSE commited on
Commit
594f2fe
·
1 Parent(s): d9eeeac
Files changed (4) hide show
  1. app.py +30 -9
  2. prompts/classify.txt +5 -6
  3. prompts/criticize.txt +16 -0
  4. schemas.py +17 -1
app.py CHANGED
@@ -1,11 +1,13 @@
 
1
  import logging
2
  import os
3
  import sys
4
  import uvicorn
5
  from fastapi import FastAPI
6
- from schemas import RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
7
  from jinja2 import Environment, FileSystemLoader
8
  from litellm import acompletion
 
9
  from dotenv import load_dotenv
10
 
11
  logging.basicConfig(
@@ -17,21 +19,30 @@ logging.basicConfig(
17
  # Load .env files
18
  load_dotenv()
19
 
20
- if "LLM_API_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
21
  logging.error(
22
- "No LLM token (`LLM_API_TOKEN`) and/or LLM model (`LLM_API_KEY`) were provided in the env vars. Exiting")
23
  sys.exit(-1)
24
 
25
- LLM_API_MODEL = os.environ.get("LLM_API_MODEL")
26
- LLM_API_KEY = os.environ.get("LLM_API_KEY")
 
 
 
 
 
 
 
 
 
27
 
28
  # Jinja2 environment to load prompt templates
29
  prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)
30
 
31
- fastapi = FastAPI()
32
 
33
 
34
- @fastapi.post("/categorize_requirements")
35
  async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
36
  """Categorize the given service requirements into categories"""
37
 
@@ -46,12 +57,14 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
46
  "max_n_categories": params.max_n_categories,
47
  "response_schema": _ReqGroupingOutput.model_json_schema()})
48
 
 
 
49
  # add system prompt with requirements
50
  messages.append({"role": "user", "content": req_prompt})
51
 
52
  # ensure all requirements items are processed
53
  for attempt in range(MAX_ATTEMPTS):
54
- req_completion = await acompletion(model=LLM_API_MODEL, api_key=LLM_API_KEY, messages=messages, response_format=_ReqGroupingOutput)
55
  output = _ReqGroupingOutput.model_validate_json(
56
  req_completion.choices[0].message.content)
57
 
@@ -85,4 +98,12 @@ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
85
  return ReqGroupingResponse(categories=final_categories)
86
 
87
 
88
- uvicorn.run(fastapi, host="0.0.0.0", port=8000)
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  import logging
3
  import os
4
  import sys
5
  import uvicorn
6
  from fastapi import FastAPI
7
+ from schemas import CriticizeSolutionsRequest, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
8
  from jinja2 import Environment, FileSystemLoader
9
  from litellm import acompletion
10
+ from litellm.router import Router
11
  from dotenv import load_dotenv
12
 
13
  logging.basicConfig(
 
19
  # Load .env files
20
  load_dotenv()
21
 
22
+ if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
23
  logging.error(
24
+ "No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting")
25
  sys.exit(-1)
26
 
27
+ # LiteLLM router
28
+ llm_router = Router(model_list=[
29
+ {
30
+ "model_name": "chat",
31
+ "litellm_params": {
32
+ "model": os.environ.get("LLM_MODEL"),
33
+ "api_key": os.environ.get("LLM_API_KEY"),
34
+ "max_retries": 5
35
+ }
36
+ }
37
+ ])
38
 
39
  # Jinja2 environment to load prompt templates
40
  prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)
41
 
42
+ api = FastAPI()
43
 
44
 
45
+ @api.post("/categorize_requirements")
46
  async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
47
  """Categorize the given service requirements into categories"""
48
 
 
57
  "max_n_categories": params.max_n_categories,
58
  "response_schema": _ReqGroupingOutput.model_json_schema()})
59
 
60
+ logging.info(req_prompt)
61
+
62
  # add system prompt with requirements
63
  messages.append({"role": "user", "content": req_prompt})
64
 
65
  # ensure all requirements items are processed
66
  for attempt in range(MAX_ATTEMPTS):
67
+ req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
68
  output = _ReqGroupingOutput.model_validate_json(
69
  req_completion.choices[0].message.content)
70
 
 
98
  return ReqGroupingResponse(categories=final_categories)
99
 
100
 
101
+ @api.post("/criticize_solution")
102
+ async def criticize_solution(params: CriticizeSolutionsRequest) -> str:
103
+ req_prompt = await prompt_env.get_template("criticize.txt").render_async(solutions=[sol.model_dump() for sol in params.solutions])
104
+ req_completion = await llm_router.acompletion(model="chat", messages=[{"role": "user", "content": req_prompt}])
105
+
106
+ return req_completion.choices[0].message.content
107
+
108
+
109
+ uvicorn.run(api, host="0.0.0.0", port=8000)
prompts/classify.txt CHANGED
@@ -3,20 +3,19 @@
3
  For each category indicate which requirements belong in that category using their ID. An item may appear in one category at a time.
4
  Please make each category title indicative of whats in it.
5
  </task>
6
- {% if max_n_categories is none %}
7
  <number_of_categories>You may have at most as much categories as you think is needed</number_of_categories>
8
- {% else %}
9
  <number_of_categories>You may have at most {{max_n_categories}} categories</number_of_categories>
10
- {%endif%}
11
 
12
  Here are the requirements:
13
  <requirements>
14
- {% for req in requirements %}
15
  - {{ loop.index }}. {{ req["requirement"] }}
16
- {% endfor %}
17
  </requirements>
18
 
19
-
20
  <response_format>
21
  Reply in JSON using the following format:
22
  {{response_schema}}
 
3
  For each category indicate which requirements belong in that category using their ID. An item may appear in one category at a time.
4
  Please make each category title indicative of whats in it.
5
  </task>
6
+ {% if max_n_categories is none -%}
7
  <number_of_categories>You may have at most as much categories as you think is needed</number_of_categories>
8
+ {% else -%}
9
  <number_of_categories>You may have at most {{max_n_categories}} categories</number_of_categories>
10
+ {% endif-%}
11
 
12
  Here are the requirements:
13
  <requirements>
14
+ {% for req in requirements -%}
15
  - {{ loop.index }}. {{ req["requirement"] }}
16
+ {% endfor -%}
17
  </requirements>
18
 
 
19
  <response_format>
20
  Reply in JSON using the following format:
21
  {{response_schema}}
prompts/criticize.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <role>You are an useful engineering assistant for innovation</role>
2
+ <task>
3
+ You are tasked with criticizing multiple solutions solving a set of requirements on different points,
4
+ namely the technical challenges of the solutions, the weaknesses and limitations of the proposed solutions.
5
+ </task>
6
+
7
+ Here are the solutions:
8
+ <solutions>
9
+ {% for solution in solutions %}
10
+ ## Solution
11
+ - Context: {{solution["Context"]}}
12
+ - Problem description: {{solution["Problem Description"]}}
13
+ - Solution description: {{solution["Solution Description"]}}
14
+ ---
15
+ {% endfor -%}
16
+ </solutions>
schemas.py CHANGED
@@ -3,6 +3,7 @@ from typing import List, Optional
3
 
4
  # Shared model schemas
5
 
 
6
  class RequirementInfo(BaseModel):
7
  """Represents an extracted requirement info"""
8
  context: str = Field(..., description="Context for the requirement.")
@@ -18,7 +19,16 @@ class ReqGroupingCategory(BaseModel):
18
  requirements: List[RequirementInfo] = Field(
19
  ..., description="List of grouped requirements")
20
 
21
- # Endpoint model schemas
 
 
 
 
 
 
 
 
 
22
 
23
  class ReqGroupingRequest(BaseModel):
24
  """Request schema of a requirement grouping call."""
@@ -42,3 +52,9 @@ class _ReqGroupingCategory(BaseModel):
42
  class _ReqGroupingOutput(BaseModel):
43
  categories: list[_ReqGroupingCategory] = Field(
44
  ..., description="List of grouping categories")
 
 
 
 
 
 
 
3
 
4
  # Shared model schemas
5
 
6
+
7
  class RequirementInfo(BaseModel):
8
  """Represents an extracted requirement info"""
9
  context: str = Field(..., description="Context for the requirement.")
 
19
  requirements: List[RequirementInfo] = Field(
20
  ..., description="List of grouped requirements")
21
 
22
+
23
+ class SolutionSearchResult(BaseModel):
24
+ Context: str
25
+ Requirements: List[str]
26
+ ProblemDescription: str
27
+ SolutionDescription: str
28
+ References: Optional[str] = ""
29
+
30
+ # Categorize requirements endpoint
31
+
32
 
33
  class ReqGroupingRequest(BaseModel):
34
  """Request schema of a requirement grouping call."""
 
52
  class _ReqGroupingOutput(BaseModel):
53
  categories: list[_ReqGroupingCategory] = Field(
54
  ..., description="List of grouping categories")
55
+
56
+
57
+ # Criticize solution endpoint
58
+
59
+ class CriticizeSolutionsRequest(BaseModel):
60
+ solutions: list[SolutionSearchResult]