File size: 31,356 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 |
import re
import math
from typing import Union
import torch
from torch import Tensor
import torch.nn.functional as F
from dataclasses import dataclass, replace
from comfy.sd import CLIP
from comfy.utils import ProgressBar
import comfy.model_management
from .utils_model import InterpolationMethod
from .utils_motion import extend_list_to_batch_size
from .utils_scheduling import SelectError, TensorInterp, convert_str_to_indexes, lerp_tensors, slerp_tensors
from .logger import logger
###############################################
#----------------------------------------------
# Regexes used to split schedule text into (idxs, value) groups.
# Each pattern matches ONE "idxs -> value" entry plus its trailing separator;
# the two capture groups are (idxs, value). None of the patterns use re.MULTILINE,
# so '$' only matches at the end of the text (or just before a final newline).
#
# JSON prompt format is as follows:
#   "idxs": "prompt", ...
_regex_prompt_json = re.compile(r'"([\d:\-\.]+)\s*"\s*:\s*"([^"]*)"(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# "([\d:\-\.]+): Opening quote, then captures idxs: any combination of digits, colons, hyphens, and periods.
# \s*": Optional whitespace, then the closing quote of the idxs.
# \s*:\s*: Matches the ':' separator with optional whitespace on both sides.
# "([^"]*)": Captures the prompt, which can be any character except for double quotation marks
#            (there is no escape handling, so a prompt cannot contain a double quote).
# (?:\s*,\s*|$): This non-capturing group (?: ... ) matches either a comma (with optional spaces before or after) or the end of the string ($).
#
# pythonic prompt format is as follows:
#   idxs = "prompt", ...
_regex_prompt_pyth = re.compile(r'([\d:\-\.]+)\s*=\s*"([^"]*)"(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# ([\d:\-\.]+): Captures idxs: any combination of digits, colons, hyphens, and periods.
# \s*=\s*: Matches the equal sign (=) with optional spaces on both sides.
# "([^"]*)": Captures the prompt, which can be any character except for double quotation marks.
# (?:\s*,\s*|$): Matches either a comma (with optional spaces before or after) or the end of the string ($).
#
# JSON value format is as follows:
#   "idxs": value, ...
_regex_value_json = re.compile(r'"([\d:\-\.]+)\s*"\s*:\s*([^,]+)(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# "([\d:\-\.]+): Opening quote, then captures idxs: any combination of digits, colons, hyphens, and periods.
# \s*": Optional whitespace, then the closing quote of the idxs.
# \s*:\s*: Matches the ':' separator with optional whitespace on both sides.
# ([^,]+): Captures the value, which can be any character except for commas (this ensures that values are correctly separated).
#          The captured text may still carry surrounding whitespace/quotes/brackets.
# (?:\s*,\s*|$): Matches either a comma (with optional spaces before or after) or the end of the string ($).
#
# pythonic value format is as follows:
#   idxs = value, ...
_regex_value_pyth = re.compile(r'([\d:\-\.]+)\s*=\s*([^,]+)(?:\s*,\s*|$)')
# NOTE: I used ChatGPT to generate this regex and summary, as I couldn't be bothered.
# ([\d:\-\.]+): Captures idxs: any combination of digits, colons, hyphens, and periods.
# \s*=\s*: Matches the equal sign (=) with optional spaces on both sides.
# ([^,]+): Captures the value, which can be any character except for commas (this ensures that values are correctly separated).
# (?:\s*,\s*|$): Matches either a comma (with optional spaces before or after) or the end of the string ($).
#----------------------------------------------
###############################################
# verify that string only contains a-z, A-Z, 0-9, or _
_regex_key_value = re.compile(r'^[a-zA-Z0-9_]+$')
def verify_key_value(key: str, raise_error=True):
    """Check that a value key consists solely of 'a-z', 'A-Z', '0-9', or '_'.

    Args:
        key: The key to validate.
        raise_error: When True (default), an invalid key raises instead of
            returning False.

    Returns:
        True when the key is valid, False otherwise (only reachable with
        raise_error=False).

    Raises:
        Exception: If the key is invalid and raise_error is True.
    """
    # fullmatch() is required here: with match(), the '$' anchor also matches
    # just before a trailing newline, so a key like "abc\n" would be accepted.
    is_valid = _regex_key_value.fullmatch(key) is not None
    if not is_valid and raise_error:
        raise Exception(f"Value key may only contain 'a-z', 'A-Z', '0-9', or '_', but was: '{key}'.")
    return is_valid
class SFormat:
    """Names of the two schedule text syntaxes the parsers in this file try."""
    # entries look like:  "idxs": value, ...
    JSON = "json"
    # entries look like:  idxs = value, ...
    PYTH = "pythonic"
@dataclass
class RegexErrorReport:
    """A span of schedule text that was not covered by any regex match."""
    # offset in the text where the unmatched span begins
    start: int
    # offset in the text where the unmatched span ends (exclusive slice bound)
    end: int
    # the unmatched text itself
    text: str
    # optional explanation of the problem; defaults to None
    reason: str = None
@dataclass
class InputPair:
    """One schedule entry: an index paired with the value assigned to it."""
    # index of the entry; NOTE(review): parsers in this file first store the raw
    # idx string here and later rebuild pairs with resolved indexes - confirm
    idx: int
    # the scheduled value (prompt text, number, or tensor)
    val: Union[int, str, Tensor]
    # True when the value should be held (not interpolated) until the next entry
    hold: bool = False
    # NOTE(review): purpose not evident from this file - confirm against callers
    end: bool = False
@dataclass
class CondHolder:
    """Encoded prompt data for a single schedule index."""
    # schedule index this holder belongs to
    idx: int
    # prompt text after value replacement (the text that was actually encoded)
    prompt: str
    # prompt text before value replacement
    raw_prompt: str
    # encoded conditioning tensor for the prompt
    cond: Tensor
    # pooled output returned alongside cond by the encoder
    pooled: Tensor
    # True when this prompt should be held until the next scheduled entry
    hold: bool = False
    # weight given to interp_prompt when this index is interpolated; None otherwise
    interp_weight: float = None
    # prompt being interpolated towards; None when no interpolation was applied
    interp_prompt: str = None
@dataclass
class ParseErrorReport:
    """A problem found while interpreting one schedule entry's idxs or value."""
    # the idx text of the offending entry (or a short label for general errors)
    idx_str: str
    # the value text of the offending entry
    val_str: str
    # human-readable description of what went wrong
    reason: str
@dataclass
class PromptOptions:
    """Settings that control how a prompt schedule is evaluated."""
    # tensor interpolation mode used between scheduled prompts (a TensorInterp value)
    interp: str = TensorInterp.LERP
    # text added before every scheduled prompt
    prepend_text: str = ''
    # text added after every scheduled prompt
    append_text: str = ''
    # maps a key to a list of per-index values; `key` placeholders (backtick
    # notation) in prompts are replaced with the value for the current index
    values_replace: dict[str, list[float]] = None
    # when True, the resolved schedule is written to the logger
    print_schedule: bool = False
    # extra entries merged into the dict returned alongside the cond
    add_dict: dict[str] = None
def evaluate_prompt_schedule(text: str, length: int, clip: CLIP, options: PromptOptions):
    """Parse prompt schedule text and return the result of scheduling it.

    The text is tried against both supported syntaxes (JSON-like and pythonic);
    the first one that matches the whole text without leftovers is handed to
    parse_prompt_groups. If neither matches cleanly, the errors of the syntax
    that looks closest to the input are reported.

    Args:
        text: Schedule text; surrounding JSON braces are stripped first.
        length: Passed through to parse_prompt_groups.
        clip: Passed through to parse_prompt_groups.
        options: Passed through to parse_prompt_groups.

    Returns:
        Whatever parse_prompt_groups returns for the matched groups.

    Raises:
        Exception: If the text is empty, or if neither syntax matches cleanly.
    """
    text = strip_input(text)
    if len(text) == 0:
        raise Exception("No text provided to Prompt Scheduling.")
    # prioritize formats based on best guess to minimize redo's
    starts_with_quote = text.startswith('"')
    if starts_with_quote:
        formats = [SFormat.JSON, SFormat.PYTH]
    else:
        formats = [SFormat.PYTH, SFormat.JSON]
    json_matches, json_errors = [], []
    pyth_matches, pyth_errors = [], []
    for schedule_format in formats:
        if schedule_format == SFormat.JSON:
            # check JSON format
            # if no errors found, assume this is the right format and pass on to parsing individual values
            json_matches, json_errors = get_matches_and_errors(text, _regex_prompt_json)
            if len(json_errors) == 0:
                return parse_prompt_groups(json_matches, length, clip, options)
        elif schedule_format == SFormat.PYTH:
            # check pythonic format
            # if no errors found, assume this is the right format and pass on to parsing individual values
            pyth_matches, pyth_errors = get_matches_and_errors(text, _regex_prompt_pyth)
            if len(pyth_errors) == 0:
                return parse_prompt_groups(pyth_matches, length, clip, options)
    # since both formats have errors, check which format is more 'correct' for the input
    # priority:
    # 1 - most matches
    # 2 - least errors
    # 3 - best guess based on how the text starts
    if len(json_matches) != len(pyth_matches):
        prefer_json = len(json_matches) > len(pyth_matches)
    elif len(json_errors) != len(pyth_errors):
        prefer_json = len(json_errors) < len(pyth_errors)
    else:
        # true tie; fall back on the same guess that ordered the attempts above
        logger.warning("same amount of matches+errors for prompt!")
        prefer_json = starts_with_quote
    if prefer_json:
        real_errors = json_errors
        assumed = SFormat.JSON
    else:
        real_errors = pyth_errors
        assumed = SFormat.PYTH
    # TODO: make separate case for when format is unknown, so that both are displayed to the user
    error_msg_list = []
    if len(real_errors) == 1:
        error_msg_list.append(f"Found 1 issue in prompt schedule (assumed {assumed} format):")
    else:
        error_msg_list.append(f"Found {len(real_errors)} issues in prompt schedule (assumed {assumed} format):")
    for error in real_errors:
        error_msg_list.append(f"Position {error.start} to {error.end}: '{error.text}'")
    error_msg = "\n".join(error_msg_list)
    raise Exception(error_msg)
def parse_prompt_groups(groups: list[tuple], length: int, clip: CLIP, options: PromptOptions):
    """Turn regex (idxs, prompt) groups into the final scheduled output.

    Args:
        groups: Tuples of (idxs string, prompt string) as captured by the regex.
        length: Passed to handle_group_idxs and handle_prompt_interpolation.
        clip: Passed to handle_prompt_interpolation.
        options: Prompt options (prepend/append text and interpolation settings).

    Returns:
        Whatever handle_prompt_interpolation returns.

    Raises:
        Exception: If any idxs are invalid, or if the idxs select nothing.
    """
    pairs: list[InputPair]
    errors: list[ParseErrorReport]
    # turn group tuples into InputPairs
    pairs = [InputPair(x[0], x[1]) for x in groups]
    # perform first parse, to get idea of indexes to handle
    pairs, errors = handle_group_idxs(pairs, length)
    # without this check an empty selection would only fail later with an
    # unrelated error; report it the same way value schedules do
    if len(pairs) == 0 and len(errors) == 0:
        errors.append(ParseErrorReport(idx_str='No valid idxs provided', val_str='', reason='Provided ranges might not be selecting anything.'))
    if len(errors) > 0:
        error_msg_list = []
        issues_formatted = f"{len(errors)} issue{'s' if len(errors)> 1 else ''}"
        error_msg_list.append(f"Found {issues_formatted} with idxs:")
        for error in errors:
            error_msg_list.append(f"{error.idx_str}: {error.reason}")
        error_msg = "\n".join(error_msg_list)
        raise Exception(error_msg)
    prepare_prompts(pairs, options)
    final_vals = handle_prompt_interpolation(pairs, length, clip, options)
    return final_vals
def prepare_prompts(pairs: list[InputPair], options: PromptOptions):
    """Apply the prepend/append text from options to every pair's prompt, in place.

    Each prompt is stripped, then joined to the prepend/append text with exactly
    one comma between them, no matter which side already carried a comma.
    """
    for pair in pairs:
        prefix = options.prepend_text.strip()
        suffix = options.append_text.strip()
        body = pair.val.strip()
        # prepend text: drop its trailing commas, then add ", " unless the
        # prompt already starts with a comma
        if prefix:
            while prefix.endswith(','):
                prefix = prefix[:-1].strip()
            joiner = "" if body.startswith(',') else ", "
            body = f"{prefix}{joiner}{body}"
        # append text: drop its leading commas, then add ", " unless the
        # prompt already ends with a comma (only a space is needed then)
        if suffix:
            while suffix.startswith(','):
                suffix = suffix[1:].strip()
            joiner = " " if body.endswith(',') else ", "
            body = f"{body}{joiner}{suffix}"
        # store the finished prompt back on the pair
        pair.val = body
def apply_values_replace_to_prompt(prompt: str, idx: int, values_replace: Union[None, dict[str, list[float]]]):
    """Substitute `key` placeholders in prompt with each key's value at idx.

    Placeholders use the FizzNodes backtick notation. When values_replace is
    None the prompt is returned untouched.
    """
    if values_replace is None:
        return prompt
    result = prompt
    for name, series in values_replace.items():
        placeholder = "".join(("`", name, "`"))
        replacement = f"{series[idx]}"
        result = result.replace(placeholder, replacement)
    return result
def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions):
    """Encode scheduled prompts and fill every index up to length.

    Prepares the value-replacement lists, measures the largest encoded prompt
    (so all conds can be padded to one size), then delegates to
    _handle_prompt_interpolation - once, or once per clip schedule keyframe
    when clip.use_clip_schedule is set.

    Args:
        pairs: Scheduled prompts with resolved indexes.
        length: Number of indexes to produce; 0 means "highest idx + 1".
        clip: CLIP used for tokenizing/encoding; cloned when its schedule is used.
        options: Prompt options; options.values_replace is not modified.

    Returns:
        A list of [cond, dict] entries as produced by _handle_prompt_interpolation;
        with a clip schedule, one entry per keyframe, each dict tagged with
        "clip_start_percent" and "clip_end_percent".
    """
    if length == 0:
        length = max(pairs, key=lambda x: x.idx).idx+1
    # prepare values_replace (should match length);
    # build a new dict so the caller's options.values_replace is left untouched
    values_replace = options.values_replace
    if values_replace is not None:
        values_replace = {
            key: extend_list_to_batch_size(value, length) if len(value) < length else value
            for key, value in values_replace.items()
        }

    scheduled_keyframes = []
    if clip.use_clip_schedule:
        clip = clip.clone()
        scheduled_keyframes = clip.patcher.forced_hooks.get_hooks_for_clip_schedule()

    # one update per pair for the size pre-pass below, plus roughly one update
    # per index for each interpolation pass (one pass per keyframe, at least one)
    pbar_total = len(pairs) + length * max(1, len(scheduled_keyframes))
    pbar = ProgressBar(pbar_total)
    # for now, use FizzNodes approach of calculating max size of tokens beforehand;
    # this can up to double total encoding time, as this will be done again.
    # TODO: do this dynamically to save encoding time
    max_size = 0
    for pair in pairs:
        prepared_prompt = apply_values_replace_to_prompt(pair.val, 0, values_replace=values_replace)
        cond: Tensor = clip.encode_from_tokens(clip.tokenize(prepared_prompt))
        max_size = max(max_size, cond.shape[1])
        pbar.update(1)

    # if do not need to schedule clip with hooks, do nothing special
    if not clip.use_clip_schedule:
        return _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar)
    # otherwise, need to account for keyframes on forced_hooks
    full_output = []
    for i, scheduled_opts in enumerate(scheduled_keyframes):
        clip.patcher.forced_hooks.reset()
        clip.patcher.unpatch_hooks()

        t_range = scheduled_opts[0]
        hooks_keyframes = scheduled_opts[1]
        for hook, keyframe in hooks_keyframes:
            hook.hook_keyframe._current_keyframe = keyframe
        # don't print_schedule on non-first iteration; read the flag before
        # entering the try so the finally clause can always restore it
        orig_print_schedule = options.print_schedule
        try:
            if orig_print_schedule and i != 0:
                options.print_schedule = False
            schedule_output = _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar)
        finally:
            options.print_schedule = orig_print_schedule
        for _, pooled_dict in schedule_output:
            # add clip_start_percent and clip_end_percent in pooled
            pooled_dict["clip_start_percent"] = t_range[0]
            pooled_dict["clip_end_percent"] = t_range[1]
        full_output.extend(schedule_output)
    return full_output
