blob: 6e824f674c862866144778652980aba2aaf5c652 [file] [log] [blame]
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A pip_parse implementation for version aware toolchains in WORKSPACE."""
load("//python/private:text_util.bzl", "render")
load(":pip_repository.bzl", pip_parse = "pip_repository")
def _multi_pip_parse_impl(rctx):
    """Writes the `requirements.bzl` and `BUILD.bazel` of a multi_pip_parse repository.

    For every `python_version -> repository name` entry of the `pip_parses`
    attribute the generated `requirements.bzl`:
      * loads `install_deps` and `all_requirements` from that repository,
      * records, per wheel, which repository prefix serves which Python
        version, and
      * exposes `requirement`, `whl_requirement`, `data_requirement`,
        `dist_info_requirement` and an `install_deps` macro that calls every
        per-version `install_deps` and declares one `whl_library_alias` per
        wheel.

    Args:
        rctx: the repository context of the `_multi_pip_parse` rule.
    """
    rules_python = rctx.attr._rules_python_workspace.workspace_name

    load_statements = []
    install_deps_calls = []
    process_requirements_calls = []
    for python_version, pypi_repository in rctx.attr.pip_parses.items():
        # The version is embedded in generated symbol names, where "." is not
        # allowed, so "3.11" becomes "3_11".
        sanitized_python_version = python_version.replace(".", "_")

        load_statement = """\
load(
    "@{pypi_repository}//:requirements.bzl",
    _{sanitized_python_version}_install_deps = "install_deps",
    _{sanitized_python_version}_all_requirements = "all_requirements",
)""".format(
            pypi_repository = pypi_repository,
            sanitized_python_version = sanitized_python_version,
        )
        load_statements.append(load_statement)

        process_requirements_call = """\
_process_requirements(
    pkg_labels = _{sanitized_python_version}_all_requirements,
    python_version = "{python_version}",
    repo_prefix = "{pypi_repository}_",
)""".format(
            pypi_repository = pypi_repository,
            python_version = python_version,
            sanitized_python_version = sanitized_python_version,
        )
        process_requirements_calls.append(process_requirements_call)

        # This line lands inside the body of the generated `install_deps`
        # macro, so it must carry the same four-space indent as that body.
        install_deps_call = """    _{sanitized_python_version}_install_deps(**whl_library_kwargs)""".format(
            sanitized_python_version = sanitized_python_version,
        )
        install_deps_calls.append(install_deps_call)

    # NOTE @aignas 2023-10-31: I am not sure it is possible to render aliases
    # for all of the packages using the `render_pkg_aliases` function because
    # we need to know what the list of packages for each version is and then
    # we would be creating directories for each.
    #
    # The two `{}` placeholders are left for the generated helpers to fill in
    # with the normalized package name and the target name.
    macro_tmpl = "@%s_{}//:{}" % rctx.attr.name

    # The template uses `dict()` rather than a dict literal so that it contains
    # no braces other than the `.format` placeholders.
    requirements_bzl = """\
# Generated by python/pip.bzl

load("@{rules_python}//python:pip.bzl", "whl_library_alias", "pip_utils")
{load_statements}

_wheel_names = []
_version_map = dict()

def _process_requirements(pkg_labels, python_version, repo_prefix):
    for pkg_label in pkg_labels:
        wheel_name = Label(pkg_label).package
        if not wheel_name:
            # We are dealing with the cases where we don't have aliases.
            workspace_name = Label(pkg_label).workspace_name
            wheel_name = workspace_name[len(repo_prefix):]

        if wheel_name not in _version_map:
            # First sighting of this wheel: record it exactly once so that
            # install_deps declares a single alias repository per wheel.
            _wheel_names.append(wheel_name)
            _version_map[wheel_name] = dict()
        _version_map[wheel_name][python_version] = repo_prefix

{process_requirements_calls}

def requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "pkg")

def whl_requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "whl")

def data_requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "data")

def dist_info_requirement(name):
    return "{macro_tmpl}".format(pip_utils.normalize_name(name), "dist_info")

