#!/usr/bin/python3
# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

"""Extensions to the tarfile module.
 * ZstdTarFile extends TarFile to deal with zstd-compressed archives.
 * get_comptype guesses the compression used for an open TarFile.
 * XAttrTarFile extends TarFile to map extended attributes to PAX headers.
"""

import os
import tarfile
import typing


# Any value usable as a file system path: text or bytes, given either
# directly or wrapped in an os.PathLike object.
TarPath = str | bytes | os.PathLike[str] | os.PathLike[bytes]


class ZstdTarFile(tarfile.TarFile):
    """Subclass of tarfile.TarFile that can read zstd compressed archives."""

    # mypy types OPEN_METH as Mapping rather than dict while it really is a
    # dict. Hence, it complains that there is no __or__ for dict and Mapping.
    OPEN_METH = {
        "zst": "zstopen",
    } | tarfile.TarFile.OPEN_METH  # type: ignore[operator]

    @classmethod
    def zstopen(
        cls,
        name: str,
        mode: typing.Literal["r", "w", "x"] = "r",
        fileobj: typing.BinaryIO | None = None,
        **kwargs: typing.Any,
    ) -> tarfile.TarFile:
        if mode not in ("r", "w", "x"):
            raise ValueError("mode must be 'r', 'w' or 'x'")
        openobj: str | typing.BinaryIO = name if fileobj is None else fileobj
        try:
            import zstandard
        except ImportError as err:
            raise tarfile.CompressionError(
                "zstandard module not available"
            ) from err
        if mode == "r":
            zfobj = zstandard.open(openobj, "rb")
        else:
            zfobj = zstandard.open(
                openobj,
                mode + "b",
                cctx=zstandard.ZstdCompressor(write_checksum=True, threads=-1),
            )
        try:
            tarobj = cls.taropen(name, mode, zfobj, **kwargs)
        except (OSError, EOFError, zstandard.ZstdError) as exc:
            zfobj.close()
            if mode == "r":
                raise tarfile.ReadError("not a zst file") from exc
            raise
        except:
            zfobj.close()
            raise
        # Setting the _extfileobj attribute is important to signal a need to
        # close this object and thus flush the compressed stream.
        # Unfortunately, tarfile.pyi doesn't know about it.
        tarobj._extfileobj = False  # type: ignore
        return tarobj


def get_comptype(tarobj: tarfile.TarFile) -> str:
    """Return the compression type used to compress the given TarFile.

    The result is one of "bz2", "gz", "xz", "zst" or, for an archive whose
    fileobj is a plain io object, "tar".

    Raises ValueError if the module implementing tarobj.fileobj is not a
    known one.
    """
    # The tarfile module does not expose the compression method selected
    # for open mode "r:*" in any way. We can guess it from the module that
    # implements the fileobj.
    compmodule = tarobj.fileobj.__class__.__module__
    comptype = {
        "bz2": "bz2",
        "gzip": "gz",
        "lzma": "xz",
        "_io": "tar",
        # The C backend of zstandard reports its types as living in "zstd"
        # whereas the CFFI backend implements them in a regular submodule.
        "zstd": "zst",
        "zstandard.backend_cffi": "zst",
    }.get(compmodule)
    if comptype is None:
        # Raised outside of any exception handler such that no unhelpful
        # KeyError is chained to it.
        raise ValueError(f"cannot guess comptype for module {compmodule}")
    return comptype


class XAttrTarFile(tarfile.TarFile):
    """A subclass to tarfile.TarFile that adds support for extended attributes
    via SCHILY.xattr.* PAX headers to extraction and creation of archives. It
    can be used as a mixin class with others as it does not add any state.
    """

    def extract(
        self,
        member: tarfile.TarInfo | str,
        path: TarPath = "",
        set_attrs: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Refer to tarfile.TarFile.extract. In addition, SCHILY.xattr.* PAX
        headers are examined and applied as extended attributes if set_attrs is
        true-ish.

        The attributes are set on the extracted object itself, i.e. symbolic
        links are not followed. Header values are encoded using the encoding
        of the archive (utf8 if unset) with surrogateescape. Failures of
        os.setxattr are not handled here and propagate to the caller.
        """
        if not set_attrs:
            super().extract(member, path, False, **kwargs)
            return

        # We also need the tarinfo, so mimic the start of the built-in extract.
        if isinstance(member, str):
            tarinfo = self.getmember(member)
        else:
            tarinfo = member

        super().extract(tarinfo, path, True, **kwargs)

        # NOTE(review): the destination is derived from the unfiltered member
        # name. An extraction filter passed via kwargs that renames or skips
        # the member would not be reflected here - confirm with the callers.
        # mypy is unhappy about the next line, but we have the same code in
        # TarFile.extract and if it bails here, it also bails there.
        path = os.path.join(path, tarinfo.name)  # type: ignore

        for attr, value in tarinfo.pax_headers.items():
            if not attr.startswith("SCHILY.xattr."):
                continue
            attr = attr.removeprefix("SCHILY.xattr.")
            os.setxattr(
                path,
                attr,
                value.encode(self.encoding or "utf8", "surrogateescape"),
                follow_symlinks=False,
            )

    def gettarinfo(
        self,
        name: TarPath | None = None,
        arcname: str | None = None,
        fileobj: typing.IO[bytes] | None = None,
    ) -> tarfile.TarInfo:
        """Refer to tarfile.TarFile.gettarinfo. In addition, the extended
        attributes of the file are recorded as SCHILY.xattr.* PAX headers of
        the returned TarInfo.

        The attributes are read from fileobj if given and from name
        otherwise. When reading by name, symbolic links are followed
        according to the dereference setting of the archive. Attribute values
        are decoded using the encoding of the archive (utf8 if unset) with
        surrogateescape.

        Raises ValueError if neither name nor fileobj is given.
        """
        tarinfo = super().gettarinfo(name, arcname, fileobj)
        if tarinfo is None:
            # Despite the annotation, the base class yields None for file
            # types it cannot represent. Pass that on like TarFile.add
            # expects instead of failing on the attribute access below.
            return tarinfo
        path: int | TarPath
        if fileobj is not None:
            path = fileobj.fileno()
            # A descriptor refers to an already opened file, so there is no
            # symbolic link left to follow. More importantly, os.listxattr
            # and os.getxattr raise a ValueError when a descriptor is
            # combined with follow_symlinks=False.
            follow_symlinks = True
        elif name is not None:
            path = name
            follow_symlinks = (
                True if self.dereference is None else self.dereference
            )
        else:
            raise ValueError("gettarinfo requires a name or fileobj")
        encoding = self.encoding or "utf8"
        for attr in os.listxattr(path, follow_symlinks=follow_symlinks):
            key = "SCHILY.xattr." + attr
            value = os.getxattr(
                path, attr, follow_symlinks=follow_symlinks
            ).decode(encoding, "surrogateescape")
            # TarInfo.pax_headers is designated as (read-only) Mapping, but it
            # really is a writable dict.
            tarinfo.pax_headers[key] = value  # type: ignore[index]
        return tarinfo