Lucas ARRIESSE
commited on
Commit
·
41c1aed
0
Parent(s):
Initial commit
Browse files- .gitignore +2 -0
- Dockerfile +10 -0
- README.md +9 -0
- app.py +88 -0
- prompts/classify.txt +23 -0
- requirements.txt +51 -0
- schemas.py +44 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.venv
|
2 |
+
__pycache__
|
Dockerfile
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.11-slim

WORKDIR /app

# Install dependencies first so the (slow) pip layer is cached and only
# rebuilt when requirements.txt changes — not on every source edit.
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy the application code after the dependency layer.
COPY . .

EXPOSE 8000

ENTRYPOINT ["python", "app.py"]
|
README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Reqroup
|
3 |
+
emoji: 🤖
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: blue
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
short_description: Categorize service requirements into groups (using AI)
|
9 |
+
---
|
app.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import uvicorn
|
5 |
+
from fastapi import FastAPI
|
6 |
+
from schemas import RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
|
7 |
+
from jinja2 import Environment, FileSystemLoader
|
8 |
+
from litellm import acompletion
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
# Log with timestamp + source location so records trace back to the code line.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load .env files
load_dotenv()

# Fail fast when the LLM configuration is incomplete.
# (Fixed: the old message referenced a nonexistent `LLM_API_TOKEN` and
# labeled `LLM_API_KEY` as the model variable.)
if "LLM_API_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "No LLM model (`LLM_API_MODEL`) and/or LLM API key (`LLM_API_KEY`) were provided in the env vars. Exiting")
    sys.exit(-1)

LLM_API_MODEL = os.environ.get("LLM_API_MODEL")
LLM_API_KEY = os.environ.get("LLM_API_KEY")

# Jinja2 environment to load prompt templates
prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)

