Spaces:
Running
Running
File size: 20,809 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 |
import torch
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload
if TYPE_CHECKING:
    # Only needed for static type checking: `Argument` appears in the string
    # (forward-reference) annotations of check_for_mutable_operation.
    from .node import Argument

# Public API of this module.
__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
           "type_matches", "normalize_function", "normalize_module"]
@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Immutable pairing of positional arguments (``args``) with keyword
    arguments (``kwargs``), as returned by the normalization helpers.
    """
    # Positional arguments, in call order.
    args: Tuple[Any, ...]
    # Keyword arguments, keyed by parameter name.
    kwargs: Dict[str, Any]
# Hand-written signatures for specific torch callables. get_signature_for_torch_op
# consults this table before falling back to the TorchScript schema lookup.
_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}


def _nonzero_schemas():
    """
    Return the two Python signatures used for ``torch.nonzero``:
    ``(self)`` and ``(self, *, as_tuple: bool)``.
    """
    Param = inspect.Parameter
    # Both overloads take a single positional-or-keyword `self`.
    self_param = Param('self', Param.POSITIONAL_OR_KEYWORD)
    # The second overload also requires a keyword-only `as_tuple` (no default).
    as_tuple_param = Param('as_tuple', Param.KEYWORD_ONLY, annotation=bool)
    return [
        inspect.Signature([self_param]),
        inspect.Signature([self_param, as_tuple_param]),
    ]


_manual_overrides[torch.nonzero] = _nonzero_schemas()
class _FakeGlobalNamespace:
    """
    Stand-in for the ``__torch__`` global when eval'ing TorchScript annotation
    strings. The only attribute that may be looked up on it is ``torch``.
    """
    def __getattr__(self, name):
        # Anything other than `torch` is unexpected in an annotation string.
        if name != 'torch':
            raise RuntimeError('Expected a torch namespace lookup')
        return torch
# Globals used by `_torchscript_type_to_python_type` when eval'ing a TorchScript
# `annotation_str`: maps the names that appear in those strings onto Python/torch
# objects.
_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
                      'number' : numbers.Number, 'Future' : torch.jit.Future,
                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
                      'Storage': torch.UntypedStorage,
                      't': typing.TypeVar('t')}
# Also expose every name reported by `dir(typing)` (List, Optional, Union, ...),
# overriding same-named entries above. A dict comprehension is used instead of a
# module-level `for` loop so no loop variable leaks into this module's namespace.
_type_eval_globals.update({name: getattr(typing, name) for name in dir(typing)})
def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
    """
    Map a TorchScript type, including any subtypes, to a Python type object.

    The type's ``annotation_str`` is evaluated against ``_type_eval_globals``,
    which binds names such as ``List`` and ``Future`` to ``typing.List`` and
    ``torch.jit.Future``.
    """
    annotation = ts_type.annotation_str
    # NOTE(review): eval() is appropriate only because the strings come from
    # TorchScript schema types (as with the callers in this module), never
    # from user input. Confirm before adding new callers.
    return eval(annotation, _type_eval_globals)
def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    """
    Translate a TorchScript ``FunctionSchema`` into an ``inspect.Signature``.

    Argument and return types are converted with
    ``_torchscript_type_to_python_type``. A schema argument named ``self`` is
    renamed to ``input``. Keyword-only schema arguments become KEYWORD_ONLY
    parameters and all others POSITIONAL_OR_KEYWORD. The exception is an
    argument named ``from`` (a Python keyword): it, and every parameter before
    it, is made POSITIONAL_ONLY.

    The return annotation is ``None`` for no returns, the converted type for a
    single return, and a plain ``tuple`` of converted types otherwise.
    """
    from inspect import Parameter

    def _as_positional_only(param: Parameter) -> Parameter:
        # Parameters preceding `from` are expected to be positional-or-keyword.
        assert param.kind == Parameter.POSITIONAL_OR_KEYWORD
        return param.replace(kind=Parameter.POSITIONAL_ONLY)

    params : List[Parameter] = []
    for schema_arg in ts_schema.arguments:
        annotation = _torchscript_type_to_python_type(schema_arg.type)
        default = schema_arg.default_value if schema_arg.has_default_value() else Parameter.empty

        # TODO: Figure out if this is safe. It seems like when generating the type signatures for
        # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
        # argument name. Downstream, if someone converts that positional argument to a keyword
        # argument, the name mismatch will break things, so here we're going to normalize the
        # name to "input"
        param_name = 'input' if schema_arg.name == 'self' else schema_arg.name
        param_kind = Parameter.KEYWORD_ONLY if schema_arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD

        if param_name == "from":
            # `from` is a Python keyword, so inspect only accepts it as POSITIONAL_ONLY.
            assert param_kind == Parameter.POSITIONAL_OR_KEYWORD
            # ParameterKind type is internal implementation detail to inspect package
            # which makes it hard to do type annotation
            param_kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
            # A positional-only parameter forces all earlier ones to be positional-only too.
            params = [_as_positional_only(p) for p in params]

        params.append(Parameter(name=param_name, kind=param_kind, default=default, annotation=annotation))

    returned = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
    if not returned:
        return_annotation = None
    elif len(returned) == 1:
        (return_annotation,) = returned
    else:
        return_annotation = tuple(returned)
    return inspect.Signature(params, return_annotation=return_annotation)
_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {}
def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
# Cached as it's called in the hot path of FakeTensor dispatch
cache_key = ts_schema.name, ts_schema.overload_name
cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
if cache_val is not None:
return cache_val
res = _torchscript_schema_to_signature_impl(ts_schema)
_SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
return res
@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
    """
    Best-effort check that ``target(*args, **kwargs)`` does not resolve to a
    mutating TorchScript schema.

    The overloads of ``target`` are looked up with ``get_signature_for_torch_op``.
    If exactly one overload's Python signature accepts ``args``/``kwargs`` and
    that overload's schema reports ``is_mutable``, a ``RuntimeError`` is raised.
    If no signatures/schemas are available, no overload matches, or several
    overloads match, the check is skipped.

    Args:
        target (Callable): The operator being traced.
        args: Positional arguments at the call site.
        kwargs: Keyword arguments at the call site.

    Raises:
        RuntimeError: If the uniquely matched schema is mutable.
    """
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)
    # Unknown op, or a manual override with no TorchScript schemas: nothing to check.
    if not (signatures and schemas):
        return

    # Keep the schema of every overload whose signature accepts these arguments.
    matched_schemas = []
    for candidate_signature, schema in zip(signatures, schemas):
        try:
            candidate_signature.bind(*args, **kwargs)
        except TypeError:
            continue
        matched_schemas.append(schema)

    # Zero matches: cannot tell which overload runs. Several matches: ambiguous.
    # Since mutability checking is best effort, only a unique match is checked.
    if len(matched_schemas) != 1:
        return

    schema_to_check = matched_schemas[0]
    if schema_to_check.is_mutable:
        raise RuntimeError(f'Tried to trace mutable operation {schema_to_check}. FX only supports functional '
                           f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
                           f'are not supported')
@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for.
            An `OpOverload` yields its single schema; an `OpOverloadPacket` yields
            one schema per overload.

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
            operator, or None if the operator signatures could not be retrieved. If
            return_schemas=True, returns a tuple containing the optional Python signatures
            and the optional TorchScript Function signature.

    NOTE(review): for an op listed in `_manual_overrides`, this returns `None` (not
    the override signatures) when `return_schemas` is False, and
    `(override_signatures, None)` when it is True. This looks possibly
    unintentional; confirm before relying on it.
    """
    if not isinstance(op, (OpOverload, OpOverloadPacket)):
        override = _manual_overrides.get(op)
        if override:
            return (override, None) if return_schemas else None
        aten_fn = torch.jit._builtins._find_builtin(op)
        if aten_fn is None:
            # Not a known TorchScript builtin: no signatures available.
            return (None, None) if return_schemas else None
        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)
    elif isinstance(op, OpOverload):
        schemas = [op._schema]
    else:
        schemas = [getattr(op, overload_name)._schema for overload_name in op.overloads()]

    signatures = list(map(_torchscript_schema_to_signature, schemas))
    if return_schemas:
        return signatures, schemas
    return signatures
