#!/usr/bin/env python3

# RISC-V pipeline model checker.
# Copyright (C) 2025 Free Software Foundation, Inc.
#
# This file is part of GCC.
#
# GCC is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
#
# GCC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GCC; see the file COPYING3.  If not see
# <http://www.gnu.org/licenses/>.

import re
import sys
import argparse
from pathlib import Path
from typing import List
import pprint

def remove_line_comments(text: str) -> str:
    # Strip ';;' comments: each line is cut at its first ';;', dropping the
    # marker and everything after it.  Lines without ';;' are kept unchanged.
    #
    # The match is purely textual, so a ';;' that happens to sit inside a
    # quoted string or a {...} block is cut as well.  The surviving lines are
    # re-joined with '\n', so a trailing newline in the input is not kept.
    #
    # partition() returns the whole line as its first item when ';;' is absent.
    return '\n'.join(row.partition(';;')[0] for row in text.splitlines())


def tokenize_sexpr(s: str) -> List[str]:
    # Split S-expression source text into a flat list of tokens.
    #
    # Token kinds:
    #   '(' / ')'  - emitted as single-character tokens
    #   "..."      - quoted string, kept WITH its quotes; a backslash skips
    #                the character that follows it
    #   {...}      - brace-balanced C block, kept with its enclosing braces
    #   atom       - any other run of characters, ended by whitespace or by
    #                one of ( ) " { }
    #
    # The tokenizer is lenient: an unterminated string or C block simply runs
    # to the end of the input.  A '}' that has no opening '{' is emitted as a
    # token of its own so that scanning always makes progress.
    tokens = []
    i = 0
    end = len(s)
    while i < end:
        c = s[i]
        if c.isspace():
            i += 1
        elif c == '(' or c == ')':
            tokens.append(c)
            i += 1
        elif c == '"':
            # Quoted string: scan to the closing quote, honouring escapes.
            j = i + 1
            while j < end and s[j] != '"':
                if s[j] == '\\':
                    j += 1  # Skip the escaped character
                j += 1
            tokens.append(s[i:j + 1])
            i = j + 1
        elif c == '{':
            # C block: track nesting depth until the matching '}'.
            depth = 1
            j = i + 1
            while j < end and depth > 0:
                if s[j] == '{':
                    depth += 1
                elif s[j] == '}':
                    depth -= 1
                j += 1
            tokens.append(s[i:j])  # Include enclosing braces
            i = j
        elif c == '}':
            # Stray closing brace.  The atom scanner below stops *at* '}',
            # so without this branch the position would never advance and
            # the loop would append empty tokens forever.
            tokens.append(c)
            i += 1
        else:
            # Atom: consume up to the next delimiter.
            j = i
            while j < end and not s[j].isspace() and s[j] not in '()"{}':
                j += 1
            tokens.append(s[i:j])
            i = j
    return tokens


def parse_sexpr(tokens: List[str]) -> object:
    # Parse ONE S-expression from the front of `tokens` and consume it.
    #
    # Returns a (possibly nested) list for a parenthesised form, or a string
    # for an atom.  Tokens that start and end with '"' are returned without
    # the surrounding quotes; every other atom, including {...} C blocks, is
    # returned as-is.  A ')' met outside any open list is returned as a
    # plain atom.
    #
    # `tokens` is modified in place: the consumed tokens are removed from its
    # front, so callers can loop `while tokens: parse_sexpr(tokens)`.
    #
    # Raises IndexError when `tokens` is empty or a '(' is never closed; in
    # that case `tokens` is left untouched.
    #
    # The parse walks the list by index and removes the consumed prefix once
    # at the end, instead of calling pop(0) per token (each pop(0) shifts
    # the whole list, which made parsing quadratic in the token count).
    open_lists = []  # Lists whose ')' has not been seen yet, innermost last
    pos = 0
    total = len(tokens)
    while True:
        if pos >= total:
            if open_lists:
                raise IndexError("unbalanced S-expression: missing ')'")
            raise IndexError('no tokens left to parse')
        token = tokens[pos]
        pos += 1
        if token == '(':
            open_lists.append([])
            continue
        if token == ')' and open_lists:
            node = open_lists.pop()  # Innermost list is now complete
        elif token.startswith('"') and token.endswith('"'):
            node = token[1:-1]  # Remove surrounding quotes
        else:
            node = token
        if not open_lists:
            break  # A complete top-level expression has been read
        open_lists[-1].append(node)
    del tokens[:pos]
    return node


