|
import ast
|
|
import math
|
|
import random
|
|
import operator as op
|
|
|
|
|
|
class AnyType(str):
    """String subclass that never reports itself as unequal to anything.

    Only ``__ne__`` is overridden; ``==`` and hashing keep the regular
    ``str`` behaviour.
    """

    def __ne__(self, __value: object) -> bool:
        # Whatever the other operand is, answer "not different".
        return False


# Wildcard instance used as the declared type of the optional a/b/c inputs.
# NOTE: this deliberately rebinds the name ``any`` for the whole module, so
# the builtin ``any()`` is not callable under that name in this file.
any = AnyType("*")
|
|
|
|
# Dispatch table from ``ast`` operator node classes to the callable that
# implements them. Binary entries take ``(left, right)``, unary entries take
# a single operand. The boolean entries return the integers 1/0 rather than
# ``True``/``False``.
operators = {
    # Arithmetic (binary)
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.FloorDiv: op.floordiv,
    ast.Pow: op.pow,
    ast.Mod: op.mod,
    # Bitwise (binary)
    ast.BitXor: op.xor,
    ast.BitAnd: op.and_,
    ast.BitOr: op.or_,
    ast.RShift: op.rshift,
    ast.LShift: op.lshift,
    # Unary
    ast.USub: op.neg,
    # Unary plus was missing, so an expression such as ``+a`` failed with a
    # bare KeyError even though the other unary operators were supported.
    ast.UAdd: op.pos,
    ast.Invert: op.invert,
    ast.Not: lambda a: 0 if a else 1,
    # Boolean (binary)
    ast.And: lambda a, b: 1 if a and b else 0,
    ast.Or: lambda a, b: 1 if a or b else 0,
}
|
|
|
|
|
|
def _fn_spec(min_args, max_args, call, hint):
    """Build one function-table entry.

    ``min_args``/``max_args`` are the inclusive bounds on the number of
    arguments (``max_args`` of ``None`` means unbounded), ``call`` is the
    implementation and ``hint`` is the parameter summary text.
    """
    return {"args": (min_args, max_args), "call": call, "hint": hint}


