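"""Reqxtract-API: a FastAPI service that categorizes service requirements and
then searches, critiques, and refines candidate solutions, using LiteLLM-routed
LLM calls (Gemini grounded with Google Search) and the Insight Finder service."""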
import asyncio
import json
import logging
import os
import sys
import uvicorn
from fastapi import APIRouter, FastAPI
from schemas import (
    _RefinedSolutionModel,
    _ReqGroupingCategory,
    _ReqGroupingOutput,
    _SearchedSolutionModel,
    _SolutionCriticismOutput,
    CriticizeSolutionsRequest,
    CritiqueResponse,
    InsightFinderConstraintsList,
    RequirementInfo,
    ReqGroupingCategory,
    ReqGroupingRequest,
    ReqGroupingResponse,
    SolutionCriticism,
    SolutionModel,
    SolutionSearchResponse,
    SolutionSearchV2Request,
    TechnologyData,
)
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from litellm.router import Router
from dotenv import load_dotenv
from util import retry_until
from httpx import AsyncClient

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Load .env files
load_dotenv()

if "LLM_MODEL" not in os.environ or "LLM_API_KEY" not in os.environ:
    logging.error(
        "No LLM token (`LLM_API_KEY`) and/or LLM model (`LLM_MODEL`) were provided in the env vars. Exiting")
    sys.exit(-1)
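
# Example configuration (illustrative values only; any LiteLLM-compatible
# model string and a matching API key will work):
#   LLM_MODEL=gemini/gemini-2.0-flash
#   LLM_API_KEY=<your provider API key>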

# LiteLLM router
llm_router = Router(model_list=[
    {
        "model_name": "chat",
        "litellm_params": {
            "model": os.environ.get("LLM_MODEL"),
            "api_key": os.environ.get("LLM_API_KEY"),
            "rpm": 15,
            "max_parallel_requests": 4,
            "allowed_fails": 1,
            "cooldown_time": 60,
            "max_retries": 10,
        }
    }
], num_retries=10, retry_after=30)

# HTTP client
INSIGHT_FINDER_BASE_URL = "https://organizedprogrammers-insight-finder.hf.space/"
# disable TLS verification only when NO_SSL=1 is set
http_client = AsyncClient(verify=os.environ.get(
    "NO_SSL", "0") != "1", timeout=None)

# Jinja2 environment to load prompt templates
prompt_env = Environment(loader=FileSystemLoader(
    'prompts'), enable_async=True, undefined=StrictUndefined)

with open("docs/docs.md") as docs_file:
    api = FastAPI(docs_url="/", title="Reqxtract-API",
                  description=docs_file.read())

# requirements routes
requirements_router = APIRouter(prefix="/reqs", tags=["requirements"])

# solution routes
solution_router = APIRouter(prefix="/solution", tags=["solution"])


async def format_prompt(prompt_name: str, **args) -> str:
    """Render the named Jinja2 prompt template with the given variables."""
    return await prompt_env.get_template(prompt_name).render_async(**args)
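
# Usage sketch (mirrors the actual call in search_solutions_if below):
#   prompt = await format_prompt("if/format_requirements.txt",
#                                category=cat.model_dump(),
#                                response_schema=InsightFinderConstraintsList.model_json_schema())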


@requirements_router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories"""

    MAX_ATTEMPTS = 5

    messages = []

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # add system prompt with requirements
    messages.append({"role": "user", "content": req_prompt})

    # ensure all requirements items are processed
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="chat", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # quick check that the LLM left no requirement out: every ID must appear in at least one category
        valid_ids_universe = set(range(0, len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # keep only non-hallucinated, valid assigned ids
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)

        # check for remaining requirements assigned to none of the categories
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})

            if attempt == MAX_ATTEMPTS - 1:
                raise Exception("Failed to classify all requirements")

    # build the final category objects
    # remove the invalid (likely hallucinated) requirement IDs
    final_categories = []
    for idx, cat in enumerate(output.categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
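
# Example request body for POST /reqs/categorize_requirements (field names
# inferred from how ReqGroupingRequest is used above; the authoritative schema
# lives in schemas.py, so treat this exact shape as an assumption):
#   {"requirements": [{"requirement": "The system shall ..."}],
#    "max_n_categories": 3}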

# ========================================================= Solution Endpoints ===========================================================


@solution_router.post("/search_solutions_gemini", response_model=SolutionSearchResponse)
async def search_solutions(params: ReqGroupingResponse) -> SolutionSearchResponse:
    """Searches solutions solving the given grouping params using Gemini and grounded on google search"""

    logging.info(f"Searching solutions for categories: {params.categories}")

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding
        req_prompt = await prompt_env.get_template("search_solution.txt").render_async(**{
            "category": cat.model_dump(),
        })

        # generate the completion in non-structured mode;
        # the googleSearch tool grounds Gemini with Google Search,
        # and tool_choice="required" forces Gemini to actually call it
        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as a json ===================================

        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)
        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ======================== build the final solution object ================================

        sources_metadata = []
        # extract source metadata from the grounding chunks, if Gemini actually performed the search tool call (and didn't hallucinate)
        try:
            sources_metadata.extend([{"name": a["web"]["title"], "url": a["web"]["uri"]}
                                     for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']])
        except (KeyError, IndexError):
            pass  # no grounding metadata was returned

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(*[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories], return_exceptions=True)
    logging.info(solutions)
    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)
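
# Example invocation (assumed shape: the request body is the JSON produced by
# /reqs/categorize_requirements, saved here as categories.json):
#   curl -X POST http://localhost:8000/solution/search_solutions_gemini \
#        -H "Content-Type: application/json" -d @categories.json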


@solution_router.post("/search_solutions_gemini/v2", response_model=SolutionSearchResponse)
async def search_solutions(params: SolutionSearchV2Request) -> SolutionSearchResponse:
    """Searches solutions solving the given grouping params and respecting the user constraints using Gemini and grounded on google search"""

    logging.info(f"Searching solutions for categories: {params}")

    async def _search_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # ================== generate the solution with web grounding
        req_prompt = await prompt_env.get_template("search_solution_v2.txt").render_async(**{
            "category": cat.model_dump(),
            "user_constraints": params.user_constraints,
        })

        # generate the completion in non-structured mode;
        # the googleSearch tool grounds Gemini with Google Search,
        # and tool_choice="required" forces Gemini to actually call it
        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], tools=[{"googleSearch": {}}], tool_choice="required")

        # ==================== structure the solution as a json ===================================

        structured_prompt = await prompt_env.get_template("structure_solution.txt").render_async(**{
            "solution": req_completion.choices[0].message.content,
            "response_schema": _SearchedSolutionModel.model_json_schema()
        })

        structured_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": structured_prompt}
        ], response_format=_SearchedSolutionModel)
        solution_model = _SearchedSolutionModel.model_validate_json(
            structured_completion.choices[0].message.content)

        # ======================== build the final solution object ================================

        sources_metadata = []
        # extract source metadata from the grounding chunks, if Gemini actually performed the search tool call (and didn't hallucinate)
        try:
            sources_metadata.extend([{"name": a["web"]["title"], "url": a["web"]["uri"]}
                                     for a in req_completion["vertex_ai_grounding_metadata"][0]['groundingChunks']])
        except (KeyError, IndexError):
            pass  # no grounding metadata was returned

        final_sol = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in solution_model.requirement_ids
            ],
            Problem_Description=solution_model.problem_description,
            Solution_Description=solution_model.solution_description,
            References=sources_metadata,
            Category_Id=cat.id,
        )
        return final_sol

    solutions = await asyncio.gather(*[retry_until(_search_inner, cat, lambda v: len(v.References) > 0, 2) for cat in params.categories], return_exceptions=True)
    logging.info(solutions)
    final_solutions = [
        sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)


# =================================================================================================================


@solution_router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions."""

    async def __criticize_single(solution: SolutionModel):
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })

        req_completion = await llm_router.acompletion(
            model="chat",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )

        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )

        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
    return CritiqueResponse(critiques=critiques)


@solution_router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(params: CritiqueResponse) -> SolutionSearchResponse:
    """Refines the previously critiqued solutions."""

    async def __refine_solution(crit: SolutionCriticism):
        req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
            "solution": crit.solution.model_dump(),
            "criticism": crit.criticism,
            "response_schema": _RefinedSolutionModel.model_json_schema(),
        })

        req_completion = await llm_router.acompletion(model="chat", messages=[
            {"role": "user", "content": req_prompt}
        ], response_format=_RefinedSolutionModel)

        req_model = _RefinedSolutionModel.model_validate_json(
            req_completion.choices[0].message.content)

        # copy previous solution model
        refined_solution = crit.solution.model_copy(deep=True)
        refined_solution.Problem_Description = req_model.problem_description
        refined_solution.Solution_Description = req_model.solution_description

        return refined_solution

    refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False)

    return SolutionSearchResponse(solutions=refined_solutions)


# ======================================== Solution generation using Insights Finder ==================

@solution_router.post("/search_solutions_if")
async def search_solutions_if(req: SolutionSearchV2Request) -> SolutionSearchResponse:
    """Search for solutions grounded on technologies retrieved from the Insight Finder service."""

    async def _search_solution_inner(cat: ReqGroupingCategory):
        # process requirements into insight finder format
        fmt_completion = await llm_router.acompletion("chat", messages=[
            {
                "role": "user",
                "content": await format_prompt("if/format_requirements.txt", **{
                    "category": cat.model_dump(),
                    "response_schema": InsightFinderConstraintsList.model_json_schema()
                })
            }], response_format=InsightFinderConstraintsList)

        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate the structured output into the dict shape Insight Finder expects
        out = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # fetch technologies from Insight Finder
        technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", content=json.dumps(out))
        technologies = TechnologyData.model_validate(technologies_req.json())

        # =============================================================== synthesize solution using LLM =========================================

        format_solution = await llm_router.acompletion("chat", messages=[{
            "role": "user",
            "content": await format_prompt("if/synthesize_solution.txt", **{
                "category": cat.model_dump(),
                "technologies": technologies.model_dump()["technologies"],
                "user_constraints": None,
                "response_schema": _SearchedSolutionModel.model_json_schema()
            })}
        ], response_format=_SearchedSolutionModel)

        format_solution_model = _SearchedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        final_solution = SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement for i in format_solution_model.requirement_ids
            ],
            Problem_Description=format_solution_model.problem_description,
            Solution_Description=format_solution_model.solution_description,
            References=[],
            Category_Id=cat.id,
        )

        # ========================================================================================================================================

        return final_solution

    solutions = await asyncio.gather(*[_search_solution_inner(cat) for cat in req.categories], return_exceptions=True)
    final_solutions = [sol for sol in solutions if not isinstance(sol, Exception)]

    return SolutionSearchResponse(solutions=final_solutions)


api.include_router(requirements_router)
api.include_router(solution_router)

if __name__ == "__main__":
    uvicorn.run(api, host="0.0.0.0", port=8000)