#!/usr/bin/env python3

################################################################################
#
# MIT License
#
# Copyright 2024-2025 AMD ROCm(TM) Software
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
# ies of the Software, and to permit persons to whom the Software is furnished
# to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
################################################################################


import argparse
import csv
import glob
import pathlib
import subprocess
import sys
import tempfile
from dataclasses import asdict
from pathlib import Path

import yaml

repo_dir = pathlib.Path(__file__).resolve().parent.parent
sys.path.append(str(repo_dir / "scripts" / "lib"))

import rrperf.git as git  # noqa: E402
import rrperf.problems as problems  # noqa: E402


class keyvalue(argparse.Action):
    """argparse action that collects ``KEY=VALUE`` tokens into a dict.

    Each occurrence of the option replaces the destination attribute with a
    fresh dict built from the tokens supplied to that occurrence. Keys and
    values are stored as strings; a later duplicate key overwrites an earlier
    one.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` into a dict and store it on ``namespace``.

        Raises:
            argparse.ArgumentError: if a token does not contain ``=``.
                argparse turns this into a normal usage error.
        """
        pairs = dict()
        for item in values:
            # Split on the first "=" only so that values may contain "=".
            key, separator, val = item.partition("=")
            if not separator:
                raise argparse.ArgumentError(
                    self, f"expected KEY=VALUE, got {item!r}"
                )
            pairs[key] = val
        setattr(namespace, self.dest, pairs)


def process_cli():
    """Build the command-line parser and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` with ``config``, ``tensile_repo_url``,
        ``tensile_commit``, ``yaml`` and ``kwargs`` attributes.
    """
    cli = argparse.ArgumentParser(description="Run Tensile")
    cli.add_argument("config", type=str, help="Path of Tensile Config.")

    # Optional string-valued flags, as (flag, default, help text).
    string_options = (
        (
            "--tensile_repo_url",
            "https://github.com/ROCmSoftwarePlatform/Tensile.git",
            "URL of Tensile Repo.",
        ),
        ("--tensile_commit", None, "Commit of Tensile to use."),
        ("--yaml", "out.yaml", "Yaml output file."),
    )
    for flag, fallback, description in string_options:
        cli.add_argument(flag, type=str, default=fallback, help=description)

    # Extra KEY=VALUE pairs gathered into a dict by the keyvalue action.
    cli.add_argument("--kwargs", nargs="*", action=keyvalue, default={})
    return cli.parse_args()


def parse_type(type_str):
    """Translate a one-letter type code into its data type name.

    ``"H"`` maps to ``"half"`` and ``"S"`` maps to ``"float"``.

    Raises:
        ValueError: if ``type_str`` is not one of the known codes.
    """
    known_codes = (("H", "half"), ("S", "float"))
    for code, type_name in known_codes:
        if type_str == code:
            return type_name
    raise ValueError(f"{type_str} is not a valid type string.")


