diff options
Diffstat (limited to 'linuxnamespaces/atlocation.py')
-rw-r--r-- | linuxnamespaces/atlocation.py | 49 |
1 files changed, 39 insertions, 10 deletions
diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py index 20d402a..46ac541 100644 --- a/linuxnamespaces/atlocation.py +++ b/linuxnamespaces/atlocation.py @@ -9,13 +9,14 @@ code for doing so. import enum import errno +import locale import os import os.path import pathlib import stat import typing -from .filedescriptor import FileDescriptor +from .filedescriptor import FileDescriptor, FileDescriptorLike, HasFileno AT_FDCWD = FileDescriptor(-100) @@ -58,7 +59,7 @@ class AtLocation: def __new__( cls, - thing: typing.Union["AtLocation", int, PathConvertible], + thing: typing.Union["AtLocation", FileDescriptorLike, PathConvertible], location: PathConvertible | None = None, flags: AtFlags = AtFlags.NONE, ) -> "AtLocation": @@ -76,13 +77,14 @@ class AtLocation: ) return thing # Don't copy. obj = super(AtLocation, cls).__new__(cls) - if isinstance(thing, int): + if not isinstance(thing, FileDescriptor): + if isinstance(thing, (int, HasFileno)): + thing = FileDescriptor(thing) + if isinstance(thing, FileDescriptor): if thing < 0 and thing != AT_FDCWD: raise ValueError("fd cannot be negative") if isinstance(thing, FileDescriptor): obj.fd = thing - else: - obj.fd = FileDescriptor(thing) if location is None: obj.location = "" obj.flags = flags | AtFlags.AT_EMPTY_PATH @@ -148,7 +150,7 @@ class AtLocation: them with a slash as separator. The returned AtLocation borrows its fd if any. """ - if isinstance(other, int): + if isinstance(other, (int, HasFileno)): # A an fd is considered an absolute AT_EMPTY_PATH path. return AtLocation(other) non_empty_flags = self.flags & ~AtFlags.AT_EMPTY_PATH @@ -218,7 +220,12 @@ class AtLocation: "chdir on AtLocation only supports flag AT_EMPTY_PATH" ) assert self.location - return os.chdir(self.location) + if self.fd == AT_FDCWD: + return os.chdir(self.location) + with FileDescriptor( + self.open(flags=os.O_PATH | os.O_CLOEXEC) + ) as dirfd: + return os.fchdir(dirfd) def chmod(self, mode: int) -> None: """Wrapper for os.chmod or os.fchmod.""" @@ -417,7 +424,7 @@ class AtLocation: assert self.location os.mknod(self.location, mode, device, dir_fd=self.fd_or_none) - def open(self, flags: int, mode: int = 0o777) -> int: + def open(self, flags: int, mode: int = 0o777) -> FileDescriptor: """Wrapper for os.open supplying path and dir_fd.""" if self.flags == AtFlags.AT_SYMLINK_NOFOLLOW: flags |= os.O_NOFOLLOW @@ -426,7 +433,9 @@ class AtLocation: "opening an AtLocation only supports flag AT_SYMLINK_NOFOLLOW" ) assert self.location - return os.open(self.location, flags, mode, dir_fd=self.fd_or_none) + return FileDescriptor( + os.open(self.location, flags, mode, dir_fd=self.fd_or_none) + ) def readlink(self) -> str: """Wrapper for os.readlink supplying path and dir_fd.""" @@ -543,6 +552,26 @@ class AtLocation: AtLocation(dirfd), ) + def write_bytes(self, data: bytes) -> None: + """Overwrite the file with the given data bytes.""" + dataview = memoryview(data) + with self.open(os.O_CREAT | os.O_WRONLY) as fd: + while dataview: + written = os.write(fd, dataview) + dataview = dataview[written:] + + def write_text( + self, data: str, encoding: str | None = None, errors: str | None = None + ) -> None: + """Overwrite the file with the given data string.""" + if encoding is None: + encoding = locale.getencoding() + if errors is None: + databytes = data.encode(encoding=encoding) + else: + databytes = data.encode(encoding=encoding, errors=errors) + self.write_bytes(databytes) + def __enter__(self) -> "AtLocation": """When used as a context manager, the associated fd will be closed on scope exit. @@ -590,4 +619,4 @@ class AtLocation: return f"{cn}({self.fd}, flags={self.flags!r})" -AtLocationLike = typing.Union[AtLocation, int, PathConvertible] +AtLocationLike = typing.Union[AtLocation, FileDescriptorLike, PathConvertible] |