def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions,
                                 values_replace: dict[str, list[float]], max_size: int, pbar: ProgressBar):
    """Encode the scheduled prompts and produce one cond/pooled per index.

    For every index below length, the prompt in effect is resolved (after
    value replacement), encoded, and padded to max_size. Between two scheduled
    prompts the conds are either held (previous pair has hold set) or
    interpolated using options.interp. Indexes left unset afterwards reuse the
    most recent earlier prompt.

    Args:
        pairs: Scheduled prompts with resolved indexes.
        length: Number of indexes to produce.
        clip: CLIP used for tokenizing/encoding.
        options: Prompt options (interp mode, print_schedule, add_dict).
        values_replace: Per-key lists of values for `key` placeholders, or None.
        max_size: Size every cond is padded up to along dim 1.
        pbar: Progress bar, advanced once per index written.

    Returns:
        [[final_cond, final_pooled_dict]] where final_cond is the per-index
        conds concatenated along dim 0 and final_pooled_dict holds
        "pooled_output" plus options.add_dict and any hooks added by clip.

    Raises:
        Exception: If pairs is empty.
    """
    if len(pairs) == 0:
        # nothing would ever be written below, which used to surface as an
        # unrelated attribute error in the gap-filling step
        raise Exception("No prompts to schedule; provided idxs might not be selecting anything.")

    def encode_padded(prompt: str):
        # encode a single prompt and pad its cond so all conds share one size
        encoded_cond, encoded_pooled = clip.encode_from_tokens(clip.tokenize(prompt), return_pooled=True)
        return pad_cond(encoded_cond, target_length=max_size), encoded_pooled

    real_holders: list[CondHolder] = [None] * length
    real_cond = [None] * length
    real_pooled = [None] * length

    prev_holder: Union[CondHolder, None] = None
    for idx, pair in enumerate(pairs):
        holder = None
        is_over_length = False
        # if no last pair is set, then use first provided val up to the idx
        if prev_holder is None:
            for i in range(idx, pair.idx+1):
                if i >= length:
                    is_over_length = True
                    continue
                real_prompt = apply_values_replace_to_prompt(pair.val, i, values_replace=values_replace)
                # only re-encode when value replacement changed the prompt text
                if holder is None or holder.prompt != real_prompt:
                    cond, pooled = encode_padded(real_prompt)
                    holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
                else:
                    holder = replace(holder)
                    holder.idx = i
                real_cond[i] = holder.cond
                real_pooled[i] = holder.pooled
                real_holders[i] = holder
                pbar.update(1)
                comfy.model_management.throw_exception_if_processing_interrupted()
        # if idx is exactly one greater than the one before, nothing special
        elif prev_holder.idx == pair.idx-1:
            comfy.model_management.throw_exception_if_processing_interrupted()
            holder = prev_holder
            if pair.idx < length:
                real_prompt = apply_values_replace_to_prompt(pair.val, pair.idx, values_replace=values_replace)
                cond, pooled = encode_padded(real_prompt)
                holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
                real_cond[pair.idx] = cond
                real_pooled[pair.idx] = pooled
                real_holders[pair.idx] = holder
                pbar.update(1)
        else:
            # if holding value, no interpolation
            if prev_holder.hold:
                # keep same value as last_holder, then calculate current index cond;
                # however, need to check if real_prompt remains the same
                for i in range(prev_holder.idx+1, pair.idx):
                    if i >= length:
                        is_over_length = True
                        continue
                    if holder is None:
                        holder = prev_holder
                    # the HELD prompt is what must be resolved for these indexes;
                    # using the upcoming pair's prompt here would end the hold early
                    real_prompt = apply_values_replace_to_prompt(holder.raw_prompt, i, values_replace=values_replace)
                    if holder.prompt != real_prompt:
                        cond, pooled = encode_padded(real_prompt)
                        holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=holder.raw_prompt, cond=cond, pooled=pooled, hold=holder.hold)
                    else:
                        holder = replace(holder)
                        holder.idx = i
                    real_cond[i] = holder.cond
                    real_pooled[i] = holder.pooled
                    real_holders[i] = holder
                    pbar.update(1)
                    comfy.model_management.throw_exception_if_processing_interrupted()
                if pair.idx < length:
                    real_prompt = apply_values_replace_to_prompt(pair.val, pair.idx, values_replace=values_replace)
                    cond, pooled = encode_padded(real_prompt)
                    holder = CondHolder(idx=pair.idx, prompt=real_prompt, raw_prompt=pair.val, cond=cond, pooled=pooled, hold=pair.hold)
                    real_cond[pair.idx] = cond
                    real_pooled[pair.idx] = pooled
                    real_holders[pair.idx] = holder
                    pbar.update(1)
                    comfy.model_management.throw_exception_if_processing_interrupted()
            # otherwise, interpolate
            else:
                diff_len = abs(pair.idx-prev_holder.idx)+1
                interp_idxs = InterpolationMethod.get_weights(num_from=prev_holder.idx, num_to=pair.idx, length=diff_len,
                                                              method=InterpolationMethod.LINEAR)
                interp_weights = InterpolationMethod.get_weights(num_from=0.0, num_to=1.0, length=diff_len,
                                                                 method=InterpolationMethod.LINEAR)
                cond_to = None
                pooled_to = None
                cond_from = None
                holder = None
                interm_holder = prev_holder
                for raw_idx, weight in zip(interp_idxs, interp_weights):
                    if raw_idx >= length:
                        is_over_length = True
                        continue
                    idx_int = round(float(raw_idx))
                    # calculate cond_to stuff if not done yet
                    real_prompt = apply_values_replace_to_prompt(pair.val, idx_int, values_replace=values_replace)
                    if holder is None or holder.prompt != real_prompt:
                        cond_to, pooled_to = encode_padded(real_prompt)
                        holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=pair.val, cond=cond_to, pooled=pooled_to, hold=pair.hold)
                    # calculate interm_holder stuff if needed
                    real_prompt = apply_values_replace_to_prompt(interm_holder.raw_prompt, idx_int, values_replace=values_replace)
                    if interm_holder.prompt != real_prompt:
                        cond_from, pooled_from = encode_padded(real_prompt)
                        interm_holder = CondHolder(idx=idx_int, prompt=real_prompt, raw_prompt=interm_holder.raw_prompt, cond=cond_from, pooled=pooled_from, hold=holder.hold)
                    else:
                        interm_holder = CondHolder(idx=interm_holder.idx, prompt=interm_holder.prompt, raw_prompt=interm_holder.raw_prompt, cond=interm_holder.cond, pooled=interm_holder.pooled, hold=interm_holder.hold)
                    # interpolate conds
                    if options.interp == TensorInterp.LERP:
                        cond_interp = lerp_tensors(tensor_from=interm_holder.cond, tensor_to=cond_to, strength_to=weight)
                    elif options.interp == TensorInterp.SLERP:
                        cond_interp = slerp_tensors(tensor_from=interm_holder.cond, tensor_to=cond_to, strength_to=weight)
                    # pooled output is not interpolated; it switches to the target
                    # as soon as the weight is non-zero
                    pooled_interp = pooled_to
                    if math.isclose(weight, 0.0):
                        pooled_interp = interm_holder.pooled
                    # NOTE(review): interm_holder now carries the interpolated cond, so
                    # the next step interpolates from an already-blended tensor unless
                    # the prompt text changes - confirm this compounding is intended
                    interm_holder = CondHolder(idx=idx_int, prompt=interm_holder.prompt, raw_prompt=interm_holder.raw_prompt, cond=cond_interp, pooled=pooled_interp, hold=holder.hold,
                                               interp_weight=weight, interp_prompt=holder.prompt)
                    real_cond[idx_int] = cond_interp
                    real_pooled[idx_int] = pooled_interp
                    real_holders[idx_int] = interm_holder
                    pbar.update(1)
                    comfy.model_management.throw_exception_if_processing_interrupted()
        # once a pair reaches past length, later pairs have nothing left to fill
        if is_over_length:
            break
        assert holder is not None
        prev_holder = holder

    # fill in None gaps with last used values
    # TODO: review if this works as intended, or if needs to be a bit more thorough
    prev_holder = None
    for i in range(len(real_holders)):
        if real_holders[i] is None:
            # check if any value replacement needs to be accounted for
            real_prompt = apply_values_replace_to_prompt(prev_holder.raw_prompt, i, values_replace=values_replace)
            if prev_holder.prompt != real_prompt:
                cond, pooled = encode_padded(real_prompt)
                prev_holder = CondHolder(idx=i, prompt=real_prompt, raw_prompt=prev_holder.raw_prompt, cond=cond, pooled=pooled, hold=prev_holder.hold)
            real_cond[i] = prev_holder.cond
            real_pooled[i] = prev_holder.pooled
            real_holders[i] = prev_holder
            pbar.update(1)
        else:
            prev_holder = real_holders[i]

    final_cond = torch.cat(real_cond, dim=0)
    final_pooled = torch.cat(real_pooled, dim=0)

    if options.print_schedule:
        logger.info(f"PromptScheduling ({len(real_holders)} prompts)")
        for i, holder in enumerate(real_holders):
            if holder.interp_prompt is None:
                logger.info(f'{i} = "{holder.prompt}"')
            else:
                logger.info(f'{i} = ({1.-holder.interp_weight:.2f})"{holder.prompt}" -> ({holder.interp_weight:.2f})"{holder.interp_prompt}"')
    # cond is a list[list[Tensor, dict[str: Any]]] format
    final_pooled_dict = {"pooled_output": final_pooled}
    if options.add_dict is not None:
        final_pooled_dict.update(options.add_dict)
    # add hooks, if needed
    clip.add_hooks_to_dict(final_pooled_dict)
    return [[final_cond, final_pooled_dict]]
