[BE]: Apply FURB118 (prev): replaces unnecessary lambdas with operator. (#116027)
This replaces a bunch of unnecessary lambdas with the operator package. This is semantically equivalent, but the operator package is faster, and arguably more readable. When the FURB rules are taken out of preview, I will enable it as a ruff check.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116027
Approved by: https://github.com/malfet
diff --git a/benchmarks/distributed/rpc/rl/agent.py b/benchmarks/distributed/rpc/rl/agent.py
index 14bfbaf..a82be1c 100644
--- a/benchmarks/distributed/rpc/rl/agent.py
+++ b/benchmarks/distributed/rpc/rl/agent.py
@@ -1,3 +1,4 @@
+import operator
import threading
import time
from functools import reduce
@@ -75,9 +76,7 @@
batch (bool): Whether to process and respond to observer requests as a batch or 1 at a time
"""
self.batch = batch
- self.policy = Policy(
- reduce((lambda x, y: x * y), state_size), nlayers, out_features
- )
+ self.policy = Policy(reduce(operator.mul, state_size), nlayers, out_features)
self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
self.batch_size = batch_size
diff --git a/benchmarks/tensorexpr/broadcast.py b/benchmarks/tensorexpr/broadcast.py
index fc68822..2401416 100644
--- a/benchmarks/tensorexpr/broadcast.py
+++ b/benchmarks/tensorexpr/broadcast.py
@@ -1,4 +1,5 @@
import itertools
+import operator
import numpy as np
import torch
@@ -262,9 +263,9 @@
def register_broadcast_ops():
binary_op_list = [
- ["mul", lambda a, b: a * b],
- ["add", lambda a, b: a + b],
- ["sub", lambda a, b: a - b],
+ ["mul", operator.mul],
+ ["add", operator.add],
+ ["sub", operator.sub],
["div", lambda a, b: a / (b + 1e-4)],
[
"pow",
diff --git a/benchmarks/tensorexpr/elementwise.py b/benchmarks/tensorexpr/elementwise.py
index f9893d2..eee6ef6 100644
--- a/benchmarks/tensorexpr/elementwise.py
+++ b/benchmarks/tensorexpr/elementwise.py
@@ -1,4 +1,5 @@
import itertools
+import operator
import numpy as np
import scipy.special
@@ -116,9 +117,9 @@
def register_element_ops():
binary_op_list = [
- ["mul", lambda a, b: a * b],
- ["add", lambda a, b: a + b],
- ["sub", lambda a, b: a - b],
+ ["mul", operator.mul],
+ ["add", operator.add],
+ ["sub", operator.sub],
["div", lambda a, b: a / (b + 1e-4)],
[
"pow",
diff --git a/benchmarks/tensorexpr/microbenchmarks.py b/benchmarks/tensorexpr/microbenchmarks.py
index cd083b4..6177502 100644
--- a/benchmarks/tensorexpr/microbenchmarks.py
+++ b/benchmarks/tensorexpr/microbenchmarks.py
@@ -1,4 +1,5 @@
import argparse
+import operator
import time
import matplotlib.pyplot as plt
@@ -105,10 +106,10 @@
te_bool = te.Dtype.Bool
binary_ops = [
- ("add", (lambda a, b: a + b), torch.add),
- ("mul", (lambda a, b: a * b), torch.mul),
- ("sub", (lambda a, b: a - b), torch.sub),
- ("div", (lambda a, b: a / b), torch.div),
+ ("add", operator.add, torch.add),
+ ("mul", operator.mul, torch.mul),
+ ("sub", operator.sub, torch.sub),
+ ("div", operator.truediv, torch.div),
(
"eq",
(lambda a, b: te.Cast.make(te_bool, a == b)),
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index 2c14b03..21e9125 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -25,6 +25,7 @@
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import gradcheck, gradgradcheck
+import operator
class TestAvgPool(TestCase):
@@ -42,11 +43,11 @@
return joined_x.view(1, joined_x.numel())
def _avg_pool2d(self, x, kernel_size):
- size = reduce((lambda x, y: x * y), kernel_size)
+ size = reduce(operator.mul, kernel_size)
return self._sum_pool2d(x, kernel_size) / size
def _avg_pool3d(self, x, kernel_size):
- size = reduce((lambda x, y: x * y), kernel_size)
+ size = reduce(operator.mul, kernel_size)
return self._sum_pool3d(x, kernel_size) / size
def test_doubletensor_avg_pool2d(self):
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index a12b9f3..ed72229 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -10,6 +10,7 @@
import inspect
import io
import itertools
+import operator
import os
import shutil
import tempfile
@@ -172,27 +173,27 @@
def test_add_broadcast(self):
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(3, requires_grad=True).double()
- self.assertONNX(lambda x, y: x + y, (x, y))
+ self.assertONNX(operator.add, (x, y))
def test_add_left_broadcast(self):
x = torch.randn(3, requires_grad=True).double()
y = torch.randn(2, 3, requires_grad=True).double()
- self.assertONNX(lambda x, y: x + y, (x, y))
+ self.assertONNX(operator.add, (x, y))
def test_add_size1_broadcast(self):
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(2, 1, requires_grad=True).double()
- self.assertONNX(lambda x, y: x + y, (x, y))
+ self.assertONNX(operator.add, (x, y))
def test_add_size1_right_broadcast(self):
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(3, requires_grad=True).double()
- self.assertONNX(lambda x, y: x + y, (x, y))
+ self.assertONNX(operator.add, (x, y))
def test_add_size1_singleton_broadcast(self):
x = torch.randn(2, 3, requires_grad=True).double()
y = torch.randn(1, 3, requires_grad=True).double()
- self.assertONNX(lambda x, y: x + y, (x, y))
+ self.assertONNX(operator.add, (x, y))
def test_rsub(self):
x = torch.randn(2, 3, requires_grad=True).double()
@@ -541,27 +542,27 @@
def test_equal(self):
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
- self.assertONNX(lambda x, y: x == y, (x, y))
+ self.assertONNX(operator.eq, (x, y))
def test_lt(self):
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
- self.assertONNX(lambda x, y: x < y, (x, y))
+ self.assertONNX(operator.lt, (x, y))
def test_gt(self):
x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
y = torch.randn(1, 4, requires_grad=False).int()
- self.assertONNX(lambda x, y: x > y, (x, y))
+ self.assertONNX(operator.gt, (x, y))
def test_le(self):
x = torch.randn(3, 4, requires_grad=False).int()
y = torch.randn(3, 4, requires_grad=False).int()
- self.assertONNX(lambda x, y: x <= y, (x, y))
+ self.assertONNX(operator.le, (x, y))
def test_ge(self):
x = torch.randn(3, 4, requires_grad=False).int()
y = torch.randn(3, 4, requires_grad=False).int()
- self.assertONNX(lambda x, y: x >= y, (x, y))
+ self.assertONNX(operator.ge, (x, y))
def test_exp(self):
x = torch.randn(3, 4, requires_grad=True)
@@ -862,7 +863,7 @@
def test_master_opset(self):
x = torch.randn(2, 3).float()
y = torch.randn(2, 3).float()
- self.assertONNX(lambda x, y: x + y, (x, y), opset_version=10)
+ self.assertONNX(operator.add, (x, y), opset_version=10)
def test_std(self):
x = torch.randn(2, 3, 4).float()
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 9fcc8b4..d5fcd55 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -2980,17 +2980,17 @@
@onlyCPU
@dtypes(torch.float)
def test_cdiv(self, device, dtype):
- self._test_cop(torch.div, lambda x, y: x / y, dtype, device)
+ self._test_cop(torch.div, operator.truediv, dtype, device)
@onlyCPU
@dtypes(torch.float)
def test_cremainder(self, device, dtype):
- self._test_cop(torch.remainder, lambda x, y: x % y, dtype, device)
+ self._test_cop(torch.remainder, operator.mod, dtype, device)
@onlyCPU
@dtypes(torch.float)
def test_cmul(self, device, dtype):
- self._test_cop(torch.mul, lambda x, y: x * y, dtype, device)
+ self._test_cop(torch.mul, operator.mul, dtype, device)
@onlyCPU
@dtypes(torch.float)
diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 61f086b..469994e 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -63,6 +63,7 @@
from torch.utils.data.datapipes.dataframe import CaptureDataFrame
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
+import operator
try:
import dill
@@ -1361,8 +1362,8 @@
# Unmatched input columns with fn arguments
_helper(None, fn_n1, 1, error=ValueError)
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
- _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
- _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
+ _helper(None, operator.add, 0, error=ValueError)
+ _helper(None, operator.add, [0, 1, 2], error=ValueError)
_helper(None, fn_cmplx, 0, 1, ValueError)
_helper(None, fn_n1_pos, 1, error=ValueError)
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
diff --git a/test/test_indexing.py b/test/test_indexing.py
index cfaedd5..e5e854d 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -16,6 +16,7 @@
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
onlyNativeDeviceTypes, skipXLA)
+import operator
class TestIndexing(TestCase):
@@ -138,7 +139,7 @@
def consec(size, start=1):
# Creates the sequence in float since CPU half doesn't support the
# needed operations. Converts to dtype before returning.
- numel = reduce(lambda x, y: x * y, size, 1)
+ numel = reduce(operator.mul, size, 1)
sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
sequence.add_(start - 1)
return sequence.view(*size).to(dtype=dtype)
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 3f00da5..8d60d66 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -33,6 +33,7 @@
_get_torch_cuda_version
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
+import operator
# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
@@ -7008,7 +7009,7 @@
# mat_chars denotes matrix characteristics
# possible values are: sym, sym_psd, sym_pd, sing, non_sym
def run_test(matsize, batchdims, mat_chars):
- num_matrices = reduce(lambda x, y: x * y, batchdims, 1)
+ num_matrices = reduce(operator.mul, batchdims, 1)
list_of_matrices = []
for idx in range(num_matrices):
diff --git a/test/test_mps.py b/test/test_mps.py
index ae43dd5..cf628a7 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -44,6 +44,7 @@
import torch
import torch.utils._pytree as pytree
from itertools import product
+import operator
test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)
@@ -1388,11 +1389,11 @@
return joined_x.view(1, joined_x.numel())
def _avg_pool2d(self, x, kernel_size):
- size = reduce((lambda x, y: x * y), kernel_size)
+ size = reduce(operator.mul, kernel_size)
return self._sum_pool2d(x, kernel_size) / size
def _avg_pool3d(self, x, kernel_size):
- size = reduce((lambda x, y: x * y), kernel_size)
+ size = reduce(operator.mul, kernel_size)
return self._sum_pool3d(x, kernel_size) / size
def test_avg_pool2d_with_zero_divisor(self):
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 12061aa..b83933d 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -22,6 +22,7 @@
all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
+import operator
if TEST_SCIPY:
import scipy.sparse as sp
@@ -3310,7 +3311,7 @@
# random bool vector w/ length equal to max possible nnz for the sparse_shape
mask_source = make_tensor(batch_mask_shape, dtype=torch.bool, device=device).flatten()
- n_batch = functools.reduce(lambda x, y: x * y, batch_shape, 1)
+ n_batch = functools.reduce(operator.mul, batch_shape, 1)
# stack random permutations of the source for each batch
mask = torch.stack([mask_source[torch.randperm(mask_source.numel())]
diff --git a/test/test_testing.py b/test/test_testing.py
index 542601d..12ef27c 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -29,6 +29,7 @@
from torch.testing._internal.common_dtype import all_types_and_complex_and, floating_types
from torch.testing._internal.common_modules import modules, module_db, ModuleInfo
from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo
+import operator
# For testing TestCase methods and torch.testing functions
class TestTesting(TestCase):
@@ -1427,7 +1428,7 @@
@parametrize("noncontiguous", [False, True])
@parametrize("shape", [tuple(), (0,), (1,), (1, 1), (2,), (2, 3), (8, 16, 32)])
def test_noncontiguous(self, dtype, device, noncontiguous, shape):
- numel = functools.reduce(lambda a, b: a * b, shape, 1)
+ numel = functools.reduce(operator.mul, shape, 1)
t = torch.testing.make_tensor(shape, dtype=dtype, device=device, noncontiguous=noncontiguous)
self.assertEqual(t.is_contiguous(), not noncontiguous or numel < 2)
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index db808c7..ceef13b 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -18,6 +18,7 @@
import numpy as np
+import operator
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -550,37 +551,37 @@
name="lt",
out_op=lambda x, y, d: torch.lt(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.lt(x, y),
- compare_op=lambda x, y: x < y,
+ compare_op=operator.lt,
),
dict(
name="le",
out_op=lambda x, y, d: torch.le(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.le(x, y),
- compare_op=lambda x, y: x <= y,
+ compare_op=operator.le,
),
dict(
name="gt",
out_op=lambda x, y, d: torch.gt(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.gt(x, y),
- compare_op=lambda x, y: x > y,
+ compare_op=operator.gt,
),
dict(
name="ge",
out_op=lambda x, y, d: torch.ge(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.ge(x, y),
- compare_op=lambda x, y: x >= y,
+ compare_op=operator.ge,
),
dict(
name="eq",
out_op=lambda x, y, d: torch.eq(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.eq(x, y),
- compare_op=lambda x, y: x == y,
+ compare_op=operator.eq,
),
dict(
name="ne",
out_op=lambda x, y, d: torch.ne(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
ret_op=lambda x, y: torch.ne(x, y),
- compare_op=lambda x, y: x != y,
+ compare_op=operator.ne,
),
]
for op in comparison_ops:
@@ -627,12 +628,12 @@
@onlyNativeDeviceTypes
def test_complex_assertraises(self, device):
comparison_ops = [
- dict(name="lt", compare_op=lambda x, y: x < y, ),
- dict(name="le", compare_op=lambda x, y: x <= y, ),
- dict(name="gt", compare_op=lambda x, y: x > y, ),
- dict(name="ge", compare_op=lambda x, y: x >= y, ),
- dict(name="eq", compare_op=lambda x, y: x == y, ),
- dict(name="ne", compare_op=lambda x, y: x != y, ),
+ dict(name="lt", compare_op=operator.lt, ),
+ dict(name="le", compare_op=operator.le, ),
+ dict(name="gt", compare_op=operator.gt, ),
+ dict(name="ge", compare_op=operator.ge, ),
+ dict(name="eq", compare_op=operator.eq, ),
+ dict(name="ne", compare_op=operator.ne, ),
]
for op in comparison_ops:
is_cuda = torch.device(device).type == 'cuda'
diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py
index 7d9343b..131ef7b 100644
--- a/test/torch_np/numpy_tests/core/test_numeric.py
+++ b/test/torch_np/numpy_tests/core/test_numeric.py
@@ -15,6 +15,7 @@
IS_WASM = False
HAS_REFCOUNT = True
+import operator
from unittest import expectedFailure as xfail, skipIf as skipif, SkipTest
from hypothesis import given, strategies as st
@@ -701,25 +702,19 @@
# The value of tiny for double double is NaN, so we need to
# pass the assert
if not np.isnan(ft_tiny):
- self.assert_raises_fpe(underflow, lambda a, b: a / b, ft_tiny, ft_max)
- self.assert_raises_fpe(underflow, lambda a, b: a * b, ft_tiny, ft_tiny)
- self.assert_raises_fpe(overflow, lambda a, b: a * b, ft_max, ftype(2))
- self.assert_raises_fpe(overflow, lambda a, b: a / b, ft_max, ftype(0.5))
- self.assert_raises_fpe(overflow, lambda a, b: a + b, ft_max, ft_max * ft_eps)
- self.assert_raises_fpe(overflow, lambda a, b: a - b, -ft_max, ft_max * ft_eps)
+ self.assert_raises_fpe(underflow, operator.truediv, ft_tiny, ft_max)
+ self.assert_raises_fpe(underflow, operator.mul, ft_tiny, ft_tiny)
+ self.assert_raises_fpe(overflow, operator.mul, ft_max, ftype(2))
+ self.assert_raises_fpe(overflow, operator.truediv, ft_max, ftype(0.5))
+ self.assert_raises_fpe(overflow, operator.add, ft_max, ft_max * ft_eps)
+ self.assert_raises_fpe(overflow, operator.sub, -ft_max, ft_max * ft_eps)
self.assert_raises_fpe(overflow, np.power, ftype(2), ftype(2**fi.nexp))
- self.assert_raises_fpe(divbyzero, lambda a, b: a / b, ftype(1), ftype(0))
- self.assert_raises_fpe(
- invalid, lambda a, b: a / b, ftype(np.inf), ftype(np.inf)
- )
- self.assert_raises_fpe(invalid, lambda a, b: a / b, ftype(0), ftype(0))
- self.assert_raises_fpe(
- invalid, lambda a, b: a - b, ftype(np.inf), ftype(np.inf)
- )
- self.assert_raises_fpe(
- invalid, lambda a, b: a + b, ftype(np.inf), ftype(-np.inf)
- )
- self.assert_raises_fpe(invalid, lambda a, b: a * b, ftype(0), ftype(np.inf))
+ self.assert_raises_fpe(divbyzero, operator.truediv, ftype(1), ftype(0))
+ self.assert_raises_fpe(invalid, operator.truediv, ftype(np.inf), ftype(np.inf))
+ self.assert_raises_fpe(invalid, operator.truediv, ftype(0), ftype(0))
+ self.assert_raises_fpe(invalid, operator.sub, ftype(np.inf), ftype(np.inf))
+ self.assert_raises_fpe(invalid, operator.add, ftype(np.inf), ftype(-np.inf))
+ self.assert_raises_fpe(invalid, operator.mul, ftype(0), ftype(np.inf))
@skipif(IS_WASM, reason="no wasm fp exception support")
def test_warnings(self):
diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py
index 4846165..68ab45b 100644
--- a/tools/testing/target_determination/heuristics/interface.py
+++ b/tools/testing/target_determination/heuristics/interface.py
@@ -1,3 +1,4 @@
+import operator
import sys
from abc import abstractmethod
from copy import copy
@@ -179,14 +180,10 @@
self._test_priorities[new_relevance.value].append(upgraded_tests)
def set_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None:
- return self._update_test_relevance(
- test_run, new_relevance, lambda curr, new: curr == new
- )
+ return self._update_test_relevance(test_run, new_relevance, operator.eq)
def raise_test_relevance(self, test_run: TestRun, new_relevance: Relevance) -> None:
- return self._update_test_relevance(
- test_run, new_relevance, lambda curr, new: curr >= new
- )
+ return self._update_test_relevance(test_run, new_relevance, operator.ge)
def validate_test_priorities(self) -> None:
# Union all TestRuns that contain include/exclude pairs
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index 81044c0..03aa55e 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -680,7 +680,7 @@
:-1
] == torch.Size(reshape_2[:-1])
can_remove_reshape = can_remove_reshape and (
- reduce(lambda x, y: x * y, reshape_2[:-1]) == reshape_1[0]
+ reduce(operator.mul, reshape_2[:-1]) == reshape_1[0]
)
if can_remove_reshape:
diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py
index 28f32bb..448c937 100644
--- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py
+++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py
@@ -4,6 +4,7 @@
from typing import Any, List, Optional, Tuple
from .base_data_sparsifier import BaseDataSparsifier
+import operator
__all__ = ['DataNormSparsifier']
@@ -35,7 +36,7 @@
sparse_block_shape: Tuple[int, int] = (1, 4),
zeros_per_block: Optional[int] = None, norm: str = 'L1'):
if zeros_per_block is None:
- zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ zeros_per_block = reduce(operator.mul, sparse_block_shape)
assert norm in ['L1', 'L2'], "only L1 and L2 norm supported at the moment"
@@ -95,7 +96,7 @@
data_norm = F.avg_pool2d(data[None, None, :], kernel_size=sparse_block_shape,
stride=sparse_block_shape, ceil_mode=True)
- values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ values_per_block = reduce(operator.mul, sparse_block_shape)
data_norm = data_norm.flatten()
num_blocks = len(data_norm)
@@ -116,7 +117,7 @@
def update_mask(self, name, data, sparsity_level,
sparse_block_shape, zeros_per_block, **kwargs):
- values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ values_per_block = reduce(operator.mul, sparse_block_shape)
if zeros_per_block > values_per_block:
raise ValueError("Number of zeros per block cannot be more than "
"the total number of elements in that block.")
diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
index 36903dd..2b24ca3 100644
--- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
+++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from .base_sparsifier import BaseSparsifier
+import operator
__all__ = ["WeightNormSparsifier"]
@@ -56,7 +57,7 @@
zeros_per_block: Optional[int] = None,
norm: Optional[Union[Callable, int]] = None):
if zeros_per_block is None:
- zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ zeros_per_block = reduce(operator.mul, sparse_block_shape)
defaults = {
"sparsity_level": sparsity_level,
"sparse_block_shape": sparse_block_shape,
@@ -108,7 +109,7 @@
mask.data = torch.ones_like(mask)
return mask
- values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ values_per_block = reduce(operator.mul, sparse_block_shape)
if values_per_block > 1:
# Reduce the data
data = F.avg_pool2d(
@@ -145,7 +146,7 @@
block_h, block_w = sparse_block_shape
dh = (block_h - h % block_h) % block_h
dw = (block_w - w % block_w) % block_w
- values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ values_per_block = reduce(operator.mul, sparse_block_shape)
if mask is None:
mask = torch.ones((h + dh, w + dw), device=data.device)
@@ -174,7 +175,7 @@
def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape,
zeros_per_block, **kwargs):
- values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
+ values_per_block = reduce(operator.mul, sparse_block_shape)
if zeros_per_block > values_per_block:
raise ValueError(
"Number of zeros per block cannot be more than the total number of elements in that block."
diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py
index 6f3f45c..f091d38 100644
--- a/torch/autograd/_functions/tensor.py
+++ b/torch/autograd/_functions/tensor.py
@@ -1,3 +1,4 @@
+import operator
import warnings
from functools import reduce
@@ -31,7 +32,7 @@
@staticmethod
def forward(ctx, tensor, sizes):
ctx.sizes = sizes
- ctx.numel = reduce(lambda x, y: x * y, sizes, 1)
+ ctx.numel = reduce(operator.mul, sizes, 1)
if tensor.numel() != ctx.numel:
raise RuntimeError(
(
diff --git a/torch/autograd/_functions/utils.py b/torch/autograd/_functions/utils.py
index 735b624..7111d89 100644
--- a/torch/autograd/_functions/utils.py
+++ b/torch/autograd/_functions/utils.py
@@ -1,3 +1,4 @@
+import operator
from functools import reduce
@@ -38,8 +39,8 @@
supported = True
len1 = len(dims1)
len2 = len(dims2)
- numel1 = reduce(lambda x, y: x * y, dims1)
- numel2 = reduce(lambda x, y: x * y, dims2)
+ numel1 = reduce(operator.mul, dims1)
+ numel2 = reduce(operator.mul, dims2)
if len1 < len2:
broadcast = True
if numel2 != 1:
diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py
index b636498..748132e 100644
--- a/torch/backends/_nnapi/serializer.py
+++ b/torch/backends/_nnapi/serializer.py
@@ -2,6 +2,7 @@
import enum
import functools
import logging
+import operator
import struct
import sys
from typing import List, NamedTuple, Optional, Tuple
@@ -1032,11 +1033,7 @@
out_shape = (
in_oper.shape[:start_dim]
- + (
- functools.reduce(
- lambda x, y: x * y, in_oper.shape[start_dim : end_dim + 1]
- ),
- )
+ + (functools.reduce(operator.mul, in_oper.shape[start_dim : end_dim + 1]),)
+ in_oper.shape[end_dim + 1 :]
)
diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py
index 1fc3114..06141fd 100644
--- a/torch/distributed/_shard/sharded_tensor/api.py
+++ b/torch/distributed/_shard/sharded_tensor/api.py
@@ -45,6 +45,7 @@
)
from torch.distributed.remote_device import _remote_device
from torch.utils import _pytree as pytree
+import operator
# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
@@ -394,7 +395,7 @@
Default: ``None``
"""
def shard_size(shard_md):
- return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
+ return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined]
if enforce_dtype:
warnings.warn("enforce_dtype is deprecated. Please use dtype instead.")
diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py
index 8c35e0f..e44a75d 100644
--- a/torch/fx/experimental/graph_gradual_typechecker.py
+++ b/torch/fx/experimental/graph_gradual_typechecker.py
@@ -244,8 +244,8 @@
elif isinstance(t1, TensorType):
assert isinstance(t1, TensorType)
a = [e if e != Dyn else 1 for e in t1.__args__]
- p1 = reduce(lambda x, y: x * y, a)
- p2 = reduce(lambda x, y: x * y, t2)
+ p1 = reduce(operator.mul, a)
+ p2 = reduce(operator.mul, t2)
if p1 % p2 == 0 or p2 % p1 == 0:
n.type = t2_type
return t2_type
@@ -498,7 +498,7 @@
if Dyn in mid:
mid = [Dyn]
else:
- mid = [reduce(lambda x, y: x * y, my_args[start_dim:end_dim])]
+ mid = [reduce(operator.mul, my_args[start_dim:end_dim])]
new_type_list = lhs + mid + rhs
return TensorType(tuple(new_type_list))
else:
diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py
index b3aa187f..8f2dff4 100644
--- a/torch/fx/experimental/sym_node.py
+++ b/torch/fx/experimental/sym_node.py
@@ -541,9 +541,9 @@
reflectable_magic_methods = {
- "add": lambda a, b: a + b,
- "sub": lambda a, b: a - b,
- "mul": lambda a, b: a * b,
+ "add": operator.add,
+ "sub": operator.sub,
+ "mul": operator.mul,
"mod": _sympy_mod,
"pow": _sympy_pow,
"and": _sympy_and,
@@ -676,7 +676,7 @@
magic_methods = {
**reflectable_magic_methods,
- "sym_not": lambda a: ~a,
+ "sym_not": operator.invert,
"eq": _sympy_eq,
"ne": _sympy_ne,
"gt": _sympy_gt,
@@ -686,7 +686,7 @@
"floor": _sympy_floor,
"sym_float": _sympy_sym_float,
"ceil": _sympy_ceil,
- "neg": lambda a: -a,
+ "neg": operator.neg,
"sym_min": _sympy_min,
"sym_max": _sympy_max,
"sym_ite": _sympy_ite,
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 04dc803..053606e 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -41,6 +41,7 @@
_uninstall_threaded_pg,
ProcessLocalGroup,
)
+import operator
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -426,7 +427,7 @@
def compute_sum(fn, world_size: int):
return reduce(
- lambda a, b: a + b, [fn(rank, world_size) for rank in range(world_size)]
+ operator.add, [fn(rank, world_size) for rank in range(world_size)]
)
return [
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index cafcd84..9f0939a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18219,7 +18219,7 @@
}
""",
num_outputs=1),
- ref=lambda i0, i1: i0 + i1,
+ ref=operator.add,
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.bool),
sample_inputs_func=partial(sample_inputs_jiterator, num_inputs=2, alpha=-0.42),
supports_out=False,
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index e5f05b7..bde3f10 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -84,6 +84,7 @@
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
from torch.utils.data.distributed import DistributedSampler
+import operator
try:
import torchvision
@@ -2167,7 +2168,7 @@
dist.ReduceOp.PRODUCT,
2,
10,
- reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
+ reduce(operator.mul, [10] * (len(group) - 1), 2),
)
@skip_but_pass_in_sandcastle_if(
@@ -2233,7 +2234,7 @@
dist.ReduceOp.PRODUCT,
2,
10,
- reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
+ reduce(operator.mul, [10] * (len(group) - 1), 2),
)
@skip_but_pass_in_sandcastle_if(
@@ -2299,7 +2300,7 @@
dist.ReduceOp.PRODUCT,
2,
10,
- reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
+ reduce(operator.mul, [10] * (len(group) - 1), 2),
)
@skip_but_pass_in_sandcastle_if(
@@ -2821,7 +2822,7 @@
dist.ReduceOp.PRODUCT,
2,
10,
- reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
+ reduce(operator.mul, [10] * (len(group) - 1), 2),
)
@skip_but_pass_in_sandcastle_if(
@@ -2871,7 +2872,7 @@
dist.ReduceOp.PRODUCT,
2,
10,
- reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
+ reduce(operator.mul, [10] * (len(group) - 1), 2),
)
@skip_if_small_worldsize
@@ -2921,7 +2922,7 @@
dist.ReduceOp.PRODUCT,
2,
10,
- reduce((lambda x, y: x * y), [10] * (len(group) - 1), 2),
+ reduce(operator.mul, [10] * (len(group) - 1), 2),
)
@skip_but_pass_in_sandcastle_if(
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
index 524ee9e..408d07a 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -17,6 +17,7 @@
import torch
from torch.utils.benchmark.utils import common, cpp_jit
from torch.utils.benchmark.utils._stubs import CallgrindModuleType
+import operator
__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
@@ -100,7 +101,7 @@
self,
other: "FunctionCounts",
) -> "FunctionCounts":
- return self._merge(other, lambda c: -c)
+ return self._merge(other, operator.neg)
def __mul__(self, other: Union[int, float]) -> "FunctionCounts":
return self._from_dict({