# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

import contextlib
import functools
import os
import pathlib
import stat
import tempfile
import typing
import unittest

import linuxnamespaces
from linuxnamespaces import AtLocation


class AtLocationTest(unittest.TestCase):
    def setUp(self) -> None:
        # Per-test scratch directory; enterContext registers its cleanup so
        # it is removed when the test finishes.
        scratch = self.enterContext(tempfile.TemporaryDirectory())
        self.tempdir = pathlib.Path(scratch)
        # Monotonic source of unique entry names handed out by create().
        self.counter = 0

    @contextlib.contextmanager
    def create(
        self,
        loctype: typing.Literal["relative", "absolute", "emptypath", "withfd"],
        filetype: typing.Literal["absent", "file", "directory", "symlink"],
        linktarget: typing.Optional[str],
        follow_symlinks: bool = True,
    ) -> typing.Iterator[AtLocation]:
        """Yield an AtLocation for a freshly created test object.

        filetype selects what is created under the temporary directory:
        nothing ("absent"), a regular "file", a "directory", or a "symlink"
        pointing at linktarget (which must then be given). Passing
        follow_symlinks=False sets AT_SYMLINK_NOFOLLOW on the "relative",
        "absolute" and "withfd" locations.

        loctype selects how the object is addressed:
          * "relative": a bare name, with the working directory switched to
            the temporary directory for the lifetime of the context;
          * "absolute": the absolute path;
          * "emptypath": an O_PATH descriptor for the object with an empty
            path;
          * "withfd": an O_PATH descriptor for the temporary directory plus
            the entry name.
        """
        name = str(self.counter)
        self.counter += 1
        target = self.tempdir / name

        # Materialize the requested object (nothing at all for "absent").
        if filetype == "symlink":
            assert linktarget is not None
            target.symlink_to(linktarget)
        elif filetype == "directory":
            target.mkdir()
        elif filetype == "file":
            target.touch()
        else:
            self.assertEqual(filetype, "absent")

        atflags = (
            linuxnamespaces.AtFlags.NONE
            if follow_symlinks
            else linuxnamespaces.AtFlags.AT_SYMLINK_NOFOLLOW
        )

        if loctype == "absolute":
            yield AtLocation(target, flags=atflags)
            return

        if loctype == "relative":
            saved_cwd = os.getcwd()
            os.chdir(self.tempdir)
            try:
                yield AtLocation(name, flags=atflags)
            finally:
                # Restore the caller's cwd even if the test body raised.
                os.chdir(saved_cwd)
            return

        if loctype == "emptypath":
            # NOTE(review): follow_symlinks=True opens with O_NOFOLLOW, so a
            # symlink's descriptor refers to the link itself. That looks
            # inverted relative to the parameter name, but create_all() only
            # requests "emptypath" with follow_symlinks=True and relies on it
            # to open dangling links (their target "X" is never created).
            openflags = (
                (os.O_PATH | os.O_NOFOLLOW) if follow_symlinks else os.O_PATH
            )
            with AtLocation(os.open(target, openflags)) as loc:
                yield loc
            return

        self.assertEqual(loctype, "withfd")
        dirfd = os.open(self.tempdir, os.O_PATH)
        with AtLocation(dirfd, name, atflags) as loc:
            yield loc

    def create_all(
        self, skip: typing.Container[str] = ()
    ) -> typing.Iterator[tuple[str, typing.ContextManager[AtLocation]]]:
        """Enumerate the matrix of AtLocation factories used by the tests.

        Yields (filetype, context manager) pairs for each combination of
        location type and file type whose names are not in skip. Absent
        objects are only produced for "relative" and "absolute" locations.
        Symlinks (always pointing at "X") are produced with both
        follow_symlinks=True and follow_symlinks=False, except for
        "emptypath", which only gets True. The context managers come from
        create() and do no work until they are entered.
        """
        wanted_loctypes = (
            lt
            for lt in ("relative", "absolute", "emptypath", "withfd")
            if lt not in skip
        )
        for loctype in wanted_loctypes:
            for filetype in ("absent", "file", "directory", "symlink"):
                if filetype in skip:
                    continue
                if filetype == "absent" and loctype in ("emptypath", "withfd"):
                    continue
                if filetype == "symlink" and loctype != "emptypath":
                    follow_choices: tuple[bool, ...] = (True, False)
                else:
                    follow_choices = (True,)
                for follow in follow_choices:
                    yield filetype, self.create(loctype, filetype, "X", follow)

    @staticmethod
    def atloc_subtest(
        skip: typing.Container[str] = (),
    ) -> typing.Callable[
        [typing.Callable[["AtLocationTest", str, AtLocation], None]],
        typing.Callable[["AtLocationTest"], None],
    ]:
        """Build a decorator that runs a test body once per AtLocation.

        The decorated method takes only self. For every (filetype, context
        manager) pair produced by create_all(skip), it enters the context
        manager and calls the wrapped function with the filetype and the
        resulting AtLocation inside a subTest labelled with both. A failure
        in one combination is therefore reported without stopping the rest.
        """

        def wrap(
            func: typing.Callable[["AtLocationTest", str, AtLocation], None]
        ) -> typing.Callable[["AtLocationTest"], None]:
            @functools.wraps(func)
            def run_all(self: "AtLocationTest") -> None:
                for filetype, atlocctx in self.create_all(skip):
                    with atlocctx as atloc:
                        label = repr(atloc)
                        with self.subTest(atlocation=label, filetype=filetype):
                            func(self, filetype, atloc)

            return run_all

        return wrap

    @atloc_subtest()
    def test_access(self, filetype: str, atloc: AtLocation) -> None:
        """Check access() and exists() against the object behind atloc.

        Regular files and directories always exist. The symlinks produced by
        create_all() point at "X", which is never created. Such a symlink
        therefore only counts as existing when the lookup does not follow
        it: either AT_SYMLINK_NOFOLLOW is set or the location is an
        AT_EMPTY_PATH descriptor.
        """
        flag_t = linuxnamespaces.AtFlags
        if filetype in ("file", "directory"):
            should_exist = True
        elif filetype == "symlink":
            no_follow_mask = flag_t.AT_SYMLINK_NOFOLLOW | flag_t.AT_EMPTY_PATH
            should_exist = (atloc.flags & no_follow_mask) != flag_t.NONE
        else:
            should_exist = False
        # access() is only checked for path-based locations; AT_EMPTY_PATH
        # locations are covered by exists() alone.
        if not atloc.flags & flag_t.AT_EMPTY_PATH:
            self.assertEqual(atloc.access(os.R_OK), should_exist)
        self.assertEqual(atloc.exists(), should_exist)

    def test_as_emptypath(self) -> None:
        """Check as_emptypath() on a path-based and on an AT_EMPTY_PATH
        location.

        Both conversions must produce an AT_EMPTY_PATH location with an empty
        path and a valid descriptor referring to the same inode as the
        original path. Converting an AT_EMPTY_PATH location again must yield
        a descriptor distinct from the one it started from.
        """
        atloc = AtLocation(self.tempdir)
        self.assertFalse(atloc.flags & linuxnamespaces.AtFlags.AT_EMPTY_PATH)
        statres = atloc.stat()

        atloc_ep = self.enterContext(atloc.as_emptypath())
        self.assertTrue(atloc_ep.flags & linuxnamespaces.AtFlags.AT_EMPTY_PATH)
        self.assertGreaterEqual(atloc_ep.fd, 0)
        self.assertEqual(atloc_ep.location, "")
        self.assertEqual(atloc_ep.stat().st_ino, statres.st_ino)

        atloc_dup = self.enterContext(atloc_ep.as_emptypath())
        self.assertTrue(
            atloc_dup.flags & linuxnamespaces.AtFlags.AT_EMPTY_PATH
        )
        # Validate the newly created descriptor, not atloc_ep's (already
        # checked above).
        self.assertGreaterEqual(atloc_dup.fd, 0)
        self.assertNotEqual(atloc_dup.fd, atloc_ep.fd)
        self.assertEqual(atloc_dup.location, "")
        self.assertEqual(atloc_dup.stat().st_ino, statres.st_ino)

    @atloc_subtest(skip=("absent", "file", "symlink"))
    def test_join_mkdir(self, _: str, atloc: AtLocation) -> None:
        """Joining a name onto a directory location yields a child location
        that can be created as a directory and removed again."""
        child = atloc / "subdir"
        # The directory behind atloc was just created empty by create().
        self.assertFalse(child.exists())
        child.mkdir()
        self.assertTrue(child.is_dir())
        child.rmdir()
        self.assertFalse(child.exists())

    def test_mknod(self) -> None:
        """mknod() creates a regular file through a path-based location and
        through a location relative to an O_PATH directory descriptor."""
        with AtLocation(os.open(self.tempdir, os.O_PATH)) as rootloc:
            by_path = AtLocation(self.tempdir) / "a"
            by_fd = rootloc / "b"
            for loc in (by_path, by_fd):
                self.assertFalse(loc.exists(), f"{loc} does not exist")
                loc.mknod(device=stat.S_IFREG)
                self.assertTrue(loc.is_file(), f"{loc} is a file")

    @atloc_subtest(skip=("absent", "file", "symlink"))
    def test_walk(self, _: str, atloc: AtLocation) -> None:
        """Populate a directory and check what walk() reports.

        The tree is: emptydir/, dir/alibi (regular file), deadlink (symlink
        to a missing target) and symlink (-> dir/alibi). Every directory
        location reported must share atloc's descriptor, and every entry
        must be relative to an AT_EMPTY_PATH descriptor for its directory.
        Every created entry must be reported exactly by basename.
        """
        (atloc / "emptydir").mkdir()
        (atloc / "dir").mkdir()
        (atloc / "dir" / "alibi").mknod(device=stat.S_IFREG)
        (atloc / "deadlink").symlink_to("doesnotexist")
        (atloc / "symlink").symlink_to("dir/alibi")
        # Record reported basenames so the test cannot pass vacuously when
        # walk() yields nothing or skips entries.
        seen_dirs: set[str] = set()
        seen_files: set[str] = set()
        for dirloc, dirnames, filenames, dirfd in atloc.walk():
            self.assertEqual(dirloc.fd, atloc.fd)
            self.assertTrue(
                dirfd.flags & linuxnamespaces.AtFlags.AT_EMPTY_PATH
            )
            for dentry in dirnames:
                self.assertEqual(dentry.fd, dirfd.fd)
                self.assertTrue(dentry.location)
                seen_dirs.add(str(dentry.location).rsplit("/", 1)[-1])
            for fentry in filenames:
                self.assertEqual(fentry.fd, dirfd.fd)
                self.assertTrue(fentry.location)
                thing = str(fentry.location).rsplit("/", 1)[-1]
                seen_files.add(thing)
                if thing == "alibi":
                    self.assertTrue(fentry.is_file())
                elif thing == "deadlink":
                    self.assertTrue(fentry.is_symlink())
                    self.assertEqual(fentry.readlink(), "doesnotexist")
                    self.assertFalse(fentry.symfollow().exists())
                elif thing == "symlink":
                    self.assertTrue(fentry.is_symlink())
                    self.assertEqual(fentry.readlink(), "dir/alibi")
                    self.assertTrue(fentry.symfollow().exists())
        self.assertEqual(seen_dirs, {"emptydir", "dir"})
        self.assertEqual(seen_files, {"alibi", "deadlink", "symlink"})