def pad_cond(cond: Tensor, target_length: int):
    """Zero-pad cond along dim 1 up to target_length (FizzNodes-style).

    The padding is split across both ends; when the amount is odd, the extra
    element goes at the end. A cond already at or beyond target_length is
    returned unchanged.
    """
    # TODO: test out other methods of padding
    missing = target_length - cond.shape[1]
    if missing <= 0:
        return cond
    # FizzNodes pads the tensor on both ends
    pad_before = missing // 2
    pad_after = missing - pad_before
    # (0, 0) leaves the last dim alone; the next pair pads the dim before it
    return F.pad(cond, (0, 0, pad_before, pad_after))
def evaluate_value_schedule(text: str, length: int):
    """Parse value schedule text and return the result of scheduling it.

    Both supported syntaxes are tried, the more likely one first; the first
    that covers the text without leftovers is passed to parse_value_groups.
    If neither does, an Exception lists the errors of the closest syntax.
    """
    text = strip_input(text)
    if not text:
        raise Exception("No text provided to Value Scheduling.")
    # a leading quote suggests JSON, so try that first to minimize redo's
    quoted_start = text.startswith('"')
    attempt_order = [SFormat.JSON, SFormat.PYTH] if quoted_start else [SFormat.PYTH, SFormat.JSON]
    for candidate in attempt_order:
        if candidate is SFormat.JSON:
            # a clean JSON match means this is the right format
            json_matches, json_errors = get_matches_and_errors(text, _regex_value_json)
            if not json_errors:
                return parse_value_groups(json_matches, length)
        elif candidate is SFormat.PYTH:
            # a clean pythonic match means this is the right format
            pyth_matches, pyth_errors = get_matches_and_errors(text, _regex_value_pyth)
            if not pyth_errors:
                return parse_value_groups(pyth_matches, length)
    # neither format matched cleanly; pick the one that fits the input best:
    # 1 - most matches
    # 2 - least errors (in favor of JSON)
    # 3 - otherwise go by how the text starts
    if len(json_matches) != len(pyth_matches):
        use_json = len(json_matches) > len(pyth_matches)
    elif len(json_errors) < len(pyth_errors):
        use_json = True
    else:
        use_json = quoted_start
    if use_json:
        real_errors, assumed = json_errors, SFormat.JSON
    else:
        real_errors, assumed = pyth_errors, SFormat.PYTH
    # TODO: make separate case for when format is unknown, so that both are displayed to the user
    issue_count = len(real_errors)
    if issue_count == 1:
        header = f"Found 1 issue in value schedule (assumed {assumed} format):"
    else:
        header = f"Found {issue_count} issues in value schedule (assumed {assumed} format):"
    report_lines = [header]
    report_lines.extend(f"Position {error.start} to {error.end}: '{error.text}'" for error in real_errors)
    raise Exception("\n".join(report_lines))
