import asyncio
import logging
import os
import sys

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router

from schemas import (
    CriticizeSolutionsRequest,
    CritiqueResponse,
    ReqGroupingCategory,
    ReqGroupingRequest,
    ReqGroupingResponse,
    RequirementInfo,
    SolutionCriticism,
    SolutionModel,
    SolutionSearchResponse,
    _ReqGroupingCategory,
    _ReqGroupingOutput,
    _SearchedSolutionModel,
    _SolutionCriticismOutput,
)

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load .env files
load_dotenv()

if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "`LLM_MODEL` and/or `LLM_API_KEY` are missing from the environment. Exiting")
    sys.exit(1)
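
# A minimal example `.env` file (values are illustrative, not real; the model
# string follows litellm's "<provider>/<model>" convention):
#
#   LLM_MODEL=gemini/gemini-2.0-flash
#   LLM_API_KEY=your-api-key-here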

# LiteLLM router with a single "chat" deployment, rate-limited client-side
# (15 requests per minute, at most 4 concurrent) with retries and cooldowns
llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "rpm": 15,
            "max_parallel_requests": 4,
            "allowed_fails": 1,
            "cooldown_time": 60,
            "max_retries": 10,
        }
    }
], num_retries=10, retry_after=30)

# Jinja2 environment to load prompt templates
prompt_env = Environment(
    loader=FileSystemLoader('prompts'),
    enable_async=True,
    undefined=StrictUndefined,
)
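
# The prompt templates referenced below (classify.txt, criticize.txt,
# search_solution.txt, structure_solution.txt) are expected under ./prompts.
# Their contents are not part of this file; as an assumption, each template
# interpolates its inputs and, where structured output is expected, embeds
# the JSON schema, e.g. (illustrative sketch, not the real template):
#
#   Categorize the following requirements: {{ requirements }}
#   Answer with JSON matching this schema: {{ response_schema }}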

api = FastAPI(docs_url="/")


@api.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Group the given service requirements into categories"""

    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # add the user prompt containing the requirements
    messages.append({"role": "user", "content": req_prompt})

    # ensure all requirements items are processed
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # sanity check: every requirement ID must appear in at least one
        # category, otherwise the LLM left some requirements out
        valid_ids_universe = set(range(0, len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # keep only non-hallucinated, valid assigned ids
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)

        # check for remaining requirements assigned to none of the categories
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"The following requirements were not assigned to any category: {unassigned_ids}. Please assign each of them to at least one category."})

            if attempt == MAX_ATTEMPTS - 1:
                raise RuntimeError(
                    f"Failed to categorize all requirements after {MAX_ATTEMPTS} attempts")

    # build the final category objects,
    # dropping invalid (likely hallucinated) requirement IDs
    final_categories = []
    for idx, cat in enumerate(categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if 0 <= i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
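
# Illustrative call (the exact request schema lives in schemas.py, so the
# payload shape below is an assumption based on how the fields are used here):
#
#   curl -X POST http://localhost:8000/categorize_requirements \
#     -H 'Content-Type: application/json' \
#     -d '{"requirements": [{"requirement": "Users must be able to log in via SSO"}],
#          "max_n_categories": 3}'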


@api.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
    """Critique the provided solutions: challenges, weaknesses, and limitations."""

    async def _criticize_single(solution: SolutionModel):
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })

        req_completion = await llm_router.acompletion(
            model="chat",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )

        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )

        # a single solution is passed per prompt, so take the first criticism
        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    # note: with return_exceptions=False, a single failed critique fails the whole request
    critiques = await asyncio.gather(
        *[_criticize_single(sol) for sol in params.solutions],
        return_exceptions=False,
    )
    return CritiqueResponse(critiques=critiques)
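
# Illustrative call (SolutionModel field names are inferred from their use in
# search_solutions below; treat the exact payload shape as an assumption):
#
#   curl -X POST http://localhost:8000/criticize_solution \
#     -H 'Content-Type: application/json' \
#     -d '{"solutions": [{"Context": "", "Requirements": ["SSO login"],
#          "Problem_Description": "...", "Solution_Description": "...",
#          "References": [], "Category_Id": 0}]}'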


@api.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
    """Search for solutions with Gemini, grounded in Google Search results"""

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ==================== generate the solution with web grounding ====================
        req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
            "category": cat.model_dump(),
        })

        # generate the completion in non-structured mode.
        # the googleSearch tool enables grounding gemini with google search
        # this also forces gemini to perform a tool call
        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as a json ===================================

        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)
        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ======================== build the final solution object ================================

        # extract the source metadata from the search items
        sources_metadata = [
            f'{a["web"]["title"]} - {a["web"]["uri"]}' for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']]

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(*[_search_inner(cat) for cat in params.categories], return_exceptions=True)

    # keep the successful searches and log the failed ones
    for sol in solutions:
        if isinstance(sol, Exception):
            logging.error("solution search failed: %s", sol)
    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)
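
# This endpoint takes a ReqGroupingResponse directly, so the output of
# /categorize_requirements can be piped straight into it (illustrative flow):
#
#   curl -s -X POST http://localhost:8000/categorize_requirements ... \
#     | curl -s -X POST http://localhost:8000/search_solutions_gemini \
#         -H 'Content-Type: application/json' -d @-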


if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)