File size: 4,028 Bytes
594f2fe
41c1aed
 
 
 
 
594f2fe
41c1aed
 
594f2fe
41c1aed
 
 
 
 
 
 
 
 
 
 
594f2fe
41c1aed
594f2fe
41c1aed
 
594f2fe
 
 
 
 
 
 
 
 
 
 
41c1aed
 
 
 
594f2fe
41c1aed
 
594f2fe
41c1aed
 
 
 
 
 
 
 
 
 
 
 
 
 
594f2fe
 
41c1aed
 
 
 
 
594f2fe
41c1aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594f2fe
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import asyncio
import logging
import os
import sys
import uvicorn
from fastapi import FastAPI
from schemas import CriticizeSolutionsRequest, RequirementInfo, ReqGroupingCategory, ReqGroupingResponse, ReqGroupingRequest, _ReqGroupingCategory, _ReqGroupingOutput
from jinja2 import Environment, FileSystemLoader
from litellm import acompletion
from litellm.router import Router
from dotenv import load_dotenv

# Root logging: INFO and above, each line tagged with time, level and source location.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Load variables from .env files into the process environment.
load_dotenv()

# Both the model identifier and its API key are mandatory; abort start-up without them.
if any(var_name not in os.environ for var_name in ("LLM_MODEL", "LLM_API_KEY")):
    logging.error(
        "No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting")
    sys.exit(-1)

# LiteLLM router with a single deployment registered under the alias "chat";
# the endpoints below address the model through that alias.
llm_router = Router(model_list=[
    dict(
        model_name="chat",
        litellm_params=dict(
            model=os.environ.get("LLM_MODEL"),
            api_key=os.environ.get("LLM_API_KEY"),
            max_retries=5,
        ),
    ),
])

# Jinja2 environment reading prompt templates from the relative "prompts" directory,
# configured for async rendering (templates are rendered with `render_async`).
prompt_env = Environment(loader=FileSystemLoader('prompts'), enable_async=True)

api = FastAPI()


@api.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    The requirements are rendered into the ``classify.txt`` prompt and sent to
    the ``chat`` LLM deployment, which must answer with JSON matching
    ``_ReqGroupingOutput``. Each returned category references requirements by
    their position in ``params.requirements``. If some requirement is not
    placed in any category, the LLM is told which ones are missing and asked
    again, up to ``MAX_ATTEMPTS`` times in total.

    NOTE(review): positions are treated as 0-based here (this matches how the
    response is built). Confirm that ``classify.txt`` numbers requirements the
    same way.

    Args:
        params: the requirements to group and the maximum number of categories.

    Returns:
        A ``ReqGroupingResponse`` whose categories are numbered in the order the
        LLM returned them. Each category holds the actual requirement objects.

    Raises:
        RuntimeError: if requirements are still uncategorized after
            ``MAX_ATTEMPTS`` attempts. Errors from the LLM call or from JSON
            validation of its answer are not caught and propagate unchanged.
    """

    MAX_ATTEMPTS = 5
    n_reqs = len(params.requirements)

    categories: list[_ReqGroupingCategory] = []

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    logging.info(req_prompt)

    # conversation sent to the LLM: starts with the rendered prompt as a user message
    messages = [{"role": "user", "content": req_prompt}]

    # retry with feedback until every requirement index is covered by some category
    for _ in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(
            model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # every valid index 0..n_reqs-1 must appear in at least one category
        # (checking range(1, n_reqs) would never detect a missing index 0)
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        unassigned_ids = set(range(n_reqs)) - assigned_ids

        if not unassigned_ids:
            categories.extend(output.categories)
            break

        # keep the LLM's answer in the conversation and ask it to fix the omissions
        messages.append(req_completion.choices[0].message)
        messages.append(
            {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
    else:
        # loop exhausted without a complete categorization
        raise RuntimeError("Failed to classify all requirements")

    # build the final category objects, dropping invalid (likely hallucinated)
    # indices; negative ones are dropped too, since Python would index them from the end
    final_categories = [
        ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if 0 <= i < n_reqs],
        )
        for idx, cat in enumerate(categories)
    ]

    return ReqGroupingResponse(categories=final_categories)


@api.post("/criticize_solution")
async def criticize_solution(params: CriticizeSolutionsRequest) -> str:
    """Ask the LLM to critique the submitted solutions.

    The solutions are rendered into the ``criticize.txt`` prompt template and
    sent as a single user message to the ``chat`` deployment. The message
    content of the first returned choice is passed back unchanged.
    """
    template = prompt_env.get_template("criticize.txt")
    solution_dicts = [solution.model_dump() for solution in params.solutions]
    prompt = await template.render_async(solutions=solution_dicts)

    completion = await llm_router.acompletion(
        model="chat",
        messages=[{"role": "user", "content": prompt}],
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# Start the HTTP server only when this file is executed as a script. Without the
# guard, merely importing the module (e.g. `uvicorn <module>:api`, or tests) would
# block inside a running server.
if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)