def install_deps(**whl_library_kwargs):
{install_deps_calls}
    for wheel_name in _wheel_names:
        whl_library_alias(
            name = "{name}_" + wheel_name,
            wheel_name = wheel_name,
            default_version = "{default_version}",
            minor_mapping = {minor_mapping},
            version_map = _version_map[wheel_name],
        )
""".format(
        name = rctx.attr.name,
        install_deps_calls = "\n".join(install_deps_calls),
        load_statements = "\n".join(load_statements),
        macro_tmpl = macro_tmpl,
        process_requirements_calls = "\n".join(process_requirements_calls),
        rules_python = rules_python,
        default_version = rctx.attr.default_version,
        # The rendered dict follows `minor_mapping = ` on the same line, hence
        # the leading whitespace added by `render.indent` is stripped again.
        minor_mapping = render.indent(render.dict(rctx.attr.minor_mapping)).lstrip(),
    )
    rctx.file("requirements.bzl", requirements_bzl)

    # Export the generated file so that it can be referenced as a label.
    rctx.file("BUILD.bazel", "exports_files(['requirements.bzl'])")
# Repository rule that generates the version-aware `requirements.bzl`; it is
# instantiated by the `multi_pip_parse` macro in this file.
_multi_pip_parse = repository_rule(
    implementation = _multi_pip_parse_impl,
    attrs = {
        # Rendered as the `default_version` of every generated
        # `whl_library_alias` call.
        "default_version": attr.string(),
        # Rendered as the `minor_mapping` of every generated
        # `whl_library_alias` call.
        "minor_mapping": attr.string_dict(),
        # Python version -> name of the pip repository created for it.
        "pip_parses": attr.string_dict(),
        # Only the workspace name of this label is read; it is used to build
        # the `load()` of `//python:pip.bzl` in the generated file.
        "_rules_python_workspace": attr.label(default = Label("//:WORKSPACE")),
    },
)
def multi_pip_parse(name, default_version, python_versions, python_interpreter_target, requirements_lock, minor_mapping, **kwargs):
    """NOT INTENDED FOR DIRECT USE!

    Declares one `pip_parse` repository per Python version, named
    `<name>_<version with dots replaced by underscores>`, and then a
    `_multi_pip_parse` repository called `name` that ties them together.

    This is intended to be used by the multi_pip_parse implementation in the
    template of the multi_toolchain_aliases repository rule.

    Args:
        name: the name of the multi_pip_parse repository.
        default_version: {type}`str` the default Python version.
        python_versions: {type}`list[str]` all Python toolchain versions currently registered.
        python_interpreter_target: {type}`dict[str, Label]` a dictionary which keys are Python
            versions and values are resolved host interpreters. Every entry of
            `python_versions` must be present as a key.
        requirements_lock: {type}`dict[str, Label]` a dictionary which keys are Python versions
            and values are locked requirements files. Every entry of
            `python_versions` must be present as a key.
        minor_mapping: {type}`dict[str, str]` mapping between `X.Y` to `X.Y.Z` format.
        **kwargs: extra arguments passed to all wrapped pip_parse.

    Returns:
        The result of instantiating the internal `_multi_pip_parse` repository rule.
    """
    repo_name_by_version = {}
    for version in python_versions:
        # Each registered version needs both an interpreter and a lock file;
        # abort with a message naming the missing piece otherwise.
        if version not in python_interpreter_target:
            fail("Missing python_interpreter_target for Python version %s in '%s'" % (version, name))
        if version not in requirements_lock:
            fail("Missing requirements_lock for Python version %s in '%s'" % (version, name))

        version_suffix = version.replace(".", "_")
        repo_name = name + "_" + version_suffix
        pip_parse(
            name = repo_name,
            python_interpreter_target = python_interpreter_target[version],
            requirements_lock = requirements_lock[version],
            **kwargs
        )
        repo_name_by_version[version] = repo_name

    return _multi_pip_parse(
        name = name,
        default_version = default_version,
        pip_parses = repo_name_by_version,
        minor_mapping = minor_mapping,
    )