###############################################################################
# Top contributors (to current version):
#   José Neto, Aina Niemetz
#
# This file is part of the cvc5 project.
#
# Copyright (c) 2009-2025 by the authors listed in the file AUTHORS
# in the top-level source directory and their institutional affiliations.
# All rights reserved.  See the file COPYING in the top-level source
# directory for licensing information.
# #############################################################################
#
# Generate theory_traits.h and type_enumerator.cpp implementations.
#
##
import argparse
import re
import sys
import os
from datetime import date

# Add the parent directory to the system path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from expr.theory_validator import TheoryValidator

try:
    import tomllib
except ImportError:
    import tomli as tomllib


class CodeGenerator:
    """Collect generated C++ fragments and splice them into a template file.

    The generator accumulates text (includes, ``TheoryTraits``
    specialisations, theory constructor cases and type-enumerator cases) in
    string attributes while theories and kinds are registered, then replaces
    the ``${...}`` placeholders of the template with that text and writes the
    result, preceded by a generated file header, to the output path.

    Expected call order: ``read_template_data``, ``generate_file_header``,
    then ``generate_code_for_theory`` / ``register_kinds`` once per theory,
    and finally ``fill_template_data`` and ``write_output_data``.
    """

    def __init__(self, theory_traits_template, theory_traits_template_output,
                 input_command):
        """Initialise empty accumulators and build the output file header.

        Args:
            theory_traits_template: Path of the template file to read.
            theory_traits_template_output: Path of the file to generate.
            input_command: Command line that is recorded in the header of
                the generated file.
        """
        # Template contents are read and written in binary mode, so keep the
        # buffer as bytes from the start; fill_template() then also works on
        # a generator whose template has not been read yet.
        self.template_data = b""
        self.input_command = input_command
        self.theory_includes = "\n"
        self.theory_traits = ""
        self.theory_constructors = ""
        self.type_enumerator_includes = "\n"
        # Space-separated names of the kinds / sort constants seen so far;
        # register_enumerator() searches these to pick the right switch.
        self.type_kinds = ""
        self.type_constants = ""
        self.mk_type_enumerator_type_constant_cases = ""
        self.mk_type_enumerator_cases = ""

        current_year = date.today().year
        self.copyright = f"2010-{current_year}"

        # Placeholders (as bytes) that may occur in the template.
        self.copyright_replacement_pattern = b'${copyright}'
        self.generation_command_replacement_pattern = b'${generation_command}'
        self.template_file_path_replacement_pattern = b'${template_file_path}'
        self.theory_includes_replacement_pattern = b'${theory_includes}'
        self.theory_traits_replacement_pattern = b'${theory_traits}'
        self.theory_constructors_replacement_pattern = b'${theory_constructors}'
        self.type_enumerator_includes_replacement_pattern = b'${type_enumerator_includes}'
        self.mk_type_enumerator_type_constant_cases_replacement_pattern = b'${mk_type_enumerator_type_constant_cases}'
        self.mk_type_enumerator_cases_replacement_pattern = b'${mk_type_enumerator_cases}'

        # Encoded as UTF-8 (the same encoding fill_template() uses) so that a
        # non-ASCII character in the command line or template path does not
        # abort generation; for pure-ASCII input the bytes are unchanged.
        self.file_header = f"""/******************************************************************************
 * This file is part of the cvc5 project.
 *
 * Copyright (c) {self.copyright} by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * This file was automatically generated by:
 *
 *     {self.input_command}
 *
 * for the cvc5 project.
 */
 
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */

/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */
/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT ! */

/* Edit the template file instead:                     */
/* {theory_traits_template} */\n
""".encode('utf-8')
        self.theory_traits_template = theory_traits_template
        self.theory_traits_template_output = theory_traits_template_output

    def read_template_data(self):
        """Load the raw bytes of the template file into ``template_data``."""
        with open(self.theory_traits_template, "rb") as f:
            self.template_data = f.read()

    def generate_file_header(self):
        """Replace ``${template_file_path}`` in the template data.

        The placeholder is substituted with the template path; the header
        comment itself (``file_header``) is written by ``write_output_data``.
        """
        self.fill_template(self.template_file_path_replacement_pattern,
                           self.theory_traits_template)

    def register_enumerator(self, enumerator, kind_name, theory_id, filename):
        """Record the include and ``case`` label for one type enumerator.

        Args:
            enumerator: Mapping with the keys ``"header"`` and ``"class"``.
            kind_name: Name of the kind or sort constant being enumerated.
            theory_id: Theory identifier, used only in the error message.
            filename: Kinds file name, used only in the error message.

        The case is added to the type-constant switch when ``kind_name`` was
        registered as a sort constant, or to the kind switch when it was
        registered as a kind. If it is known as neither, an error is printed
        and the process exits with status 1.
        """
        header = enumerator["header"]
        enumerator_class = enumerator["class"]

        self.type_enumerator_includes += f"#include \"{header}\"\n"

        # Whole-word match against the space-separated name lists.
        name_pattern = rf"\b{re.escape(kind_name)}\b"
        if re.search(name_pattern, self.type_constants):
            self.mk_type_enumerator_type_constant_cases += f"        case {kind_name}:\n          return new {enumerator_class}(type, tep);\n\n"
        elif re.search(name_pattern, self.type_kinds):
            self.mk_type_enumerator_cases += f"      case Kind::{kind_name}:\n        return new {enumerator_class}(type, tep);\n\n"
        else:
            print(
                f"{filename}:{theory_id}: error: don't know anything about {kind_name}; enumerator must appear after definition"
            )
            print(f"type_constants: {self.type_constants}")
            print(f"type_kinds : {self.type_kinds}")
            sys.exit(1)

    def register_kinds(self, kinds, theory_id, filename):
        """Record the names of all kinds of a theory and their enumerators.

        Args:
            kinds: Sequence of kind mappings. Each needs a ``"type"`` entry
                and a name entry (``"K1"`` for type ``"parameterized"``,
                ``"name"`` otherwise); an optional ``"enumerator"`` entry is
                forwarded to ``register_enumerator``. A falsy value is a
                no-op.
            theory_id: Theory identifier, forwarded for error reporting.
            filename: Kinds file name, forwarded for error reporting.
        """
        if not kinds:
            return

        target_kinds_types = (
            "variable", "operator", "parameterized", "constant",
            "nullaryoperator"
        )
        target_constant_types = ("sort",)

        for kind in kinds:
            kind_type = kind["type"]
            name_property = "K1" if kind_type == "parameterized" else "name"
            kind_name = kind[name_property]

            if kind_type in target_kinds_types:
                self.type_kinds += f"{kind_name} "
            elif kind_type in target_constant_types:
                self.type_constants += f"{kind_name} "

            # The name must be recorded before its enumerator is registered,
            # because register_enumerator() looks it up in the lists above.
            if "enumerator" in kind:
                self.register_enumerator(kind["enumerator"], kind_name,
                                         theory_id, filename)

    def generate_code_for_theory(self, theory, rewriter):
        """Append the includes, constructor case and traits of one theory.

        Args:
            theory: Mapping with the keys ``"id"``, ``"base_class"``,
                ``"base_class_header"`` and ``"properties"``; the latter
                must support ``in`` tests on property names.
            rewriter: Mapping with the keys ``"class"`` and ``"header"``.
        """
        self.generate_code_for_theory_includes(theory["base_class_header"])
        self.generate_code_for_rewriter(rewriter)

        rewriter_class = rewriter["class"]

        theory_id = theory["id"]
        properties = theory["properties"]
        theory_class = theory["base_class"]

        self.theory_constructors += f"""
      case {theory_id}:\n
        engine->addTheory< {theory_class} >({theory_id});
        return;
        """

        def flag(property_name):
            # C++ boolean literal telling whether the property is declared.
            return 'true' if property_name in properties else 'false'

        theory_stable_infinite = flag('stable-infinite')
        theory_finite = flag('finite')
        theory_polite = flag('polite')
        theory_parametric = flag('parametric')
        theory_has_check = flag('check')
        theory_has_propagate = flag('propagate')
        theory_has_ppStaticLearn = flag('ppStaticLearn')
        theory_has_notifyRestart = flag('notifyRestart')
        theory_has_presolve = flag('presolve')

        self.theory_traits += f"""
template<>
struct TheoryTraits<{theory_id}> {{
    // typedef {theory_class} theory_class;
    typedef {rewriter_class} rewriter_class;

    static const bool isStableInfinite = {theory_stable_infinite};
    static const bool isFinite = {theory_finite};
    static const bool isPolite = {theory_polite};
    static const bool isParametric = {theory_parametric};

    static const bool hasCheck = {theory_has_check};
    static const bool hasPropagate = {theory_has_propagate};
    static const bool hasPpStaticLearn = {theory_has_ppStaticLearn};
    static const bool hasNotifyRestart = {theory_has_notifyRestart};
    static const bool hasPresolve = {theory_has_presolve};
}};/* struct TheoryTraits<{theory_id}> */
"""

    def generate_code_for_rewriter(self, rewriter):
        """Append an ``#include`` line for the rewriter's header."""
        self.generate_code_for_theory_includes(rewriter["header"])

    def generate_code_for_theory_includes(self, theory_include):
        """Append an ``#include "<theory_include>"`` line to the includes."""
        self.theory_includes += f"#include \"{theory_include}\"\n"

    def fill_template_data(self):
        """Substitute every accumulated fragment into the template data."""
        substitutions = (
            (self.theory_includes_replacement_pattern, self.theory_includes),
            (self.theory_traits_replacement_pattern, self.theory_traits),
            (self.theory_constructors_replacement_pattern,
             self.theory_constructors),
            (self.type_enumerator_includes_replacement_pattern,
             self.type_enumerator_includes),
            (self.mk_type_enumerator_type_constant_cases_replacement_pattern,
             self.mk_type_enumerator_type_constant_cases),
            (self.mk_type_enumerator_cases_replacement_pattern,
             self.mk_type_enumerator_cases),
        )
        for target_pattern, replacement_string in substitutions:
            self.fill_template(target_pattern, replacement_string)

    def fill_template(self, target_pattern, replacement_string):
        """Replace all occurrences of a placeholder in the template data.

        Args:
            target_pattern: Placeholder to look for, as bytes.
            replacement_string: Replacement text (str); it is encoded with
                the default encoding of ``str.encode`` (UTF-8).
        """
        self.template_data = self.template_data.replace(
            target_pattern, replacement_string.encode())

    def write_output_data(self):
        """Write the file header followed by the template data to the output."""
        with open(self.theory_traits_template_output, 'wb') as f:
            f.write(self.file_header)
            f.write(self.template_data)


