#!/usr/bin/python3
# Copyright (C) 2018 Jelmer Vernooij <jelmer@debian.org>
# This file is a part of debmutate.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

# Public API of this module. The exception classes are exported as well,
# since they are raised by the exported functions/Editor and callers are
# expected to catch them.
__all__ = [
    "check_preserve_formatting",
    "check_generated_file",
    "edit_formatted_file",
    "Editor",
    "FormattingUnpreservable",
    "GeneratedFile",
]


import logging
import os
import sys
from types import TracebackType
from typing import (
    Generic,
    Iterator,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
)

if sys.version_info < (3, 11):
    from typing_extensions import Self
else:
    from typing import Self

# Default text encoding used by edit_formatted_file and Editor when reading or
# writing files in text (non-binary) mode.
DEFAULT_ENCODING = "utf-8"


class GeneratedFile(Exception):
    """The specified file is generated.

    Attributes:
      path: Path of the file that appears to be generated
      template_path: Path of the template the file is generated from, if one
        was found next to it
      template_type: Optional description of the template kind (never set by
        check_generated_file)
    """

    def __init__(
        self,
        path: str,
        template_path: Optional[str] = None,
        template_type: Optional[str] = None,
    ) -> None:
        # Initialise the base Exception (as FormattingUnpreservable does) so
        # that str(exc) and exc.args carry the path, and so the exception can
        # be pickled/unpickled (BaseException re-creates it from self.args).
        super().__init__(path)
        self.path = path
        self.template_path = template_path
        self.template_type = template_type


class FormattingUnpreservable(Exception):
    """The file is unpreservable.

    Raised when rewriting a file through a parser/serializer would change its
    formatting and reformatting is not allowed.

    Attributes:
      path: Path of the affected file
      original_contents: The original file contents (may be None when the
        file did not exist)
      rewritten_contents: The contents as rewritten by the parser/serializer
        (may be None)
    """

    def __init__(
        self,
        path: str,
        original_contents: Optional[Union[str, bytes]],
        rewritten_contents: Optional[Union[str, bytes]],
    ) -> None:
        super().__init__(path)
        self.path = path
        self.original_contents = original_contents
        self.rewritten_contents = rewritten_contents

    def diff(self) -> Iterator[str]:
        """Return a unified diff from the original to the rewritten contents.

        Bytes are decoded as UTF-8 with undecodable sequences replaced, so the
        diff is always over text lines. A missing (None) side is treated as
        empty, e.g. for a file that did not exist before.
        """
        from difflib import unified_diff

        def as_lines(contents: Optional[Union[str, bytes]]) -> List[str]:
            # None happens when edit_formatted_file compares against a file
            # that did not exist; diff it as an empty file instead of crashing.
            if contents is None:
                return []
            if isinstance(contents, bytes):
                contents = contents.decode("utf-8", errors="replace")
            return contents.splitlines(True)

        return unified_diff(
            as_lines(self.original_contents),
            as_lines(self.rewritten_contents),
            fromfile="original",
            tofile="rewritten",
        )


def check_preserve_formatting(
    rewritten_text: Union[str, bytes],
    text: Union[str, bytes],
    path: str,
    allow_reformatting: bool = False,
) -> None:
    """Verify that rewriting a file keeps its formatting intact.

    The check passes when the rewritten contents equal the original ones, or
    when reformatting has been explicitly allowed.

    Args:
      rewritten_text: The file contents after being rewritten
      text: The file contents as originally read
      path: Path to the file; only used to build the exception
      allow_reformatting: Accept differing contents instead of raising
    Raises:
      FormattingUnpreservable: if the contents differ and reformatting is not
        allowed
    """
    identical = rewritten_text == text
    if not (identical or allow_reformatting):
        raise FormattingUnpreservable(path, text, rewritten_text)


def check_generated_file(path: str) -> None:
    """Raise GeneratedFile if ``path`` appears to be a generated file.

    A file is considered generated when a template exists next to it
    (``path`` with ``.in``, ``.m4`` or ``.stub`` appended), or when one of its
    leading lines contains a "do not edit" style marker. A file that does not
    exist is not reported.

    Args:
      path: Path to the file to check
    Raises:
      GeneratedFile: when a generated file is found
    """
    for suffix in (".in", ".m4", ".stub"):
        template_path = path + suffix
        if os.path.exists(template_path):
            raise GeneratedFile(path, template_path)

    DO_NOT_EDIT_SCAN_LINES = 20
    markers = (b"DO NOT EDIT", b"Do not edit!", b"This file is autogenerated")
    try:
        with open(path, "rb") as f:
            # Lines 0 through DO_NOT_EDIT_SCAN_LINES inclusive are inspected,
            # i.e. DO_NOT_EDIT_SCAN_LINES + 1 lines in total.
            head = [line for _, line in zip(range(DO_NOT_EDIT_SCAN_LINES + 1), f)]
    except FileNotFoundError:
        return
    if any(marker in line for line in head for marker in markers):
        raise GeneratedFile(path)


