blob: 65fa547123a4b076dd045150943fd75dff07a476 [file] [log] [blame]
"""
This script uses linear programming to analyze the output of Triton mm config tuning.
To generate input for this script, set the env var TORCHINDUCTOR_MM_LOGGING_FILE.
Feed the resulting log file to this script to compute the minimum total weighted matmul time as a function of the number of allowed Triton templates.
"""
import json
import click
import pulp
def parse_log_file(file_path):
    """Load a mm tuning log and split it into invocation counts and benchmark timings.

    The file is parsed with ``json.load`` and iterated as a sequence of mapping
    entries. An entry that contains the key ``"invoke"`` counts one occurrence of
    ``entry["invoke"]`` (its other keys are ignored). Any other entry maps shapes
    to lists of timing records, and those lists are concatenated per shape across
    the whole file.

    Args:
        file_path: Path to the JSON log file.

    Returns:
        A pair ``(occurrence_count, benchmark_logs)``: ``{shape: number of invoke
        entries}`` and ``{shape: list of all timing records logged for it}``.
    """
    with open(file_path) as handle:
        entries = json.load(handle)

    occurrence_count = {}
    benchmark_logs = {}
    for entry in entries:
        if "invoke" not in entry:
            # Benchmark entry: accumulate every shape's timing records.
            for shape, records in entry.items():
                benchmark_logs.setdefault(shape, []).extend(records)
            continue
        invoked_shape = entry["invoke"]
        occurrence_count[invoked_shape] = occurrence_count.get(invoked_shape, 0) + 1
    return occurrence_count, benchmark_logs
def _template_key(timing):
    """Return the tuple identifying the Triton config of one timing record."""
    return (
        timing["BLOCK_M"],
        timing["BLOCK_N"],
        timing["BLOCK_K"],
        timing["num_stages"],
        timing["num_warps"],
    )


def _fastest_options(timings):
    """Return ``(fastest cuBLAS time or None, {template: fastest Triton time})``.

    Records whose ``"type"`` is neither ``"cublas"`` nor ``"triton"`` are ignored.
    """
    cublas_time = None
    triton_times = {}
    for timing in timings:
        if timing["type"] == "cublas":
            elapsed = timing["time"]
            if cublas_time is None or elapsed < cublas_time:
                cublas_time = elapsed
        elif timing["type"] == "triton":
            elapsed = timing["time"]
            template = _template_key(timing)
            if template not in triton_times or elapsed < triton_times[template]:
                triton_times[template] = elapsed
    return cublas_time, triton_times