def parse_value_groups(groups: list[tuple], length: int):
    """Convert regex (idxs, value) groups into a full list of scheduled values.

    Values are parsed as floats first; idxs are only resolved when every value
    parsed. Any problem found raises an Exception listing all of them.
    """
    # first pass: values -> floats
    pairs, errors = handle_float_vals(groups)
    # second pass: resolve idxs, only worthwhile when all values were valid
    if not errors:
        pairs, errors = handle_group_idxs(pairs, length)
    if not pairs:
        errors.append(ParseErrorReport(idx_str='No valid idxs provided', val_str='', reason='Provided ranges might not be selecting anything.'))
    if errors:
        issue_count = len(errors)
        plural = 's' if issue_count > 1 else ''
        report_lines = [f"Found {issue_count} issue{plural} with idxs/vals:"]
        report_lines += [f"{error.idx_str}: {error.reason}" for error in errors]
        raise Exception("\n".join(report_lines))
    # third pass: hold/interpolate to fill in the in-between values
    return handle_val_interpolation(pairs, length)
def handle_float_vals(groups: list[tuple]):
    """Parse each group's value text as a float.

    Returns a tuple (pairs, errors): pairs holds an InputPair (idx string,
    float) for every value that parsed, errors a ParseErrorReport for every
    value that did not.
    """
    parsed: list[InputPair] = []
    failures: list[ParseErrorReport] = []
    for idx_str, raw_val in groups:
        # drop whitespace and wrapping brackets/quotes before converting
        cleaned = strip_value(raw_val)
        try:
            number = float(cleaned)
        except ValueError:
            failures.append(ParseErrorReport(idx_str, cleaned, f"Value '{cleaned}' is not a valid number"))
        else:
            parsed.append(InputPair(idx_str, number))
    return parsed, failures
