|
|
|
|
|
# Upper bound on how many elements CycleIteratorVariable will buffer from its
# input iterator; once ``saved`` grows past this, tracing bails out via
# ``unimplemented``.
MAX_CYCLE = 3_000
|
|
|
import itertools |
|
import operator |
|
|
|
from typing import Dict, List, Optional |
|
|
|
from .. import polyfill, variables |
|
from ..exc import unimplemented |
|
|
|
from .base import MutableLocal, VariableTracker |
|
from .constant import ConstantVariable |
|
|
|
|
|
class ItertoolsVariable(VariableTracker):
    """Tracks a callable from the :mod:`itertools` module during tracing.

    ``self.value`` is the itertools callable itself (e.g. ``itertools.chain``).
    :meth:`call_function` models a subset of those callables:

    * ``product``, ``chain``, ``accumulate``, ``combinations`` and ``groupby``
      are evaluated eagerly over unpackable inputs into a
      ``ListIteratorVariable``;
    * single-argument ``repeat``, ``count`` and ``cycle`` become dedicated
      iterator trackers;
    * bounded ``repeat`` and ``dropwhile`` are traced through ``polyfill``;
    * anything else defers to ``VariableTracker.call_function``.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self):
        return f"ItertoolsVariable({self.value})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @staticmethod
    def _is_const_none(var):
        """Return True when ``var`` is a ``ConstantVariable`` wrapping ``None``."""
        return isinstance(var, ConstantVariable) and var.as_python_constant() is None

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Simulate ``self.value(*args, **kwargs)``.

        Argument combinations that cannot be modelled are reported through
        ``unimplemented``.
        """
        if (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.chain
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = list(itertools.chain.from_iterable(seqs))
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            from .builtin import BuiltinVariable

            if any(key not in ["initial", "func"] for key in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
                )

            acc = kwargs.get("initial")
            # CPython treats ``initial=None`` as "no initial value"; it must not
            # be emitted as a leading element of the result.
            if self._is_const_none(acc):
                acc = None

            if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)

                if len(args) == 2 and "func" in kwargs:
                    # CPython raises TypeError when func is passed both
                    # positionally and by keyword; don't silently pick one.
                    unimplemented(
                        "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
                    )
                func_var = args[1] if len(args) == 2 else kwargs.get("func")
                if func_var is None or self._is_const_none(func_var):
                    # No func (or func=None) means addition, as in CPython.
                    func = BuiltinVariable(operator.add).call_function
                else:
                    func = func_var.call_function
            else:
                unimplemented("Unsupported arguments for itertools.accumulate")

            items = []
            if acc is not None:
                items.append(acc)
            for item in seq:
                if acc is None:
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception as e:
                        unimplemented(
                            "Unexpected failure in invoking function during accumulate. "
                            f"Failed running func {func}({acc}, {item})",
                            from_exc=e,
                        )
                items.append(acc)

            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            if any(kw != "key" for kw in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(set(kwargs.keys()) - {'key'})}"
                )

            def retrieve_const_key(key):
                # Group keys must be concrete Python values so that
                # itertools.groupby can compare them with ``==``.
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
                key_var = kwargs.get("key")
                if key_var is None or self._is_const_none(key_var):
                    # Without a key function, elements are grouped by their own
                    # value. Compare the unwrapped constants: the tracker
                    # objects themselves would otherwise be compared by
                    # identity (NOTE(review): assumes VariableTracker defines
                    # no value-based __eq__), splitting equal elements apart.
                    def keyfunc(x):
                        return retrieve_const_key(x)

                else:

                    def keyfunc(x):
                        return retrieve_const_key(key_var.call_function(tx, [x], {}))

            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    # keyfunc always yields a plain Python value; anything that
                    # cannot be re-wrapped as a constant is unsupported.
                    if not variables.ConstantVariable.is_literal(k):
                        unimplemented(
                            f"Unsupported key value for itertools.groupby: {k!r}"
                        )
                    result.append(
                        variables.TupleVariable(
                            [
                                variables.ConstantVariable.create(k),
                                variables.ListIteratorVariable(
                                    list(v), mutable_local=MutableLocal()
                                ),
                            ],
                            mutable_local=MutableLocal(),
                        )
                    )
            except Exception as e:
                unimplemented(
                    "Unexpected failure when calling itertools.groupby",
                    from_exc=e,
                )
            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
        elif self.value is itertools.repeat:
            if len(args) == 1 and not kwargs:
                # repeat(obj) with no bound: yields the same object forever.
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            # Any other form (e.g. ``times`` passed positionally or by keyword)
            # is traced through the polyfill with the original args/kwargs,
            # so a ``times=`` keyword is no longer ignored.
            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfill.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            # Bind arguments like CPython's ``count(start=0, step=1)``;
            # keyword arguments used to be dropped silently.
            extra = set(kwargs) - {"start", "step"}
            if extra:
                unimplemented(
                    "Unsupported kwargs for itertools.count: "
                    f"{','.join(sorted(extra))}"
                )
            bound = dict(zip(("start", "step"), args))
            if len(args) > 2 or bound.keys() & kwargs.keys():
                unimplemented("Unsupported arguments for itertools.count")
            bound.update(kwargs)
            return variables.CountIteratorVariable(
                bound.get("start", 0),
                bound.get("step", 1),
                mutable_local=MutableLocal(),
            )
        elif self.value is itertools.cycle:
            # cycle() takes exactly one positional iterable.
            if len(args) != 1 or kwargs:
                unimplemented("Unsupported arguments for itertools.cycle")
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.dropwhile:
            return variables.UserFunctionVariable(polyfill.dropwhile).call_function(
                tx, args, kwargs
            )
        else:
            return super().call_function(tx, args, kwargs)
