#!/usr/bin/env python3

################################################################################
#
# MIT License
#
# Copyright 2024-2025 AMD ROCm(TM) Software
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
# ies of the Software, and to permit persons to whom the Software is furnished
# to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
################################################################################


import argparse
import pathlib
import shlex
import shutil
import stat
import subprocess
import sys

try:
    from tqdm.auto import tqdm
except ImportError:
    # tqdm is optional; None signals "no progress bars" to the rest of the file.
    tqdm = None


def write(*args, **kwargs):
    """Emit a message through tqdm.write when tqdm was imported, else print.

    All positional and keyword arguments are forwarded unchanged.
    """
    if tqdm is None:
        print(*args, **kwargs)
    else:
        tqdm.write(*args, **kwargs)


# Directory two levels above this script's resolved location
# (presumably the repository root -- confirm against the repo layout).
repo_dir = pathlib.Path(__file__).resolve().parent.parent
# Extend the module search path with <repo_dir>/scripts/lib before the import
# below; that import has to come after this line, hence the E402 suppression.
sys.path.append(str(repo_dir / "scripts" / "lib"))

from rrperf import git  # noqa: E402


def format_and_patch(command: str, path: pathlib.Path):
    """Run a diff-producing formatter command and apply its diff to ``path``.

    ``command`` is executed through the shell. An exit status of 0 is taken to
    mean that nothing needs to change. Any other exit status means the
    command's standard output is a diff, which is piped into ``patch`` to
    rewrite ``path`` in place; the output of ``patch`` is echoed via ``write``.

    Args:
        command: Shell command line whose stdout is a diff for ``path``.
        path: File the diff is applied to.

    Raises:
        subprocess.CalledProcessError: If ``command`` exits non-zero without
            printing a diff, or if ``patch`` exits with a non-zero status.
    """
    p = subprocess.run(command, shell=True, capture_output=True, text=True)

    output = p.stdout

    if p.returncode == 0:
        assert output.strip() == "", f"unexpected output from {command!r}"
        return

    if not output.strip():
        # A non-zero status with nothing on stdout means the tool itself failed
        # rather than reported a diff. Feeding that to `patch` would only yield
        # a misleading patch error, so surface the tool's own diagnostics.
        if p.stderr:
            print(p.stderr, file=sys.stderr, end="")
        raise subprocess.CalledProcessError(
            p.returncode, command, output=output, stderr=p.stderr
        )

    # Original behavior: pad the diff with trailing blank lines before piping
    # it to patch.
    output += "\n\n"

    p2 = subprocess.run(
        ["patch", str(path)],
        check=True,
        input=output,
        encoding="UTF-8",
        stdout=subprocess.PIPE,
    )
    write(p2.stdout)