def handle_val_interpolation(pairs: list[InputPair], length: int):
    """Expand scheduled (idx, value) pairs into one value per index.

    Indexes before the first pair take its value. Between two pairs the value
    is either held (previous pair has hold set) or linearly interpolated.
    Indexes still unset afterwards reuse the closest earlier value. A length
    of 0 means "highest idx + 1".
    """
    if length == 0:
        length = max(entry.idx for entry in pairs) + 1

    schedule = [None] * length
    previous = None
    for current in pairs:
        if previous is None:
            # very first pair: its value applies from 0 up to its own idx
            for i in range(0, current.idx + 1):
                if i >= length:
                    break
                schedule[i] = current.val
        elif previous.idx == current.idx - 1:
            # directly follows the previous pair, so there is no gap to fill
            if current.idx < length:
                schedule[current.idx] = current.val
        elif previous.hold:
            # hold: repeat the previous value across the gap, then set the new one
            for i in range(previous.idx + 1, current.idx):
                if i >= length:
                    continue
                schedule[i] = previous.val
            if current.idx < length:
                schedule[current.idx] = current.val
        else:
            # interpolate linearly from the previous pair to this one
            steps = abs(current.idx - previous.idx) + 1
            step_idxs = InterpolationMethod.get_weights(num_from=previous.idx, num_to=current.idx, length=steps,
                                                        method=InterpolationMethod.LINEAR)
            step_vals = InterpolationMethod.get_weights(num_from=previous.val, num_to=current.val, length=steps,
                                                        method=InterpolationMethod.LINEAR)
            for step_idx, step_val in zip(step_idxs, step_vals):
                if step_idx >= length:
                    continue
                schedule[round(float(step_idx))] = float(step_val)
        previous = current

    # fill in None gaps with last used value
    # TODO: review if this works as intended, or if needs to be a bit more thorough
    last_seen = None
    for i, scheduled in enumerate(schedule):
        if scheduled is None:
            schedule[i] = last_seen
        else:
            last_seen = scheduled
    return schedule