def find_define_attr_type(ast: any) -> List[List[str]]:
    # Collect every list of the form ['define_attr', 'type', ...] found
    # anywhere inside `ast`, in pre-order (a node before its children,
    # children left to right).  Non-list nodes contribute nothing.
    matches = []
    pending = [ast]  # Explicit stack; the top is the next node to visit
    while pending:
        node = pending.pop()
        if not isinstance(node, list):
            continue
        if len(node) >= 2 and node[0] == 'define_attr' and node[1] == 'type':
            matches.append(node)
        # Push children reversed so the leftmost child is visited first.
        pending.extend(reversed(node))
    return matches


def parse_md_file(path: Path):
    # Load the file at `path` (UTF-8), strip ';;' comments, tokenize the
    # remainder and parse it into a list holding one entry per top-level
    # S-expression.
    source = Path(path).read_text(encoding='utf-8')
    remaining = tokenize_sexpr(remove_line_comments(source))
    forms = []
    # parse_sexpr() consumes tokens from the front of `remaining`, so this
    # loop ends once every token has been used.
    while remaining:
        forms.append(parse_sexpr(remaining))
    return forms

def parsing_str_set(s: str) -> set:
    # Turn a comma-separated string into a set of names.
    #
    # Every backslash is dropped first (presumably these are line
    # continuations inside MD strings), then the text is split on ',' and
    # each piece is stripped of surrounding whitespace.  Empty pieces are
    # kept, so '' maps to {''}.
    without_backslashes = s.replace('\\', '')
    return {piece.strip() for piece in without_backslashes.split(',')}

def get_avaliable_types(md_file_path: str):
    # Return the set of type names declared by the first
    # (define_attr "type" ...) form found in the given MD file.
    #
    # The third element of that form is the comma-separated value list; it
    # is split into a set by parsing_str_set().  An IndexError propagates if
    # the file has no such form or the form has no value list.
    forms = parse_md_file(Path(md_file_path))
    first_type_attr = find_define_attr_type(forms)[0]
    return parsing_str_set(first_type_attr[2])

def get_consumed_type(entry: List[str]) -> set:
    # Return the set of "type" attribute values named by a condition
    # expression, as found in a define_insn_reservation entry.
    #
    #   (and A B) / (ior A B)    -> union of the types named by the operands
    #   (eq_attr "type" "a,b")   -> {'a', 'b'}
    #   anything else            -> empty set
    #
    # 'ior' is the operator name MD attribute expressions use for logical
    # OR; 'or' is still accepted for backward compatibility.  Atoms, empty
    # lists and truncated forms yield an empty set instead of raising.
    if not isinstance(entry, list) or not entry:
        return set()
    head = entry[0]
    if head in ('and', 'ior', 'or'):
        consumed = set()
        for operand in entry[1:]:
            consumed |= get_consumed_type(operand)
        return consumed
    if head == 'eq_attr' and len(entry) >= 3 and entry[1] == 'type':
        return parsing_str_set(entry[2])
    return set()

def check_pipemodel(md_file_path: str):
    # Parse the pipeline model MD file and return the union of all types
    # named by the condition (fourth element) of every top-level
    # define_insn_reservation form.  Other top-level forms are ignored.
    reservations = (
        form for form in parse_md_file(Path(md_file_path))
        if form[0] == "define_insn_reservation"
    )
    covered = set()
    for reservation in reservations:
        covered.update(get_consumed_type(reservation[3]))
    return covered


def main():
    # Command-line entry point: compare the types declared in the base MD
    # file with the types consumed by the pipeline model, report the
    # result, and exit with status 1 if any declared type is not consumed.
    cli = argparse.ArgumentParser(description='Check GCC pipeline model for instruction type coverage')
    cli.add_argument('pipeline_model', help='Pipeline model file to check')
    cli.add_argument('--base-md',
                     default=None,
                     help='Base machine description file (default: riscv.md in script directory)')
    cli.add_argument('-v', '--verbose',
                     action='store_true',
                     help='Show detailed type information')
    options = cli.parse_args()

    # Without --base-md, use the riscv.md that sits next to this script.
    if options.base_md is not None:
        base_md = Path(options.base_md)
    else:
        base_md = Path(__file__).parent / "riscv.md"

    declared = get_avaliable_types(str(base_md))
    covered = check_pipemodel(options.pipeline_model)

    if options.verbose:
        print("Available types:\n", declared)
        print("Consumed types:\n", covered)

    # Success path first; everything below it is the failure report.
    if declared.issubset(covered):
        print("All available types are consumed by the pipemodel.")
        return

    print("Error: Some types are not consumed by the pipemodel")
    print("Missing types:\n", declared - covered)
    sys.exit(1)


# Run the checker only when this file is executed as a script, so that
# importing it as a module has no side effects.
if __name__ == '__main__':
    main()