def parse_tensile_result(result_row):
    """Convert one row of a Tensile benchmark CSV into GEMM result fields.

    Args:
        result_row: Mapping with the string-valued keys ``WinnerName``,
            ``WinnerTimeUS``, ``SizeL``, ``SizeI`` and ``SizeJ``.

    Returns:
        A dict with the keys ``client``, ``path``, ``kernelGenerate``,
        ``kernelAssemble``, ``kernelExecute``, ``K``, ``M``, ``N``,
        ``scheduler``, ``type_A``/``type_B``/``type_C``/``type_D``/``type_acc``
        and ``trans_A``/``trans_B``.

    Raises:
        ValueError: if the winner name has no ``MT...`` field, or a type
            letter is not recognised by ``parse_type``.
    """
    winner_name = result_row["WinnerName"]

    retval = dict()
    retval["client"] = "GEMMv00"
    retval["path"] = winner_name
    retval["kernelGenerate"] = 0
    retval["kernelAssemble"] = 0
    # WinnerTimeUS is scaled by 1000 (presumably microseconds to nanoseconds).
    retval["kernelExecute"] = [float(result_row["WinnerTimeUS"]) * 1000]
    retval["K"] = int(result_row["SizeL"])
    retval["M"] = int(result_row["SizeI"])
    retval["N"] = int(result_row["SizeJ"])
    retval["scheduler"] = "Tensile"

    # Information on Tensile kernel names:
    # https://github.com/ROCmSoftwarePlatform/Tensile/blob/36b5a45961477433b1da8518df3b79295e76fec0/Tensile/SolutionStructs.py#L1115
    # This functionality may need to be expanded if we add more configs.
    kernel_name_fields = winner_name.split("_")

    # Index of the first field starting with "MT"; None when there is none.
    macro_tile_pos = next(
        (
            index
            for index, field in enumerate(kernel_name_fields)
            if field.startswith("MT")
        ),
        None,
    )
    if macro_tile_pos is None:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            f"Kernel name {winner_name!r} has no field starting with 'MT'."
        )

    type_field = kernel_name_fields[3]
    if macro_tile_pos > 4 and len(type_field) == 3:
        # Three-letter type field: A/B type, C/D type, accumulation type.
        retval["type_A"] = parse_type(type_field[0])
        retval["type_B"] = parse_type(type_field[0])
        retval["type_C"] = parse_type(type_field[1])
        retval["type_D"] = parse_type(type_field[1])
        retval["type_acc"] = parse_type(type_field[2])
    else:
        # Otherwise the first letter is used for every operand.
        retval["type_A"] = parse_type(type_field[0])
        retval["type_B"] = retval["type_A"]
        retval["type_C"] = retval["type_A"]
        retval["type_D"] = retval["type_A"]
        retval["type_acc"] = retval["type_A"]

    retval["trans_A"] = "N" if kernel_name_fields[1] == "Ailk" else "T"
    retval["trans_B"] = "N" if kernel_name_fields[2] == "Bljk" else "T"
    return retval


if __name__ == "__main__":
    # Flow: clone Tensile into a temporary directory, run it on the given
    # config, convert every benchmark CSV row into a GEMMResult, and write
    # the results to the requested YAML file.
    args = process_cli()

    with tempfile.TemporaryDirectory(
        prefix="rocRoller_", suffix="_profileTensile"
    ) as tensile_repo:
        print("Created temporary directory", tensile_repo)
        git.clone(args.tensile_repo_url, tensile_repo)

        if args.tensile_commit is not None:
            git.checkout(tensile_repo, args.tensile_commit)

        tensile_path: Path = Path(tensile_repo) / "Tensile" / "bin" / "Tensile"
        if not tensile_path.exists():
            raise FileNotFoundError(f"Tensile not found at {tensile_path}")

        with tempfile.TemporaryDirectory(
            prefix="rocRoller_", suffix="_profileTensile"
        ) as working_dir_name:
            print("Created temporary directory", working_dir_name)
            # check=True aborts the script if the Tensile run fails.
            subprocess.run(
                [str(tensile_path), args.config, working_dir_name],
                check=True,
            )
            results = []
            # glob returns files in arbitrary order; sort so the YAML output
            # is ordered the same way on every run.
            for csv_file_name in sorted(
                glob.glob(working_dir_name + "/2_BenchmarkData/*.csv")
            ):
                # newline="" is what the csv module expects of file objects.
                with open(csv_file_name, mode="r", newline="") as csv_file:
                    csv_reader = csv.DictReader(
                        csv_file, skipinitialspace=True, delimiter=","
                    )
                    for row in csv_reader:
                        gemm_args = parse_tensile_result(row)
                        # User-supplied kwargs only fill in missing fields;
                        # they never override values parsed from the CSV.
                        for key, value in args.kwargs.items():
                            gemm_args.setdefault(key, value)
                        results.append(problems.GEMMResult(**gemm_args))

        with Path(args.yaml).open("w") as yfile:
            yaml.dump_all(
                [asdict(x) for x in results],
                yfile,
                default_flow_style=False,
                default_style=None,
            )