@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Build a `typing` hint for a list or tuple of classes.

    A list yields ``List[T]`` and a tuple yields ``Tuple[T, ...]``. ``T`` is
    ``Any`` for an empty container. Otherwise ``T`` is found by scanning the
    elements while keeping a running base class: it widens to an element that
    is a superclass of the current base, and falls back to ``Any`` as soon as
    an element is neither a subclass nor a superclass of it.

    Any other ``x`` is returned unchanged. Building the hint is best effort: if
    it raises (e.g. an element is not a class, so ``issubclass`` fails), a
    warning is emitted and ``x`` is returned unchanged.
    """
    try:
        if isinstance(x, (list, tuple)):
            # todo(chilli): Figure out the right way for mypy to handle this
            if isinstance(x, list):
                def ret_type(elem_type):
                    return List[elem_type]  # type: ignore[valid-type]
            else:
                def ret_type(elem_type):
                    return Tuple[elem_type, ...]
            if len(x) == 0:
                return ret_type(Any)
            base_type = x[0]
            for t in x:
                if issubclass(t, base_type):
                    continue
                elif issubclass(base_type, t):
                    # Widen to the more general class seen so far.
                    base_type = t
                else:
                    return ret_type(Any)
            return ret_type(base_type)
    except Exception:
        # Deliberately best effort: report the failure and fall through to
        # returning `x` unchanged.
        warnings.warn(f"We were not able to successfully create type hint from the type {x}")
    return x
@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
    """
    Best-effort check that a value of type ``argument_type`` is acceptable for a
    parameter annotated ``signature_type``.

    Accepted cases:
    - identical types;
    - ``Union`` annotations, if any member accepts ``argument_type``;
    - ``int`` for ``List[int]``;
    - for ``List[T]`` (``T`` a class), a ``List`` of a subclass of ``T``, or a
      tuple type whose members are all ``Ellipsis`` or subclasses of ``T``;
    - ``torch.dtype`` for ``int``;
    - ``int``/``float`` for ``numbers.Number``;
    - plain subclass relationships between classes.

    Everything else returns ``False``. For a ``List[T]`` whose ``T`` is not a
    class, a warning is emitted and ``False`` is returned. ``issubclass`` may
    raise ``TypeError`` if an element type of ``argument_type`` is not a class.
    """
    sig_origin_type = getattr(signature_type, '__origin__', signature_type)
    if signature_type is argument_type:
        return True

    # Union in the signature: the argument type must match one of its members.
    if sig_origin_type is typing.Union and signature_type != argument_type:
        return any(type_matches(member, argument_type) for member in signature_type.__args__)

    # A bare int may be promoted to List[int].
    if signature_type is List[int] and argument_type is int:
        return True

    if getattr(signature_type, '__origin__', None) in {list, List}:
        element_type = signature_type.__args__[0]
        if not inspect.isclass(element_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
            return False
        arg_origin = getattr(argument_type, '__origin__', None)
        if arg_origin in {list, List}:
            return issubclass(argument_type.__args__[0], element_type)
        # Otherwise only a homogeneous Tuple[T, ...] / Tuple[T, T] is accepted for List[T].
        if arg_origin not in {tuple, Tuple}:
            return False
        members = argument_type.__args__
        if members == ((),):  # Tuple[()].__args__ == ((),) on some Python versions
            return True
        return all((m is Ellipsis) or issubclass(m, element_type) for m in members)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True

    both_classes = inspect.isclass(argument_type) and inspect.isclass(signature_type)
    return issubclass(argument_type, signature_type) if both_classes else False
@compatibility(is_backward_compatible=False)
def normalize_function(
        target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the function's
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs). Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: If `target` is a builtin function, `OpOverload` or
            `OpOverloadPacket` whose arguments match more than one schema and
            neither `arg_types` nor `kwarg_types` was given to disambiguate.
    """
    if kwargs is None:
        kwargs = {}

    if not isinstance(target, (types.BuiltinFunctionType, OpOverloadPacket, OpOverload)):
        target_for_analysis = target
        if target in boolean_dispatched:
            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
            # branches of the dispatch have exactly the same signature. If they do, use the `true`
            # branch signature for analysis. Otherwise, leave this un-normalized
            assert not isinstance(target, str)
            dispatched = boolean_dispatched[target]
            if_true, if_false = dispatched['if_true'], dispatched['if_false']
            if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
                return None
            target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        return _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)

    assert callable(target)
    torch_op_schemas = get_signature_for_torch_op(target)
    if not torch_op_schemas:
        # No signatures could be retrieved for this op. Cannot normalize
        return None

    # Collect every overload signature that accepts the concrete args/kwargs.
    matched_schemas = []
    for candidate_signature in torch_op_schemas:
        try:
            candidate_signature.bind(*args, **kwargs)
        except TypeError:
            continue
        matched_schemas.append(candidate_signature)

    if not matched_schemas:
        # Did not match any schema. Cannot normalize
        return None

    if len(matched_schemas) == 1:
        # Matched exactly one schema, unambiguous
        return _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
                                                      normalize_to_only_use_kwargs)

    if arg_types is None and kwarg_types is None:
        # Matched more than one schema. In this situation, the caller must provide the types of
        # the arguments of the overload they expect.
        schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
        raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
                           f'the schema match was ambiguous! Please provide argument types to '
                           f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')

    arg_types = arg_types if arg_types else cast(Tuple[Any], ())
    kwarg_types = kwarg_types if kwarg_types else {}
    # Disambiguate by type, but only among overloads that already accepted the
    # concrete args/kwargs. An overload that binds the *types* but not the values
    # would make the normalization call below raise an uncaught TypeError.
    for candidate_signature in matched_schemas:
        try:
            bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
            sig_matches = all(
                type_matches(candidate_signature.parameters[arg_name].annotation, arg_type)
                for arg_name, arg_type in bound_types.arguments.items())
        except TypeError:
            sig_matches = False
        if sig_matches:
            return _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
                                                          normalize_to_only_use_kwargs)
    return None
