File size: 2,116 Bytes
7e16d4f
78a1bf0
 
7e16d4f
 
 
 
 
 
 
 
 
 
 
1f626ee
 
 
 
 
 
 
7e16d4f
 
 
 
 
 
 
 
 
3caf047
7e16d4f
 
78a1bf0
7e16d4f
 
78a1bf0
7e16d4f
 
 
3caf047
 
78a1bf0
3caf047
 
 
 
 
78a1bf0
 
 
3caf047
 
 
78a1bf0
3caf047
 
7e16d4f
 
78a1bf0
7e16d4f
3caf047
 
78a1bf0
7e16d4f
 
 
 
 
 
 
78a1bf0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import re
from typing import Dict, List

import weave
from pydantic import BaseModel


class RegexResult(BaseModel):
    """Outcome of evaluating a set of named regex patterns against one text.

    Produced by ``RegexModel.check``.
    """

    # True only when no pattern matched the text (i.e. matched_patterns is empty).
    passed: bool
    # Pattern name -> every match found for that pattern, in order of occurrence.
    # Entries for patterns with capture groups hold the participating groups
    # joined by "-"; otherwise they hold the full matched text.
    matched_patterns: Dict[str, List[str]]
    # Names of the patterns that produced no match at all.
    failed_patterns: List[str]


class RegexModel(weave.Model):
    """
    Check text against a collection of named regular expressions.

    A check "passes" only when none of the patterns match; see ``check`` for
    the exact shape of the result.

    Args:
        patterns (Dict[str, str]): Dictionary where key is pattern name and value is regex pattern.
    """

    patterns: Dict[str, str]

    def __init__(self, patterns: Dict[str, str]) -> None:
        super().__init__(patterns=patterns)
        # Compile every pattern once up front. An invalid regex raises
        # re.error here, at construction time, rather than on first use.
        self._compiled_patterns = {
            name: re.compile(pattern) for name, pattern in patterns.items()
        }

    @weave.op()
    def check(self, prompt: str) -> RegexResult:
        """
        Check text against all patterns and return detailed results.

        Args:
            prompt: Input text to check against patterns.

        Returns:
            RegexResult where:
              - matched_patterns maps each pattern name that matched to the
                list of its matches, in order of occurrence. For a pattern
                with capture groups, each entry is the non-None groups joined
                with "-"; otherwise it is the full matched text.
              - failed_patterns lists the names of patterns with no match.
              - passed is True only when no pattern matched at all.
        """
        matched_patterns: Dict[str, List[str]] = {}
        failed_patterns: List[str] = []

        # Reuse the patterns compiled in __init__. The cache can be absent if
        # the instance was created without running this __init__, and can be
        # stale if self.patterns was replaced after construction; in either
        # case fall back to compiling the current pattern string.
        compiled_cache = getattr(self, "_compiled_patterns", None) or {}

        for pattern_name, pattern in self.patterns.items():
            compiled = compiled_cache.get(pattern_name)
            if compiled is None or compiled.pattern != pattern:
                compiled = re.compile(pattern)

            matches: List[str] = []
            for match in compiled.finditer(prompt):
                groups = match.groups()
                if groups:
                    # Pattern has capture groups: join the ones that participated.
                    matches.append(
                        "-".join(str(g) for g in groups if g is not None)
                    )
                else:
                    # No capture groups: record the full match.
                    matches.append(match.group(0))

            if matches:
                matched_patterns[pattern_name] = matches
            else:
                failed_patterns.append(pattern_name)

        return RegexResult(
            matched_patterns=matched_patterns,
            failed_patterns=failed_patterns,
            # Passing means that no pattern matched anywhere in the prompt.
            passed=not matched_patterns,
        )

    @weave.op()
    def predict(self, text: str) -> RegexResult:
        """
        Alias for check() to maintain consistency with other models.

        Args:
            text: Input text, forwarded unchanged to check() as ``prompt``.

        Returns:
            The RegexResult produced by check().
        """
        return self.check(text)