def edit_formatted_file(
    path: str,
    original_contents: Optional[Union[str, bytes]],
    rewritten_contents: Optional[Union[str, bytes]],
    updated_contents: Union[str, bytes],
    allow_generated: bool = False,
    allow_reformatting: bool = False,
    encoding: str = DEFAULT_ENCODING,
) -> bool:
    """Edit a formatted file, preserving its formatting where possible.

    If the parser/serializer round-trip of the file (``rewritten_contents``)
    matches the original contents (ignoring leading and trailing whitespace),
    or reformatting is allowed, ``updated_contents`` is written verbatim.
    Otherwise the changes are applied to the original text with a three-way
    merge using the optional ``merge3`` package, provided it is available and
    the merge has no conflicts.

    Args:
      path: path to the file
      original_contents: The original contents of the file, or None if the
        file did not exist
      rewritten_contents: The contents rewritten with our parser/serializer
      updated_contents: Updated contents rewritten with our parser/serializer
        after changes were made.
      allow_generated: Do not raise GeneratedFile when encountering a generated
        file
      allow_reformatting: Whether to allow reformatting of the file
      encoding: Encoding used when writing ``str`` contents
    Returns:
      True if the file was written; False if the updated contents equal the
      rewritten or original contents, in which case nothing is written
    Raises:
      TypeError: if updated_contents and rewritten_contents are of different
        types (one str, the other bytes)
      GeneratedFile: if the file looks generated and allow_generated is false
      FormattingUnpreservable: if formatting cannot be preserved and the
        three-way merge fallback is unavailable or conflicts
    """
    if (
        updated_contents is not None
        and rewritten_contents is not None
        and type(updated_contents) is not type(rewritten_contents)
    ):
        raise TypeError(
            f"inconsistent types: {type(updated_contents)!r}, {type(rewritten_contents)!r}"
        )
    # Nothing was changed by the caller: leave the file alone.
    if updated_contents in (rewritten_contents, original_contents):
        return False
    if not allow_generated:
        check_generated_file(path)
    try:
        # Surrounding whitespace differences are tolerated when deciding
        # whether the serializer reproduces the original formatting.
        check_preserve_formatting(
            rewritten_contents.strip()  # type: ignore
            if rewritten_contents is not None
            else None,
            original_contents.strip() if original_contents is not None else None,
            path,
            allow_reformatting=allow_reformatting,
        )
    except FormattingUnpreservable as e:
        if (
            rewritten_contents is None
            or original_contents is None
            or updated_contents is None
        ):
            raise
        # Run three way merge
        logging.debug("Unable to preserve formatting; falling back to merge3")
        try:
            import merge3
        except ModuleNotFoundError:
            # Report the formatting problem itself, not the missing module.
            raise e from None
        # Older merge3 releases cannot merge bytes.
        if isinstance(rewritten_contents, bytes) and merge3.__version__ < (0, 0, 7):
            raise e
        # Merge3(base, a, b): the serializer's rendering of the original is
        # the common base, the original text carries the file's own
        # formatting and the updated text carries our changes.
        if isinstance(updated_contents, bytes):
            assert isinstance(rewritten_contents, bytes)
            assert isinstance(original_contents, bytes)
            m3: merge3.Merge3[bytes] = merge3.Merge3(
                rewritten_contents.splitlines(True),
                original_contents.splitlines(True),
                updated_contents.splitlines(True),
            )
            if any(region[0] == "conflict" for region in m3.merge_regions()):
                raise
            with open(path, "wb") as f:
                f.writelines(m3.merge_lines())
        else:
            assert isinstance(rewritten_contents, str)
            assert isinstance(original_contents, str)
            m3_str: merge3.Merge3[str] = merge3.Merge3(
                rewritten_contents.splitlines(True),
                original_contents.splitlines(True),
                updated_contents.splitlines(True),
            )
            if any(region[0] == "conflict" for region in m3_str.merge_regions()):
                raise
            with open(path, "w", encoding=encoding) as f:
                f.writelines(m3_str.merge_lines())
    else:
        # Formatting can be preserved or is allowed to change - write the updated content
        if isinstance(updated_contents, bytes):
            with open(path, "wb") as f:
                f.write(updated_contents)
        else:
            with open(path, "w", encoding=encoding) as f:
                f.write(updated_contents)
    return True


