#!/usr/bin/env python3

import argparse
import os
from pathlib import Path

# Command-line interface: the extension list to read (it is also rewritten in sorted form)
# and the path of the header to generate.
parser = argparse.ArgumentParser(
    description="Sort a RISC-V extension list and generate the matching extensions header.",
)
parser.add_argument("extension_file", help="extension list to read; it is rewritten in sorted order")
# The header path is used unconditionally further down, so let argparse reject a missing -o
# with a usage error instead of the script failing later with a TypeError from Path(None).
parser.add_argument("-o", "--output", required=True, help="path of the generated header")
args = parser.parse_args()

# Order in which single-letter extensions are emitted. Z* extensions are additionally grouped
# by the position of their second letter in this string.
SINGLE_LETTER_EXTENSION_ORDER = "iemafdgqlcbkjtpvh"

# One bucket per naming category, filled while reading the extension file.
# Sets are used, so an extension listed more than once is only kept once.
single_letter_extensions = set()
unprivileged_extensions = set()  # Z*
supervisor_extensions = set()  # Ss*, Sv*
hypervisor_extensions = set()  # Sh*
machine_extensions = set()  # Sm*
non_standard_extensions = set()  # X*

# Read the extension list and put every entry into the bucket of its naming category.
# Everything after a "#" is a comment, blank lines are skipped and names are compared
# case-insensitively (they are stored lowercased).
#
# Raises ValueError for a name that fits no category, or that could not be sorted later on.
with open(args.extension_file) as extension_file:
    # Iterate the file object directly; there is no need to load all lines up front.
    for line_number, line in enumerate(extension_file, start=1):
        extension = line.partition("#")[0].strip().lower()

        if not extension:
            continue

        if extension.startswith("z"):
            # Z* extensions are sorted by the letter following the "z". Reject names where
            # that letter is missing or unknown here, with a useful message, rather than
            # letting the sort fail with an opaque IndexError/ValueError.
            if len(extension) < 2 or extension[1] not in SINGLE_LETTER_EXTENSION_ORDER:
                raise ValueError(
                    f"Line {line_number}: cannot determine the related single-letter extension of: {extension}"
                )
            unprivileged_extensions.add(extension)
        elif extension.startswith(("ss", "sv")):
            supervisor_extensions.add(extension)
        elif extension.startswith("sh"):
            hypervisor_extensions.add(extension)
        elif extension.startswith("sm"):
            machine_extensions.add(extension)
        elif extension.startswith("x"):
            non_standard_extensions.add(extension)
        elif len(extension) == 1:
            # Single-letter extensions are sorted by their position in the order string,
            # so a letter missing from it could not be placed.
            if extension not in SINGLE_LETTER_EXTENSION_ORDER:
                raise ValueError(
                    f"Line {line_number}: single-letter extension has no defined sort position: {extension}"
                )
            single_letter_extensions.add(extension)
        else:
            raise ValueError(f"Unknown extension category: {extension}")


def unprivileged_extension_sort_key(extension):
    """Return the sort key of a Z* extension.

    The key is a tuple: the position of the related single-letter extension (the character
    right after the leading "z") in SINGLE_LETTER_EXTENSION_ORDER, followed by the rest of
    the name, so ties within one letter are broken alphabetically.
    """
    base_letter, remainder = extension[1], extension[2:]
    return SINGLE_LETTER_EXTENSION_ORDER.index(base_letter), remainder


# Single-letter extensions follow the fixed order string.
sorted_single_letter_extensions = list(single_letter_extensions)
sorted_single_letter_extensions.sort(key=SINGLE_LETTER_EXTENSION_ORDER.index)

# Z* extensions are grouped by their related single-letter extension.
sorted_unprivileged_extensions = list(unprivileged_extensions)
sorted_unprivileged_extensions.sort(key=unprivileged_extension_sort_key)

# Every remaining category is simply sorted alphabetically.
(
    sorted_supervisor_extensions,
    sorted_hypervisor_extensions,
    sorted_machine_extensions,
    sorted_non_standard_extensions,
) = (
    sorted(category)
    for category in (
        supervisor_extensions,
        hypervisor_extensions,
        machine_extensions,
        non_standard_extensions,
    )
)

# Leading comment block of the rewritten extension list. The per-category sections are
# appended to this string further down before it is written back to the input file.
sorted_extension_file = """\
# This file is used to generate the RISC-V Extensions.h header.
# It is automatically sorted by Meta/riscv_extensions_generator.py to keep the canonical extension order defined by the RISC-V ISA manual.
# To add an extension, simply insert it anywhere in this file and run the script to sort it correctly.

"""  # noqa: E501

# Preamble of the generated header. It ends inside the ENUMERATE_RISCV_EXTENSIONS macro
# definition (the trailing backslash continues the macro onto the next line), so the
# E(...) entries appended further down become the macro body.
extension_header = """\
// This file was automatically generated by Meta/riscv_extensions_generator.py.

#pragma once

namespace Kernel {

// The extensions are sorted in canonical order, see https://github.com/riscv/riscv-isa-manual/blob/main/src/naming.adoc.

// 1st argument: extension name (used as our CPUFeature flag name)
// 2nd argument: lowercase extension name (used in the devicetree, see https://www.kernel.org/doc/Documentation/devicetree/bindings/riscv/extensions.yaml)
// 3rd argument: CPUFeature bitmask index
#define ENUMERATE_RISCV_EXTENSIONS(E) \\
"""  # noqa: E501

# Walk the categories in output order and collect, per extension, one line for the sorted
# extension list and one E(...) entry for the header macro. The lines are gathered in lists
# and joined once at the end.
listing_lines = []
macro_lines = []

for category_title, category_members in (
    ("Single-letter extensions", sorted_single_letter_extensions),
    ("Unprivileged extensions", sorted_unprivileged_extensions),
    ("Supervisor extensions", sorted_supervisor_extensions),
    ("Hypervisor extensions", sorted_hypervisor_extensions),
    ("Machine extensions", sorted_machine_extensions),
    ("Non-standard extensions", sorted_non_standard_extensions),
):
    listing_lines.append(f"# {category_title}\n")

    for name in category_members:
        # Only the first character is upper-cased, e.g. "zicsr" -> "Zicsr".
        display_name = name[0].upper() + name[1:]

        # The third macro argument is a running index over all extensions, which equals
        # the number of entries emitted so far.
        macro_lines.append(f"    E({display_name}, {name}, {len(macro_lines)}) \\\n")
        listing_lines.append(display_name + "\n")

    # Blank line after each category section.
    listing_lines.append("\n")

sorted_extension_file += "".join(listing_lines)
extension_header += "".join(macro_lines)

# Drop the final " \" line continuation, then close the namespace.
extension_header = extension_header.removesuffix(" \\\n") + "\n\n}"


# FIXME: Share this function with TIFFGenerator.py.
def update_file(target: Path, new_content: str):
    """Write new_content to target unless the file already contains exactly that text.

    A missing file is created. A file whose current text equals new_content is not
    opened for writing at all.
    """
    if target.exists() and target.read_text() == new_content:
        return

    target.write_text(new_content)


# Resolve both destinations before anything is written.
sorted_list_target = Path(args.extension_file)
header_target = Path(args.output)

# Write the sorted list back over the input file.
update_file(sorted_list_target, sorted_extension_file)

# The directory of the header may not exist yet; create it before writing the header.
os.makedirs(header_target.parent, exist_ok=True)
update_file(header_target, extension_header)