@compatibility(is_backward_compatible=False)
def normalize_module(
        root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the submodule's `forward`
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Only a submodule whose class is the identically-named class exported from
    `torch.nn` is normalized. Any other submodule yields `None`.

    Args:
        root (nn.Module): root module upon which we query modules
        target (str): Qualified name of the submodule of `root` being called
        args (Tuple[Any]): Tuple of args to the submodule
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the submodule
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: If `root.get_submodule(target)` raises AttributeError.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
                           f"have that target!") from e

    submod_cls = submod.__class__
    classname = getattr(submod_cls, '__name__', None)
    # Skip anything that is not exactly the `torch.nn` class of the same name.
    if classname is None or not getattr(torch.nn, classname, None) == submod_cls:
        return None

    sig = inspect.signature(inspect.unwrap(submod.forward))
    if kwargs is None:
        kwargs = {}
    return _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
                                           kwargs : Dict[str, Any],
                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
    """
    Bind a call's args/kwargs to ``sig`` and redistribute them, with defaults
    filled in, into an ArgsKwargsPair.

    The first ``len(args)`` parameters, in signature order, stay positional and
    the rest become keyword arguments. When ``normalize_to_only_use_kwargs`` is
    True, everything becomes a keyword argument.

    Args:
        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            the signature has positional-only or varargs parameters (apart from
            the single exception below). ``sig.bind`` raises TypeError if the
            arguments do not fit ``sig``.
    """
    accepted_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    param_names = list(sig.parameters)
    all_supported = all(p.kind in accepted_kinds for p in sig.parameters.values())
    # Add an exception for one signature, which is common for random/uniform, i.e.:
    # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
    # `from` is Python keyword and as such functions with that signature should have
    # positional-only args, but at the same time they could be dispatched as kwargs
    if not all_supported and param_names != ['input', 'from', 'to', 'generator']:
        return None

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    values = bound.arguments

    # Leading parameters that stay positional (none when normalizing to kwargs only).
    split_at = 0 if normalize_to_only_use_kwargs else len(args)
    positional = tuple(values[name] for name in param_names[:split_at])
    keyword : Dict[str, Any] = {name: values[name] for name in param_names[split_at:]}
    return ArgsKwargsPair(positional, keyword)
|