# Type of the parsed representation returned by Editor._parse.
T = TypeVar("T")
# Type of the raw file contents: str when read in text mode, bytes when the
# Editor mode contains "b".
P = TypeVar("P", str, bytes)


class Editor(Generic[T, P]):
    """Context object for editing a file, preserving formatting.

    Subclasses implement ``_parse`` and ``_format`` and may override
    ``_nonexistent``.
    - Entering the context reads and parses the file.
    - The caller then modifies the parsed object.
    - Leaving the context formats the object again. The result is written
      back through ``edit_formatted_file``, or the file is deleted if the
      formatted result is ``None``.

    After the ``with`` block, ``changed`` tells whether the file was written
    or removed, and ``changed_files`` lists the affected paths.
    """

    # Paths written or removed on exit; always set by __exit__.
    changed_files: List[str]
    # Whether __exit__ wrote or removed the file.
    changed: bool
    # Serialization of the freshly parsed file, before any caller changes.
    _rewritten_content: Optional[P]

    def __init__(
        self,
        path: str,
        mode: str = "",
        allow_generated: bool = False,
        allow_reformatting: Optional[bool] = None,
        encoding: str = DEFAULT_ENCODING,
    ) -> None:
        """Create an editor for ``path``.

        Args:
          path: Path of the file to edit
          mode: Extra characters appended to "r" when opening the file;
            include "b" to work on bytes instead of text
          allow_generated: Edit the file even if it looks generated
          allow_reformatting: Allow the file to be reformatted; if None, this
            is read from the REFORMATTING environment variable ("allow")
          encoding: Text encoding, used when not in binary mode
        """
        self.path = path
        self.mode = mode
        self.allow_generated = allow_generated
        # TODO(jelmer): Don't make this class check the environment
        if allow_reformatting is None:
            allow_reformatting = os.environ.get("REFORMATTING", "disallow") == "allow"
        self.allow_reformatting = allow_reformatting
        self.encoding = encoding

    def _nonexistent(self) -> Optional[T]:
        """Return the parsed object to use when the file does not exist.

        The default re-raises the FileNotFoundError currently being handled;
        it is only called from the ``except FileNotFoundError`` clause in
        ``__enter__``. Subclasses may return a new object instead, or None to
        keep the file absent.
        """
        raise

    def _parse(self, content: P) -> T:
        """Parse the file contents (str or bytes) and return the parsed object."""
        raise NotImplementedError(self._parse)

    def _format(self, parsed: T) -> Optional[P]:
        """Serialize the parsed object; None means the file should be removed."""
        raise NotImplementedError(self._format)

    def __enter__(self) -> Self:
        kwargs = {}
        if "b" not in self.mode:
            kwargs["encoding"] = self.encoding
        try:
            with open(self.path, "r" + self.mode, **kwargs) as f:  # type: ignore
                self._orig_content = f.read()
        except FileNotFoundError:
            self._orig_content = None
            self._parsed = self._nonexistent()
        else:
            self._parsed = self._parse(self._orig_content)
        # Baseline used to detect changes and as the merge base on exit.
        if self._parsed is not None:
            self._rewritten_content = self._format(self._parsed)
        else:
            self._rewritten_content = None
        return self

    def _updated_content(self) -> Optional[P]:
        """Serialize the current parsed object, or None if there is none."""
        if self._parsed is not None:
            return self._format(self._parsed)
        else:
            return None

    def has_changed(self) -> bool:
        """Check if any changes have been made so far."""
        return self._updated_content() not in (
            self._rewritten_content,
            self._orig_content,
        )

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Literal[False]:
        # Always define the result attributes so they can be inspected even
        # when nothing is written (e.g. an absent file that stays absent).
        self.changed_files = []
        self.changed = False
        # NOTE(review): changes are written even when the with-body raised
        # (exc_type is not None); confirm callers rely on this before changing.
        updated_content = self._updated_content()

        if updated_content is None:
            if os.path.exists(self.path):
                os.unlink(self.path)
                self.changed = True
                self.changed_files = [self.path]
        else:
            self.changed = edit_formatted_file(
                self.path,
                self._orig_content,
                self._rewritten_content,
                updated_content,
                allow_generated=self.allow_generated,
                allow_reformatting=self.allow_reformatting,
                encoding=self.encoding,
            )
            if self.changed:
                self.changed_files = [self.path]
        # Never suppress exceptions raised inside the with-block.
        return False
