Send patches - preferably formatted by git format-patch - to patches at archlinux32 dot org.
summaryrefslogtreecommitdiff
path: root/archinstall/lib/general.py
diff options
context:
space:
mode:
authorAnton Hvornum <anton@hvornum.se>2021-12-02 20:20:31 +0000
committerGitHub <noreply@github.com>2021-12-02 21:20:31 +0100
commitb1b820f4cbf4c6a02a9c4638a1f1887d66dbde9c (patch)
tree0b9fb050f16eeacdf7881c0697158e8244f1e43c /archinstall/lib/general.py
parent908c7b8cc0a804e9522d93fcf0dc71034c53ccdb (diff)
Fixing some mypy complaints (#780)
* Fixed some mypy issues regarding SysCommand* and logging * Fixed imports and undefined variable
Diffstat (limited to 'archinstall/lib/general.py')
-rw-r--r--archinstall/lib/general.py179
1 files changed, 107 insertions, 72 deletions
diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py
index 48de4cbe..ea0bafc9 100644
--- a/archinstall/lib/general.py
+++ b/archinstall/lib/general.py
@@ -9,10 +9,11 @@ import string
import sys
import time
from datetime import datetime, date
-from typing import Union
-try:
+from typing import Callable, Optional, Dict, Any, List, Union, Iterator
+
+if sys.platform == 'linux':
from select import epoll, EPOLLIN, EPOLLHUP
-except:
+else:
import select
EPOLLIN = 0
EPOLLHUP = 0
@@ -22,20 +23,20 @@ except:
Create a epoll() implementation that simulates the epoll() behavior.
This so that the rest of the code doesn't need to worry weither we're using select() or epoll().
"""
- def __init__(self):
- self.sockets = {}
- self.monitoring = {}
+ def __init__(self) -> None:
+ self.sockets: Dict[str, Any] = {}
+ self.monitoring: Dict[int, Any] = {}
- def unregister(self, fileno, *args, **kwargs):
+ def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None:
try:
del(self.monitoring[fileno])
except:
pass
- def register(self, fileno, *args, **kwargs):
+ def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None:
self.monitoring[fileno] = True
- def poll(self, timeout=0.05, *args, **kwargs):
+ def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]:
try:
return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]]
except OSError:
@@ -66,13 +67,13 @@ def multisplit(s, splitters):
s = ns
return s
-def locate_binary(name):
+def locate_binary(name :str) -> str:
for PATH in os.environ['PATH'].split(':'):
for root, folders, files in os.walk(PATH):
for file in files:
if file == name:
return os.path.join(root, file)
- break # Don't recurse
+ break # Don't recurse
raise RequirementError(f"Binary {name} does not exist.")
@@ -157,7 +158,14 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder):
return super(UNSAFE_JSON, self).encode(self._encode(obj))
class SysCommandWorker:
- def __init__(self, cmd, callbacks=None, peak_output=False, environment_vars=None, logfile=None, working_directory='./'):
+ def __init__(self,
+ cmd :Union[str, List[str]],
+ callbacks :Optional[Dict[str, Any]] = None,
+ peak_output :Optional[bool] = False,
+ environment_vars :Optional[Dict[str, Any]] = None,
+ logfile :Optional[None] = None,
+ working_directory :Optional[str] = './'):
+
if not callbacks:
callbacks = {}
if not environment_vars:
@@ -166,6 +174,7 @@ class SysCommandWorker:
if type(cmd) is str:
cmd = shlex.split(cmd)
+ cmd = list(cmd) # This is to please mypy
if cmd[0][0] != '/' and cmd[0][:2] != './':
# "which" doesn't work as it's a builtin to bash.
# It used to work, but for whatever reason it doesn't anymore.
@@ -179,15 +188,15 @@ class SysCommandWorker:
self.logfile = logfile
self.working_directory = working_directory
- self.exit_code = None
+ self.exit_code :Optional[int] = None
self._trace_log = b''
self._trace_log_pos = 0
self.poll_object = epoll()
- self.child_fd = None
- self.started = None
- self.ended = None
+ self.child_fd :Optional[int] = None
+ self.started :Optional[float] = None
+ self.ended :Optional[float] = None
- def __contains__(self, key: bytes):
+ def __contains__(self, key: bytes) -> bool:
"""
Contains will also move the current buffert position forward.
This is to avoid re-checking the same data when looking for output.
@@ -199,21 +208,21 @@ class SysCommandWorker:
return contains
- def __iter__(self, *args, **kwargs):
+ def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]:
for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'):
if line:
yield line + b'\n'
self._trace_log_pos = self._trace_log.rfind(b'\n')
- def __repr__(self):
+ def __repr__(self) -> str:
self.make_sure_we_are_executing()
return str(self._trace_log)
- def __enter__(self):
+ def __enter__(self) -> 'SysCommandWorker':
return self
- def __exit__(self, *args):
+ def __exit__(self, *args :str) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
@@ -233,9 +242,9 @@ class SysCommandWorker:
log(args[1], level=logging.ERROR, fg='red')
if self.exit_code != 0:
- raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}")
+ raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}", self.exit_code)
- def is_alive(self):
+ def is_alive(self) -> bool:
self.poll()
if self.started and self.ended is None:
@@ -243,22 +252,26 @@ class SysCommandWorker:
return False
- def write(self, data: bytes, line_ending=True):
+ def write(self, data: bytes, line_ending :bool = True) -> int:
assert type(data) == bytes # TODO: Maybe we can support str as well and encode it
self.make_sure_we_are_executing()
- os.write(self.child_fd, data + (b'\n' if line_ending else b''))
+ if self.child_fd:
+ return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
- def make_sure_we_are_executing(self):
+ return 0
+
+ def make_sure_we_are_executing(self) -> bool:
if not self.started:
return self.execute()
+ return True
def tell(self) -> int:
self.make_sure_we_are_executing()
return self._trace_log_pos
- def seek(self, pos):
+ def seek(self, pos :int) -> None:
self.make_sure_we_are_executing()
# Safety check to ensure 0 < pos < len(tracelog)
self._trace_log_pos = min(max(0, pos), len(self._trace_log))
@@ -271,39 +284,41 @@ class SysCommandWorker:
except UnicodeDecodeError:
return False
- sys.stdout.write(output)
+ sys.stdout.write(str(output))
sys.stdout.flush()
+
return True
- def poll(self):
+ def poll(self) -> None:
self.make_sure_we_are_executing()
- got_output = False
- for fileno, event in self.poll_object.poll(0.1):
- try:
- output = os.read(self.child_fd, 8192)
- got_output = True
- self.peak(output)
- self._trace_log += output
- except OSError:
+ if self.child_fd:
+ got_output = False
+ for fileno, event in self.poll_object.poll(0.1):
+ try:
+ output = os.read(self.child_fd, 8192)
+ got_output = True
+ self.peak(output)
+ self._trace_log += output
+ except OSError:
+ self.ended = time.time()
+ break
+
+ if self.ended or (got_output is False and pid_exists(self.pid) is False):
self.ended = time.time()
- break
-
- if self.ended or (got_output is False and pid_exists(self.pid) is False):
- self.ended = time.time()
- try:
- self.exit_code = os.waitpid(self.pid, 0)[1]
- except ChildProcessError:
try:
- self.exit_code = os.waitpid(self.child_fd, 0)[1]
+ self.exit_code = os.waitpid(self.pid, 0)[1]
except ChildProcessError:
- self.exit_code = 1
+ try:
+ self.exit_code = os.waitpid(self.child_fd, 0)[1]
+ except ChildProcessError:
+ self.exit_code = 1
def execute(self) -> bool:
import pty
if (old_dir := os.getcwd()) != self.working_directory:
- os.chdir(self.working_directory)
+ os.chdir(str(self.working_directory))
# Note: If for any reason, we get a Python exception between here
# and until os.close(), the traceback will get locked inside
@@ -320,7 +335,7 @@ class SysCommandWorker:
except PermissionError:
pass
- os.execve(self.cmd[0], self.cmd, {**os.environ, **self.environment_vars})
+ os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
if storage['arguments'].get('debug'):
log(f"Executing: {self.cmd}", level=logging.DEBUG)
@@ -334,15 +349,23 @@ class SysCommandWorker:
return True
- def decode(self, encoding='UTF-8'):
+ def decode(self, encoding :str = 'UTF-8') -> str:
return self._trace_log.decode(encoding)
class SysCommand:
- def __init__(self, cmd, callback=None, start_callback=None, peak_output=False, environment_vars=None, working_directory='./'):
+ def __init__(self,
+ cmd :Union[str, List[str]],
+ callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None,
+ start_callback :Optional[Callable[[Any], Any]] = None,
+ peak_output :Optional[bool] = False,
+ environment_vars :Optional[Dict[str, Any]] = None,
+ working_directory :Optional[str] = './'):
+
_callbacks = {}
- if callback:
- _callbacks['on_end'] = callback
+ if callbacks:
+ for hook, func in callbacks.items():
+ _callbacks[hook] = func
if start_callback:
_callbacks['on_start'] = start_callback
@@ -352,26 +375,28 @@ class SysCommand:
self.environment_vars = environment_vars
self.working_directory = working_directory
- self.session = None
+ self.session :Optional[SysCommandWorker] = None
self.create_session()
- def __enter__(self):
+ def __enter__(self) -> Optional[SysCommandWorker]:
return self.session
- def __exit__(self, *args, **kwargs):
+ def __exit__(self, *args :str, **kwargs :Dict[str, Any]) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
if len(args) >= 2 and args[1]:
log(args[1], level=logging.ERROR, fg='red')
- def __iter__(self, *args, **kwargs):
-
- for line in self.session:
- yield line
+ def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]:
+ if self.session:
+ for line in self.session:
+ yield line
- def __getitem__(self, key):
- if type(key) is slice:
+ def __getitem__(self, key :slice) -> Optional[bytes]:
+ if not self.session:
+ raise KeyError(f"SysCommand() does not have an active session.")
+ elif type(key) is slice:
start = key.start if key.start else 0
end = key.stop if key.stop else len(self.session._trace_log)
@@ -379,10 +404,12 @@ class SysCommand:
else:
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
- def __repr__(self, *args, **kwargs):
- return self.session._trace_log.decode('UTF-8')
+ def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str:
+ if self.session:
+ return self.session._trace_log.decode('UTF-8')
+ return ''
- def __json__(self):
+ def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]:
return {
'cmd': self.cmd,
'callbacks': self._callbacks,
@@ -391,7 +418,7 @@ class SysCommand:
'session': True if self.session else False
}
- def create_session(self):
+ def create_session(self) -> bool:
if self.session:
return True
@@ -406,16 +433,23 @@ class SysCommand:
return True
- def decode(self, fmt='UTF-8'):
- return self.session._trace_log.decode(fmt)
+ def decode(self, fmt :str = 'UTF-8') -> Optional[str]:
+ if self.session:
+ return self.session._trace_log.decode(fmt)
+ return None
@property
- def exit_code(self):
- return self.session.exit_code
+ def exit_code(self) -> Optional[int]:
+ if self.session:
+ return self.session.exit_code
+ else:
+ return None
@property
- def trace_log(self):
- return self.session._trace_log
+ def trace_log(self) -> Optional[bytes]:
+ if self.session:
+ return self.session._trace_log
+ return None
def prerequisite_check():
@@ -428,7 +462,8 @@ def prerequisite_check():
def reboot():
SysCommand("/usr/bin/reboot")
-def pid_exists(pid: int):
+
+def pid_exists(pid: int) -> bool:
try:
return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
except subprocess.CalledProcessError: