# Copyright 2024 Helmut Grohne <helmut@subdivi.de>
# SPDX-License-Identifier: GPL-3

import asyncio
import errno
import os
import pathlib
import signal
import socket
import unittest

import pytest

import linuxnamespaces


class MountFlagsTest(unittest.TestCase):
    """Round-trip checks between MountFlags values and their string form."""

    def test_tostrfromstr(self) -> None:
        # Walk over every value with one or two of the low 32 bits set
        # (low == high yields the single-bit values). Values for which
        # tostr() raises ValueError are skipped; every other value must be
        # recovered unchanged by fromstr().
        mount_flags = linuxnamespaces.MountFlags
        for low in range(32):
            for high in range(low, 32):
                combined = mount_flags(1 << low) | mount_flags(1 << high)
                try:
                    rendered = combined.tostr()
                except ValueError:
                    continue
                self.assertEqual(mount_flags.fromstr(rendered), combined)


class IDAllocationTest(unittest.TestCase):
    """Exercise finding, allocating and reserving ranges in IDAllocation."""

    def test_idalloc(self) -> None:
        ids = linuxnamespaces.IDAllocation()
        for start, count in ((1, 2), (5, 4)):
            ids.add_range(start, count)
        # Three consecutive IDs only fit into the range starting at 5.
        self.assertIn(ids.find(3), (5, 6))
        # find() did not consume anything, so allocate() still succeeds.
        self.assertIn(ids.allocate(3), (5, 6))
        # After the allocation no remaining range can hold three IDs.
        with self.assertRaises(ValueError):
            ids.find(3)
        with self.assertRaises(ValueError):
            ids.allocate(3)
        # The first range is untouched.
        self.assertEqual(ids.find(2), 1)

    def test_merge(self) -> None:
        # The ranges 1-2 and 3-4 are adjacent: a request for three IDs can
        # only be satisfied if they are treated as one range.
        ids = linuxnamespaces.IDAllocation()
        for start, count in ((1, 2), (3, 2)):
            ids.add_range(start, count)
        self.assertIn(ids.allocate(3), (1, 2))

    def test_reserve(self) -> None:
        ids = linuxnamespaces.IDAllocation()
        ids.add_range(0, 10)

        # Reserving from the middle splits the range in two.
        ids.reserve(3, 3)
        self.assertEqual(ids.ranges, [(0, 3), (6, 4)])

        # Requests overlapping the reserved IDs 3-5 are rejected.
        with self.assertRaises(ValueError):
            ids.reserve(0, 4)
        with self.assertRaises(ValueError):
            ids.reserve(5, 4)

        # Reserving the head of a range shrinks it from the front.
        ids.reserve(0, 2)
        self.assertEqual(ids.ranges, [(2, 1), (6, 4)])

        # Reserving the tail of a range shrinks it from the back.
        ids.reserve(7, 3)
        self.assertEqual(ids.ranges, [(2, 1), (6, 1)])

        # Reserving an entire range removes it.
        ids.reserve(6, 1)
        self.assertEqual(ids.ranges, [(2, 1)])


class AsnycioTest(unittest.IsolatedAsyncioTestCase):
    """Tests for the asyncio-facing parts of linuxnamespaces.

    Covers EventFD.aread, SignalFD.aread, async_run_in_fork and
    async_copyfd. Several tests yield to the event loop with a tiny sleep
    and then check whether a pending future has completed, so the order of
    statements inside each test matters.
    """

    async def asyncSetUp(self) -> None:
        # Remember the loop running the test so that the synchronous helper
        # wait_readable can register callbacks on it.
        self.loop = asyncio.get_running_loop()

    def wait_readable(self, rfd: int) -> asyncio.Future[None]:
        """Return a future that completes once rfd is reported readable.

        The reader callback unregisters itself on its first invocation and
        resolves the future with None.
        """
        fut = self.loop.create_future()
        def callback() -> None:
            self.loop.remove_reader(rfd)
            fut.set_result(None)
        self.loop.add_reader(rfd, callback)
        return fut

    async def test_eventfd(self) -> None:
        # The eventfd is created with the value 1, so the first read is
        # expected to complete as soon as the loop gets to run.
        with linuxnamespaces.EventFD(
            1, linuxnamespaces.EventFDFlags.NONBLOCK
        ) as efd:
            fut = asyncio.ensure_future(efd.aread())
            await asyncio.sleep(0.000001)  # Let the loop run
            self.assertTrue(fut.done())
            self.assertEqual(await fut, 1)
            # The second read has nothing to return until write() is
            # called.
            fut = asyncio.ensure_future(efd.aread())
            await asyncio.sleep(0.000001)  # Let the loop run
            self.assertFalse(fut.done())
            efd.write()
            self.assertEqual(await fut, 1)

    async def test_signalfd(self) -> None:
        testsig = signal.SIGUSR1
        sfd = linuxnamespaces.SignalFD(
            [testsig], linuxnamespaces.SignalFDFlags.NONBLOCK
        )
        self.addCleanup(sfd.close)
        # Block the signal so that it stays pending for the signalfd
        # instead of triggering its default action. SIG_BLOCK only adds
        # testsig to the current mask, whereas SIG_SETMASK would also
        # unblock every other signal that is blocked while the test runs.
        oldmask = signal.pthread_sigmask(signal.SIG_BLOCK, [testsig])
        # Cleanups run in reverse order of registration: the previous mask
        # is restored before sfd is closed.
        self.addCleanup(signal.pthread_sigmask, signal.SIG_SETMASK, oldmask)
        fut = asyncio.ensure_future(sfd.aread())
        await asyncio.sleep(0.000001)  # Let the loop run
        # No signal has been sent yet.
        self.assertFalse(fut.done())
        sigval = 123456789
        mypid = os.getpid()
        # Queue the signal to ourselves with an integer payload and check
        # that the siginfo read from the signalfd reflects sender and
        # payload.
        linuxnamespaces.sigqueue(mypid, testsig, sigval)
        siginfo = await fut
        self.assertEqual(siginfo.ssi_signo, testsig)
        self.assertEqual(siginfo.ssi_pid, mypid)
        self.assertEqual(siginfo.ssi_uid, os.getuid())
        self.assertEqual(siginfo.ssi_int, sigval)

    async def test_run_in_fork(self) -> None:
        # efd.write is handed to async_run_in_fork. The eventfd must stay
        # unreadable until the returned callable is awaited.
        # NOTE(review): presumably the write happens in a forked child and
        # reaches us via the shared eventfd - confirm against
        # async_run_in_fork.
        with linuxnamespaces.EventFD(
            0, linuxnamespaces.EventFDFlags.NONBLOCK
        ) as efd:
            fut = asyncio.ensure_future(efd.aread())
            set_ready = linuxnamespaces.async_run_in_fork(efd.write)
            await asyncio.sleep(0.000001)  # Let the loop run
            self.assertFalse(fut.done())
            await set_ready()
            # Bound the wait so that a missing write fails the test
            # instead of hanging it.
            await asyncio.wait_for(fut, 10)

    async def test_copyfd_file_sock(self) -> None:
        # Copy from a regular file into one end of a socket pair and read
        # the data from the other end.
        # NOTE(review): the third argument (999) is presumably the maximum
        # number of bytes to copy - confirm against async_copyfd.
        sock1, sock2 = socket.socketpair()
        with (
            sock1,
            sock2,
            linuxnamespaces.FileDescriptor(
                os.open("/etc/passwd", os.O_RDONLY)
            ) as rfd,
        ):
            fut = asyncio.ensure_future(
                linuxnamespaces.async_copyfd(rfd, sock1.fileno(), 999)
            )
            await self.wait_readable(sock2.fileno())
            self.assertGreater(len(sock2.recv(999)), 0)
            # By the time data arrived, the copy must have finished and
            # report a positive result.
            self.assertTrue(fut.done())
            self.assertGreater(await fut, 0)

    async def test_copyfd_file_pipe(self) -> None:
        # Same as test_copyfd_file_sock, but the destination is the write
        # end of a non-blocking pipe.
        rfdp, wfdp = linuxnamespaces.FileDescriptor.pipe(blocking=False)
        with (
            rfdp,
            wfdp,
            linuxnamespaces.FileDescriptor(
                os.open("/etc/passwd", os.O_RDONLY)
            ) as rfd,
        ):
            fut = asyncio.ensure_future(
                linuxnamespaces.async_copyfd(rfd, wfdp, 999)
            )
            await self.wait_readable(rfdp)
            self.assertGreater(len(os.read(rfdp, 999)), 0)
            self.assertTrue(fut.done())
            self.assertGreater(await fut, 0)

    async def test_copyfd_pipe_pipe(self) -> None:
        # Copy from one pipe to another without a size argument. The copy
        # must keep running while the source's write end is open and
        # finish, reporting 10 bytes, once that end has been closed.
        rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
        rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
        with wfd2, rfd2, rfd1:
            with wfd1:
                fut = asyncio.ensure_future(
                    linuxnamespaces.async_copyfd(rfd1, wfd2)
                )
                os.write(wfd1, b"hello")
                await asyncio.sleep(0.000001)  # Let the loop run
                os.write(wfd1, b"world")
                await self.wait_readable(rfd2)
                self.assertEqual(os.read(rfd2, 11), b"helloworld")
                self.assertFalse(fut.done())
            # Leaving the inner block closed wfd1.
            await asyncio.sleep(0.000001)  # Let the loop run
            self.assertTrue(fut.done())
        self.assertEqual(await fut, 10)

    async def test_copyfd_epipe(self) -> None:
        # The read end of the destination pipe is closed before any data
        # is written to the source, so the copy is expected to fail with
        # EPIPE.
        rfd1, wfd1 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
        rfd2, wfd2 = linuxnamespaces.FileDescriptor.pipe(blocking=False)
        with wfd2, wfd1, rfd1:
            with rfd2:
                fut = asyncio.ensure_future(
                    linuxnamespaces.async_copyfd(rfd1, wfd2)
                )
            os.write(wfd1, b"hello")
            await asyncio.sleep(0.000001)  # Let the loop run
            self.assertTrue(fut.done())
        # Inspect the exception instead of awaiting, which would raise it.
        exc = fut.exception()
        self.assertIsInstance(exc, OSError)
        assert isinstance(exc, OSError)  # also tell mypy
        self.assertEqual(exc.errno, errno.EPIPE)


class UnshareTest(unittest.TestCase):
    """Namespace tests that map at most the caller's own user ID.

    Every test carries pytest.mark.forked; presumably because unshare()
    permanently changes the calling process - confirm against the test
    setup.
    """

    @pytest.mark.forked
    def test_unshare_user(self) -> None:
        overflow_path = pathlib.Path("/proc/sys/fs/overflowuid")
        overflowuid = int(overflow_path.read_text())
        # The mapping must be built before unshare() while os.getuid()
        # still reports the outer uid.
        root_map = linuxnamespaces.IDMapping(0, os.getuid(), 1)
        linuxnamespaces.unshare(linuxnamespaces.CloneFlags.NEWUSER)
        # Without a mapping the uid reads as the overflow uid.
        self.assertEqual(os.getuid(), overflowuid)
        linuxnamespaces.newuidmap(-1, [root_map], False)
        self.assertEqual(os.getuid(), 0)
        # UID 1 is not mapped.
        with self.assertRaises(OSError):
            os.setuid(1)

    @pytest.mark.forked
    def test_mount_proc(self) -> None:
        root_map = linuxnamespaces.IDMapping(0, os.getuid(), 1)
        namespaces = (
            linuxnamespaces.CloneFlags.NEWUSER
            | linuxnamespaces.CloneFlags.NEWNS
            | linuxnamespaces.CloneFlags.NEWPID
        )
        linuxnamespaces.unshare(namespaces)
        linuxnamespaces.newuidmap(-1, [root_map], False)

        def mount_proc_as_init() -> None:
            # The first process in the new PID namespace is its init.
            self.assertEqual(os.getpid(), 1)
            linuxnamespaces.mount("proc", "/proc", "proc")

        # run_in_fork is applied only after unshare() and newuidmap(),
        # exactly where the decorator used to be evaluated.
        run_child = linuxnamespaces.run_in_fork(mount_proc_as_init)
        run_child()

    @pytest.mark.forked
    def test_sethostname(self) -> None:
        # Setting the hostname must fail before entering new namespaces
        # (socket.error is an alias of OSError).
        with self.assertRaises(OSError):
            socket.sethostname("example")
        namespaces = (
            linuxnamespaces.CloneFlags.NEWUSER
            | linuxnamespaces.CloneFlags.NEWUTS
        )
        linuxnamespaces.unshare(namespaces)
        socket.sethostname("example")

    @pytest.mark.forked
    def test_populate_dev(self) -> None:
        namespaces = (
            linuxnamespaces.CloneFlags.NEWUSER
            | linuxnamespaces.CloneFlags.NEWNS
        )
        linuxnamespaces.unshare_user_idmap_nohelper(0, 0, namespaces)
        linuxnamespaces.mount("tmpfs", "/mnt", "tmpfs", data="mode=0755")
        os.mkdir("/mnt/dev")
        linuxnamespaces.populate_dev("/", "/mnt", pidns=False)
        # The populated null device must be writable in practice, not
        # only according to access().
        null_path = "/mnt/dev/null"
        self.assertTrue(os.access(null_path, os.W_OK))
        pathlib.Path(null_path).write_text("")


class UnshareIdmapTest(unittest.TestCase):
    """User namespace tests that map a full 65536-ID subordinate range.

    The tests are skipped unless both the uid and the gid allocation
    loaded by IDAllocation.loadsubid contain 65536 consecutive IDs.
    """

    def setUp(self) -> None:
        super().setUp()
        self.uidalloc = linuxnamespaces.IDAllocation.loadsubid("uid")
        self.gidalloc = linuxnamespaces.IDAllocation.loadsubid("gid")
        # Check the uid allocation first, then the gid allocation.
        for allocation in (self.uidalloc, self.gidalloc):
            try:
                allocation.find(65536)
            except ValueError:
                self.skipTest("insufficient /etc/sub?id allocation")

    @pytest.mark.forked
    def test_unshare_user_idmap(self) -> None:
        # Inner IDs 0-65535 come from the subordinate range; our own
        # uid/gid is mapped to inner ID 65536.
        uidmaps = [
            linuxnamespaces.IDMapping(0, self.uidalloc.allocate(65536), 65536),
            linuxnamespaces.IDMapping(65536, os.getuid(), 1),
        ]
        self.assertNotEqual(os.getuid(), uidmaps[0].outerstart)
        gidmaps = [
            linuxnamespaces.IDMapping(0, self.gidalloc.allocate(65536), 65536),
            linuxnamespaces.IDMapping(65536, os.getgid(), 1),
        ]
        parent_pid = os.getpid()

        def write_idmaps() -> None:
            linuxnamespaces.newidmaps(parent_pid, uidmaps, gidmaps)

        # run_in_fork is applied before unshare(), exactly where the
        # decorator used to be evaluated; the maps are written afterwards.
        apply_idmaps = linuxnamespaces.run_in_fork(write_idmaps)
        linuxnamespaces.unshare(linuxnamespaces.CloneFlags.NEWUSER)
        apply_idmaps()
        self.assertEqual(os.getuid(), 65536)
        os.setuid(0)
        self.assertEqual(os.getuid(), 0)
        # Keep root in saved-set for later setuid
        os.setresuid(1, 1, 0)
        self.assertEqual(os.getuid(), 1)
        # Regain root and a full set of capabilities to save test coverage
        os.setuid(0)

    @pytest.mark.forked
    def test_populate_dev(self) -> None:
        uidmaps = [
            linuxnamespaces.IDMapping(0, self.uidalloc.allocate(65536), 65536),
            # Also map our own uid to make coverage testing work
            linuxnamespaces.IDMapping(65536, os.getuid(), 1),
        ]
        self.assertNotEqual(os.getuid(), uidmaps[0].outerstart)
        gidmaps = [
            linuxnamespaces.IDMapping(0, self.gidalloc.allocate(65536), 65536),
            linuxnamespaces.IDMapping(65536, os.getgid(), 1),
        ]
        parent_pid = os.getpid()

        def write_idmaps() -> None:
            linuxnamespaces.newidmaps(parent_pid, uidmaps, gidmaps)

        # As above: wrap before unshare(), run afterwards.
        apply_idmaps = linuxnamespaces.run_in_fork(write_idmaps)
        namespaces = (
            linuxnamespaces.CloneFlags.NEWUSER
            | linuxnamespaces.CloneFlags.NEWNS
            | linuxnamespaces.CloneFlags.NEWPID
        )
        linuxnamespaces.unshare(namespaces)
        apply_idmaps()
        # Switch to the namespace's root user and group.
        os.setreuid(0, 0)
        os.setregid(0, 0)
        linuxnamespaces.mount("tmpfs", "/mnt", "tmpfs")
        os.mkdir("/mnt/dev")

        def populate() -> None:
            linuxnamespaces.populate_dev("/", "/mnt")

        # populate_dev is run through run_in_fork rather than called
        # directly in this process.
        run_populate = linuxnamespaces.run_in_fork(populate)
        run_populate()