Spaces:
Sleeping
Sleeping
File size: 2,651 Bytes
1b7e88c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import json
import re
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar
from ....utils.error import VQLError
from ..base import BotBase
# Type of the structured value produced by a concrete parser's ``_parse``.
T = TypeVar("T")


class BaseOutputParser(BotBase, ABC, Generic[T]):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses. Subclasses
    implement ``_parse``; callers use ``parse``, which can first narrow the
    raw text with ``regex`` before delegating to ``_parse``.
    """

    # Optional pattern searched for in the raw text by ``parse``; when set,
    # only the selected capture group is handed to ``_parse``.
    regex: Optional[str] = None
    # Capture group of ``regex`` to keep (0 selects the whole match).
    regex_group: Optional[int] = 0

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"

    @abstractmethod
    def _parse(self, text: str) -> T:
        """Parse the output of an LLM call.

        A method which takes in a string (assumed output of language model)
        and parses it into some structure.

        Args:
            text: output of language model

        Returns:
            structured output
        """

    def parse(self, text: str) -> T:
        """Optionally narrow ``text`` with ``regex``, then parse it.

        Args:
            text: output of language model

        Returns:
            The value produced by ``_parse``.

        Raises:
            VQLError: code 800 when ``regex`` is set but finds no match.
        """
        if self.regex:
            match = re.search(self.regex, text)
            if match is None:
                raise VQLError(
                    800, detail=f"Not valid json [{text}] for regex [{self.regex}]"
                )
            text = match.group(self.regex_group)
        return self._parse(text)

    @property
    def _type(self) -> str:
        """Return the type key.

        Reads the instance's ``type`` attribute. NOTE(review): assumes
        ``BotBase`` or a subclass supplies ``type`` — confirm.

        Raises:
            NotImplementedError: if no ``type`` attribute is available.
        """
        type_key = getattr(self, "type", None)
        if type_key is None:
            raise NotImplementedError(
                f"_type property is not implemented for {self.__class__.__name__}"
            )
        return type_key
class DictParser(BaseOutputParser):
    """Parse the first JSON object embedded in an LLM response into a dict.

    Text surrounding the object is ignored. If the raw text does not decode,
    parsing is retried after escaping stray backslashes. If that also fails,
    it is retried once more with single quotes turned into double quotes,
    which handles Python-style dict output.
    """

    def _escape_backslashes(self, input_str: str) -> str:
        """Double every backslash that does not already begin a JSON escape."""
        # Replace single backslashes with double backslashes,
        # while leaving already escaped ones intact.
        return re.sub(
            r'(?<!\\)\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', r"\\\\", input_str
        )

    def _fix_json_input(self, input_str: str) -> str:
        """Repair stray backslashes and single-quoted strings in ``input_str``."""
        # str.replace returns a new string, so its result must be kept.
        return self._escape_backslashes(input_str).replace("'", '"')

    def _find_json(self, input_str: str) -> dict:
        """Decode and return the first JSON object found in ``input_str``.

        Every ``{`` is tried as a start position in order, so stray braces in
        leading prose do not hide a later, valid object.

        Raises:
            json.JSONDecodeError: if no ``{`` starts a decodable JSON object
                (including when the text contains no ``{`` at all).
        """
        decoder = json.JSONDecoder()
        error = json.JSONDecodeError("No JSON object found", input_str, 0)
        start = input_str.find("{")
        while start != -1:
            try:
                return decoder.raw_decode(input_str, start)[0]
            except json.JSONDecodeError as err:
                error = err
            start = input_str.find("{", start + 1)
        raise error

    def _parse(self, text: str) -> dict:
        """Parse ``text`` into a dict.

        Raises:
            VQLError: code 800 if no JSON object can be recovered.
        """
        try:
            return self._find_json(text)
        except json.JSONDecodeError:
            pass
        # Retry with progressively more aggressive repairs. The backslash-only
        # pass runs first so apostrophes inside valid strings are preserved
        # whenever that repair is sufficient.
        last_error = None
        for repair in (self._escape_backslashes, self._fix_json_input):
            try:
                return self._find_json(repair(text))
            except json.JSONDecodeError as err:
                last_error = err
        raise VQLError(800, detail=f"Not valid json [{text}]") from last_error
class ListParser(BaseOutputParser):
    """Turn an LLM response into a list of whitespace-trimmed items."""

    # Delimiter used to split the raw text into items.
    separator: str = ","

    def _parse(self, text: str) -> list:
        """Split ``text`` on ``separator`` and strip each resulting piece."""
        return list(map(str.strip, text.split(self.separator)))
class StrParser(BaseOutputParser):
    """Pass-through parser: the model output is used unchanged."""

    def _parse(self, text: str) -> str:
        """Return ``text`` exactly as received."""
        parsed = text
        return parsed
|