[BE] Remove dependency on `six` and `future` (#94709)
Remove the Python 2/3 compatibility libraries [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future), along with the internal shim `torch._six`. We only support Python 3.8+ now; it's time to retire them.
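
The replacements applied throughout this patch follow a handful of mechanical patterns. A minimal sketch of each (the `Meta`/`Example` class names below are illustrative, not taken from the codebase):

```python
# six.moves shims are replaced by the Python 3 standard library:
from urllib.request import urlretrieve  # was: from six.moves.urllib.request import urlretrieve

# torch._six re-exported math.inf / math.nan; torch itself now exposes them:
from torch import inf, nan  # was: from torch._six import inf, nan

# string_classes was (str, bytes); isinstance checks narrow to plain str:
assert isinstance("key", str)  # was: isinstance("key", string_classes)

# six.with_metaclass / torch._six.with_metaclass become the native
# Python 3 metaclass keyword:
class Meta(type):
    pass

class Example(metaclass=Meta):  # was: class Example(with_metaclass(Meta)):
    pass
```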
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index 36c0604..f3b5a0a 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -36,11 +36,6 @@
#Pinned versions: 2.0
#test that import:
-#future #this breaks linux-bionic-rocm4.5-py3.7
-#Description: compatibility layer between python 2 and python 3
-#Pinned versions:
-#test that import:
-
hypothesis==5.35.1
# Pin hypothesis to avoid flakiness: https://github.com/pytorch/pytorch/issues/31136
#Description: advanced library for generating parametrized tests
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 30178d9..5cb89ac 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1101,7 +1101,7 @@
cd ${PROJ_ROOT}/ios/TestApp/benchmark
mkdir -p ../models
if [ ${USE_COREML_DELEGATE} == 1 ]; then
- pip install coremltools==5.0b5 protobuf==3.20.1 six==1.16.0
+ pip install coremltools==5.0b5 protobuf==3.20.1
python coreml_backend.py
else
cd "${PROJ_ROOT}"
diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh
index 323d461..f273816 100755
--- a/.circleci/scripts/binary_linux_test.sh
+++ b/.circleci/scripts/binary_linux_test.sh
@@ -82,8 +82,7 @@
mkl>=2018 \
ninja \
typing-extensions \
- ${PROTOBUF_PACKAGE} \
- six
+ ${PROTOBUF_PACKAGE}
if [[ "$DESIRED_CUDA" == 'cpu' ]]; then
retry conda install -c pytorch -y cpuonly
else
@@ -100,7 +99,7 @@
)
elif [[ "$PACKAGE_TYPE" != libtorch ]]; then
pip install "\$pkg" --extra-index-url "https://download.pytorch.org/whl/nightly/${DESIRED_CUDA}"
- retry pip install -q future numpy protobuf typing-extensions six
+ retry pip install -q numpy protobuf typing-extensions
fi
if [[ "$PACKAGE_TYPE" == libtorch ]]; then
pkg="\$(ls /final_pkgs/*-latest.zip)"
diff --git a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
index 6050ea0..f03e173 100644
--- a/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
+++ b/.circleci/verbatim-sources/job-specs/job-specs-custom.yml
@@ -626,7 +626,7 @@
cd ${PROJ_ROOT}/ios/TestApp/benchmark
mkdir -p ../models
if [ ${USE_COREML_DELEGATE} == 1 ]; then
- pip install coremltools==5.0b5 protobuf==3.20.1 six==1.16.0
+ pip install coremltools==5.0b5 protobuf==3.20.1
python coreml_backend.py
else
cd "${PROJ_ROOT}"
diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 494b72a..1ad7074 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-9cbcdb4008c14ad8251c5d4d7723aa616f659edb
+d29eb67c27af0f18d4f487d76b86f43b0a69aade
diff --git a/.github/requirements/conda-env-macOS-ARM64 b/.github/requirements/conda-env-macOS-ARM64
index 05dede3..b467a7b 100644
--- a/.github/requirements/conda-env-macOS-ARM64
+++ b/.github/requirements/conda-env-macOS-ARM64
@@ -5,7 +5,6 @@
typing-extensions=4.3.0
dataclasses=0.8
pip=22.2.2
-six=1.16.0
pillow=9.2.0
pkg-config=0.29.2
wheel=0.37.1
diff --git a/.github/requirements/conda-env-macOS-X64 b/.github/requirements/conda-env-macOS-X64
index 18e6b06..a22e6c4 100644
--- a/.github/requirements/conda-env-macOS-X64
+++ b/.github/requirements/conda-env-macOS-X64
@@ -7,7 +7,6 @@
typing-extensions=4.3.0
dataclasses=0.8
pip=22.2.2
-six=1.16.0
pillow=9.2.0
libuv=1.40.0
pkg-config=0.29.2
diff --git a/.github/requirements/pip-requirements-iOS.txt b/.github/requirements/pip-requirements-iOS.txt
index 773be0e..0befad8 100644
--- a/.github/requirements/pip-requirements-iOS.txt
+++ b/.github/requirements/pip-requirements-iOS.txt
@@ -1,4 +1,3 @@
# iOS simulator requirements
coremltools==5.0b5
protobuf==3.20.2
-six==1.16.0
diff --git a/.github/workflows/run_torchbench.yml b/.github/workflows/run_torchbench.yml
index 676379e..8d55f6a 100644
--- a/.github/workflows/run_torchbench.yml
+++ b/.github/workflows/run_torchbench.yml
@@ -41,7 +41,7 @@
conda activate pr-ci
conda install -y numpy="${NUMPY_VERSION}" requests ninja pyyaml mkl mkl-include \
setuptools cmake=3.22.* typing-extensions boto3 \
- six pillow pytest tabulate gitpython git-lfs tqdm psutil
+ pillow pytest tabulate gitpython git-lfs tqdm psutil
pip install --pre torch torchvision torchtext -f https://download.pytorch.org/whl/nightly/cu116/torch_nightly.html
- name: Setup TorchBench branch
run: |
diff --git a/.lintrunner.toml b/.lintrunner.toml
index c76a07c..8782a8c 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -145,7 +145,6 @@
'expecttest==0.1.3',
'mypy==0.960',
'types-requests==2.27.25',
- 'types-six==1.16.15',
'types-PyYAML==6.0.7',
'types-tabulate==0.8.8',
'types-protobuf==3.19.18',
diff --git a/benchmarks/dynamo/Makefile b/benchmarks/dynamo/Makefile
index 90f7899..6dc0bf1 100644
--- a/benchmarks/dynamo/Makefile
+++ b/benchmarks/dynamo/Makefile
@@ -28,7 +28,7 @@
# conda create --name torchdynamo -y python=3.8
# conda activate torchdynamo
conda install -y astunparse numpy scipy ninja pyyaml mkl mkl-include setuptools cmake \
- typing-extensions six requests protobuf numba cython scikit-learn
+ typing-extensions requests protobuf numba cython scikit-learn
conda install -y -c pytorch magma-cuda116
conda install -y -c conda-forge librosa
(cd ../../../torchvision && python setup.py clean && python setup.py develop)
diff --git a/caffe2/experiments/python/device_reduce_sum_bench.py b/caffe2/experiments/python/device_reduce_sum_bench.py
index ce9364c..c57bff5 100644
--- a/caffe2/experiments/python/device_reduce_sum_bench.py
+++ b/caffe2/experiments/python/device_reduce_sum_bench.py
@@ -25,7 +25,6 @@
import logging
import os
-from six import add_metaclass
import numpy as np
from caffe2.python import workspace, core
@@ -46,8 +45,7 @@
return cls
-@add_metaclass(BenchmarkMeta)
-class Benchmark:
+class Benchmark(metaclass=BenchmarkMeta):
def __init__(self):
self.results = []
diff --git a/docs/caffe2/installation.md b/docs/caffe2/installation.md
index 6abc67f..6c8ac2f 100644
--- a/docs/caffe2/installation.md
+++ b/docs/caffe2/installation.md
@@ -58,10 +58,6 @@
## Python support
-To use Caffe2 in Python, you need two libraries, future and six.
-
- pip install future six
-
To run the tutorials, download additional source from GitHub.
git clone --recursive https://github.com/caffe2/tutorials.git caffe2_tutorials
diff --git a/docs/cpp/requirements.txt b/docs/cpp/requirements.txt
index ca3eb7d..da401f2 100644
--- a/docs/cpp/requirements.txt
+++ b/docs/cpp/requirements.txt
@@ -6,4 +6,3 @@
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
bs4
lxml
-six
diff --git a/pyproject.toml b/pyproject.toml
index 4570800..338bdc9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,6 @@
"setuptools",
"cmake",
"typing-extensions",
- "six",
"requests",
]
# Use legacy backend to import local packages in setup.py
diff --git a/scripts/build_tegra_x1.sh b/scripts/build_tegra_x1.sh
index 49c559a..b1121ff 100755
--- a/scripts/build_tegra_x1.sh
+++ b/scripts/build_tegra_x1.sh
@@ -41,10 +41,6 @@
# the one provided by apt-get is quite old so we install it via pip
sudo pip install hypothesis
-# Install the six module, which includes Python 2 and 3 compatibility utilities,
-# and is required for Caffe2
-sudo pip install six
-
# Now, actually build the android target.
echo "Building caffe2"
cd $BUILD_ROOT
diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh
index c9d26ce..33fc65c 100755
--- a/scripts/build_tizen.sh
+++ b/scripts/build_tizen.sh
@@ -95,10 +95,6 @@
# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that
# the one provided by zypper is quite old so we install it via pip
sudo pip install hypothesis
-
-# Install the six module, which includes Python 2 and 3 compatibility utilities,
-# and is required for Caffe2
-sudo pip install six
}
caffe2_full_build(){
diff --git a/scripts/model_zoo/update-caffe2-models.py b/scripts/model_zoo/update-caffe2-models.py
index e9a5f28..7f9c8e9 100755
--- a/scripts/model_zoo/update-caffe2-models.py
+++ b/scripts/model_zoo/update-caffe2-models.py
@@ -6,7 +6,7 @@
import tarfile
import tempfile
-from six.moves.urllib.request import urlretrieve
+from urllib.request import urlretrieve
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
diff --git a/scripts/model_zoo/update-models-from-caffe2.py b/scripts/model_zoo/update-models-from-caffe2.py
index fb58871..9e408d6 100644
--- a/scripts/model_zoo/update-models-from-caffe2.py
+++ b/scripts/model_zoo/update-models-from-caffe2.py
@@ -17,7 +17,7 @@
import boto3
-from six.moves.urllib.request import urlretrieve
+from urllib.request import urlretrieve
from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory
from caffe2.proto import caffe2_pb2
diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py
index eb7afae..bd26fca 100644
--- a/test/distributed/test_store.py
+++ b/test/distributed/test_store.py
@@ -16,7 +16,6 @@
sys.exit(0)
import torch.testing._internal.common_utils as common
-from torch._six import string_classes
from torch.testing._internal.common_distributed import (
skip_if_win32,
create_tcp_store
@@ -336,7 +335,7 @@
self.store = {}
def set(self, key, value):
- if not isinstance(key, string_classes):
+ if not isinstance(key, str):
raise AssertionError("Expected set to be called with string key")
if type(value) is not bytes:
raise AssertionError("Expected set to be called with bytes value")
diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 836b595..db36429 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -42,7 +42,7 @@
# Distributions tests use double as the default dtype
torch.set_default_dtype(torch.double)
-from torch._six import inf, nan
+from torch import inf, nan
from torch.testing._internal.common_utils import \
(TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests,
gradcheck, skipIfTorchDynamo)
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py
index e795d6b..9a9124a 100644
--- a/test/nn/test_pooling.py
+++ b/test/nn/test_pooling.py
@@ -10,7 +10,7 @@
import itertools
import math
-from torch._six import inf, nan
+from torch import inf, nan
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, set_default_dtype, \
diff --git a/test/test_autograd.py b/test/test_autograd.py
index efacfc0..9fecbab 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -23,7 +23,7 @@
import torch
from torch import nn
-from torch._six import inf, nan
+from torch import inf, nan
from torch.autograd.function import once_differentiable
from torch.autograd.profiler import (profile, record_function, emit_nvtx, emit_itt)
from torch.autograd.profiler_util import (_format_time, EventList, FunctionEvent, FunctionEventAvg)
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 82113ef..3f23be1 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -14,7 +14,7 @@
from functools import partial
import torch.autograd.forward_ad as fwAD
-from torch._six import inf, nan
+from torch import inf, nan
from torch.testing._internal.common_utils import (
TestCase,
slowTest,
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 9bb601c..344e66d 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -22,9 +22,9 @@
import torch
import torch.cuda
import torch.cuda.comm as comm
+from torch import inf, nan
from torch.nn.parallel import scatter_gather
from torch.utils.checkpoint import checkpoint_sequential
-from torch._six import inf, nan
from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \
@@ -1595,7 +1595,7 @@
p = subprocess.Popen([sys.executable, '-c', f"""\
import sys
import torch
-from torch._six import inf, nan
+from torch import inf, nan
try:
with torch.random.fork_rng(devices=[0]):
torch.multinomial(torch.tensor({probs}).to('cuda'), 2, replacement=True)
diff --git a/test/test_mps.py b/test/test_mps.py
index c03e4e3..f45601fa 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -17,7 +17,7 @@
import torch.nn.functional as F
import itertools
from collections import defaultdict
-from torch._six import inf
+from torch import inf
from torch.nn import Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
diff --git a/test/test_nn.py b/test/test_nn.py
index fc1d623..be5ca93 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -21,7 +21,7 @@
# NN tests use double as the default dtype
torch.set_default_dtype(torch.double)
-from torch._six import inf, nan
+from torch import inf, nan
import torch.autograd.forward_ad as fwAD
import torch.backends.cudnn as cudnn
import torch.nn as nn
diff --git a/test/test_reductions.py b/test/test_reductions.py
index e14225d..29fc72e 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -11,7 +11,7 @@
from itertools import product, combinations, permutations
import warnings
-from torch._six import inf, nan
+from torch import inf, nan
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py
index a43d632..d3fefca 100644
--- a/test/test_shape_ops.py
+++ b/test/test_shape_ops.py
@@ -8,7 +8,7 @@
import random
import warnings
-from torch._six import nan
+from torch import nan
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, skipIfTorchDynamo, torch_to_numpy_dtype_dict)
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 1343e1a..540df06 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -4,7 +4,7 @@
import numpy as np
import random
-from torch._six import nan
+from torch import nan
from itertools import permutations, product
from torch.testing import make_tensor
diff --git a/test/test_torch.py b/test/test_torch.py
index 205328f..7069ccc 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -24,7 +24,7 @@
import subprocess
import weakref
import sys
-from torch._six import inf, nan, string_classes
+from torch import inf, nan
from itertools import product, combinations, permutations
from functools import partial
from torch import multiprocessing as mp
@@ -8288,7 +8288,7 @@
ns_name = ns.__name__
skip_regexes = []
for r in skips:
- if isinstance(r, string_classes):
+ if isinstance(r, str):
skip_regexes.append(re.compile('^{}$'.format(re.escape(r))))
else:
skip_regexes.append(r)
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 77a1940..bb9107b 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -8,7 +8,7 @@
import random
import unittest
-from torch._six import inf, nan
+from torch import inf, nan
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in
index c3b167d..8a5a638 100644
--- a/torch/_C/_VariableFunctions.pyi.in
+++ b/torch/_C/_VariableFunctions.pyi.in
@@ -1,8 +1,7 @@
# ${generated_comment}
-from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
+from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided, inf
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
-from torch._six import inf
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device
import torch
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 3b565fb..1bd547c 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2,7 +2,7 @@
import torch
from torch.package import PackageExporter
-from torch import Tensor
+from torch import Tensor, inf
from torch.autograd.graph import Node as _Node
from enum import Enum
from pathlib import Path
@@ -10,7 +10,6 @@
Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
Literal, Generic, Set, AnyStr)
-from torch._six import inf
from torch.types import (
_int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt, _dispatchkey
@@ -150,11 +149,11 @@
per_channel_affine_float_qparams: qscheme = ...
# Defined in torch/csrc/autograd/python_function.cpp
-class _FunctionBase(object):
+class _FunctionBase:
...
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
-class _LegacyVariableBase(object):
+class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy
def __init__(
self,
data: Optional[Tensor]=...,
@@ -168,7 +167,7 @@
class JITException: ...
-class Future(object):
+class Future:
def __init__(self, devices: List[device]) -> None: ...
def done(self) -> _bool: ...
def value(self) -> Any: ...
@@ -178,7 +177,7 @@
def set_result(self, result: Any) -> None: ...
def _set_unwrap_func(self, callback: Callable) -> None: ...
-class _Await(object):
+class _Await:
def __init__(self) -> None: ...
def fn(self) -> Callable: ...
def args(self) -> Tuple[Any, ...]: ...
@@ -700,7 +699,7 @@
def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
# Defined in torch/csrc/jit/python/script_init.cpp
-class ScriptModuleSerializer(object):
+class ScriptModuleSerializer:
def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
def write_files(self) -> None: ...
@@ -708,14 +707,14 @@
...
# Defined in torch/csrc/jit/python/script_init.cpp
-class SerializationStorageContext(object):
+class SerializationStorageContext:
def __init__(self) -> None: ...
def has_storage(self, storage: Storage) -> _bool: ...
def get_or_add_storage(self, storage: Storage) -> _int: ...
...
# Defined in torch/csrc/jit/python/script_init.cpp
-class DeserializationStorageContext(object):
+class DeserializationStorageContext:
def __init__(self) -> None: ...
def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
def has_storage(self, name: str) -> _bool: ...
@@ -971,7 +970,7 @@
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
-class _InferenceMode(object):
+class _InferenceMode:
def __init__(self, mode: _bool) -> None: ...
class _DisableFuncTorch:
@@ -987,7 +986,7 @@
def __init__(self, mode: _bool) -> None: ...
# Defined in torch/csrc/jit/python/script_init.cpp
-class LoggerBase(object):
+class LoggerBase:
...
class NoopLogger(LoggerBase):
@@ -1000,7 +999,7 @@
SUM = 0
AVG = 1
-class FileCheck(object):
+class FileCheck:
def run(self, test_string: str) -> None: ...
def check(self, test_string: str) -> 'FileCheck': ...
def check_not(self, test_string: str) -> 'FileCheck': ...
@@ -1012,7 +1011,7 @@
...
# Defined in torch/csrc/jit/python/init.cpp
-class PyTorchFileReader(object):
+class PyTorchFileReader:
@overload
def __init__(self, name: str) -> None: ...
@overload
@@ -1020,7 +1019,7 @@
def get_record(self, name: str) -> bytes: ...
...
-class PyTorchFileWriter(object):
+class PyTorchFileWriter:
@overload
def __init__(self, name: str) -> None: ...
@overload
@@ -1048,7 +1047,7 @@
def _rename_privateuse1_backend(backend: str) -> None: ...
# Defined in torch/csrc/Generator.cpp
-class Generator(object):
+class Generator:
device: _device
def __init__(self, device: Union[_device, str, None] = None) -> None: ...
def get_state(self) -> Tensor: ...
@@ -1127,28 +1126,28 @@
def _are_functorch_transforms_active() -> _bool: ...
# Define in torch/csrc/autograd/init.cpp
-class _DisablePythonDispatcher(object):
+class _DisablePythonDispatcher:
pass
-class _EnablePythonDispatcher(object):
+class _EnablePythonDispatcher:
pass
def _set_python_dispatcher(dispatcher: object) -> None: ...
# Defined in torch/csrc/utils/init.cpp
-class BenchmarkConfig(object):
+class BenchmarkConfig:
num_calling_threads: _int
num_worker_threads: _int
num_warmup_iters: _int
num_iters: _int
profiler_output_path: str
-class BenchmarkExecutionStats(object):
+class BenchmarkExecutionStats:
latency_avg_ms: _float
num_iters: _int
-class ThroughputBenchmark(object):
+class ThroughputBenchmark:
def __init__(self, module: Any) -> None: ...
def add_input(self, *args: Any, **kwargs: Any) -> None: ...
def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
@@ -1162,7 +1161,9 @@
# Defined in torch/csrc/autograd/python_engine.cpp
class _ImperativeEngine:
- ...
+ def queue_callback(self, callback: Callable[[], None]) -> None: ...
+ def run_backward(self, *args: Any, **kwargs: Any) -> Tuple[Tensor, ...]: ...
+ def is_checkpoint_valid(self) -> _bool: ...
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorMeta(type):
diff --git a/torch/_C/return_types.pyi.in b/torch/_C/return_types.pyi.in
index 299f2d9..ca5e3f8 100644
--- a/torch/_C/return_types.pyi.in
+++ b/torch/_C/return_types.pyi.in
@@ -1,8 +1,7 @@
# ${generated_comment}
-from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
+from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided, inf
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
-from torch._six import inf
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout
diff --git a/torch/__init__.py b/torch/__init__.py
index 1e7850b..61062bf 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -28,8 +28,6 @@
else:
from .torch_version import __version__ as __version__
-from ._six import string_classes as _string_classes
-
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
import builtins
@@ -593,7 +591,7 @@
torch.float64
"""
- if isinstance(t, _string_classes):
+ if isinstance(t, str):
t = _import_dotted_name(t)
_C._set_default_tensor_type(t)
diff --git a/torch/_six.py b/torch/_six.py
deleted file mode 100644
index 7ccc12f..0000000
--- a/torch/_six.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) 2010-2017 Benjamin Peterson
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-import math
-
-inf = math.inf
-nan = math.nan
-string_classes = (str, bytes)
-
-
-def with_metaclass(meta: type, *bases) -> type:
- """Create a base class with a metaclass."""
- # This requires a bit of explanation: the basic idea is to make a dummy
- # metaclass for one level of class instantiation that replaces itself with
- # the actual metaclass.
- class metaclass(meta): # type: ignore[misc, valid-type]
- def __new__(cls, name, this_bases, d):
- return meta(name, bases, d)
-
- @classmethod
- def __prepare__(cls, name, this_bases):
- return meta.__prepare__(name, bases)
-
- return type.__new__(metaclass, "temporary_class", (), {})
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 13d85f6..adea080 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -3,7 +3,7 @@
from typing import Optional
import torch
-from torch._six import inf
+from torch import inf
class __PrinterOptions:
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index b6100c6..880ef80 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -3,7 +3,6 @@
from torch._C import _functions
import torch._functorch as _functorch
import torch.utils.hooks as hooks
-from torch._six import with_metaclass
import functools
import warnings
from collections import OrderedDict
@@ -294,8 +293,7 @@
super(FunctionMeta, cls).__init__(name, bases, attrs)
-# mypy doesn't understand `with_metaclass` from torch._six
-class _SingleLevelFunction(with_metaclass(FunctionMeta, _C._FunctionBase, FunctionCtx, _HookMixin)): # type: ignore[misc]
+class _SingleLevelFunction(_C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
r"""
@@ -505,7 +503,7 @@
if not torch._C._are_functorch_transforms_active():
# See NOTE: [functorch vjp and autograd interaction]
args = _functorch.utils.unwrap_dead_wrappers(args)
- return super().apply(*args, **kwargs)
+ return super().apply(*args, **kwargs) # type: ignore[misc]
if cls.setup_context == _SingleLevelFunction.setup_context:
raise RuntimeError(
@@ -680,14 +678,14 @@
def _do_forward(self, *input):
self._nested_input = input
flat_input = tuple(_iter_tensors(input))
- flat_output = super()._do_forward(*flat_input)
+ flat_output = super()._do_forward(*flat_input) # type: ignore[misc]
nested_output = self._nested_output
nested_tensors = _unflatten(flat_output, self._nested_output)
return nested_tensors
def _do_backward(self, gradients, retain_variables):
self.retain_variables = retain_variables
- result = super()._do_backward(gradients, retain_variables)
+ result = super()._do_backward(gradients, retain_variables) # type: ignore[misc]
if not retain_variables:
del self._nested_output
del self._to_save_nested
@@ -713,7 +711,7 @@
@property
def saved_tensors(self):
- flat_tensors = super().saved_tensors
+ flat_tensors = super().saved_tensors # type: ignore[misc]
return _unflatten(flat_tensors, self._to_save_nested)
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index 57b210e..ed841d4 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -1,15 +1,14 @@
import torch
-from torch._six import with_metaclass
+from torch._C import _ImperativeEngine as ImperativeEngine
+
__all__ = ["VariableMeta", "Variable"]
+
class VariableMeta(type):
def __instancecheck__(cls, other):
return isinstance(other, torch.Tensor)
-# mypy doesn't understand torch._six.with_metaclass
-class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): # type: ignore[misc]
- pass
-from torch._C import _ImperativeEngine as ImperativeEngine
-Variable._execution_engine = ImperativeEngine()
+class Variable(torch._C._LegacyVariableBase, metaclass=VariableMeta): # type: ignore[misc]
+ _execution_engine = ImperativeEngine()
diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py
index cd3b7f4..d9347ec 100644
--- a/torch/cuda/amp/autocast_mode.py
+++ b/torch/cuda/amp/autocast_mode.py
@@ -6,7 +6,6 @@
HAS_NUMPY = True
except ModuleNotFoundError:
np = None # type: ignore[assignment]
-from torch._six import string_classes
from typing import Any
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
@@ -48,7 +47,7 @@
if isinstance(value, torch.Tensor):
is_eligible = (value.is_floating_point() and value.is_cuda and (value.dtype is not torch.float64))
return value.to(dtype) if is_eligible else value
- elif isinstance(value, string_classes):
+ elif isinstance(value, str):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
diff --git a/torch/distributed/_composable/_ddp.py b/torch/distributed/_composable/_ddp.py
index 4a20665..a2a4cb3 100644
--- a/torch/distributed/_composable/_ddp.py
+++ b/torch/distributed/_composable/_ddp.py
@@ -81,7 +81,7 @@
# Enqueue delay allreduce for static graph training on the first
# iteration.
if state_dict["static_graph"] and state_dict["num_iterations"] == 1:
- Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce)
+ Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce) # type: ignore[call-arg,misc]
return (None, None, *grad_outputs)
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 00fa7ea..be0006d 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -32,7 +32,6 @@
get_debug_level,
Work
)
-from torch._six import string_classes
from torch.autograd.profiler import record_function
from .constants import default_pg_timeout
from .c10d_error_logger import _get_or_create_logger
@@ -178,7 +177,7 @@
backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
def __new__(cls, name: str):
- if not isinstance(name, string_classes):
+ if not isinstance(name, str):
raise ValueError("Backend name must be a string, but got: {}".format(name))
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py
index 5a4d6ce..4a6d132 100644
--- a/torch/distributed/rendezvous.py
+++ b/torch/distributed/rendezvous.py
@@ -11,7 +11,6 @@
from datetime import timedelta
from typing import Dict, Optional
-import torch._six as six
from torch.distributed import FileStore, PrefixStore, Store, TCPStore
from .constants import default_pg_timeout
@@ -91,7 +90,7 @@
def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
- if not isinstance(url, six.string_classes):
+ if not isinstance(url, str):
raise RuntimeError("`url` must be a string. {}: {}".format(type(url), url))
if not isinstance(rank, numbers.Integral):
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index 9557484..9d9b0fd 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -1,7 +1,7 @@
from numbers import Number
import torch
-from torch._six import nan
+from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index 06372a3..7cff0e4 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -1,5 +1,5 @@
import torch
-from torch._six import nan
+from torch import nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property
diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 8e45131..2ef0fb9 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -1,5 +1,5 @@
import math
-from torch._six import inf, nan
+from torch import inf, nan
from numbers import Number
import torch
diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py
index fe9e2c4..26511ab 100644
--- a/torch/distributions/fishersnedecor.py
+++ b/torch/distributions/fishersnedecor.py
@@ -1,6 +1,6 @@
from numbers import Number
import torch
-from torch._six import nan
+from torch import nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import Gamma
diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py
index fac77fc..c501076 100644
--- a/torch/distributions/half_cauchy.py
+++ b/torch/distributions/half_cauchy.py
@@ -1,7 +1,7 @@
import math
import torch
-from torch._six import inf
+from torch import inf
from torch.distributions import constraints
from torch.distributions.transforms import AbsTransform
from torch.distributions.cauchy import Cauchy
diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py
index 3fa1e7e..184d6f1 100644
--- a/torch/distributions/half_normal.py
+++ b/torch/distributions/half_normal.py
@@ -1,7 +1,7 @@
import math
import torch
-from torch._six import inf
+from torch import inf
from torch.distributions import constraints
from torch.distributions.transforms import AbsTransform
from torch.distributions.normal import Normal
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 57eaade..26d7b47 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -4,7 +4,7 @@
from typing import Type, Dict, Callable, Tuple
import torch
-from torch._six import inf
+from torch import inf
from .bernoulli import Bernoulli
from .beta import Beta
diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py
index b781490..249cdf0 100644
--- a/torch/distributions/kumaraswamy.py
+++ b/torch/distributions/kumaraswamy.py
@@ -1,5 +1,5 @@
import torch
-from torch._six import nan
+from torch import nan
from torch.distributions import constraints
from torch.distributions.uniform import Uniform
from torch.distributions.transformed_distribution import TransformedDistribution
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
index 4befced..579febb 100644
--- a/torch/distributions/multinomial.py
+++ b/torch/distributions/multinomial.py
@@ -1,5 +1,5 @@
import torch
-from torch._six import inf
+from torch import inf
from torch.distributions.binomial import Binomial
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py
index 674af46..83b06c6 100644
--- a/torch/distributions/studentT.py
+++ b/torch/distributions/studentT.py
@@ -1,7 +1,7 @@
import math
import torch
-from torch._six import inf, nan
+from torch import inf, nan
from torch.distributions import Chi2, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, broadcast_all
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index b73bfc2..cbbd8d1 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -1,7 +1,7 @@
from numbers import Number
import torch
-from torch._six import nan
+from torch import nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py
index 3bc6ad4..0c9c541 100644
--- a/torch/distributions/wishart.py
+++ b/torch/distributions/wishart.py
@@ -4,7 +4,7 @@
from typing import Union
import torch
-from torch._six import nan
+from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import lazy_property
diff --git a/torch/fx/experimental/unification/multipledispatch/variadic.py b/torch/fx/experimental/unification/multipledispatch/variadic.py
index d9280e9..6d50ff6 100644
--- a/torch/fx/experimental/unification/multipledispatch/variadic.py
+++ b/torch/fx/experimental/unification/multipledispatch/variadic.py
@@ -1,5 +1,3 @@
-import six
-
from .utils import typename
__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"]
@@ -72,7 +70,7 @@
)
-class Variadic(six.with_metaclass(VariadicSignatureMeta)):
+class Variadic(metaclass=VariadicSignatureMeta):
"""A class whose getitem method can be used to generate a new type
representing a specific variadic signature.
Examples
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 553a702..cee7a24 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -23,7 +23,6 @@
from torch.nn import Module
from torch.jit._state import _enabled
from torch.jit._builtins import _register_builtin
-from torch._six import with_metaclass
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
from torch._jit_internal import _qualified_name
from torch.jit._fuser import _graph_for, _script_method_graph_for
@@ -484,7 +483,7 @@
# did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward
# which always throws an exception.
- class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore[misc]
+ class ScriptModule(Module, metaclass=ScriptMeta):
r"""
A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
contain methods, attributes, parameters, and
@@ -495,7 +494,7 @@
def __init__(self):
super().__init__()
- forward = _CachedForward()
+ forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
def __getattr__(self, attr):
if "_actual_script_module" not in self.__dict__:
@@ -650,11 +649,11 @@
modules = {}
for name, cpp_module in torch._C.ModuleDict(self._c).items():
modules[name] = wrap_cpp_module(cpp_module)
- self._modules = OrderedModuleDict(self._c, modules)
+ self._modules = OrderedModuleDict(self._c, modules) # type: ignore[assignment]
# Copy parameters and buffers.
- self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c))
- self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c))
+ self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c)) # type: ignore[assignment]
+ self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c)) # type: ignore[assignment]
# Get rid of the functions from the old C++ module.
self.__dict__ = {
@@ -679,7 +678,7 @@
``forward`` method. This graph will be preprocessed to inline all function and method calls.
See :ref:`interpreting-graphs` for details.
"""
- return self.forward.inlined_graph
+ return self.forward.inlined_graph # type: ignore[attr-defined]
@property
def code(self):
@@ -688,7 +687,7 @@
the internal graph for the ``forward`` method. See
:ref:`inspecting-code` for details.
"""
- return self.forward.code
+ return self.forward.code # type: ignore[attr-defined]
@property
def code_with_constants(self):
@@ -702,7 +701,7 @@
See :ref:`inspecting-code` for details.
"""
- r = self.forward.code_with_constants
+ r = self.forward.code_with_constants # type: ignore[attr-defined]
return (r[0], ConstMap(r[1]))
def save(self, f, **kwargs):
@@ -740,7 +739,7 @@
return "original_name={}".format(self.original_name)
def graph_for(self, *args, **kwargs):
- return self.forward.graph_for(self, *args, **kwargs)
+ return self.forward.graph_for(self, *args, **kwargs) # type: ignore[attr-defined]
@property
def original_name(self):
diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py
index b3762b3..c8c2975 100644
--- a/torch/jit/_serialization.py
+++ b/torch/jit/_serialization.py
@@ -11,7 +11,6 @@
import pathlib
import torch
-from torch._six import string_classes
from torch.jit._recursive import wrap_cpp_module
from torch.serialization import validate_cuda_device
@@ -148,7 +147,7 @@
os.remove("scriptmodule.pt")
"""
- if isinstance(f, string_classes):
+ if isinstance(f, str):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
if os.path.isdir(f):
@@ -197,7 +196,7 @@
def jit_module_from_flatbuffer(f):
ff = get_ff_module()
- if isinstance(f, string_classes):
+ if isinstance(f, str):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError("The provided filename {} does not exist".format(f)) # type: ignore[str-bytes-safe]
if os.path.isdir(f):
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index 1e2c61f..f0da4a1 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -16,7 +16,7 @@
import warnings
import inspect
import re
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional, Set
from torch.jit._state import _python_cu, _enabled
from torch.jit._script import ScriptModule, _CachedForward, script
@@ -1198,7 +1198,7 @@
class TopLevelTracedModule(TracedModule):
- forward = _CachedForward()
+ forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
def _reconstruct(self, cpp_module):
"""
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 0287960..0c8837f 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -512,7 +512,7 @@
if '_buffers' not in self.__dict__:
raise AttributeError(
"cannot assign buffer before Module.__init__() call")
- elif not isinstance(name, torch._six.string_classes):
+ elif not isinstance(name, str):
raise TypeError("buffer name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
@@ -553,7 +553,7 @@
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
- elif not isinstance(name, torch._six.string_classes):
+ elif not isinstance(name, str):
raise TypeError("parameter name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
@@ -595,7 +595,7 @@
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
torch.typename(module)))
- elif not isinstance(name, torch._six.string_classes):
+ elif not isinstance(name, str):
raise TypeError("module name should be a string. Got {}".format(
torch.typename(name)))
elif hasattr(self, name) and name not in self._modules:
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 742b3bb..99aca62 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -184,7 +184,7 @@
ctx.state_dict["static_graph"]
and ctx.state_dict["num_iterations"] == 1
):
- Variable._execution_engine.queue_callback(
+ Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
ctx.reducer._delay_all_reduce
)
diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py
index 8cc8b58..900d042 100644
--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -2,8 +2,7 @@
from typing import Union, Iterable, List, Dict, Tuple, Optional
import torch
-from torch import Tensor
-from torch._six import inf
+from torch import Tensor, inf
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py
index e8d37b2..90326a3 100644
--- a/torch/onnx/_internal/jit_utils.py
+++ b/torch/onnx/_internal/jit_utils.py
@@ -310,7 +310,7 @@
@_beartype.beartype
def _is_onnx_list(value):
return (
- not isinstance(value, torch._six.string_classes)
+ not isinstance(value, str)
and not isinstance(value, torch.Tensor)
and isinstance(value, Iterable)
)
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 6c015b4..f882729 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -959,7 +959,7 @@
if isinstance(model, torch.jit.ScriptModule):
try:
- graph = model.forward.graph
+ graph = model.forward.graph # type: ignore[attr-defined]
except AttributeError as e:
raise RuntimeError("'forward' method must be a script method") from e
_C._jit_pass_onnx_function_substitution(graph)
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index f82fd8a..273fe4a 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1,6 +1,6 @@
import types
import math
-from torch._six import inf
+from torch import inf
from functools import wraps
import warnings
import weakref
diff --git a/torch/serialization.py b/torch/serialization.py
index af3b3c3..83f6fa2 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -10,7 +10,6 @@
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
-from ._six import string_classes as _string_classes
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
@@ -1079,7 +1078,7 @@
def restore_location(storage, location):
location = map_location.get(location, location)
return default_restore_location(storage, location)
- elif isinstance(map_location, _string_classes):
+ elif isinstance(map_location, str):
def restore_location(storage, location):
return default_restore_location(storage, map_location)
elif isinstance(map_location, torch.device):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index bfc9607..8460741 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11,7 +11,7 @@
import torch
import numpy as np
-from torch._six import inf, nan
+from torch import inf, nan
from typing import Any, Dict, List, Tuple, Union, Sequence
from torch.testing import make_tensor
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index b8cca44..03193f5 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -65,7 +65,6 @@
import torch.cuda
from torch import Tensor
from torch._C import ScriptDict, ScriptList # type: ignore[attr-defined]
-from torch._six import string_classes
from torch._utils_internal import get_writable_path
from torch.nn import (
ModuleDict,
@@ -589,7 +588,7 @@
# `p.wait()` in a `final` block for the code to be portable.
#
# https://github.com/python/cpython/blob/71b6c1af727fbe13525fb734568057d78cea33f3/Lib/subprocess.py#L309-L323
- assert not isinstance(command, torch._six.string_classes), "Command to shell should be a list or tuple of tokens"
+ assert not isinstance(command, str), "Command to shell should be a list or tuple of tokens"
p = subprocess.Popen(command, universal_newlines=True, cwd=cwd, env=env, stdout=stdout, stderr=stderr)
return wait_for_process(p)
@@ -1924,7 +1923,7 @@
class StringPair(UnittestPair):
- CLS = string_classes
+ CLS = str
TYPE_NAME = "string"
diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py
index cd09ee0..ec82aa2 100644
--- a/torch/testing/_internal/jit_metaprogramming_utils.py
+++ b/torch/testing/_internal/jit_metaprogramming_utils.py
@@ -15,7 +15,7 @@
import math # noqa: F401
# Testing utils
-from torch._six import inf
+from torch import inf
# TODO: include files like this should not set the default dtype
torch.set_default_dtype(torch.double)
diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py
index 72479e0..839cbbe 100644
--- a/torch/utils/data/_utils/collate.py
+++ b/torch/utils/data/_utils/collate.py
@@ -13,7 +13,6 @@
import torch
from typing import Callable, Dict, Optional, Tuple, Type, Union
-from torch._six import string_classes
np_str_obj_array_pattern = re.compile(r'[SaUO]')
@@ -70,7 +69,7 @@
return elem_type(*(default_convert(d) for d in data))
elif isinstance(data, tuple):
return [default_convert(d) for d in data] # Backwards compatibility.
- elif isinstance(data, collections.abc.Sequence) and not isinstance(data, string_classes):
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
try:
return elem_type([default_convert(d) for d in data])
except TypeError:
@@ -198,7 +197,7 @@
default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
default_collate_fn_map[float] = collate_float_fn
default_collate_fn_map[int] = collate_int_fn
-default_collate_fn_map[string_classes] = collate_str_fn
+default_collate_fn_map[str] = collate_str_fn
def default_collate(batch):
diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py
index 466cf0c..7d2b745 100644
--- a/torch/utils/data/_utils/pin_memory.py
+++ b/torch/utils/data/_utils/pin_memory.py
@@ -9,7 +9,6 @@
import queue
import torch
-from torch._six import string_classes
from . import MP_STATUS_CHECK_INTERVAL
from torch._utils import ExceptionWrapper
@@ -54,7 +53,7 @@
def pin_memory(data, device=None):
if isinstance(data, torch.Tensor):
return data.pin_memory(device)
- elif isinstance(data, string_classes):
+ elif isinstance(data, str):
return data
elif isinstance(data, collections.abc.Mapping):
try:
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 9796d1f..85098ae 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -22,7 +22,6 @@
import torch.utils.data.graph_settings
from torch._utils import ExceptionWrapper
-from torch._six import string_classes
from . import (
IterDataPipe,
@@ -396,7 +395,7 @@
def multiprocessing_context(self, multiprocessing_context):
if multiprocessing_context is not None:
if self.num_workers > 0:
- if isinstance(multiprocessing_context, string_classes):
+ if isinstance(multiprocessing_context, str):
valid_start_methods = multiprocessing.get_all_start_methods()
if multiprocessing_context not in valid_start_methods:
raise ValueError(