def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False):
    """Choose the N Triton templates that minimize total weighted matmul time.

    Builds and solves a 0/1 integer program with PuLP/CBC:

    * one binary per candidate Triton template (every Triton config seen in
      ``benchmark_logs``), exactly ``N`` of which must be set;
    * for every invoked shape, one binary per available option (cuBLAS, and each
      Triton template benchmarked for that shape); exactly one option is chosen
      per shape, and a Triton option may only be chosen if its template is set;
    * objective: ``sum(occurrence_count[shape] * time of chosen option)``, using
      the fastest logged time for each option.

    Args:
        N: Number of Triton templates to allow.
        occurrence_count: ``{shape: number of invocations}``.
        benchmark_logs: ``{shape: [timing record, ...]}``. Each record has
            ``"type"`` and ``"time"``; Triton records also carry ``BLOCK_M``,
            ``BLOCK_N``, ``BLOCK_K``, ``num_stages`` and ``num_warps``.
        verbose: Print the model, the raw solution and per-shape details.

    Returns:
        ``(selected_templates, total_time)``: the chosen template tuples and the
        optimal occurrence-weighted total time.

    Raises:
        ValueError: if an invoked shape has neither a cuBLAS nor a Triton timing,
            if ``N`` is not in ``[0, number of candidate templates]``, or if the
            solver does not report an optimal solution.
    """
    # Every Triton config benchmarked for any shape is a candidate template.
    triton_templates = {
        _template_key(timing)
        for timings in benchmark_logs.values()
        for timing in timings
        if timing["type"] == "triton"
    }

    if verbose:
        print("Occurrence Count:", occurrence_count)
        print("Triton Templates:", triton_templates)

    # "Exactly N templates" is infeasible outside this range; fail clearly
    # instead of reading meaningless values from an infeasible solve.
    if not 0 <= N <= len(triton_templates):
        raise ValueError(
            f"Cannot select exactly N={N} templates: only "
            f"{len(triton_templates)} candidate Triton templates are available"
        )

    # Fastest time per option for each invoked shape, computed in one pass over
    # that shape's records and reused for both the constraints and debug output.
    best_options = {}
    for shape in occurrence_count:
        cublas_time, triton_times = _fastest_options(benchmark_logs.get(shape, []))
        if cublas_time is None and not triton_times:
            raise ValueError(f"No cuBLAS or Triton timings logged for shape {shape!r}")
        best_options[shape] = (cublas_time, triton_times)

    # Variables are named by index rather than by shape/template text: PuLP
    # rewrites some characters in names to "_", so distinct shapes could map to
    # the same variable name and be rejected as duplicates.
    template_ids = {template: idx for idx, template in enumerate(triton_templates)}
    template_vars = {
        template: pulp.LpVariable(f"Template_{idx}", 0, 1, pulp.LpBinary)
        for template, idx in template_ids.items()
    }
    shape_ids = {shape: idx for idx, shape in enumerate(occurrence_count)}
    min_time_vars = {
        shape: pulp.LpVariable(f"MinTime_{idx}", 0, None, pulp.LpContinuous)
        for shape, idx in shape_ids.items()
    }
    selection_vars = {}

    prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize)

    # Objective: minimize the occurrence-weighted total time.
    prob += pulp.lpSum(
        occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count
    )

    # Exactly N templates are allowed. With no candidates N is already known to
    # be 0, so there is nothing to constrain.
    if template_vars:
        prob += pulp.lpSum(template_vars.values()) == N

    for shape, shape_idx in shape_ids.items():
        cublas_time, triton_times = best_options[shape]
        options = []  # (selection variable, time) pairs available for this shape
        if cublas_time is not None:
            cublas_var = pulp.LpVariable(
                f"Select_{shape_idx}_cublas", 0, 1, pulp.LpBinary
            )
            selection_vars[(shape, "cublas")] = cublas_var
            options.append((cublas_var, cublas_time))
        for template, triton_time in triton_times.items():
            triton_var = pulp.LpVariable(
                f"Select_{shape_idx}_{template_ids[template]}", 0, 1, pulp.LpBinary
            )
            selection_vars[(shape, template)] = triton_var
            options.append((triton_var, triton_time))
            # A Triton option is only usable if its template is among the N allowed.
            prob += triton_var <= template_vars[template]
        # Exactly one option per shape, and the shape's time is that option's time.
        prob += pulp.lpSum(var for var, _ in options) == 1
        prob += min_time_vars[shape] == pulp.lpSum(
            var * elapsed for var, elapsed in options
        )

    if verbose:
        print("Constraints:")
        for constraint in prob.constraints.values():
            print(constraint)

    # Solve with solver output suppressed; anything but an optimal status means
    # the variable values below would be meaningless.
    status = prob.solve(pulp.PULP_CBC_CMD(msg=False))
    if status != pulp.LpStatusOptimal:
        raise ValueError(
            f"No optimal template selection for N={N}: "
            f"solver status {pulp.LpStatus[status]!r}"
        )

    # Compare against 0.5 rather than == 1: the solver reports binaries as floats
    # that may deviate slightly from exact 0/1.
    selected_templates = [
        template
        for template, var in template_vars.items()
        if pulp.value(var) > 0.5
    ]
    total_time = sum(
        pulp.value(min_time_vars[shape]) * occurrence_count[shape]
        for shape in occurrence_count
    )

    if verbose:
        print("Decision Variable Values:")
        for var in prob.variables():
            print(f"{var.name} = {var.varValue}")
        # Per-shape details, using each shape's own fastest cuBLAS time.
        for shape in occurrence_count:
            cublas_time, triton_times = best_options[shape]
            print(f"Shape: {shape}")
            print(f" Min Time: {pulp.value(min_time_vars[shape])}")
            print(f" Occurrences: {occurrence_count[shape]}")
            if cublas_time is not None:
                print(
                    f" Min CuBLAS Time: {cublas_time} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}"
                )
            for template, triton_time in triton_times.items():
                print(
                    f" Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}"
                )

    return selected_templates, total_time
# Command-line entry point: parse the log once, then for every N from
# --min-templates to --max-templates (inclusive) solve the template-selection
# problem and print the chosen templates and the weighted total time.
# Documented with comments instead of a docstring because click would turn a
# docstring into the command's --help text.
@click.command()
@click.argument("filename")
@click.option("--min-templates", default=0, help="Minimum number of templates.")
@click.option("--max-templates", default=10, help="Maximum number of templates.")
@click.option("--verbose", is_flag=True, help="Enable verbose output.")
def main(filename, min_templates, max_templates, verbose):
    counts, logs = parse_log_file(filename)
    weighted_totals = []
    for n_allowed in range(min_templates, max_templates + 1):
        chosen, weighted_total = optimize_templates(n_allowed, counts, logs, verbose)
        weighted_totals.append(weighted_total)
        print(f"N = {n_allowed}")
        print(f"Selected Templates: {chosen}")
        print(f"Total Weighted Time: {weighted_total}")
    # Final summary: the weighted total for each N, in increasing order of N.
    print(weighted_totals)


if __name__ == "__main__":
    main()