summaryrefslogtreecommitdiff
path: root/linuxnamespaces/atlocation.py
diff options
context:
space:
mode:
Diffstat (limited to 'linuxnamespaces/atlocation.py')
-rw-r--r--linuxnamespaces/atlocation.py49
1 files changed, 39 insertions, 10 deletions
diff --git a/linuxnamespaces/atlocation.py b/linuxnamespaces/atlocation.py
index 20d402a..46ac541 100644
--- a/linuxnamespaces/atlocation.py
+++ b/linuxnamespaces/atlocation.py
@@ -9,13 +9,14 @@ code for doing so.
import enum
import errno
+import locale
import os
import os.path
import pathlib
import stat
import typing
-from .filedescriptor import FileDescriptor
+from .filedescriptor import FileDescriptor, FileDescriptorLike, HasFileno
AT_FDCWD = FileDescriptor(-100)
@@ -58,7 +59,7 @@ class AtLocation:
def __new__(
cls,
- thing: typing.Union["AtLocation", int, PathConvertible],
+ thing: typing.Union["AtLocation", FileDescriptorLike, PathConvertible],
location: PathConvertible | None = None,
flags: AtFlags = AtFlags.NONE,
) -> "AtLocation":
@@ -76,13 +77,14 @@ class AtLocation:
)
return thing # Don't copy.
obj = super(AtLocation, cls).__new__(cls)
- if isinstance(thing, int):
+ if not isinstance(thing, FileDescriptor):
+ if isinstance(thing, (int, HasFileno)):
+ thing = FileDescriptor(thing)
+ if isinstance(thing, FileDescriptor):
if thing < 0 and thing != AT_FDCWD:
raise ValueError("fd cannot be negative")
if isinstance(thing, FileDescriptor):
obj.fd = thing
- else:
- obj.fd = FileDescriptor(thing)
if location is None:
obj.location = ""
obj.flags = flags | AtFlags.AT_EMPTY_PATH
@@ -148,7 +150,7 @@ class AtLocation:
them with a slash as separator. The returned AtLocation borrows its fd
if any.
"""
- if isinstance(other, int):
+ if isinstance(other, (int, HasFileno)):
# A an fd is considered an absolute AT_EMPTY_PATH path.
return AtLocation(other)
non_empty_flags = self.flags & ~AtFlags.AT_EMPTY_PATH
@@ -218,7 +220,12 @@ class AtLocation:
"chdir on AtLocation only supports flag AT_EMPTY_PATH"
)
assert self.location
- return os.chdir(self.location)
+ if self.fd == AT_FDCWD:
+ return os.chdir(self.location)
+ with FileDescriptor(
+ self.open(flags=os.O_PATH | os.O_CLOEXEC)
+ ) as dirfd:
+ return os.fchdir(dirfd)
def chmod(self, mode: int) -> None:
"""Wrapper for os.chmod or os.fchmod."""
@@ -417,7 +424,7 @@ class AtLocation:
assert self.location
os.mknod(self.location, mode, device, dir_fd=self.fd_or_none)
- def open(self, flags: int, mode: int = 0o777) -> int:
+ def open(self, flags: int, mode: int = 0o777) -> FileDescriptor:
"""Wrapper for os.open supplying path and dir_fd."""
if self.flags == AtFlags.AT_SYMLINK_NOFOLLOW:
flags |= os.O_NOFOLLOW
@@ -426,7 +433,9 @@ class AtLocation:
"opening an AtLocation only supports flag AT_SYMLINK_NOFOLLOW"
)
assert self.location
- return os.open(self.location, flags, mode, dir_fd=self.fd_or_none)
+ return FileDescriptor(
+ os.open(self.location, flags, mode, dir_fd=self.fd_or_none)
+ )
def readlink(self) -> str:
"""Wrapper for os.readlink supplying path and dir_fd."""
@@ -543,6 +552,26 @@ class AtLocation:
AtLocation(dirfd),
)
+ def write_bytes(self, data: bytes) -> None:
+ """Overwrite the file with the given data bytes."""
+ dataview = memoryview(data)
+ with self.open(os.O_CREAT | os.O_WRONLY) as fd:
+ while dataview:
+ written = os.write(fd, dataview)
+ dataview = dataview[written:]
+
+ def write_text(
+ self, data: str, encoding: str | None = None, errors: str | None = None
+ ) -> None:
+ """Overwrite the file with the given data string."""
+ if encoding is None:
+ encoding = locale.getencoding()
+ if errors is None:
+ databytes = data.encode(encoding=encoding)
+ else:
+ databytes = data.encode(encoding=encoding, errors=errors)
+ self.write_bytes(databytes)
+
def __enter__(self) -> "AtLocation":
"""When used as a context manager, the associated fd will be closed on
scope exit.
@@ -590,4 +619,4 @@ class AtLocation:
return f"{cn}({self.fd}, flags={self.flags!r})"
-AtLocationLike = typing.Union[AtLocation, int, PathConvertible]
+AtLocationLike = typing.Union[AtLocation, FileDescriptorLike, PathConvertible]