|
import re |
|
import math |
|
from typing import Any, Union
|
import torch |
|
from torch import Tensor |
|
import torch.nn.functional as F |
|
from dataclasses import dataclass, replace |
|
|
|
from comfy.sd import CLIP |
|
from comfy.utils import ProgressBar |
|
import comfy.model_management |
|
|
|
from .utils_model import InterpolationMethod |
|
from .utils_motion import extend_list_to_batch_size |
|
from .utils_scheduling import SelectError, TensorInterp, convert_str_to_indexes, lerp_tensors, slerp_tensors |
|
from .logger import logger |
|
|
|
|
|
|
|
|
|
|
|
_regex_prompt_json = re.compile(r'"([\d:\-\.]+)\s*"\s*:\s*"([^"]*)"(?:\s*,\s*|$)') |
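# _regex_prompt_json matches JSON-style prompt entries of the form "idx": "prompt",
# e.g.  "0": "a cat", "16": "a dog"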
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_regex_prompt_pyth = re.compile(r'([\d:\-\.]+)\s*=\s*"([^"]*)"(?:\s*,\s*|$)') |
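# _regex_prompt_pyth matches pythonic prompt entries of the form idx = "prompt",
# e.g.  0 = "a cat", 16 = "a dog"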
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_regex_value_json = re.compile(r'"([\d:\-\.]+)\s*"\s*:\s*([^,]+)(?:\s*,\s*|$)') |
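# _regex_value_json matches JSON-style value entries of the form "idx": value,
# e.g.  "0": 0.5, "16": 1.0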
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_regex_value_pyth = re.compile(r'([\d:\-\.]+)\s*=\s*([^,]+)(?:\s*,\s*|$)') |
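# _regex_value_pyth matches pythonic value entries of the form idx = value,
# e.g.  0 = 0.5, 16 = 1.0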
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_regex_key_value = re.compile(r'^[a-zA-Z0-9_]+$') |
|
def verify_key_value(key: str, raise_error=True): |
|
match = re.match(_regex_key_value, key) |
|
if not match and raise_error: |
|
raise Exception(f"Value key may only contain 'a-z', 'A-Z', '0-9', or '_', but was: '{key}'.") |
|
return match is not None |
|
|
|
|
|
class SFormat: |
|
JSON = "json" |
|
PYTH = "pythonic" |
|
|
|
@dataclass |
|
class RegexErrorReport: |
|
start: int |
|
end: int |
|
text: str |
|
reason: str = None |
|
|
|
@dataclass |
|
class InputPair: |
|
idx: int |
|
val: Union[int, str, Tensor] |
|
hold: bool = False |
|
end: bool = False |
|
|
|
@dataclass |
|
class CondHolder: |
|
idx: int |
|
prompt: str |
|
raw_prompt: str |
|
cond: Tensor |
|
pooled: Tensor |
|
hold: bool = False |
|
interp_weight: float = None |
|
interp_prompt: str = None |
|
|
|
@dataclass |
|
class ParseErrorReport: |
|
idx_str: str |
|
val_str: str |
|
reason: str |
|
|
|
@dataclass |
|
class PromptOptions: |
|
interp: str = TensorInterp.LERP |
|
prepend_text: str = '' |
|
append_text: str = '' |
|
values_replace: dict[str, list[float]] = None |
|
print_schedule: bool = False |
|
    add_dict: dict[str, Any] = None
|
|
|
|
|
def evaluate_prompt_schedule(text: str, length: int, clip: CLIP, options: PromptOptions): |
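    """Parse a prompt schedule written in either JSON-style or pythonic format, encode each
    prompt with the provided CLIP, and return scheduled conds covering `length` frames.

    Example schedule text (both formats are accepted; a trailing ':' on an idx holds that
    prompt instead of interpolating toward the next one):
        "0": "a cat", "16": "a dog"
        0 = "a cat", 16 = "a dog"
    """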
|
text = strip_input(text) |
|
if len(text) == 0: |
|
raise Exception("No text provided to Prompt Scheduling.") |
|
|
|
if text.startswith('"'): |
|
formats = [SFormat.JSON, SFormat.PYTH] |
|
else: |
|
formats = [SFormat.PYTH, SFormat.JSON] |
|
for format in formats: |
|
if format is SFormat.JSON: |
|
|
|
|
|
json_matches, json_errors = get_matches_and_errors(text, _regex_prompt_json) |
|
if len(json_errors) == 0: |
|
return parse_prompt_groups(json_matches, length, clip, options) |
|
elif format is SFormat.PYTH: |
|
|
|
|
|
pyth_matches, pyth_errors = get_matches_and_errors(text, _regex_prompt_pyth) |
|
if len(pyth_errors) == 0: |
|
return parse_prompt_groups(pyth_matches, length, clip, options) |
|
|
|
|
|
|
|
|
|
if len(json_matches) > len(pyth_matches): |
|
real_errors = json_errors |
|
assumed = SFormat.JSON |
|
elif len(json_matches) < len(pyth_matches): |
|
real_errors = pyth_errors |
|
assumed = SFormat.PYTH |
|
elif len(json_errors) < len(pyth_errors): |
|
real_errors = json_errors |
|
assumed = SFormat.JSON |
|
else: |
|
        logger.warning("Prompt schedule produced the same number of matches and errors in both formats; assuming pythonic format.")
|
real_errors = pyth_errors |
|
assumed = SFormat.PYTH |
|
|
|
error_msg_list = [] |
|
if len(real_errors) == 1: |
|
error_msg_list.append(f"Found 1 issue in prompt schedule (assumed {assumed} format):") |
|
else: |
|
error_msg_list.append(f"Found {len(real_errors)} issues in prompt schedule (assumed {assumed} format):") |
|
for error in real_errors: |
|
error_msg_list.append(f"Position {error.start} to {error.end}: '{error.text}'") |
|
error_msg = "\n".join(error_msg_list) |
|
raise Exception(error_msg) |
|
|
|
|
|
def parse_prompt_groups(groups: list[tuple], length: int, clip: CLIP, options: PromptOptions): |
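    """Convert regex groups into InputPairs, resolve their idx strings, then encode and interpolate prompts."""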
|
pairs: list[InputPair] |
|
errors: list[ParseErrorReport] |
|
|
|
pairs = [InputPair(x[0], x[1]) for x in groups] |
|
|
|
    pairs, errors = handle_group_idxs(pairs, length)
    # mirror the value-schedule path: report empty selections instead of failing later with a cryptic error
    if len(pairs) == 0:
        errors.append(ParseErrorReport(idx_str='No valid idxs provided', val_str='', reason='Provided ranges might not be selecting anything.'))
    if len(errors) > 0:
|
error_msg_list = [] |
|
issues_formatted = f"{len(errors)} issue{'s' if len(errors)> 1 else ''}" |
|
error_msg_list.append(f"Found {issues_formatted} with idxs:") |
|
for error in errors: |
|
error_msg_list.append(f"{error.idx_str}: {error.reason}") |
|
error_msg = "\n".join(error_msg_list) |
|
raise Exception(error_msg) |
|
prepare_prompts(pairs, options) |
|
final_vals = handle_prompt_interpolation(pairs, length, clip, options) |
|
return final_vals |
|
|
|
|
|
def prepare_prompts(pairs: list[InputPair], options: PromptOptions): |
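    """Apply options.prepend_text/append_text to every prompt, inserting comma separators as needed."""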
|
for pair in pairs: |
|
prepend_text = options.prepend_text.strip() |
|
append_text = options.append_text.strip() |
|
prompt = pair.val.strip() |
|
|
|
|
|
if len(prepend_text) > 0: |
|
while prepend_text.endswith(','): |
|
prepend_text = prepend_text[:-1].strip() |
|
            # only prepend a comma separator if the prompt does not already start with one
            if not prompt.startswith(','):
                prepend_text = f"{prepend_text}, "
|
prompt = prepend_text + prompt |
|
|
|
if len(append_text) > 0: |
|
while append_text.startswith(','): |
|
append_text = append_text[1:].strip() |
|
if prompt.endswith(','): |
|
append_text = f" {append_text}" |
|
else: |
|
append_text = f", {append_text}" |
|
prompt = prompt + append_text |
|
|
|
pair.val = prompt |
|
|
|
|
|
def apply_values_replace_to_prompt(prompt: str, idx: int, values_replace: Union[None, dict[str, list[float]]]): |
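    """Replace backtick-wrapped keys in the prompt (e.g. `my_val`) with the value scheduled for idx."""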
|
|
|
if values_replace is None: |
|
return prompt |
|
for key, value in values_replace.items(): |
|
|
|
match_str = '`' + key + '`' |
|
value_str = f"{value[idx]}" |
|
prompt = prompt.replace(match_str, value_str) |
|
return prompt |
|
|
|
|
|
def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions): |
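    """Encode all prompts and build per-frame conds; when CLIP hook scheduling is active, repeat
    the interpolation for each scheduled hook keyframe range and tag the results with its range."""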
|
if length == 0: |
|
length = max(pairs, key=lambda x: x.idx).idx+1 |
|
|
|
values_replace = options.values_replace |
|
    if values_replace is not None:
        # copy so that extending the per-key value lists below does not mutate the caller's dict
        values_replace = values_replace.copy()
|
for key, value in values_replace.items(): |
|
if len(value) < length: |
|
values_replace[key] = extend_list_to_batch_size(value, length) |
|
|
|
scheduled_keyframes = [] |
|
if clip.use_clip_schedule: |
|
clip = clip.clone() |
|
scheduled_keyframes = clip.patcher.forced_hooks.get_hooks_for_clip_schedule() |
|
|
|
pairs_lengths = len(pairs) * max(1, len(scheduled_keyframes)) |
|
pbar_total = length + pairs_lengths |
|
pbar = ProgressBar(pbar_total) |
|
|
|
|
|
|
|
max_size = 0 |
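    # first pass: encode every prompt once to find the longest cond, so all conds can later be
    # padded to a common token length (max_size) before interpolation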
|
for pair in pairs: |
|
prepared_prompt = apply_values_replace_to_prompt(pair.val, 0, values_replace=values_replace) |
|
cond: Tensor = clip.encode_from_tokens(clip.tokenize(prepared_prompt)) |
|
max_size = max(max_size, cond.shape[1]) |
|
pbar.update(1) |
|
|
|
|
|
if not clip.use_clip_schedule: |
|
return _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar) |
|
|
|
full_output = [] |
|
for i, scheduled_opts in enumerate(scheduled_keyframes): |
|
clip.patcher.forced_hooks.reset() |
|
clip.patcher.unpatch_hooks() |
|
|
|
t_range = scheduled_opts[0] |
|
hooks_keyframes = scheduled_opts[1] |
|
for hook, keyframe in hooks_keyframes: |
|
hook.hook_keyframe._current_keyframe = keyframe |
|
try: |
|
|
|
orig_print_schedule = options.print_schedule |
|
if orig_print_schedule and i != 0: |
|
options.print_schedule = False |
|
schedule_output = _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar) |
|
finally: |
|
options.print_schedule = orig_print_schedule |
|
for cond, pooled_dict in schedule_output: |
|
            pooled_dict: dict[str, Any]
|
|
|
pooled_dict["clip_start_percent"] = t_range[0] |
|
pooled_dict["clip_end_percent"] = t_range[1] |
|
full_output.extend(schedule_output) |
|
return full_output |
|
|
|
|
|
def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions, |
|
values_replace: dict[str, list[float]], max_size: int, pbar: ProgressBar): |
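    """Build a cond/pooled pair for every frame: prompts at their exact idxs, held or interpolated
    (lerp/slerp per options.interp) conds between idxs, and the last known cond carried forward
    over any remaining gaps."""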
|
real_holders: list[CondHolder] = [None] * length |
|
real_cond = [None] * length |
|
real_pooled = [None] * length |
|
prev_holder: Union[CondHolder, None] = None |
|
for idx, pair in enumerate(pairs): |
|
holder = None |
|
is_over_length = False |
|
|
|
if prev_holder is None: |
|
for i in range(idx, pair.idx+1): |
|
if i >= length: |
|
is_over_length = True |
|
continue |
|
real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace) |
|
if holder is None or holder.prompt != real_prompt: |
|
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond = pad_cond(cond, target_length=max_size) |
|
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold) |
|
else: |
|
holder = replace(holder) |
|
holder.idx = i |
|
real_cond[i] = cond |
|
real_pooled[i] = pooled |
|
real_holders[i] = holder |
|
pbar.update(1) |
|
comfy.model_management.throw_exception_if_processing_interrupted() |
|
|
|
elif prev_holder.idx == pair.idx-1: |
|
comfy.model_management.throw_exception_if_processing_interrupted() |
|
holder = prev_holder |
|
if pair.idx < length: |
|
real_prompt = apply_values_replace_to_prompt(pair.val, pair.idx, values_replace=values_replace) |
|
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond = pad_cond(cond, target_length=max_size) |
|
holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold) |
|
real_cond[pair.idx] = cond |
|
real_pooled[pair.idx] = pooled |
|
real_holders[pair.idx] = holder |
|
pbar.update(1) |
|
else: |
|
|
|
if prev_holder.hold: |
|
|
|
|
|
for i in range(prev_holder.idx+1, pair.idx): |
|
if i >= length: |
|
is_over_length = True |
|
continue |
|
if holder is None: |
|
holder = prev_holder |
|
real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace) |
|
if holder.prompt != real_prompt: |
|
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond = pad_cond(cond, target_length=max_size) |
|
holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold) |
|
else: |
|
holder = replace(holder) |
|
holder.idx = i |
|
real_cond[i] = holder.cond |
|
real_pooled[i] = holder.pooled |
|
real_holders[i] = holder |
|
pbar.update(1) |
|
comfy.model_management.throw_exception_if_processing_interrupted() |
|
if pair.idx < length: |
|
real_prompt = apply_values_replace_to_prompt(pair.val, pair.idx, values_replace=values_replace) |
|
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond = pad_cond(cond, target_length=max_size) |
|
holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold) |
|
real_cond[pair.idx] = cond |
|
real_pooled[pair.idx] = pooled |
|
real_holders[pair.idx] = holder |
|
pbar.update(1) |
|
comfy.model_management.throw_exception_if_processing_interrupted() |
|
|
|
else: |
|
diff_len = abs(pair.idx-prev_holder.idx)+1 |
|
interp_idxs = InterpolationMethod.get_weights(num_from=prev_holder.idx, num_to=pair.idx, length=diff_len, |
|
method=InterpolationMethod.LINEAR) |
|
interp_weights = InterpolationMethod.get_weights(num_from=0.0, num_to=1.0, length=diff_len, |
|
method=InterpolationMethod.LINEAR) |
|
cond_to = None |
|
pooled_to = None |
|
cond_from = None |
|
holder = None |
|
interm_holder = prev_holder |
|
for raw_idx, weight in zip(interp_idxs, interp_weights): |
|
if raw_idx >= length: |
|
is_over_length = True |
|
continue |
|
idx_int = round(float(raw_idx)) |
|
|
|
real_prompt = apply_values_replace_to_prompt(pair.val, idx_int, values_replace=values_replace) |
|
if holder is None or holder.prompt != real_prompt: |
|
cond_to, pooled_to = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond_to = pad_cond(cond_to, target_length=max_size) |
|
holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold) |
|
|
|
real_prompt = apply_values_replace_to_prompt(interm_holder.raw_prompt, idx_int, values_replace=values_replace) |
|
if interm_holder.prompt != real_prompt: |
|
cond_from, pooled_from = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond_from = pad_cond(cond_from, target_length=max_size) |
|
interm_holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=interm_holder.raw_prompt, cond=cond_from, pooled=pooled_from, hold=holder.hold) |
|
else: |
|
interm_holder = CondHolder(idx=interm_holder.idx, prompt=interm_holder.prompt, raw_prompt=interm_holder.raw_prompt, cond=interm_holder.cond, pooled=interm_holder.pooled, hold=interm_holder.hold) |
|
|
|
if options.interp == TensorInterp.LERP: |
|
cond_interp = lerp_tensors(tensor_from=interm_holder.cond, tensor_to=cond_to, strength_to=weight) |
|
elif options.interp == TensorInterp.SLERP: |
|
cond_interp = slerp_tensors(tensor_from=interm_holder.cond, tensor_to=cond_to, strength_to=weight) |
|
pooled_interp = pooled_to |
|
if math.isclose(weight, 0.0): |
|
pooled_interp = interm_holder.pooled |
|
interm_holder = CondHolder(idx=idx_int, prompt=interm_holder.prompt, raw_prompt=interm_holder.raw_prompt, cond=cond_interp, pooled=pooled_interp, hold=holder.hold, |
|
interp_weight=weight, interp_prompt=holder.prompt) |
|
real_cond[idx_int] = cond_interp |
|
real_pooled[idx_int] = pooled_interp |
|
real_holders[idx_int] = interm_holder |
|
pbar.update(1) |
|
comfy.model_management.throw_exception_if_processing_interrupted() |
|
if is_over_length: |
|
break |
|
assert holder is not None |
|
prev_holder = holder |
|
|
|
|
|
|
|
prev_holder = None |
|
for i in range(len(real_holders)): |
|
if real_holders[i] is None: |
|
|
|
real_prompt = apply_values_replace_to_prompt(prev_holder.raw_prompt, i, values_replace=values_replace) |
|
if prev_holder.prompt != real_prompt: |
|
cond, pooled = clip.encode_from_tokens(clip.tokenize(real_prompt), return_pooled=True) |
|
cond = pad_cond(cond, target_length=max_size) |
|
prev_holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=prev_holder.raw_prompt, cond=cond, pooled=pooled, hold=prev_holder.hold) |
|
real_cond[i] = prev_holder.cond |
|
real_pooled[i] = prev_holder.pooled |
|
real_holders[i] = prev_holder |
|
pbar.update(1) |
|
else: |
|
prev_holder = real_holders[i] |
|
|
|
final_cond = torch.cat(real_cond, dim=0) |
|
final_pooled = torch.cat(real_pooled, dim=0) |
|
|
|
if options.print_schedule: |
|
logger.info(f"PromptScheduling ({len(real_holders)} prompts)") |
|
for i, holder in enumerate(real_holders): |
|
if holder.interp_prompt is None: |
|
logger.info(f'{i} = "{holder.prompt}"') |
|
else: |
|
logger.info(f'{i} = ({1.-holder.interp_weight:.2f})"{holder.prompt}" -> ({holder.interp_weight:.2f})"{holder.interp_prompt}"') |
|
|
|
final_pooled_dict = {"pooled_output": final_pooled} |
|
if options.add_dict is not None: |
|
final_pooled_dict.update(options.add_dict) |
|
|
|
clip.add_hooks_to_dict(final_pooled_dict) |
|
return [[final_cond, final_pooled_dict]] |
|
|
|
|
|
def pad_cond(cond: Tensor, target_length: int): |
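    """Pad cond along the token dimension up to target_length, splitting the padding between the
    start and end of the sequence."""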
|
|
|
|
|
curr_length = cond.shape[1] |
|
if curr_length < target_length: |
|
pad_length = target_length - curr_length |
|
|
|
left_pad = pad_length // 2 |
|
right_pad = pad_length - left_pad |
|
|
|
cond = F.pad(cond, (0, 0, left_pad, right_pad)) |
|
return cond |
|
|
|
|
|
def evaluate_value_schedule(text: str, length: int): |
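    """Parse a float value schedule in either JSON-style or pythonic format and return a list of
    `length` interpolated values, e.g.  "0": 0.0, "16": 1.0  or  0 = 0.0, 16 = 1.0."""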
|
text = strip_input(text) |
|
if len(text) == 0: |
|
raise Exception("No text provided to Value Scheduling.") |
|
|
|
if text.startswith('"'): |
|
formats = [SFormat.JSON, SFormat.PYTH] |
|
else: |
|
formats = [SFormat.PYTH, SFormat.JSON] |
|
for format in formats: |
|
if format is SFormat.JSON: |
|
|
|
|
|
json_matches, json_errors = get_matches_and_errors(text, _regex_value_json) |
|
if len(json_errors) == 0: |
|
return parse_value_groups(json_matches, length) |
|
elif format is SFormat.PYTH: |
|
|
|
|
|
pyth_matches, pyth_errors = get_matches_and_errors(text, _regex_value_pyth) |
|
if len(pyth_errors) == 0: |
|
return parse_value_groups(pyth_matches, length) |
|
|
|
|
|
|
|
|
|
if len(json_matches) > len(pyth_matches): |
|
real_errors = json_errors |
|
assumed = SFormat.JSON |
|
elif len(json_matches) < len(pyth_matches): |
|
real_errors = pyth_errors |
|
assumed = SFormat.PYTH |
|
elif len(json_errors) < len(pyth_errors): |
|
real_errors = json_errors |
|
assumed = SFormat.JSON |
|
else: |
|
|
|
if text.startswith('"'): |
|
real_errors = json_errors |
|
assumed = SFormat.JSON |
|
else: |
|
real_errors = pyth_errors |
|
assumed = SFormat.PYTH |
|
|
|
error_msg_list = [] |
|
if len(real_errors) == 1: |
|
error_msg_list.append(f"Found 1 issue in value schedule (assumed {assumed} format):") |
|
else: |
|
error_msg_list.append(f"Found {len(real_errors)} issues in value schedule (assumed {assumed} format):") |
|
for error in real_errors: |
|
error_msg_list.append(f"Position {error.start} to {error.end}: '{error.text}'") |
|
error_msg = "\n".join(error_msg_list) |
|
raise Exception(error_msg) |
|
|
|
|
|
def parse_value_groups(groups: list[tuple], length: int): |
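    """Convert regex groups into float InputPairs, resolve their idx strings, then interpolate values."""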
|
|
|
pairs: list[InputPair] |
|
errors: list[ParseErrorReport] |
|
|
|
pairs, errors = handle_float_vals(groups) |
|
if len(errors) == 0: |
|
|
|
pairs, errors = handle_group_idxs(pairs, length) |
|
if len(pairs) == 0: |
|
errors.append(ParseErrorReport(idx_str='No valid idxs provided', val_str='', reason='Provided ranges might not be selecting anything.')) |
|
if len(errors) > 0: |
|
error_msg_list = [] |
|
issues_formatted = f"{len(errors)} issue{'s' if len(errors)> 1 else ''}" |
|
error_msg_list.append(f"Found {issues_formatted} with idxs/vals:") |
|
for error in errors: |
|
error_msg_list.append(f"{error.idx_str}: {error.reason}") |
|
error_msg = "\n".join(error_msg_list) |
|
raise Exception(error_msg) |
|
|
|
final_vals = handle_val_interpolation(pairs, length) |
|
return final_vals |
|
|
|
|
|
def handle_float_vals(groups: list[tuple]): |
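    """Convert each value string to float, collecting a ParseErrorReport for anything non-numeric."""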
|
actual_pairs: list[InputPair] = [] |
|
errors: list[ParseErrorReport] = [] |
|
for idx_str, val_str in groups: |
|
val_str = strip_value(val_str) |
|
try: |
|
val = float(val_str) |
|
except ValueError: |
|
errors.append(ParseErrorReport(idx_str, val_str, f"Value '{val_str}' is not a valid number")) |
|
continue |
|
actual_pairs.append(InputPair(idx_str, val)) |
|
return actual_pairs, errors |
|
|
|
|
|
def handle_val_interpolation(pairs: list[InputPair], length: int): |
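    """Fill a list of `length` values: exact values at scheduled idxs, linear interpolation (or a
    hold) between idxs, and carry-forward of the last value over any remaining gaps."""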
|
if length == 0: |
|
length = max(pairs, key=lambda x: x.idx).idx+1 |
|
real_vals = [None] * length |
|
|
|
prev_pair = None |
|
for pair in pairs: |
|
|
|
if prev_pair is None: |
|
for i in range(0, pair.idx+1): |
|
if i >= length: |
|
break |
|
real_vals[i] = pair.val |
|
|
|
elif prev_pair.idx == pair.idx-1: |
|
if pair.idx < length: |
|
real_vals[pair.idx] = pair.val |
|
else: |
|
|
|
if prev_pair.hold: |
|
|
|
for i in range(prev_pair.idx+1, pair.idx): |
|
if i >= length: |
|
continue |
|
real_vals[i] = prev_pair.val |
|
if pair.idx < length: |
|
real_vals[pair.idx] = pair.val |
|
|
|
else: |
|
diff_len = abs(pair.idx-prev_pair.idx)+1 |
|
interp_idxs = InterpolationMethod.get_weights(num_from=prev_pair.idx, num_to=pair.idx, length=diff_len, |
|
method=InterpolationMethod.LINEAR) |
|
interp_vals = InterpolationMethod.get_weights(num_from=prev_pair.val, num_to=pair.val, length=diff_len, |
|
method=InterpolationMethod.LINEAR) |
|
for idx, val in zip(interp_idxs, interp_vals): |
|
if idx >= length: |
|
continue |
|
real_vals[round(float(idx))] = float(val) |
|
prev_pair = pair |
|
|
|
|
|
prev_val = None |
|
for i in range(len(real_vals)): |
|
if real_vals[i] is None: |
|
real_vals[i] = prev_val |
|
else: |
|
prev_val = real_vals[i] |
|
return real_vals |
|
|
|
|
|
def handle_group_idxs(pairs: list[InputPair], length: int): |
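    """Expand each pair's idx string (single idxs or ranges) into concrete idxs; a trailing ':'
    marks the pair as a hold."""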
|
actual_pairs: list[InputPair] = [] |
|
errors: list[ParseErrorReport] = [] |
|
for pair in pairs: |
|
idx_str, val_str = pair.idx, pair.val |
|
idx_str: str = idx_str.strip() |
|
hold = False |
|
|
|
if idx_str.startswith(':'): |
|
errors.append(ParseErrorReport(idx_str, val_str, "Idx can't begin with ':'")) |
|
continue |
|
|
|
        if idx_str.count(':') > 1:
            errors.append(ParseErrorReport(idx_str, val_str, "Idx can't have more than one ':'"))
            continue
|
if idx_str.endswith(':'): |
|
hold = True |
|
idx_str = idx_str[:-1] |
|
try: |
|
idxs = convert_str_to_indexes(idx_str, length, allow_range=True, allow_missing=True, fix_reverse=True, same_is_one=True, allow_decimal=True) |
|
except SelectError as e: |
|
errors.append(ParseErrorReport(idx_str, val_str, f"Couldn't convert idxs; {str(e)}")) |
|
continue |
|
for idx in idxs: |
|
actual_pairs.append(InputPair(idx, val_str, hold)) |
|
return actual_pairs, errors |
|
|
|
|
|
def get_matches_and_errors(text: str, pattern: re.Pattern) -> tuple[list, list[RegexErrorReport]]: |
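    """Run pattern over text, returning matched groups plus RegexErrorReports for any unmatched spans."""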
|
last_match_end = 0 |
|
matches = [] |
|
errors: list[RegexErrorReport] = [] |
|
|
|
for match in re.finditer(pattern, text): |
|
start, end = match.span() |
|
|
|
if start != last_match_end: |
|
errors.append(RegexErrorReport(last_match_end, start, text[last_match_end:start].replace('\n','\t'))) |
|
|
|
last_match_end = end |
|
|
|
matches.append(match.groups()) |
|
|
|
|
|
if last_match_end != len(text): |
|
errors.append(RegexErrorReport(last_match_end, len(text), text[last_match_end:].replace('\n','\t'))) |
|
|
|
return matches, errors |
|
|
|
|
|
def is_surrounded(text: str, pair): |
|
return text.startswith(pair[0]) and text.endswith(pair[1]) |
|
|
|
|
|
def is_surrounded_pairs(text: str, pairs): |
|
for pair in pairs: |
|
if is_surrounded(text, pair): |
|
return True |
|
return False |
|
|
|
|
|
def strip_value(text: str, limit=-1): |
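    """Strip whitespace and up to `limit` layers of surrounding brackets/quotes (unlimited when limit=-1)."""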
|
text = text.strip() |
|
|
|
symbol_pairs = [ |
|
("(", ")"), |
|
("[", "]"), |
|
("{", "}"), |
|
('"', '"'), |
|
("'", "'"), |
|
] |
|
while limit != 0 and is_surrounded_pairs(text, symbol_pairs): |
|
text = text[1:-1].strip() |
|
limit -= 1 |
|
return text |
|
|
|
|
|
def strip_input(text: str): |
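    """Strip whitespace and, if present, one pair of surrounding curly braces."""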
|
text = text.strip() |
|
|
|
if text.startswith('{') and text.endswith('}'): |
|
return text[1:-1].strip() |
|
return text |
|
|