def handle_group_idxs(pairs: list[InputPair], length: int):
    """Resolve each pair's idx string into concrete indexes.

    A trailing ':' on the idx string marks the value as "hold". The rest of
    the string is converted with convert_str_to_indexes, and one InputPair is
    produced per resulting index.

    Args:
        pairs: Pairs whose idx is still the raw idx string.
        length: Passed to convert_str_to_indexes.

    Returns:
        A tuple (actual_pairs, errors): the pairs with resolved indexes, and a
        ParseErrorReport for every idx string that could not be used.
    """
    actual_pairs: list[InputPair] = []
    errors: list[ParseErrorReport] = []
    for pair in pairs:
        idx_str, val_str = pair.idx, pair.val
        idx_str: str = idx_str.strip()
        hold = False
        # if starts with :, wrong
        if idx_str.startswith(':'):
            errors.append(ParseErrorReport(idx_str, val_str, "Idx can't begin with ':'"))
            continue
        # if has more than one :, wrong
        if idx_str.count(':') > 1:
            errors.append(ParseErrorReport(idx_str, val_str, "Idx can't have more than one ':'"))
            # skip this entry, otherwise it would still be converted below and
            # either add a second error for the same idx or slip through as a pair
            continue
        if idx_str.endswith(':'):
            hold = True
            idx_str = idx_str[:-1]
        try:
            idxs = convert_str_to_indexes(idx_str, length, allow_range=True, allow_missing=True, fix_reverse=True, same_is_one=True, allow_decimal=True)
        except SelectError as e:
            errors.append(ParseErrorReport(idx_str, val_str, f"Couldn't convert idxs; {str(e)}"))
            continue
        for idx in idxs:
            actual_pairs.append(InputPair(idx, val_str, hold))
    return actual_pairs, errors