# Functions that may be called from an expression, keyed by the name used in
# the expression text. Callables are only ever invoked positionally.
functions = {
    "round": _fn_spec(
        1, 2, lambda number, places=None: round(number, places), "number, dp? = 0"
    ),
    "ceil": _fn_spec(1, 1, math.ceil, "number"),
    "floor": _fn_spec(1, 1, math.floor, "number"),
    "min": _fn_spec(2, None, min, "...numbers"),
    "max": _fn_spec(2, None, max, "...numbers"),
    "randomint": _fn_spec(2, 2, random.randint, "min, max"),
    # random.choice wants a single sequence, so the arguments are packed.
    "randomchoice": _fn_spec(
        2, None, lambda *options: random.choice(options), "...numbers"
    ),
    "sqrt": _fn_spec(1, 1, math.sqrt, "number"),
    "int": _fn_spec(1, 1, lambda value=None: int(value), "number"),
    "iif": _fn_spec(
        3,
        3,
        lambda condition, if_true, if_false=None: if_true if condition else if_false,
        "value, truepart, falsepart",
    ),
}
|
|
|
|
# One autocomplete entry per entry in ``functions``, in the same order.
# ``value`` is the function name followed by empty parentheses; the
# ``caretOffset`` of -1 presumably leaves the caret between them (consumer
# not visible here -- confirm).
autocompleteWords = [
    {
        "text": name,
        "value": f"{name}()",
        "showValue": False,
        "hint": f"{spec['hint']}",
        "caretOffset": -1,
    }
    for name, spec in functions.items()
]
|
|
|
|
|
|
class MathExpression:
    """Node that evaluates a restricted arithmetic expression.

    The expression text is parsed with ``ast`` and walked by hand instead of
    being handed to ``eval``, so only the node kinds handled in
    :meth:`evaluate` are accepted. The expression may reference the optional
    inputs ``a``, ``b`` and ``c``, call entries of the module-level
    ``functions`` table, and read widget values of other nodes through
    ``NodeName.widget_name``. The result is returned both as INT and FLOAT.
    """

    # Comparison node classes supported in expressions and their
    # implementations.
    _COMPARISONS = {
        ast.Eq: op.eq,
        ast.NotEq: op.ne,
        ast.Gt: op.gt,
        ast.GtE: op.ge,
        ast.Lt: op.lt,
        ast.LtE: op.le,
    }

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the expression text, three optional
        wildcard inputs and the hidden workflow/prompt data."""
        return {
            "required": {
                "expression": ("STRING", {"multiline": True, "dynamicPrompts": False, "pysssss.autocomplete": {
                    "words": autocompleteWords,
                    "separator": ""
                }}),
            },
            "optional": {
                # ``any`` is the module-level wildcard type object, not the
                # builtin function.
                "a": (any, ),
                "b": (any,),
                "c": (any, ),
            },
            "hidden": {"extra_pnginfo": "EXTRA_PNGINFO",
                       "prompt": "PROMPT"},
        }

    RETURN_TYPES = ("INT", "FLOAT", )
    FUNCTION = "evaluate"
    CATEGORY = "utils"
    OUTPUT_NODE = True

    @classmethod
    def IS_CHANGED(cls, expression, **kwargs):
        """Return a change token for the expression.

        Expressions containing the text ``random`` yield NaN, which never
        compares equal to itself, so successive tokens always differ
        (presumably how the host detects that the node must run again --
        confirm). Any other expression is its own token.
        """
        if "random" in expression:
            return float("nan")
        return expression

    def get_widget_value(self, extra_pnginfo, prompt, node_name, widget_name):
        """Look up a widget value of another node in the workflow.

        A node matches ``node_name`` by its "Node name for S&R" property
        (falling back to its type) or, failing that, by its title.

        Args:
            extra_pnginfo: Mapping that may contain a ``"workflow"`` entry
                with a ``"nodes"`` list; ``None`` is treated as empty.
            prompt: Mapping from node id (as a string) to that node's data.
            node_name: Name of the node to search for.
            widget_name: Key inside the node's ``"inputs"``.

        Raises:
            NameError: If no node matches, or it has no such widget.
            ValueError: If the widget value is a list.
        """
        # ``extra_pnginfo`` may be None; the original membership test
        # crashed on that with a TypeError.
        workflow = (extra_pnginfo or {}).get("workflow", {"nodes": []})

        node_id = None
        for node in workflow["nodes"]:
            name = node["type"]
            if "properties" in node and "Node name for S&R" in node["properties"]:
                name = node["properties"]["Node name for S&R"]
            if name == node_name:
                node_id = node["id"]
                break
            if "title" in node and node["title"] == node_name:
                node_id = node["id"]
                break

        if node_id is None:
            raise NameError(f"Node not found: {node_name}.{widget_name}")

        values = prompt[str(node_id)]
        if "inputs" in values and widget_name in values["inputs"]:
            value = values["inputs"][widget_name]
            if isinstance(value, list):
                raise ValueError("Converted widgets are not supported via named reference, use the inputs instead.")
            return value
        raise NameError(f"Widget not found: {node_name}.{widget_name}")

    def get_size(self, target, property):
        """Return the width or height of ``target``.

        A dict with a ``"samples"`` entry is read from ``samples.shape``
        axes 3 (width) and 2 (height), multiplied by 8 (presumably the
        latent-to-pixel scale -- confirm). Anything else is read directly
        from ``target.shape`` axes 2 (width) and 1 (height). Any
        ``property`` other than ``"width"`` is treated as height.
        """
        if isinstance(target, dict) and "samples" in target:
            shape = target["samples"].shape
            if property == "width":
                return shape[3] * 8
            return shape[2] * 8

        if property == "width":
            return target.shape[2]
        return target.shape[1]

    def evaluate(self, expression, prompt, extra_pnginfo=None, a=None, b=None, c=None):
        """Evaluate ``expression`` and return it as INT and FLOAT.

        Args:
            expression: Expression text; newlines are folded into one line.
            prompt: Prompt data, used for ``Node.widget`` references.
            extra_pnginfo: Workflow data, used for ``Node.widget``
                references. Defaults to ``None`` rather than a shared
                mutable ``{}``.
            a, b, c: Optional inputs available by name in the expression.

        Returns:
            ``{"ui": {"value": [r]}, "result": (int(r), float(r))}``.

        Raises:
            NameError: Unknown name, function, node or widget.
            SyntaxError: Function called with the wrong number of arguments.
            TypeError: Unsupported syntax node, or a non-numeric input
                referenced without ``.width``/``.height``.
            NotImplementedError: Unsupported comparison operator.
        """
        expression = expression.replace('\n', ' ').replace('\r', '')
        tree = ast.parse(expression, mode='eval').body

        lookup = {"a": a, "b": b, "c": c}

        def as_number(value):
            # Ints (including bools) pass through; everything else must be
            # convertible to float.
            return value if isinstance(value, int) else float(value)

        def eval_op(node, l, r):
            left = eval_expr(l)
            right = eval_expr(r)
            return operators[type(node.op)](as_number(left), as_number(right))

        def eval_expr(node):
            # ``ast.Num`` and ``Constant.n`` are deprecated aliases that
            # newer Python versions drop; ``ast.Constant``/``.value`` cover
            # the same literals.
            if isinstance(node, ast.Constant):
                return node.value

            if isinstance(node, ast.BinOp):
                return eval_op(node, node.left, node.right)

            if isinstance(node, ast.BoolOp):
                # ``x and y and z`` is one BoolOp with three values; looking
                # only at the first two gave wrong results. Every operand is
                # still evaluated (no short-circuit), as before.
                operands = [eval_expr(value) for value in node.values]
                truthy = [bool(as_number(value)) for value in operands]
                if isinstance(node.op, ast.And):
                    return 1 if all(truthy) else 0
                # The builtin ``any`` is shadowed at module level, so test
                # membership instead.
                return 1 if True in truthy else 0

            if isinstance(node, ast.UnaryOp):
                return operators[type(node.op)](eval_expr(node.operand))

            if isinstance(node, ast.Attribute):
                if node.value.id in lookup:
                    if node.attr == "width" or node.attr == "height":
                        return self.get_size(lookup[node.value.id], node.attr)

                return self.get_widget_value(extra_pnginfo, prompt, node.value.id, node.attr)

            if isinstance(node, ast.Name):
                if node.id in lookup:
                    val = lookup[node.id]
                    if isinstance(val, (int, float, complex)):
                        return val
                    raise TypeError(
                        f"Compex types (LATENT/IMAGE) need to reference their width/height, e.g. {node.id}.width")
                raise NameError(f"Name not found: {node.id}")

            if isinstance(node, ast.Call):
                if node.func.id in functions:
                    fn = functions[node.func.id]
                    min_args, max_args = fn["args"]
                    count = len(node.args)
                    if count < min_args or (max_args is not None and count > max_args):
                        if max_args is None:
                            toErr = " or more"
                        else:
                            toErr = f" to {fn['args'][1]}"
                        raise SyntaxError(
                            f"Invalid function call: {node.func.id} requires {fn['args'][0]}{toErr} arguments")
                    args = [eval_expr(arg) for arg in node.args]
                    return fn["call"](*args)
                raise NameError(f"Invalid function call: {node.func.id}")

            if isinstance(node, ast.Compare):
                # Chained comparisons (``1 < x < 10``) carry several
                # operators; each adjacent pair must hold, not only the
                # first one.
                left = eval_expr(node.left)
                for cmp_op, comparator in zip(node.ops, node.comparators):
                    right = eval_expr(comparator)
                    compare = self._COMPARISONS.get(type(cmp_op))
                    if compare is None:
                        raise NotImplementedError(
                            "Operator " + cmp_op.__class__.__name__ + " not supported.")
                    if not compare(left, right):
                        return 0
                    left = right
                return 1

            raise TypeError(node)

        r = eval_expr(tree)
        return {"ui": {"value": [r]}, "result": (int(r), float(r),)}
|
|
|
|
|
|
# Key under which this node is listed in both mappings below.
_NODE_KEY = "MathExpression|pysssss"

# Node key -> implementing class.
NODE_CLASS_MAPPINGS = {_NODE_KEY: MathExpression}

# Node key -> display name (presumably read by the host application when it
# loads this module -- consumer not visible here).
NODE_DISPLAY_NAME_MAPPINGS = {_NODE_KEY: "Math Expression 🐍"}
|
|
|
|
|