def is_tool_available(tool: str) -> bool:
    """Report whether ``tool`` can be located and run successfully.

    The tool counts as available when ``shutil.which`` finds it and invoking
    it with ``--version`` exits with status 0.
    """
    located = shutil.which(tool)
    if located is None:
        return False

    try:
        probe = subprocess.run(
            [tool, "--version"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except FileNotFoundError:
        return False

    # Equivalent to check=True + catching CalledProcessError.
    return probe.returncode == 0


def run_clang_format(path: pathlib.Path):
    """Format ``path`` with clang-format by diffing its output and patching.

    The path is shell-quoted in both places it appears so that names with
    spaces or shell metacharacters are passed through intact.
    """
    quoted = shlex.quote(str(path))
    # NOTE(review): the shell reports only diff's exit status for this
    # pipeline, so a failing clang-format is not detected here.
    format_and_patch(f"clang-format {quoted} | diff {quoted} -", path)


def run_black(path: pathlib.Path):
    """Format ``path`` with black by applying the diff it reports.

    The path is shell-quoted because the command is run through the shell.
    """
    quoted = shlex.quote(str(path))
    format_and_patch(f"black --quiet --check --diff {quoted}", path)


def run_isort(path: pathlib.Path):
    """Sort the imports of ``path`` with isort by applying the diff it reports.

    isort is run with the ``black`` profile. The path is shell-quoted because
    the command is run through the shell.
    """
    quoted = shlex.quote(str(path))
    format_and_patch(f"isort --profile black --quiet --check --diff {quoted}", path)


def get_cpp_files(dir: pathlib.Path):
    """Return the C++ sources among the paths ``git.ls_tree(dir)`` reports.

    A path qualifies when it ends in ``.h``, ``.hpp`` or ``.cpp``, optionally
    followed by ``.in`` (e.g. ``config.h.in``).

    Args:
        dir: Directory handed to ``git.ls_tree``.

    Returns:
        A list of ``pathlib.Path`` objects, in the order they were reported.
    """
    cpp_suffixes = {".h", ".hpp", ".cpp"}

    def is_cpp(p: pathlib.Path) -> bool:
        # Path.suffix only yields the final extension, so "x.h.in" reports
        # ".in"; look at the extension underneath it in that case.
        if p.suffix == ".in":
            return pathlib.Path(p.stem).suffix in cpp_suffixes
        return p.suffix in cpp_suffixes

    return [p for x in git.ls_tree(dir) if is_cpp(p := pathlib.Path(x))]


def get_py_files(dir: pathlib.Path):
    """Return the Python files among the paths ``git.ls_tree(dir)`` reports.

    A path qualifies when it has a ``.py`` suffix, or when it is a regular
    file that is executable by its owner and whose first line mentions
    ``python`` (e.g. a shebang line).

    Args:
        dir: Directory handed to ``git.ls_tree``.

    Returns:
        A list of ``pathlib.Path`` objects, in the order they were reported.
    """
    paths = []
    for x in git.ls_tree(dir):
        p = pathlib.Path(x)
        if p.suffix == ".py":
            paths.append(p)
        elif p.is_file() and (stat.S_IXUSR & p.stat().st_mode):
            # Read just the first line, as bytes: an empty file yields b""
            # instead of an IndexError, and a binary executable cannot raise
            # UnicodeDecodeError.
            with p.open("rb") as f:
                first_line = f.readline()
            if b"python" in first_line:
                paths.append(p)
    return paths


def fix_cpp(dir: pathlib.Path):
    """Run clang-format over every C++ file found for ``dir``.

    Does nothing except print a notice when clang-format is not available.
    A tqdm progress bar is shown when tqdm was imported.
    """
    if is_tool_available("clang-format"):
        cpp_paths = get_cpp_files(dir)
        write(f"Fixing {len(cpp_paths)} C++ files:")

        # Wrap the list in a progress bar only when tqdm is present.
        work = cpp_paths if tqdm is None else tqdm(cpp_paths, unit=" files")
        for cpp_path in work:
            run_clang_format(cpp_path)

        write("Done.")
    else:
        print("clang-format is not available, skipping C++ formatting.")


def check_python(dir: pathlib.Path):
    """Run flake8 with ``dir`` as the working directory.

    The command line is echoed first. flake8's exit status is deliberately
    not checked, so reported problems do not stop the script. When flake8 is
    not available the check is skipped with a notice, in the same way the
    formatting steps handle their missing tools.
    """
    if not is_tool_available("flake8"):
        print("flake8 is not available, skipping Python checks.")
        return

    command = ["flake8"]
    write(" ".join(command))

    subprocess.run(command, cwd=dir, check=False)


def fix_python(dir: pathlib.Path):
    """Run isort and black over every Python file found for ``dir``.

    Each tool is used only if it is available; a notice is printed for every
    missing tool, and nothing else happens when both are missing. For each
    file isort runs before black. A tqdm progress bar is shown when tqdm was
    imported.
    """
    have_isort = is_tool_available("isort")
    have_black = is_tool_available("black")
    if not (have_isort or have_black):
        print("isort and black are not available, skipping Python formatting.")
        return

    # Collect the usable formatters in the order they must be applied.
    formatters = []
    if have_isort:
        formatters.append(run_isort)
    else:
        print("isort is not available, skipping.")
    if have_black:
        formatters.append(run_black)
    else:
        print("black is not available, skipping.")

    py_paths = get_py_files(dir)
    write(f"Fixing {len(py_paths)} Python files:")

    work = py_paths if tqdm is None else tqdm(py_paths, unit=" files")
    for py_path in work:
        for apply_formatter in formatters:
            apply_formatter(py_path)

    write("Done.")


if __name__ == "__main__":
    # Command-line entry point: optionally format C++ files, then lint and
    # format Python files, for the given directory (default: the current one).
    cli = argparse.ArgumentParser(
        description="Fix C++ formatting and check Python formatting."
    )

    # Each stage is on by default and switched off by its --skip-* flag.
    skip_flags = (
        ("--skip-cpp", "cpp", "Don't format C++ files."),
        ("--skip-python", "python", "Don't check Python files."),
    )
    for flag, dest, help_text in skip_flags:
        cli.add_argument(
            flag,
            dest=dest,
            default=True,
            action="store_false",
            help=help_text,
        )

    cli.add_argument("dir", type=pathlib.Path, default=pathlib.Path.cwd(), nargs="?")

    options = cli.parse_args()
    target_dir = options.dir

    if options.cpp:
        fix_cpp(target_dir)

    if options.python:
        check_python(target_dir)
        fix_python(target_dir)