def get_matches_and_errors(text: str, pattern: re.Pattern) -> tuple[list, list[RegexErrorReport]]:
    """Collect every pattern match in text plus the spans no match covered.

    Returns a tuple (matches, errors): matches holds the capture groups of
    each match in order; errors holds a RegexErrorReport for each stretch of
    text between matches or after the last one (newlines shown as tabs).
    """
    cursor = 0
    found = []
    problems: list[RegexErrorReport] = []
    for hit in re.finditer(pattern, text):
        hit_start, hit_end = hit.span()
        # anything skipped between the previous match and this one is an error
        if hit_start != cursor:
            skipped = text[cursor:hit_start]
            problems.append(RegexErrorReport(cursor, hit_start, skipped.replace('\n', '\t')))
        cursor = hit_end
        found.append(hit.groups())
    # leftover text after the final match is an error too
    text_end = len(text)
    if cursor != text_end:
        leftover = text[cursor:]
        problems.append(RegexErrorReport(cursor, text_end, leftover.replace('\n', '\t')))
    return found, problems
def is_surrounded(text: str, pair):
    """Return True when text starts with pair[0] and ends with pair[1].

    The opening and closing symbols must not overlap, so a lone quote
    character is not considered to be surrounded by quotes.
    """
    opener, closer = pair[0], pair[1]
    # too short to hold both symbols; without this check a single '"' would
    # satisfy both startswith and endswith with the same character
    if len(text) < len(opener) + len(closer):
        return False
    return text.startswith(opener) and text.endswith(closer)
def is_surrounded_pairs(text: str, pairs):
    """Return True when text is surrounded by at least one of the symbol pairs."""
    return any(is_surrounded(text, pair) for pair in pairs)
def strip_value(text: str, limit=-1):
    """Strip whitespace and wrapping bracket/quote pairs from text.

    Wrapping layers are peeled off one at a time. limit caps how many layers
    are removed; the default of -1 removes all of them, 0 removes none.
    """
    # common paired symbols that may wrap a value
    wrappers = (
        ("(", ")"),
        ("[", "]"),
        ("{", "}"),
        ('"', '"'),
        ("'", "'"),
    )
    stripped = text.strip()
    remaining = limit
    while remaining != 0:
        if not is_surrounded_pairs(stripped, wrappers):
            break
        stripped = stripped[1:-1].strip()
        remaining -= 1
    return stripped
def strip_input(text: str):
    """Strip whitespace and, if present, one pair of enclosing JSON braces."""
    stripped = text.strip()
    has_braces = stripped.startswith('{') and stripped.endswith('}')
    # only the outermost { } are removed, then the inside is stripped again
    return stripped[1:-1].strip() if has_braces else stripped
|