File size: 1,973 Bytes
41c1aed
 
 
 
 
594f2fe
41c1aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594f2fe
 
 
 
 
 
 
 
 
 
41c1aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594f2fe
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from pydantic import BaseModel, Field
from typing import List, Optional

# Shared model schemas


class RequirementInfo(BaseModel):
    """Represents an extracted requirement info"""
    # Every field is required (default=...). The descriptions are emitted in
    # the generated JSON schema, so their wording is part of the contract.
    context: str = Field(
        default=...,
        description="Context for the requirement.",
    )
    requirement: str = Field(
        default=...,
        description="The requirement itself.",
    )
    document: str = Field(
        default=...,
        description="The document the requirement is extracted from.",
    )


class ReqGroupingCategory(BaseModel):
    """Represents the category of requirements grouped together"""
    # Numeric identifier of the category (required).
    id: int = Field(default=..., description="ID of the grouping category")
    # Human-readable name of the category (required).
    title: str = Field(
        default=...,
        description="Title given to the grouping category",
    )
    # Full RequirementInfo records assigned to this category (required),
    # as opposed to the integer IDs used by _ReqGroupingCategory.
    requirements: list[RequirementInfo] = Field(
        default=...,
        description="List of grouped requirements",
    )


class SolutionSearchResult(BaseModel):
    # Appears to describe one solution-search result (inferred from the
    # name). Field names are PascalCase and, with no aliases defined, they
    # are also the serialized keys, so they must not be renamed.
    # Documented with comments rather than a docstring so the generated
    # schema keeps no class description, exactly as before.
    Context: str
    Requirements: list[str]
    ProblemDescription: str
    SolutionDescription: str
    # Defaults to an empty string; the Optional annotation also lets callers
    # send an explicit None.
    References: Optional[str] = Field(default="")

# Categorize requirements endpoint


class ReqGroupingRequest(BaseModel):
    """Request schema of a requirement grouping call."""
    # Requirements to be grouped into categories (required).
    requirements: List[RequirementInfo]
    # Optional cap on how many categories to build. None leaves it unset;
    # no range validation is applied at the model level.
    max_n_categories: Optional[int] = Field(
        None,
        description="Max number of categories to construct. Defaults to None",
    )


class ReqGroupingResponse(BaseModel):
    """Response of a requirement grouping call."""
    # Resulting categories; each one carries its full requirement records.
    categories: list[ReqGroupingCategory]


# INFO: keep in sync with prompt
class _ReqGroupingCategory(BaseModel):
    # Private schema for a single category, in which requirements are
    # referenced by integer ID rather than embedded.
    # Documented with comments instead of a docstring because pydantic copies
    # a class docstring into the JSON schema description, and this schema
    # must stay in sync with the prompt.
    title: str = Field(
        default=...,
        description="Title given to the grouping category",
    )
    items: List[int] = Field(
        default=...,
        description="List of the IDs of the requirements belonging to the category.",
    )


class _ReqGroupingOutput(BaseModel):
    # Private top-level container: a list of ID-based categories.
    # Comments (not a docstring) keep the generated JSON schema unchanged.
    categories: List[_ReqGroupingCategory] = Field(
        default=...,
        description="List of grouping categories",
    )


# Criticize solution endpoint

class CriticizeSolutionsRequest(BaseModel):
    # Request payload for the "criticize solution" endpoint: the solutions
    # to be evaluated (required).
    solutions: List[SolutionSearchResult]