Adding all project files
This commit is contained in:
parent
6c9e127bdc
commit
cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
537
venv/Lib/site-packages/torch/ao/ns/fx/utils.py
Normal file
537
venv/Lib/site-packages/torch/ao/ns/fx/utils.py
Normal file
|
@ -0,0 +1,537 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import enum
|
||||
import operator
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.ao.nn.intrinsic.quantized as nniq
|
||||
import torch.ao.nn.quantized as nnq
|
||||
import torch.nn as nn
|
||||
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
|
||||
from torch.ao.quantization.observer import _is_activation_post_process
|
||||
from torch.ao.quantization.utils import getattr_from_fqn
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.graph import Node
|
||||
|
||||
from .ns_types import NSNodeTargetType, NSResultsType
|
||||
|
||||
|
||||
toq = torch.ops.quantized
|
||||
|
||||
|
||||
# TODO(future PR): consider deleting this enum and using the torch types
|
||||
# directly. This might be tricky because it is not a one to one mapping.
|
||||
class NodeInputOrOutputType(enum.Enum):
|
||||
FP32 = enum.auto() # torch.float
|
||||
INT8 = enum.auto() # torch.qint8 or torch.quint8
|
||||
FP16 = enum.auto() # torch.float16
|
||||
UNKNOWN = enum.auto() # we cannot determine input/output dtype
|
||||
# TODO(future PR): while these functions can support multiple dtypes,
|
||||
# for the purposes of numerical debugging we want to get the actual
|
||||
# dtype used in the model. We will likely need some kind of dtype
|
||||
# propagation to estimate this.
|
||||
FP32_OR_INT8 = enum.auto() # either torch.float or torch.quint8 or torch.qint8
|
||||
# TODO(future PRs): dynamic quant, fake quant, etc
|
||||
|
||||
|
||||
def get_node_first_input_and_output_type(
|
||||
node: Node,
|
||||
gm: GraphModule,
|
||||
logger_cls: Callable,
|
||||
node_type_to_io_type_map: dict[str, set[NSNodeTargetType]],
|
||||
) -> tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
|
||||
# TODO(future PR): clean this up
|
||||
FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
|
||||
FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
|
||||
FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
|
||||
FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
|
||||
MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
|
||||
MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
|
||||
MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
|
||||
METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]
|
||||
|
||||
if node.op == "call_function":
|
||||
if node.target in FUNS_IO_TYPE_FP32:
|
||||
return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
|
||||
if node.target in FUNS_IO_TYPE_FP16:
|
||||
return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
|
||||
elif node.target in FUNS_IO_TYPE_INT8:
|
||||
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
|
||||
elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
|
||||
first_arg = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(first_arg, Node)
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
) = get_node_first_input_and_output_type(
|
||||
first_arg, gm, logger_cls, node_type_to_io_type_map
|
||||
)
|
||||
return (prev_node_output_type, prev_node_output_type)
|
||||
else:
|
||||
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
|
||||
|
||||
elif node.op == "call_module":
|
||||
assert node.op == "call_module"
|
||||
assert isinstance(node.target, str)
|
||||
mod = getattr_from_fqn(gm, node.target)
|
||||
is_known_fp32_or_int8_input_module = any(
|
||||
isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # type: ignore[arg-type]
|
||||
)
|
||||
if (
|
||||
isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase)) # type: ignore[arg-type]
|
||||
or is_known_fp32_or_int8_input_module
|
||||
):
|
||||
# A logger or observer's input and output type is the output
|
||||
# type of the preceding node.
|
||||
first_arg = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(first_arg, Node)
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
) = get_node_first_input_and_output_type(
|
||||
first_arg, gm, logger_cls, node_type_to_io_type_map
|
||||
)
|
||||
return (prev_node_output_type, prev_node_output_type)
|
||||
is_known_fp32_input_module = any(
|
||||
isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32 # type: ignore[arg-type]
|
||||
)
|
||||
is_known_int8_input_module = any(
|
||||
isinstance(mod, target_type) for target_type in MODS_IO_TYPE_INT8 # type: ignore[arg-type]
|
||||
)
|
||||
if is_known_fp32_input_module:
|
||||
return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
|
||||
elif is_known_int8_input_module:
|
||||
return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
|
||||
else:
|
||||
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
|
||||
|
||||
elif node.op == "call_method":
|
||||
if node.target == "dequantize":
|
||||
# Dequantize is a special node because it allows multiple input types.
|
||||
# So, we look up the output type of the previous node and return that
|
||||
# as the input type of this node instance.
|
||||
prev_node = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(prev_node, Node)
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
) = get_node_first_input_and_output_type(
|
||||
prev_node, gm, logger_cls, node_type_to_io_type_map
|
||||
)
|
||||
return (prev_node_output_type, NodeInputOrOutputType.FP32)
|
||||
|
||||
elif node.target == "to":
|
||||
# to is a special node because it allows multiple input types.
|
||||
# So, we look up the output type of the previous node and return that
|
||||
# as the input type of this node instance. We also look up the target
|
||||
# of to and return the correct output type.
|
||||
prev_node = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(prev_node, Node)
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
) = get_node_first_input_and_output_type(
|
||||
prev_node, gm, logger_cls, node_type_to_io_type_map
|
||||
)
|
||||
|
||||
cur_node_dtype_target = get_normalized_nth_input(node, gm, 1)
|
||||
assert (
|
||||
cur_node_dtype_target is torch.float16
|
||||
), f"{cur_node_dtype_target} handling needs to be added"
|
||||
|
||||
return (prev_node_output_type, NodeInputOrOutputType.FP16)
|
||||
|
||||
elif node.target in METHS_IO_TYPE_FP32_OR_INT8:
|
||||
first_arg = get_normalized_nth_input(node, gm, 0)
|
||||
assert isinstance(first_arg, Node)
|
||||
(
|
||||
_prev_node_input_type,
|
||||
prev_node_output_type,
|
||||
) = get_node_first_input_and_output_type(
|
||||
first_arg, gm, logger_cls, node_type_to_io_type_map
|
||||
)
|
||||
return (prev_node_output_type, prev_node_output_type)
|
||||
|
||||
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
|
||||
else:
|
||||
return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
|
||||
|
||||
|
||||
def get_node_input_qparams(
|
||||
node: Node,
|
||||
gm: GraphModule,
|
||||
node_type_to_io_type_map: dict[str, set[NSNodeTargetType]],
|
||||
) -> Optional[tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
|
||||
"""
|
||||
Returns the qparams (scale, zero_point) of the first input to `node`,
|
||||
if they can be inferred from the graph.
|
||||
"""
|
||||
prev_node = get_normalized_nth_input(node, gm, 0)
|
||||
|
||||
if not isinstance(prev_node, Node):
|
||||
return None
|
||||
|
||||
MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
|
||||
|
||||
def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx):
|
||||
scale_node = get_normalized_nth_input(node, gm, scale_arg_idx)
|
||||
zp_node = get_normalized_nth_input(node, gm, zp_arg_idx)
|
||||
assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
|
||||
assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
|
||||
scale_obj = getattr_from_fqn(gm, scale_node.target)
|
||||
zp_obj = getattr_from_fqn(gm, zp_node.target)
|
||||
return (scale_obj, zp_obj)
|
||||
|
||||
if prev_node.op == "call_function":
|
||||
# quantize - read the args directly
|
||||
if prev_node.target == torch.quantize_per_tensor:
|
||||
return _get_scale_zp_from_function_args(prev_node, gm, 1, 2)
|
||||
elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
|
||||
return _get_scale_zp_from_function_args(prev_node, gm, 2, 3)
|
||||
|
||||
return None
|
||||
# TODO(future PR): handle more functionals
|
||||
# TODO(future PR): handle functional ops which inherit qparams from input
|
||||
|
||||
elif prev_node.op == "call_module":
|
||||
# get type of the module
|
||||
assert isinstance(prev_node.target, str)
|
||||
module_obj = getattr_from_fqn(gm, prev_node.target)
|
||||
if isinstance(
|
||||
module_obj,
|
||||
(
|
||||
nnq.Linear,
|
||||
nnq.Conv1d,
|
||||
nnq.Conv2d,
|
||||
nniq.ConvReLU2d,
|
||||
nnq.Conv3d,
|
||||
nnq.BatchNorm2d,
|
||||
nnq.BatchNorm3d,
|
||||
nnq.ConvTranspose1d,
|
||||
nnq.ConvTranspose2d,
|
||||
nnq.ELU,
|
||||
nnq.GroupNorm,
|
||||
nnq.InstanceNorm1d,
|
||||
nnq.InstanceNorm2d,
|
||||
nnq.InstanceNorm3d,
|
||||
nnq.LayerNorm,
|
||||
nnq.Hardswish,
|
||||
nnq.LeakyReLU,
|
||||
nnq.ReLU6,
|
||||
nniq.BNReLU2d,
|
||||
nniq.BNReLU3d,
|
||||
nniq.ConvReLU1d,
|
||||
nniq.ConvReLU2d,
|
||||
nniq.ConvReLU3d,
|
||||
nniq.LinearReLU,
|
||||
),
|
||||
):
|
||||
return (module_obj.scale, module_obj.zero_point) # type: ignore[return-value]
|
||||
|
||||
is_known_fp32_or_int8_input_module = any(
|
||||
isinstance(module_obj, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8 # type: ignore[arg-type]
|
||||
)
|
||||
if is_known_fp32_or_int8_input_module:
|
||||
return get_node_input_qparams(prev_node, gm, node_type_to_io_type_map)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def return_first_non_observer_node(
|
||||
node: Node,
|
||||
gm: GraphModule,
|
||||
) -> Node:
|
||||
"""
|
||||
If node is not an observer, returns it. If node is an observer,
|
||||
navigates up the graph and returns the first parent which is not an
|
||||
observer. For example,
|
||||
|
||||
graph: (node_non_obs), node = node_non_obs : returns node_non_obs
|
||||
graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
|
||||
graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs
|
||||
"""
|
||||
if node.op == "call_module":
|
||||
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
|
||||
if _is_activation_post_process(node_obj):
|
||||
assert len(node.args) == 1
|
||||
assert isinstance(node.args[0], Node)
|
||||
node = node.args[0]
|
||||
# code duplication intended, not worth refactoring
|
||||
assert isinstance(node.target, str)
|
||||
node_obj = getattr_from_fqn(gm, node.target)
|
||||
if _is_activation_post_process(node_obj):
|
||||
assert len(node.args) == 1
|
||||
assert isinstance(node.args[0], Node)
|
||||
node = node.args[0]
|
||||
return node
|
||||
|
||||
|
||||
def get_number_of_non_param_args(
|
||||
node: Node,
|
||||
gm: GraphModule,
|
||||
) -> int:
|
||||
"""
|
||||
Assumes that all non-param args occur first. Returns the number of
|
||||
non-param args expected for a node. For example, for
|
||||
|
||||
F.linear(x, weight, bias)
|
||||
|
||||
Returns 1, because x is a non-param arg and weight and bias are params.
|
||||
For
|
||||
|
||||
lstm_mod(x, hid)
|
||||
|
||||
Returns 2, because both x and hid are non-param args.
|
||||
"""
|
||||
if node.op == "call_module":
|
||||
node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type]
|
||||
if isinstance(node_obj, nn.LSTM):
|
||||
return 2
|
||||
|
||||
# default is 1
|
||||
return 1
|
||||
|
||||
|
||||
def get_arg_indices_of_inputs_to_log(node: Node) -> list[int]:
|
||||
"""
|
||||
Returns the indices of args of the node which we should attach
|
||||
loggers to, if input logging is enabled.
|
||||
|
||||
For example,
|
||||
* for (x + y), returns [0, 1]
|
||||
* for (1 + y), returns [1]
|
||||
* for (x + 1), returns [0]
|
||||
* for (linear(x, w, b)) returns [0]
|
||||
* by default, returns [0]
|
||||
"""
|
||||
if len(node.args) == 0:
|
||||
return []
|
||||
if node.op == "call_function" and (
|
||||
# TODO(future PR): use relationship map instead of hardcoding
|
||||
node.target in (torch.add, torch.ops.quantized.add, operator.add)
|
||||
or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
|
||||
):
|
||||
result = [i for i in range(2) if type(node.args[i]) == Node]
|
||||
return result
|
||||
return [0]
|
||||
|
||||
|
||||
def get_target_type_str(node: Node, gm: GraphModule) -> str:
|
||||
"""
|
||||
Returns a string representation of the type of the function or module
|
||||
pointed to by this node, or '' for other node types.
|
||||
"""
|
||||
target_type = ""
|
||||
if node.op in ("call_function", "call_method"):
|
||||
target_type = torch.typename(node.target)
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
target_mod = getattr_from_fqn(gm, node.target)
|
||||
target_type = torch.typename(target_mod)
|
||||
return target_type
|
||||
|
||||
|
||||
def rekey_logger_info_on_node_name_of_model(
|
||||
results: NSResultsType,
|
||||
model_name: str,
|
||||
) -> NSResultsType:
|
||||
"""
|
||||
Rekeys the layer name of a results dictionary to use node names
|
||||
from `model_name`.
|
||||
|
||||
For example, transforms
|
||||
|
||||
{'base_op_1_0': {'node_output': {'model_a':
|
||||
[{'ref_node_name': 'linear1', ...}]}}}
|
||||
|
||||
into
|
||||
|
||||
{'linear1': {'node_output': {'model_a':
|
||||
[{'ref_node_name': 'linear1', ...}]}}}
|
||||
|
||||
Note: we cannot use these node names directly because they are not
|
||||
guaranteed to be consistent across models. This is why we extract
|
||||
the results first and rekey afterwards.
|
||||
"""
|
||||
new_results = {}
|
||||
for old_layer_name, result_type_to_results in results.items():
|
||||
new_layer_name = None
|
||||
for model_name_to_results in result_type_to_results.values():
|
||||
for cur_model_name, list_of_results in model_name_to_results.items():
|
||||
if cur_model_name == model_name:
|
||||
assert len(list_of_results)
|
||||
new_layer_name = list_of_results[0]["ref_node_name"]
|
||||
else:
|
||||
continue
|
||||
if new_layer_name is not None:
|
||||
new_results[new_layer_name] = result_type_to_results
|
||||
else:
|
||||
new_results[old_layer_name] = result_type_to_results
|
||||
return new_results
|
||||
|
||||
|
||||
def maybe_add_missing_fqns(results: NSResultsType) -> None:
|
||||
"""
|
||||
If `fqn` entries are filled in for one of the models in `results`, copies
|
||||
them over to any models which do not have them filled out.
|
||||
|
||||
A common use case benefitting from this is comparing a model prepared by
|
||||
quantization to a quantized model. In this case, the model prepared by
|
||||
quantization would have `fqn` entries, and the quantized model would not.
|
||||
"""
|
||||
|
||||
# Check in the first result to find any model with fqn entries defined.
|
||||
model_name_with_fqns = None
|
||||
for result_type_to_results in results.values():
|
||||
for model_name_to_results in result_type_to_results.values():
|
||||
for model_name, model_results in model_name_to_results.items():
|
||||
if len(model_results) > 0:
|
||||
if model_results[0]["fqn"] is not None:
|
||||
model_name_with_fqns = model_name
|
||||
break
|
||||
break
|
||||
break
|
||||
|
||||
if model_name_with_fqns:
|
||||
for result_type_to_results in results.values():
|
||||
for model_name_to_results in result_type_to_results.values():
|
||||
ref_model_results = model_name_to_results[model_name_with_fqns]
|
||||
for model_name, model_results in model_name_to_results.items():
|
||||
if model_name == model_name_with_fqns:
|
||||
continue
|
||||
for i in range(len(model_results)):
|
||||
fqn = ref_model_results[i]["fqn"]
|
||||
model_results[i]["fqn"] = fqn
|
||||
|
||||
|
||||
def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
|
||||
def inner(*args, **kwargs):
|
||||
a0, a1, *a_other = args
|
||||
|
||||
if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
|
||||
isinstance(a0, list) and isinstance(a1, list)
|
||||
):
|
||||
results = []
|
||||
for el0, el1 in zip(a0, a1):
|
||||
new_args = (el0, el1, *a_other)
|
||||
results.append(inner(*new_args, **kwargs))
|
||||
return results
|
||||
|
||||
elif isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
|
||||
if a0.is_quantized:
|
||||
a0 = a0.dequantize()
|
||||
if a1.is_quantized:
|
||||
a1 = a1.dequantize()
|
||||
|
||||
# for the purposes of this util, only handle floats
|
||||
if a0.dtype != torch.float or a1.dtype != torch.float:
|
||||
return None
|
||||
|
||||
new_args = (a0, a1, *a_other)
|
||||
return f(*new_args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
|
||||
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes the SQNR between `x` and `y`.
|
||||
|
||||
Args:
|
||||
x: Tensor or tuple of tensors
|
||||
y: Tensor or tuple of tensors
|
||||
|
||||
Return:
|
||||
float or tuple of floats
|
||||
"""
|
||||
Ps = torch.norm(x)
|
||||
Pn = torch.norm(x - y)
|
||||
return 20 * torch.log10(Ps / Pn)
|
||||
|
||||
|
||||
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
|
||||
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes the normalized L2 error between `x` and `y`.
|
||||
|
||||
Args:
|
||||
x: Tensor or tuple of tensors
|
||||
y: Tensor or tuple of tensors
|
||||
|
||||
Return:
|
||||
float or tuple of floats
|
||||
"""
|
||||
return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())
|
||||
|
||||
|
||||
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
|
||||
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes the cosine similarity between `x` and `y`.
|
||||
|
||||
Args:
|
||||
x: Tensor or tuple of tensors
|
||||
y: Tensor or tuple of tensors
|
||||
|
||||
Return:
|
||||
float or tuple of floats
|
||||
"""
|
||||
# For convolutions, the shape of the quantized weight has one additional
|
||||
# dimension compared to the shape of the fp32 weight. Match the shapes
|
||||
# to enable cosine similarity comparison.
|
||||
x = x.reshape(1, -1)
|
||||
y = y.reshape(1, -1)
|
||||
return torch.nn.functional.cosine_similarity(x, y)
|
||||
|
||||
|
||||
def op_type_supports_shadowing(node: Node) -> bool:
|
||||
if node.op == "call_function":
|
||||
if node.target in (
|
||||
torch.add,
|
||||
torch.mul,
|
||||
operator.add,
|
||||
operator.mul,
|
||||
torch.cat,
|
||||
torch.stack,
|
||||
):
|
||||
# shadowing for ops with multiple tensor inputs is not implemented yet
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
|
||||
"""
|
||||
Given a node, gets the n'th input to that node, normalizing
|
||||
args and kwargs to the best of its ability.
|
||||
"""
|
||||
try:
|
||||
norm_args_and_kwargs = node.normalized_arguments(
|
||||
gm, normalize_to_only_use_kwargs=True
|
||||
)
|
||||
if norm_args_and_kwargs is not None:
|
||||
norm_args, norm_kwargs = norm_args_and_kwargs
|
||||
assert len(norm_args) + len(norm_kwargs) > idx
|
||||
if idx < len(norm_args):
|
||||
return norm_args[idx]
|
||||
else:
|
||||
# note: in Python 3.7+ dicts are ordered
|
||||
return list(norm_kwargs.values())[idx]
|
||||
else:
|
||||
assert len(node.args) + len(node.kwargs) > idx
|
||||
if idx < len(node.args):
|
||||
return node.args[idx] # type: ignore[return-value]
|
||||
else:
|
||||
kwargs_idx = idx + len(node.args)
|
||||
return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value]
|
||||
except RuntimeError:
|
||||
# this RuntimeError happens when node argument normalization
|
||||
# requires typehints to proceed, such as for torch.add where
|
||||
# either the first, second or both arguments could be tensors
|
||||
assert len(node.args) + len(node.kwargs) > idx
|
||||
if idx < len(node.args):
|
||||
return node.args[idx] # type: ignore[return-value]
|
||||
else:
|
||||
kwargs_idx = idx + len(node.args)
|
||||
return list(node.kwargs.values())[kwargs_idx] # type: ignore[return-value]
|
Loading…
Add table
Add a link
Reference in a new issue