import argparse
import gzip
import os
import shutil
import sys
from urllib.error import URLError
from urllib.request import urlretrieve
| |
| |
# Base URLs that host the MNIST files; download() tries them in this order and
# stops at the first one that succeeds.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]

# The four gzipped IDX files that make up MNIST: images and labels for the
# training split ("train") and the test split ("t10k"), in that order.
RESOURCES = [
    f"{split}-{kind}-ubyte.gz"
    for split in ("train", "t10k")
    for kind in ("images-idx3", "labels-idx1")
]
| |
| |
def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a 64-column text progress bar; intended as an ``urlretrieve`` reporthook.

    Args:
        chunk_number: Number of blocks transferred so far.
        chunk_size: Size of each block in bytes.
        file_size: Total size in bytes, or -1 when the size is unknown.

    Nothing is drawn when the total size is unknown (-1) or zero. A zero size
    previously reached the division below and raised ZeroDivisionError.
    """
    if file_size <= 0:
        return
    # Cap at 100%: the last block is usually only partially filled.
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    # "\r" rewinds to the start of the line so each call overwrites the last bar.
    sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
    # No newline is written, so flush explicitly for live updates when stdout
    # is not line-buffered.
    sys.stdout.flush()
| |
| |
def download(
    destination_path: str,
    resource: str,
    quiet: bool,
    *,
    mirrors: "list[str] | None" = None,
) -> None:
    """Fetch ``resource`` into ``destination_path``, trying each mirror in turn.

    Does nothing if ``destination_path`` already exists.

    Args:
        destination_path: Local file to create.
        resource: File name appended to each mirror's base URL.
        quiet: Suppress progress output (the "Downloading ..." line and the
            progress bar). Mirror failures are still reported.
        mirrors: Base URLs to try, in order. Defaults to ``MIRRORS``.

    Raises:
        RuntimeError: If no mirror could deliver the file. No partial file is
            left at ``destination_path`` in that case.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print(f"{destination_path} already exists, skipping ...")
        return

    if mirrors is None:
        mirrors = MIRRORS
    hook = None if quiet else report_download_progress

    # Stream into a sibling ".part" file and move it into place only after a
    # complete transfer. Otherwise a failed or interrupted download would leave
    # a truncated file that the existence check above skips on every later run.
    partial_path = destination_path + ".part"
    try:
        for mirror in mirrors:
            url = mirror + resource
            if not quiet:
                print(f"Downloading {url} ...")
            try:
                urlretrieve(url, partial_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                print(f"Failed to download (trying next):\n{e}")
                continue
            finally:
                if not quiet:
                    # Terminate the progress-bar line.
                    print()
            os.replace(partial_path, destination_path)
            return
        raise RuntimeError("Error downloading resource!")
    finally:
        # Remove leftovers from a failed or interrupted attempt. After a
        # successful os.replace the file is already gone, hence the ignored
        # FileNotFoundError.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
| |
| |
def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress a gzip file next to itself, dropping its last extension.

    Does nothing if the decompressed file already exists.

    Args:
        zipped_path: Path of the gzip archive, e.g. ``train-images-idx3-ubyte.gz``.
        quiet: Suppress status messages.

    Raises:
        OSError (including gzip.BadGzipFile) or EOFError: If the archive is
            missing, is not gzip data, or is truncated. No output file is left
            behind in that case.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print(f"{unzipped_path} already exists, skipping ... ")
        return

    # Decompress into a ".part" file and rename it on success. Gzip errors only
    # surface while reading, after the output has been created, and a truncated
    # output would otherwise be skipped as "already exists" on every later run.
    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file:
            with open(partial_path, "wb") as unzipped_file:
                # Stream in chunks instead of holding the whole file in memory.
                shutil.copyfileobj(zipped_file, unzipped_file)
        os.replace(partial_path, unzipped_path)
    finally:
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
    if not quiet:
        print(f"Unzipped {zipped_path} ...")
| |
| |
def main() -> None:
    """Parse command-line options, then download and decompress every MNIST file.

    Options:
        -d/--destination: Directory to store the files in (created if missing;
            defaults to the current directory).
        -q/--quiet: Suppress progress output.

    On Ctrl-C, prints "Interrupted" and exits with status 130. Callers such as
    shell scripts then do not mistake an incomplete download for success.
    """
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args()

    # exist_ok avoids the race between a separate exists() check and creation.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")
        # 128 + SIGINT (2): the conventional exit status for an interrupted process.
        sys.exit(130)
| |
| |
# Entry point when the file is executed as a script rather than imported.
if __name__ == "__main__":
    raise SystemExit(main())