Lucas ARRIESSE commited on
Commit
41c1aed
·
0 Parent(s):

Initial commit

Browse files
Files changed (7) hide show
  1. .gitignore +2 -0
  2. Dockerfile +10 -0
  3. README.md +9 -0
  4. app.py +88 -0
  5. prompts/classify.txt +23 -0
  6. requirements.txt +51 -0
  7. schemas.py +44 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ __pycache__
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . .
6
+
7
+ EXPOSE 8000
8
+ RUN pip3 install -r requirements.txt
9
+
10
+ ENTRYPOINT ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Reqroup
3
+ emoji: 🤖
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: docker
7
+ pinned: false
8
+ short_description: Categorize service requirements into groups (using AI)
9
+ ---
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ import uvicorn
5
+ from fastapi import FastAPI
6
+ from schemas import RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
7
+ from jinja2 import Environment, FileSystemLoader
8
+ from litellm import acompletion
9
+ from dotenv import load_dotenv
10
+
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
14
+ datefmt='%Y-%m-%d %H:%M:%S'
15
+ )
16
+
17
+ # Load .env files
18
+ load_dotenv()
19
+
20
+ if "LLM_API_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
21
+ logging.error(
22
+ "No LLM token (`LLM_API_TOKEN`) and/or LLM model (`LLM_API_KEY`) were provided in the env vars. Exiting")
23
+ sys.exit(-1)
24
+
25
+ LLM_API_MODEL = os.environ.get("LLM_API_MODEL")
26
+ LLM_API_KEY = os.environ.get("LLM_API_KEY")
27
+
28
+ # Jinja2 environment to load prompt templates
29
+ prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)
30
+
31
+ fastapi = FastAPI()
32
+
33
+
34
+ @fastapi.post("/categorize_requirements")
35
+ async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
36
+ """Categorize the given service requirements into categories"""
37
+
38
+ MAX_ATTEMPTS = 5
39
+
40
+ categories: list[_ReqGroupingCategory] = []
41
+ messages = []
42
+
43
+ # categorize the requirements using their indices
44
+ req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
45
+ "requirements": [rq.model_dump() for rq in params.requirements],
46
+ "max_n_categories": params.max_n_categories,
47
+ "response_schema": _ReqGroupingOutput.model_json_schema()})
48
+
49
+ # add system prompt with requirements
50
+ messages.append({"role": "user", "content": req_prompt})
51
+
52
+ # ensure all requirements items are processed
53
+ for attempt in range(MAX_ATTEMPTS):
54
+ req_completion = await acompletion(model=LLM_API_MODEL, api_key=LLM_API_KEY, messages=messages, response_format=_ReqGroupingOutput)
55
+ output = _ReqGroupingOutput.model_validate_json(
56
+ req_completion.choices[0].message.content)
57
+
58
+ # quick check to ensure no requirement was left out by the LLM by checking all IDs are contained in at least a single category
59
+ assigned_ids = {
60
+ req_id for cat in output.categories for req_id in cat.items}
61
+ unassigned_ids = set(range(1, len(params.requirements))) - assigned_ids
62
+
63
+ if len(unassigned_ids) == 0:
64
+ categories.extend(output.categories)
65
+ break
66
+ else:
67
+ messages.append(req_completion.choices[0].message)
68
+ messages.append(
69
+ {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
70
+
71
+ if attempt == MAX_ATTEMPTS - 1:
72
+ raise Exception("Failed to classify all requirements")
73
+
74
+ # build the final category objects
75
+ # remove the invalid (likely hallucinated) requirement IDs
76
+ final_categories = []
77
+ for idx, cat in enumerate(output.categories):
78
+ final_categories.append(ReqGroupingCategory(
79
+ id=idx,
80
+ title=cat.title,
81
+ requirements=[params.requirements[i]
82
+ for i in cat.items if i < len(params.requirements)]
83
+ ))
84
+
85
+ return ReqGroupingResponse(categories=final_categories)
86
+
87
+
88
+ uvicorn.run(fastapi, host="0.0.0.0", port=8000)
prompts/classify.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <role>You are an useful assistant who excels at categorizing technical extracted requirements</role>
2
+ <task>You are tasked with classifying each element of a list of technical requirements into categories which you may arbitrarily define.
3
+ For each category indicate which requirements belong in that category using their ID. An item may appear in one category at a time.
4
+ Please make each category title indicative of whats in it.
5
+ </task>
6
+ {% if max_n_categories is none %}
7
+ <number_of_categories>You may have at most as much categories as you think is needed</number_of_categories>
8
+ {% else %}
9
+ <number_of_categories>You may have at most {{max_n_categories}} categories</number_of_categories>
10
+ {%endif%}
11
+
12
+ Here are the requirements:
13
+ <requirements>
14
+ {% for req in requirements %}
15
+ - {{ loop.index }}. {{ req["requirement"] }}
16
+ {% endfor %}
17
+ </requirements>
18
+
19
+
20
+ <response_format>
21
+ Reply in JSON using the following format:
22
+ {{response_schema}}
23
+ </response_format>
requirements.txt ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.12.13
3
+ aiosignal==1.3.2
4
+ annotated-types==0.7.0
5
+ anyio==4.9.0
6
+ attrs==25.3.0
7
+ certifi==2025.6.15
8
+ charset-normalizer==3.4.2
9
+ click==8.2.1
10
+ distro==1.9.0
11
+ dotenv==0.9.9
12
+ fastapi==0.115.12
13
+ filelock==3.18.0
14
+ frozenlist==1.7.0
15
+ fsspec==2025.5.1
16
+ h11==0.16.0
17
+ hf-xet==1.1.4
18
+ httpcore==1.0.9
19
+ httpx==0.28.1
20
+ huggingface-hub==0.33.0
21
+ idna==3.10
22
+ importlib_metadata==8.7.0
23
+ Jinja2==3.1.6
24
+ jiter==0.10.0
25
+ jsonschema==4.24.0
26
+ jsonschema-specifications==2025.4.1
27
+ litellm==1.72.6
28
+ MarkupSafe==3.0.2
29
+ multidict==6.4.4
30
+ openai==1.88.0
31
+ packaging==25.0
32
+ propcache==0.3.2
33
+ pydantic==2.11.7
34
+ pydantic_core==2.33.2
35
+ python-dotenv==1.1.0
36
+ PyYAML==6.0.2
37
+ referencing==0.36.2
38
+ regex==2024.11.6
39
+ requests==2.32.4
40
+ rpds-py==0.25.1
41
+ sniffio==1.3.1
42
+ starlette==0.46.2
43
+ tiktoken==0.9.0
44
+ tokenizers==0.21.1
45
+ tqdm==4.67.1
46
+ typing-inspection==0.4.1
47
+ typing_extensions==4.14.0
48
+ urllib3==2.4.0
49
+ uvicorn==0.34.3
50
+ yarl==1.20.1
51
+ zipp==3.23.0
schemas.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Optional
3
+
4
+ # Shared model schemas
5
+
6
+ class RequirementInfo(BaseModel):
7
+ """Represents an extracted requirement info"""
8
+ context: str = Field(..., description="Context for the requirement.")
9
+ requirement: str = Field(..., description="The requirement itself.")
10
+ document: str = Field(...,
11
+ description="The document the requirement is extracted from.")
12
+
13
+
14
+ class ReqGroupingCategory(BaseModel):
15
+ """Represents the category of requirements grouped together"""
16
+ id: int = Field(..., description="ID of the grouping category")
17
+ title: str = Field(..., description="Title given to the grouping category")
18
+ requirements: List[RequirementInfo] = Field(
19
+ ..., description="List of grouped requirements")
20
+
21
+ # Endpoint model schemas
22
+
23
+ class ReqGroupingRequest(BaseModel):
24
+ """Request schema of a requirement grouping call."""
25
+ requirements: list[RequirementInfo]
26
+ max_n_categories: Optional[int] = Field(
27
+ default=None, description="Max number of categories to construct. Defaults to None")
28
+
29
+
30
+ class ReqGroupingResponse(BaseModel):
31
+ """Response of a requirement grouping call."""
32
+ categories: List[ReqGroupingCategory]
33
+
34
+
35
+ # INFO: keep in sync with prompt
36
+ class _ReqGroupingCategory(BaseModel):
37
+ title: str = Field(..., description="Title given to the grouping category")
38
+ items: list[int] = Field(
39
+ ..., description="List of the IDs of the requirements belonging to the category.")
40
+
41
+
42
+ class _ReqGroupingOutput(BaseModel):
43
+ categories: list[_ReqGroupingCategory] = Field(
44
+ ..., description="List of grouping categories")