def mktheorytraits_main():
    """Generate the output file from a template and a list of kinds files.

    Command-line arguments (all required):
        --kinds     one or more input TOML files
        --template  path to the template
        --output    output path

    Every kinds file is loaded, passed to ``TheoryValidator.validate_theory``
    and registered with a ``CodeGenerator``; the filled-in template is then
    written to the output path. The process exits with a non-zero status if
    a kinds file does not exist or if loading/processing one raises an
    exception.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--kinds',
                        nargs='+',
                        help='List of input TOML files',
                        required=True,
                        type=str)
    parser.add_argument('--template',
                        help='Path to the template',
                        required=True,
                        type=str)
    parser.add_argument('--output', help='Output path', required=True)

    args = parser.parse_args()
    theory_traits_template_path = args.template
    output_path = args.output
    kinds_files = args.kinds

    # Recorded verbatim in the header of the generated file.
    input_command = ' '.join(sys.argv)

    cg = CodeGenerator(theory_traits_template_path, output_path, input_command)
    cg.read_template_data()
    cg.generate_file_header()

    tv = TheoryValidator()

    # Check if given kinds files exist.
    # sys.exit() is used instead of the exit() builtin: the latter is
    # injected by the site module for interactive use and is not guaranteed
    # to exist (e.g. under `python -S`).
    for kinds_path in kinds_files:
        if not os.path.exists(kinds_path):
            sys.exit(f"Kinds file '{kinds_path}' does not exist")

    # Parse and check toml files
    for filename in kinds_files:
        try:
            # Only the load needs the file handle; close it before the
            # validation and code generation steps.
            with open(filename, "rb") as f:
                kinds_data = tomllib.load(f)
            tv.validate_theory(filename, kinds_data)

            theory = kinds_data["theory"]
            rewriter = kinds_data["rewriter"]
            kinds = kinds_data["kinds"]

            cg.generate_code_for_theory(theory, rewriter)
            cg.register_kinds(kinds, theory["id"], filename)
        except Exception as e:
            # SystemExit is not an Exception subclass, so an exit requested
            # by the generator itself is not intercepted here.
            print(f"Could not parse file {filename}")
            print(e)
            sys.exit(1)

    cg.fill_template_data()
    cg.write_output_data()


if __name__ == "__main__":
    mktheorytraits_main()
    # sys.exit() rather than the site-provided exit() builtin, which is
    # intended for interactive sessions and is absent under `python -S`.
    sys.exit(0)