fastapi = FastAPI()
|
32 |
+
|
33 |
+
|
34 |
+
@fastapi.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    Retries up to MAX_ATTEMPTS times until the LLM has assigned every
    requirement to at least one category, then maps the LLM's 1-based
    requirement IDs (see prompts/classify.txt, Jinja `loop.index`) back
    onto the 0-based `params.requirements` list.

    Raises:
        Exception: if some requirements remain uncategorized after
            MAX_ATTEMPTS rounds.
    """

    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # categorize the requirements using their indices
    # (the template numbers them with Jinja's 1-based loop.index)
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # add system prompt with requirements
    messages.append({"role": "user", "content": req_prompt})

    # ensure all requirements items are processed
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await acompletion(model=LLM_API_MODEL, api_key=LLM_API_KEY, messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        # IDs are 1-based, so the full set is 1..len inclusive.
        # (Fixed off-by-one: range(1, len) silently skipped the last ID.)
        unassigned_ids = set(
            range(1, len(params.requirements) + 1)) - assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            # Feed the assistant's own answer back plus a correction request.
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})

        if attempt == MAX_ATTEMPTS - 1:
            raise Exception("Failed to classify all requirements")

    # build the final category objects
    # remove the invalid (likely hallucinated) requirement IDs
    final_categories = []
    for idx, cat in enumerate(output.categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            # Convert the LLM's 1-based IDs to 0-based list indices.
            # (Fixed: the old code indexed with `i` directly, shifting every
            # assignment by one and dropping requirement #1 / #len.)
            requirements=[params.requirements[i - 1]
                          for i in cat.items if 1 <= i <= len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == "__main__":
    # Start the API server only when executed as a script (`python app.py`,
    # as the Dockerfile does) — not on mere import of this module.
    uvicorn.run(fastapi, host="0.0.0.0", port=8000)
|
prompts/classify.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<role>You are a useful assistant who excels at categorizing technical extracted requirements</role>
|
2 |
+
<task>You are tasked with classifying each element of a list of technical requirements into categories which you may arbitrarily define.
|
3 |
+
For each category indicate which requirements belong in that category using their ID. An item may appear in one category at a time.
|
4 |
+
Please make each category title indicative of what's in it.
|
5 |
+
</task>
|
6 |
+
{% if max_n_categories is none %}
|
7 |
+
<number_of_categories>You may define as many categories as you think are needed</number_of_categories>
|
8 |
+
{% else %}
|
9 |
+
<number_of_categories>You may have at most {{max_n_categories}} categories</number_of_categories>
|
10 |
+
{%endif%}
|
11 |
+
|
12 |
+
Here are the requirements:
|
13 |
+
<requirements>
|
14 |
+
{% for req in requirements %}
|
15 |
+
- {{ loop.index }}. {{ req["requirement"] }}
|
16 |
+
{% endfor %}
|
17 |
+
</requirements>
|
18 |
+
|
19 |
+
|
20 |
+
<response_format>
|
21 |
+
Reply in JSON using the following format:
|
22 |
+
{{response_schema}}
|
23 |
+
</response_format>
|
requirements.txt
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiohappyeyeballs==2.6.1
|
2 |
+
aiohttp==3.12.13
|
3 |
+
aiosignal==1.3.2
|
4 |
+
annotated-types==0.7.0
|
5 |
+
anyio==4.9.0
|
6 |
+
attrs==25.3.0
|
7 |
+
certifi==2025.6.15
|
8 |
+
charset-normalizer==3.4.2
|
9 |
+
click==8.2.1
|
10 |
+
distro==1.9.0
|
11 |
+
# dotenv==0.9.9  # redundant: the app imports `dotenv` via python-dotenv (pinned below)
|
12 |
+
fastapi==0.115.12
|
13 |
+
filelock==3.18.0
|
14 |
+
frozenlist==1.7.0
|
15 |
+
fsspec==2025.5.1
|
16 |
+
h11==0.16.0
|
17 |
+
hf-xet==1.1.4
|
18 |
+
httpcore==1.0.9
|
19 |
+
httpx==0.28.1
|
20 |
+
huggingface-hub==0.33.0
|
21 |
+
idna==3.10
|
22 |
+
importlib_metadata==8.7.0
|
23 |
+
Jinja2==3.1.6
|
24 |
+
jiter==0.10.0
|
25 |
+
jsonschema==4.24.0
|
26 |
+
jsonschema-specifications==2025.4.1
|
27 |
+
litellm==1.72.6
|
28 |
+
MarkupSafe==3.0.2
|
29 |
+
multidict==6.4.4
|
30 |
+
openai==1.88.0
|
31 |
+
packaging==25.0
|
32 |
+
propcache==0.3.2
|
33 |
+
pydantic==2.11.7
|
34 |
+
pydantic_core==2.33.2
|
35 |
+
python-dotenv==1.1.0
|
36 |
+
PyYAML==6.0.2
|
37 |
+
referencing==0.36.2
|
38 |
+
regex==2024.11.6
|
39 |
+
requests==2.32.4
|
40 |
+
rpds-py==0.25.1
|
41 |
+
sniffio==1.3.1
|
42 |
+
starlette==0.46.2
|
43 |
+
tiktoken==0.9.0
|
44 |
+
tokenizers==0.21.1
|
45 |
+
tqdm==4.67.1
|
46 |
+
typing-inspection==0.4.1
|
47 |
+
typing_extensions==4.14.0
|
48 |
+
urllib3==2.4.0
|
49 |
+
uvicorn==0.34.3
|
50 |
+
yarl==1.20.1
|
51 |
+
zipp==3.23.0
|
schemas.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
# Shared model schemas
|
5 |
+
|
6 |
+
class RequirementInfo(BaseModel):
    """Represents an extracted requirement info"""
    # Field descriptions double as schema documentation in the generated
    # JSON/OpenAPI output, so they are the authoritative field docs.
    context: str = Field(..., description="Context for the requirement.")
    requirement: str = Field(..., description="The requirement itself.")
    document: str = Field(...,
                          description="The document the requirement is extracted from.")
|
12 |
+
|
13 |
+
|
14 |
+
class ReqGroupingCategory(BaseModel):
    """Represents the category of requirements grouped together"""
    # Assigned sequentially by the endpoint (enumeration index), not by the LLM.
    id: int = Field(..., description="ID of the grouping category")
    title: str = Field(..., description="Title given to the grouping category")
    # Fully-resolved requirement objects (not the raw IDs the LLM emits).
    requirements: List[RequirementInfo] = Field(
        ..., description="List of grouped requirements")
|
20 |
+
|
21 |
+
# Endpoint model schemas
|
22 |
+
|
23 |
+
class ReqGroupingRequest(BaseModel):
    """Request schema of a requirement grouping call."""
    # Order matters: the prompt refers to requirements by their (1-based)
    # position in this list.
    requirements: list[RequirementInfo]
    # None lets the LLM pick the category count itself (the prompt template
    # branches on `max_n_categories is none`).
    max_n_categories: Optional[int] = Field(
        default=None, description="Max number of categories to construct. Defaults to None")
|
28 |
+
|
29 |
+
|
30 |
+
class ReqGroupingResponse(BaseModel):
    """Response of a requirement grouping call."""
    # Final categories with requirement objects resolved from the LLM's IDs.
    categories: List[ReqGroupingCategory]
|
33 |
+
|
34 |
+
|
35 |
+
# INFO: keep in sync with prompt
# LLM-facing category shape: requirements are referenced by the 1-based IDs
# the prompt template assigns (Jinja loop.index), not embedded as objects.
# No docstring on purpose: model_json_schema() of the enclosing output model
# is injected into the prompt, and docstrings become schema "description"s.
class _ReqGroupingCategory(BaseModel):
    title: str = Field(..., description="Title given to the grouping category")
    items: list[int] = Field(
        ..., description="List of the IDs of the requirements belonging to the category.")
|
40 |
+
|
41 |
+
|
42 |
+
class _ReqGroupingOutput(BaseModel):
    # Root of the LLM's structured output; app.py validates completions
    # against it with model_validate_json and injects its JSON schema into
    # the prompt. Keep in sync with prompts/classify.txt.
    categories: list[_ReqGroupingCategory] = Field(
        ..., description="List of grouping categories")
|