jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import re
import math
from typing import Union
import torch
from torch import Tensor
import torch.nn.functional as F
from dataclasses import dataclass, replace
from comfy.sd import CLIP
from comfy.utils import ProgressBar
import comfy.model_management
from .utils_model import InterpolationMethod
from .utils_motion import extend_list_to_batch_size
from .utils_scheduling import SelectError, TensorInterp, convert_str_to_indexes, lerp_tensors, slerp_tensors
from .logger import logger
###############################################
#----------------------------------------------
# JSON prompt format is as follows:
# "idxs": "prompt", ...
_regex_prompt_json = re.compile(r'"([\d:\-\.]+)\s*"\s*:\s*"([^"]*)"(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# ([\d:\-\.]+): Matches idxs, which can be any combination of digits, colons, and periods.
# \s*: Matches optional whitespace.
# ":\s*": Matches the ":" separator with optional spaces.
# "([^"]*)": Captures the prompt, which can be any character except for double quotation marks.
# (?:\s*,\s*|$): This non-capturing group (?: ... ) matches either a comma (with optional spaces before or after) or the end of the string ($).
# pythonic prompt format is as follows:
# idxs = "prompt", ...
_regex_prompt_pyth = re.compile(r'([\d:\-\.]+)\s*=\s*"([^"]*)"(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# ([\d:\-\.]+): Matches idx, which can be any combination of digits, colons, and periods.
# \s*=\s*: Matches the equal sign (=) with optional spaces on both sides.
# "([^"]*)": Captures the prompt, which can be any character except for double quotation marks.
# (?:\s*,\s*|$): Matches either a comma (with optional spaces before or after) or the end of the string ($).
# JSON value format is as follows:
# "idxs": value, ...
_regex_value_json = re.compile(r'"([\d:\-\.]+)\s*"\s*:\s*([^,]+)(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# ([\d:\-\.]+): Matches idxs, which can be any combination of digits, colons, and periods.
# \s*: Matches optional whitespace.
# ":\s*: Matches the ":" separator with optional spaces.
# ([^,]+): Captures the value, which can be any character except for commas (this ensures that values are correctly separated).
# (?:\s*,\s*|$): Matches either a comma (with optional spaces before or after) or the end of the string ($).
# pythonic value format is as follows:
# idxs = value, ...
_regex_value_pyth = re.compile(r'([\d:\-\.]+)\s*=\s*([^,]+)(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# ([\d:\-\.]+): Matches idx, which can be any combination of digits, colons, and periods.
# \s*=\s*: Matches the equal sign (=) with optional spaces on both sides.
# ([^,]+): Captures the value, which can be any character except for commas (this ensures that values are correctly separated).
# (?:\s*,\s*|$): Matches either a comma (with optional spaces before or after) or the end of the string ($).
#----------------------------------------------
###############################################
# verify that string only contains a-z, A-Z, 0-9, or _
_regex_key_value = re.compile(r'^[a-zA-Z0-9_]+$')
def verify_key_value(key: str, raise_error=True):
match = re.match(_regex_key_value, key)
if not match and raise_error:
raise Exception(f"Value key may only contain 'a-z', 'A-Z', '0-9', or '_', but was: '{key}'.")
return match is not None
class SFormat:
JSON = "json"
PYTH = "pythonic"
@dataclass
class RegexErrorReport:
start: int
end: int
text: str
reason: str = None
@dataclass
class InputPair:
idx: int
val: Union[int, str, Tensor]
hold: bool = False
end: bool = False
@dataclass
class CondHolder:
idx: int
prompt: str
raw_prompt: str
cond: Tensor
pooled: Tensor
hold: bool = False
interp_weight: float = None
interp_prompt: str = None
@dataclass
class ParseErrorReport:
idx_str: str
val_str: str
reason: str
@dataclass
class PromptOptions:
interp: str = TensorInterp.LERP
prepend_text: str = ''
append_text: str = ''
values_replace: dict[str, list[float]] = None
print_schedule: bool = False
add_dict: dict[str] = None
def evaluate_prompt_schedule(text: str, length: int, clip: CLIP, options: PromptOptions):
text = strip_input(text)
if len(text) == 0:
raise Exception("No text provided to Prompt Scheduling.")
# prioritize formats based on best guess to minimize redo's
if text.startswith('"'):
formats = [SFormat.JSON, SFormat.PYTH]
else:
formats = [SFormat.PYTH, SFormat.JSON]
for format in formats:
if format is SFormat.JSON:
# check JSON format
# if no errors found, assume this is the right format and pass on to parsing individual values
json_matches, json_errors = get_matches_and_errors(text, _regex_prompt_json)
if len(json_errors) == 0:
return parse_prompt_groups(json_matches, length, clip, options)
elif format is SFormat.PYTH:
# check pythonic format
# if no errors found, assume this is the right format and pass on to parsing individual values
pyth_matches, pyth_errors = get_matches_and_errors(text, _regex_prompt_pyth)
if len(pyth_errors) == 0:
return parse_prompt_groups(pyth_matches, length, clip, options)
# since both formats have errors, check which format is more 'correct' for the input
# priority:
# 1 - most matches
# 2 - least errors
if len(json_matches) > len(pyth_matches):
real_errors = json_errors
assumed = SFormat.JSON
elif len(json_matches) < len(pyth_matches):
real_errors = pyth_errors
assumed = SFormat.PYTH
elif len(json_errors) < len(pyth_errors):
real_errors = json_errors
assumed = SFormat.JSON
else:
logger.warn("same amount of matches+errors for prompt!")
real_errors = pyth_errors
assumed = SFormat.PYTH
# TODO: make separate case for when format is unknown, so that both are displayed to the user
error_msg_list = []
if len(real_errors) == 1:
error_msg_list.append(f"Found 1 issue in prompt schedule (assumed {assumed} format):")
else:
error_msg_list.append(f"Found {len(real_errors)} issues in prompt schedule (assumed {assumed} format):")
for error in real_errors:
error_msg_list.append(f"Position {error.start} to {error.end}: '{error.text}'")
error_msg = "\n".join(error_msg_list)
raise Exception(error_msg)
def parse_prompt_groups(groups: list[tuple], length: int, clip: CLIP, options: PromptOptions):
pairs: list[InputPair]
errors: list[ParseErrorReport]
# turn group tuples into InputPairs
pairs = [InputPair(x[0], x[1]) for x in groups]
# perform first parse, to get idea of indexes to handle
pairs, errors = handle_group_idxs(pairs, length)
if len(errors) > 0:
error_msg_list = []
issues_formatted = f"{len(errors)} issue{'s' if len(errors)> 1 else ''}"
error_msg_list.append(f"Found {issues_formatted} with idxs:")
for error in errors:
error_msg_list.append(f"{error.idx_str}: {error.reason}")
error_msg = "\n".join(error_msg_list)
raise Exception(error_msg)
prepare_prompts(pairs, options)
final_vals = handle_prompt_interpolation(pairs, length, clip, options)
return final_vals
def prepare_prompts(pairs: list[InputPair], options: PromptOptions):
for pair in pairs:
prepend_text = options.prepend_text.strip()
append_text = options.append_text.strip()
prompt = pair.val.strip()
# when adding prepend and append text, handle commas properly
# prepend text
if len(prepend_text) > 0:
while prepend_text.endswith(','):
prepend_text = prepend_text[:-1].strip()
if prompt.startswith(','):
prepend_text = f"{prepend_text}"
else:
prepend_text = f"{prepend_text}, "
prompt = prepend_text + prompt
# append text
if len(append_text) > 0:
while append_text.startswith(','):
append_text = append_text[1:].strip()
if prompt.endswith(','):
append_text = f" {append_text}"
else:
append_text = f", {append_text}"
prompt = prompt + append_text
# update value w/ prompt
pair.val = prompt
def apply_values_replace_to_prompt(prompt: str, idx: int, values_replace: Union[None, dict[str, list[float]]]):
# if no values to replace, do nothing
if values_replace is None:
return prompt
for key, value in values_replace.items():
# use FizzNodes `` notation
match_str = '`' + key + '`'
value_str = f"{value[idx]}"
prompt = prompt.replace(match_str, value_str)
return prompt
def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions):
if length == 0:
length = max(pairs, key=lambda x: x.idx).idx+1
# prepare values_replace (should match length)
values_replace = options.values_replace
if values_replace is not None:
values_replace.copy()
for key, value in values_replace.items():
if len(value) < length:
values_replace[key] = extend_list_to_batch_size(value, length)
scheduled_keyframes = []
if clip.use_clip_schedule:
clip = clip.clone()
scheduled_keyframes = clip.patcher.forced_hooks.get_hooks_for_clip_schedule()
pairs_lengths = len(pairs) * max(1, len(scheduled_keyframes))
pbar_total = length + pairs_lengths
pbar = ProgressBar(pbar_total)
# for now, use FizzNodes approach of calculating max size of tokens beforehand;
# this can up to double total encoding time, as this will be done again.
# TODO: do this dynamically to save encoding time
max_size = 0
for pair in pairs:
prepared_prompt = apply_values_replace_to_prompt(pair.val, 0, values_replace=values_replace)
cond: Tensor = clip.encode_from_tokens(clip.tokenize(prepared_prompt))
max_size = max(max_size, cond.shape[1])
pbar.update(1)
# if do not need to schedule clip with hooks, do nothing special
if not clip.use_clip_schedule:
return _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar)
# otherwise, need to account for keyframes on forced_hooks
full_output = []
for i, scheduled_opts in enumerate(scheduled_keyframes):
clip.patcher.forced_hooks.reset()
clip.patcher.unpatch_hooks()
t_range = scheduled_opts[0]
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
try:
# don't print_schedule on non-first iteration
orig_print_schedule = options.print_schedule
if orig_print_schedule and i != 0:
options.print_schedule = False
schedule_output = _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar)
finally:
options.print_schedule = orig_print_schedule
for cond, pooled_dict in schedule_output:
pooled_dict: dict[str]
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
full_output.extend(schedule_output)
return full_output
def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions,
values_replace: dict[str, list[float]], max_size: int, pbar: ProgressBar):
real_holders: list[CondHolder] = [None] * length
real_cond = [None] * length
real_pooled = [None] * length
prev_holder: Union[CondHolder, None] = None
for idx, pair in enumerate(pairs):
holder = None
is_over_length = False
# if no last pair is set, then use first provided val up to the idx
if prev_holder is None:
for i in range(idx, pair.idx+1):
if i >= length:
is_over_length = True
continue
real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace)
if holder is None or holder.prompt != real_prompt:
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
else:
holder = replace(holder)
holder.idx = i
real_cond[i] = cond
real_pooled[i] = pooled
real_holders[i] = holder
pbar.update(1)
comfy.model_management.throw_exception_if_processing_interrupted()
# if idx is exactly one greater than the one before, nothing special
elif prev_holder.idx == pair.idx-1:
comfy.model_management.throw_exception_if_processing_interrupted()
holder = prev_holder
if pair.idx < length:
real_prompt = apply_values_replace_to_prompt(pair.val, pair.idx, values_replace=values_replace)
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
real_cond[pair.idx] = cond
real_pooled[pair.idx] = pooled
real_holders[pair.idx] = holder
pbar.update(1)
else:
# if holding value, no interpolation
if prev_holder.hold:
# keep same value as last_holder, then calculate current index cond;
# however, need to check if real_prompt remains the same
for i in range(prev_holder.idx+1, pair.idx):
if i >= length:
is_over_length = True
continue
if holder is None:
holder = prev_holder
real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace)
if holder.prompt != real_prompt:
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
else:
holder = replace(holder)
holder.idx = i
real_cond[i] = holder.cond
real_pooled[i] = holder.pooled
real_holders[i] = holder
pbar.update(1)
comfy.model_management.throw_exception_if_processing_interrupted()
if pair.idx < length:
real_prompt = apply_values_replace_to_prompt(pair.val, pair.idx, values_replace=values_replace)
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
real_cond[pair.idx] = cond
real_pooled[pair.idx] = pooled
real_holders[pair.idx] = holder
pbar.update(1)
comfy.model_management.throw_exception_if_processing_interrupted()
# otherwise, interpolate
else:
diff_len = abs(pair.idx-prev_holder.idx)+1
interp_idxs = InterpolationMethod.get_weights(num_from=prev_holder.idx, num_to=pair.idx, length=diff_len,
method=InterpolationMethod.LINEAR)
interp_weights = InterpolationMethod.get_weights(num_from=0.0, num_to=1.0, length=diff_len,
method=InterpolationMethod.LINEAR)
cond_to = None
pooled_to = None
cond_from = None
holder = None
interm_holder = prev_holder
for raw_idx, weight in zip(interp_idxs, interp_weights):
if raw_idx >= length:
is_over_length = True
continue
idx_int = round(float(raw_idx))
# calculate cond_to stuff if not done yet
real_prompt = apply_values_replace_to_prompt(pair.val, idx_int, values_replace=values_replace)
if holder is None or holder.prompt != real_prompt:
cond_to, pooled_to = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond_to = pad_cond(cond_to, target_length=max_size)
holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold)
# calculate interm_holder stuff if needed
real_prompt = apply_values_replace_to_prompt(interm_holder.raw_prompt, idx_int, values_replace=values_replace)
if interm_holder.prompt != real_prompt:
cond_from, pooled_from = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond_from = pad_cond(cond_from, target_length=max_size)
interm_holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=interm_holder.raw_prompt, cond=cond_from, pooled=pooled_from, hold=holder.hold)
else:
interm_holder = CondHolder(idx=interm_holder.idx, prompt=interm_holder.prompt, raw_prompt=interm_holder.raw_prompt, cond=interm_holder.cond, pooled=interm_holder.pooled, hold=interm_holder.hold)
# interpolate conds
if options.interp == TensorInterp.LERP:
cond_interp = lerp_tensors(tensor_from=interm_holder.cond, tensor_to=cond_to, strength_to=weight)
elif options.interp == TensorInterp.SLERP:
cond_interp = slerp_tensors(tensor_from=interm_holder.cond, tensor_to=cond_to, strength_to=weight)
pooled_interp = pooled_to
if math.isclose(weight, 0.0):
pooled_interp = interm_holder.pooled
interm_holder = CondHolder(idx=idx_int, prompt=interm_holder.prompt, raw_prompt=interm_holder.raw_prompt, cond=cond_interp, pooled=pooled_interp, hold=holder.hold,
interp_weight=weight, interp_prompt=holder.prompt)
real_cond[idx_int] = cond_interp
real_pooled[idx_int] = pooled_interp
real_holders[idx_int] = interm_holder
pbar.update(1)
comfy.model_management.throw_exception_if_processing_interrupted()
if is_over_length:
break
assert holder is not None
prev_holder = holder
# fill in None gaps with last used values
# TODO: review if this works as intended, or if needs to be a bit more thorough
prev_holder = None
for i in range(len(real_holders)):
if real_holders[i] is None:
# check if any value replacement needs to be accounted for
real_prompt = apply_values_replace_to_prompt(prev_holder.raw_prompt, i, values_replace=values_replace)
if prev_holder.prompt != real_prompt:
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True)
cond = pad_cond(cond, target_length=max_size)
prev_holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=prev_holder.raw_prompt, cond=cond, pooled=pooled, hold=prev_holder.hold)
real_cond[i] = prev_holder.cond
real_pooled[i] = prev_holder.pooled
real_holders[i] = prev_holder
pbar.update(1)
else:
prev_holder = real_holders[i]
final_cond = torch.cat(real_cond, dim=0)
final_pooled = torch.cat(real_pooled, dim=0)
if options.print_schedule:
logger.info(f"PromptScheduling ({len(real_holders)} prompts)")
for i, holder in enumerate(real_holders):
if holder.interp_prompt is None:
logger.info(f'{i} = "{holder.prompt}"')
else:
logger.info(f'{i} = ({1.-holder.interp_weight:.2f})"{holder.prompt}" -> ({holder.interp_weight:.2f})"{holder.interp_prompt}"')
# cond is a list[list[Tensor, dict[str: Any]]] format
final_pooled_dict = {"pooled_output": final_pooled}
if options.add_dict is not None:
final_pooled_dict.update(options.add_dict)
# add hooks, if needed
clip.add_hooks_to_dict(final_pooled_dict)
return [[final_cond, final_pooled_dict]]
def pad_cond(cond: Tensor, target_length: int):
# FizzNodes-style cond padding
# TODO: test out other methods of padding
curr_length = cond.shape[1]
if curr_length < target_length:
pad_length = target_length - curr_length
# FizzNodes pads the tensor on both ends
left_pad = pad_length // 2
right_pad = pad_length - left_pad
# perform padding
cond = F.pad(cond, (0, 0, left_pad, right_pad))
return cond
def evaluate_value_schedule(text: str, length: int):
text = strip_input(text)
if len(text) == 0:
raise Exception("No text provided to Value Scheduling.")
# prioritize formats based on best guess to minimize redo's
if text.startswith('"'):
formats = [SFormat.JSON, SFormat.PYTH]
else:
formats = [SFormat.PYTH, SFormat.JSON]
for format in formats:
if format is SFormat.JSON:
# check JSON format
# if no errors found, assume this is the right format and pass on to parsing individual values
json_matches, json_errors = get_matches_and_errors(text, _regex_value_json)
if len(json_errors) == 0:
return parse_value_groups(json_matches, length)
elif format is SFormat.PYTH:
# check pythonic format
# if no errors found, assume this is the right format and pass on to parsing individual values
pyth_matches, pyth_errors = get_matches_and_errors(text, _regex_value_pyth)
if len(pyth_errors) == 0:
return parse_value_groups(pyth_matches, length)
# since both formats have errors, check which format is more 'correct' for the input
# priority:
# 1 - most matches
# 2 - least errors
if len(json_matches) > len(pyth_matches):
real_errors = json_errors
assumed = SFormat.JSON
elif len(json_matches) < len(pyth_matches):
real_errors = pyth_errors
assumed = SFormat.PYTH
elif len(json_errors) < len(pyth_errors):
real_errors = json_errors
assumed = SFormat.JSON
else:
#logger.info("same amount of matches+errors for value!")
if text.startswith('"'):
real_errors = json_errors
assumed = SFormat.JSON
else:
real_errors = pyth_errors
assumed = SFormat.PYTH
# TODO: make separate case for when format is unknown, so that both are displayed to the user
error_msg_list = []
if len(real_errors) == 1:
error_msg_list.append(f"Found 1 issue in value schedule (assumed {assumed} format):")
else:
error_msg_list.append(f"Found {len(real_errors)} issues in value schedule (assumed {assumed} format):")
for error in real_errors:
error_msg_list.append(f"Position {error.start} to {error.end}: '{error.text}'")
error_msg = "\n".join(error_msg_list)
raise Exception(error_msg)
def parse_value_groups(groups: list[tuple], length: int):
#logger.info(groups)
pairs: list[InputPair]
errors: list[ParseErrorReport]
# perform first parse, where we convert vals to floats
pairs, errors = handle_float_vals(groups)
if len(errors) == 0:
# perform second parse, to get idea of indexes to handle
pairs, errors = handle_group_idxs(pairs, length)
if len(pairs) == 0:
errors.append(ParseErrorReport(idx_str='No valid idxs provided', val_str='', reason='Provided ranges might not be selecting anything.'))
if len(errors) > 0:
error_msg_list = []
issues_formatted = f"{len(errors)} issue{'s' if len(errors)> 1 else ''}"
error_msg_list.append(f"Found {issues_formatted} with idxs/vals:")
for error in errors:
error_msg_list.append(f"{error.idx_str}: {error.reason}")
error_msg = "\n".join(error_msg_list)
raise Exception(error_msg)
# perform third parse, where hold and interpolation is used to fill in any in-between values
final_vals = handle_val_interpolation(pairs, length)
return final_vals
def handle_float_vals(groups: list[tuple]):
actual_pairs: list[InputPair] = []
errors: list[ParseErrorReport] = []
for idx_str, val_str in groups:
val_str = strip_value(val_str)
try:
val = float(val_str)
except ValueError:
errors.append(ParseErrorReport(idx_str, val_str, f"Value '{val_str}' is not a valid number"))
continue
actual_pairs.append(InputPair(idx_str, val))
return actual_pairs, errors
def handle_val_interpolation(pairs: list[InputPair], length: int):
if length == 0:
length = max(pairs, key=lambda x: x.idx).idx+1
real_vals = [None] * length
prev_pair = None
for pair in pairs:
# if no last pair is set, then use first provided val up to the idx
if prev_pair is None:
for i in range(0, pair.idx+1):
if i >= length:
break
real_vals[i] = pair.val
# if idx is exactly one greater than the one before, nothing special
elif prev_pair.idx == pair.idx-1:
if pair.idx < length:
real_vals[pair.idx] = pair.val
else:
# if holding value, no interpolation
if prev_pair.hold:
# keep same value as last_pair, then assign current index value
for i in range(prev_pair.idx+1, pair.idx):
if i >= length:
continue
real_vals[i] = prev_pair.val
if pair.idx < length:
real_vals[pair.idx] = pair.val
# otherwise, interpolate
else:
diff_len = abs(pair.idx-prev_pair.idx)+1
interp_idxs = InterpolationMethod.get_weights(num_from=prev_pair.idx, num_to=pair.idx, length=diff_len,
method=InterpolationMethod.LINEAR)
interp_vals = InterpolationMethod.get_weights(num_from=prev_pair.val, num_to=pair.val, length=diff_len,
method=InterpolationMethod.LINEAR)
for idx, val in zip(interp_idxs, interp_vals):
if idx >= length:
continue
real_vals[round(float(idx))] = float(val)
prev_pair = pair
# fill in None gaps with last used value
# TODO: review if this works as intended, or if needs to be a bit more thorough
prev_val = None
for i in range(len(real_vals)):
if real_vals[i] is None:
real_vals[i] = prev_val
else:
prev_val = real_vals[i]
return real_vals
def handle_group_idxs(pairs: list[InputPair], length: int):
actual_pairs: list[InputPair] = []
errors: list[ParseErrorReport] = []
for pair in pairs:
idx_str, val_str = pair.idx, pair.val
idx_str: str = idx_str.strip()
hold = False
# if starts with :, wrong
if idx_str.startswith(':'):
errors.append(ParseErrorReport(idx_str, val_str, "Idx can't begin with ':'"))
continue
# if has more than one :, wrong
if idx_str.count(':') > 1:
errors.append(ParseErrorReport(idx_str, val_str, "Idx can't have more than one ':'"))
if idx_str.endswith(':'):
hold = True
idx_str = idx_str[:-1]
try:
idxs = convert_str_to_indexes(idx_str, length, allow_range=True, allow_missing=True, fix_reverse=True, same_is_one=True, allow_decimal=True)
except SelectError as e:
errors.append(ParseErrorReport(idx_str, val_str, f"Couldn't convert idxs; {str(e)}"))
continue
for idx in idxs:
actual_pairs.append(InputPair(idx, val_str, hold))
return actual_pairs, errors
def get_matches_and_errors(text: str, pattern: re.Pattern) -> tuple[list, list[RegexErrorReport]]:
last_match_end = 0
matches = []
errors: list[RegexErrorReport] = []
for match in re.finditer(pattern, text):
start, end = match.span()
# if there is any text between last match and current, consider as error
if start != last_match_end:
errors.append(RegexErrorReport(last_match_end, start, text[last_match_end:start].replace('\n','\t')))
# update match
last_match_end = end
# store match
matches.append(match.groups())
# check for any trailing unmatched text
if last_match_end != len(text):
errors.append(RegexErrorReport(last_match_end, len(text), text[last_match_end:].replace('\n','\t')))
return matches, errors
def is_surrounded(text: str, pair):
return text.startswith(pair[0]) and text.endswith(pair[1])
def is_surrounded_pairs(text: str, pairs):
for pair in pairs:
if is_surrounded(text, pair):
return True
return False
def strip_value(text: str, limit=-1):
text = text.strip()
# strip common paired symbols
symbol_pairs = [
("(", ")"),
("[", "]"),
("{", "}"),
('"', '"'),
("'", "'"),
]
while limit != 0 and is_surrounded_pairs(text, symbol_pairs):
text = text[1:-1].strip()
limit -= 1
return text
def strip_input(text: str):
text = text.strip()
# strip JSON brackets, if needed
if text.startswith('{') and text.endswith('}'):
return text[1:-1].strip()
return text