|
|
|
|
|
class IteratorVariable(VariableTracker):
    """Common base for the iterator trackers defined in this module.

    Concrete subclasses provide :meth:`next_variable`; the base version only
    reports that it is unimplemented.
    """

    def __init__(self, **kwargs):
        # No state of its own: everything is forwarded to VariableTracker.
        super().__init__(**kwargs)

    def next_variable(self, tx):
        """Abstract hook for producing the next element; always unimplemented here."""
        unimplemented("abstract method, must implement")
|
|
|
|
|
class RepeatIteratorVariable(IteratorVariable):
    """Iterator tracker that yields one fixed ``item`` on every step.

    It keeps no position and never raises ``StopIteration``.
    """

    def __init__(self, item: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        # The single element handed back by every next_variable call.
        self.item = item

    def next_variable(self, tx):
        """Return the stored item unchanged; ``tx`` is not consulted."""
        repeated = self.item
        return repeated
|
|
|
|
|
class CountIteratorVariable(IteratorVariable):
    """Tracks an ``itertools.count(start, step)`` iterator.

    ``item`` holds the value the *next* call to :meth:`next_variable` will
    return, and ``step`` is added to it after each call. Plain Python values
    for either are wrapped with ``ConstantVariable.create``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
        if not isinstance(step, VariableTracker):
            step = ConstantVariable.create(step)
        self.item = item
        self.step = step

    def next_variable(self, tx):
        """Return the current value, then advance it by ``step``.

        This mirrors ``itertools.count``: the start value itself is the first
        element produced (``count(0)`` yields 0, 1, 2, ...). The previous
        implementation advanced first and so skipped the start value.
        """
        assert self.mutable_local
        # self.item is reassigned below; report that mutation to side_effects.
        tx.output.side_effects.mutation(self)
        current = self.item
        self.item = current.call_method(tx, "__add__", [self.step], {})
        return current
|
|
|
|
|
class CycleIteratorVariable(IteratorVariable):
    """Tracks ``itertools.cycle(iterator)``.

    While ``iterator`` still produces elements, each one is returned and
    buffered in ``saved`` (at most about ``MAX_CYCLE`` of them). Once it is
    exhausted, ``iterator`` is set to None and elements are replayed from
    ``saved`` in order. ``saved_index`` is the position of the next element to
    replay. ``item`` is the most recently returned element.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ):
        # Fresh list per instance; never share a default list between trackers.
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variable(self, tx):
        """Return the next element of the cycle.

        Raises:
            StopIteration: if the input iterator produced no elements at all.
        """
        assert self.mutable_local

        if self.iterator is not None:
            try:
                new_item = self.iterator.next_variable(tx)
                if len(self.saved) > MAX_CYCLE:
                    unimplemented(
                        "input iterator to itertools.cycle has too many items"
                    )
                tx.output.side_effects.mutation(self)
                self.saved.append(new_item)
                self.item = new_item
                if self.item is None:
                    return self.next_variable(tx)
                return self.item
            except StopIteration:
                # First pass finished: switch to replaying ``saved``.
                self.iterator = None
                return self.next_variable(tx)
        elif len(self.saved) > 0:
            tx.output.side_effects.mutation(self)
            # Replay the buffered element at saved_index, then advance. The old
            # code advanced the index but returned the stale ``self.item``, so
            # it repeated the last element forever.
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item
        else:
            raise StopIteration
|
|