Spaces:
Build error
Build error
import collections | |
import re | |
from warnings import warn | |
import yaml | |
def yaml_parser(message: str) -> tuple[dict, bool, str]:
    """Parse a YAML reply and report whether a retry is needed.

    Parameters
    ----------
    message : str
        The raw text to parse as YAML.

    Returns:
    -------
    value
        The object returned by ``yaml.safe_load`` on success (its type is not
        checked here), or an empty dict when parsing fails.
    bool
        True when the text was loaded without a ``yaml.YAMLError``.
    str
        An empty string on success, otherwise the instruction to send back
        so the reply can be retried.
    """
    # Pull a value that starts on the next, unindented line back onto the
    # line of its "key:" (saves gpt-3.5 from some yaml parsing errors).
    # Indented continuations are left alone by the lookahead.
    normalized = re.sub(r':\s*\n(?=\S|\n)', ': ', message)

    try:
        parsed = yaml.safe_load(normalized)
    except yaml.YAMLError as err:
        # Surface the parser error as a warning attributed to the caller.
        warn(str(err), stacklevel=2)
        return (
            {},
            False,
            "Your response is not a valid yaml. Please try again and be careful to the format. Don't add any apology or comment, just the answer.",
        )

    return parsed, True, ''
def _compress_chunks( | |
text: str, identifier: str, skip_list: list[str], split_regex: str = '\n\n+' | |
) -> tuple[dict[str, str], str]: | |
"""Compress a string by replacing redundant chunks by identifiers. Chunks are defined by the split_regex.""" | |
text_list = re.split(split_regex, text) | |
text_list = [chunk.strip() for chunk in text_list] | |
counter = collections.Counter(text_list) | |
def_dict = {} | |
id = 0 | |
# Store items that occur more than once in a dictionary | |
for item, count in counter.items(): | |
if count > 1 and item not in skip_list and len(item) > 10: | |
def_dict[f'{identifier}-{id}'] = item | |
id += 1 | |
# Replace redundant items with their identifiers in the text | |
compressed_text = '\n'.join(text_list) | |
for key, value in def_dict.items(): | |
compressed_text = compressed_text.replace(value, key) | |
return def_dict, compressed_text | |
def compress_string(text: str) -> str:
    """Compress ``text`` by replacing repeated paragraphs and lines.

    Two passes of ``_compress_chunks`` are applied: first on paragraphs
    (split on blank lines, keys prefixed with '§'), then on single lines
    (keys prefixed with '¶'). The result is a ``<definitions>`` section
    listing every key with its content, followed by the compressed text.
    The definitions section is emitted even when nothing was replaced.
    """
    # Pass 1: paragraphs.
    paragraph_defs, reduced = _compress_chunks(
        text, identifier='§', skip_list=[], split_regex='\n\n+'
    )

    # Pass 2: lines. Paragraph keys are skipped so that a key standing alone
    # on several lines is not itself replaced.
    line_defs, reduced = _compress_chunks(
        reduced, '¶', list(paragraph_defs.keys()), split_regex='\n+'
    )

    all_defs = {**paragraph_defs, **line_defs}

    # Assemble the definitions header followed by the compressed body.
    entries = [f'{key}:\n{value}' for key, value in all_defs.items()]
    definitions = '\n'.join(['<definitions>', *entries, '</definitions>'])
    return definitions + '\n' + reduced
def extract_html_tags(text: str, keys: list[str]) -> dict[str, list[str]]:
    """Extract the content within HTML tags for a list of keys.

    Parameters
    ----------
    text : str
        The input string containing the HTML tags.
    keys : list of str
        The HTML tags to extract the content from.

    Returns:
    -------
    dict
        A dictionary mapping each key to the list of stripped substrings of
        `text` found between ``<key>`` and ``</key>``. Keys with no match are
        omitted from the dictionary.

    Notes:
    -----
    Matching is case-sensitive: neither `text` nor `keys` is lowercased.
    Keys are matched literally, so characters that are special in regular
    expressions (``.``, ``(``, ``*`` ...) have no special meaning.
    """
    content_dict = {}
    for key in keys:
        # Escape the key so it is treated as literal text, not as a pattern.
        tag = re.escape(key)
        # Non-greedy body so consecutive tags yield separate matches;
        # DOTALL lets the content span several lines.
        pattern = f'<{tag}>(.*?)</{tag}>'
        matches = re.findall(pattern, text, re.DOTALL)
        if matches:
            content_dict[key] = [match.strip() for match in matches]
    return content_dict
class ParseError(Exception):
    """Raised by parse_html_tags_raise when parsing is not successful.

    The exception message is the retry message produced by the parser.
    """
def parse_html_tags_raise(
    text: str,
    keys: list[str] | None = None,
    optional_keys: list[str] | None = None,
    merge_multiple: bool = False,
) -> dict[str, str]:
    """Run parse_html_tags and raise instead of returning a failure flag.

    Parameters
    ----------
    text : str
        The input string containing the HTML tags.
    keys : list of str
        The required HTML tags.
    optional_keys : list of str
        The HTML tags that may be absent.
    merge_multiple : bool
        Forwarded to parse_html_tags.

    Returns:
    -------
    dict
        The dictionary returned by parse_html_tags.

    Raises:
    ------
    ParseError
        When parse_html_tags reports the parsing as invalid; the exception
        message is the retry message it produced.
    """
    parsed, ok, message = parse_html_tags(
        text, keys, optional_keys, merge_multiple=merge_multiple
    )
    if ok:
        return parsed
    raise ParseError(message)
def parse_html_tags(
    text: str,
    keys: list[str] | None = None,
    optional_keys: list[str] | None = None,
    merge_multiple: bool = False,
) -> tuple[dict[str, str], bool, str]:
    """Satisfy the parse api: extract one value per key and validate presence.

    Parameters
    ----------
    text : str
        The input string containing the HTML tags.
    keys : list of str
        The HTML tags to extract the content from; each must be present.
    optional_keys : list of str
        The HTML tags to extract the content from, but are optional.
    merge_multiple : bool
        When a key appears several times: if True, the occurrences are joined
        with newlines; if False, an error message is recorded and the key is
        left out of the result.

    Returns:
    -------
    dict
        A dictionary mapping each key to the subset of `text` that matches it.
    bool
        Whether the parsing was successful (no error message was recorded).
    str
        A message to be displayed to the agent if the parsing was not
        successful; one line per problem, empty on success.
    """
    required = keys or []
    optional = optional_keys or []
    wanted = list(required) + list(optional)

    found = extract_html_tags(text, wanted)

    problems: list[str] = []
    result_dict: dict[str, str] = {}

    for key in wanted:
        if key not in found:
            # Only a missing non-optional key is an error.
            if key not in optional:
                problems.append(f'Missing the key <{key}> in the answer.')
            continue

        matches = found[key]
        if len(matches) <= 1:
            result_dict[key] = matches[0]
        elif merge_multiple:
            # Merge the multiple instances into one newline-separated value.
            result_dict[key] = '\n'.join(matches)
        else:
            problems.append(
                f'Found multiple instances of the key {key}. You should have only one of them.'
            )

    return result_dict, not problems, '\n'.join(problems)