From b1b820f4cbf4c6a02a9c4638a1f1887d66dbde9c Mon Sep 17 00:00:00 2001 From: Anton Hvornum Date: Thu, 2 Dec 2021 20:20:31 +0000 Subject: Fixing some mypy complaints (#780) * Fixed some mypy issues regarding SysCommand* and logging * Fixed imports and undefined variable --- archinstall/lib/exceptions.py | 4 +- archinstall/lib/general.py | 179 +++++++++++++++++++++++++----------------- archinstall/lib/output.py | 86 +++----------------- 3 files changed, 121 insertions(+), 148 deletions(-) diff --git a/archinstall/lib/exceptions.py b/archinstall/lib/exceptions.py index 147b239b..aa86124b 100644 --- a/archinstall/lib/exceptions.py +++ b/archinstall/lib/exceptions.py @@ -1,3 +1,5 @@ +from typing import Optional + class RequirementError(BaseException): pass @@ -15,7 +17,7 @@ class ProfileError(BaseException): class SysCallError(BaseException): - def __init__(self, message, exit_code): + def __init__(self, message :str, exit_code :Optional[int]) -> None: super(SysCallError, self).__init__(message) self.message = message self.exit_code = exit_code diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py index 48de4cbe..ea0bafc9 100644 --- a/archinstall/lib/general.py +++ b/archinstall/lib/general.py @@ -9,10 +9,11 @@ import string import sys import time from datetime import datetime, date -from typing import Union -try: +from typing import Callable, Optional, Dict, Any, List, Union, Iterator + +if sys.platform == 'linux': from select import epoll, EPOLLIN, EPOLLHUP -except: +else: import select EPOLLIN = 0 EPOLLHUP = 0 @@ -22,20 +23,20 @@ except: Create a epoll() implementation that simulates the epoll() behavior. This so that the rest of the code doesn't need to worry weither we're using select() or epoll(). """ - def __init__(self): - self.sockets = {} - self.monitoring = {} + def __init__(self) -> None: + self.sockets: Dict[str, Any] = {} + self.monitoring: Dict[int, Any] = {} - def unregister(self, fileno, *args, **kwargs): + def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None: try: del(self.monitoring[fileno]) except: pass - def register(self, fileno, *args, **kwargs): + def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None: self.monitoring[fileno] = True - def poll(self, timeout=0.05, *args, **kwargs): + def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]: try: return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]] except OSError: @@ -66,13 +67,13 @@ def multisplit(s, splitters): s = ns return s -def locate_binary(name): +def locate_binary(name :str) -> str: for PATH in os.environ['PATH'].split(':'): for root, folders, files in os.walk(PATH): for file in files: if file == name: return os.path.join(root, file) - break # Don't recurse + break # Don't recurse raise RequirementError(f"Binary {name} does not exist.") @@ -157,7 +158,14 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder): return super(UNSAFE_JSON, self).encode(self._encode(obj)) class SysCommandWorker: - def __init__(self, cmd, callbacks=None, peak_output=False, environment_vars=None, logfile=None, working_directory='./'): + def __init__(self, + cmd :Union[str, List[str]], + callbacks :Optional[Dict[str, Any]] = None, + peak_output :Optional[bool] = False, + environment_vars :Optional[Dict[str, Any]] = None, + logfile :Optional[None] = None, + working_directory :Optional[str] = './'): + if not callbacks: callbacks = {} if not environment_vars: @@ -166,6 +174,7 @@ class SysCommandWorker: if type(cmd) is str: cmd = shlex.split(cmd) + cmd = list(cmd) # This is to please mypy if cmd[0][0] != '/' and cmd[0][:2] != './': # "which" doesn't work as it's a builtin to bash. # It used to work, but for whatever reason it doesn't anymore. @@ -179,15 +188,15 @@ class SysCommandWorker: self.logfile = logfile self.working_directory = working_directory - self.exit_code = None + self.exit_code :Optional[int] = None self._trace_log = b'' self._trace_log_pos = 0 self.poll_object = epoll() - self.child_fd = None - self.started = None - self.ended = None + self.child_fd :Optional[int] = None + self.started :Optional[float] = None + self.ended :Optional[float] = None - def __contains__(self, key: bytes): + def __contains__(self, key: bytes) -> bool: """ Contains will also move the current buffert position forward. This is to avoid re-checking the same data when looking for output. @@ -199,21 +208,21 @@ class SysCommandWorker: return contains - def __iter__(self, *args, **kwargs): + def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]: for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'): if line: yield line + b'\n' self._trace_log_pos = self._trace_log.rfind(b'\n') - def __repr__(self): + def __repr__(self) -> str: self.make_sure_we_are_executing() return str(self._trace_log) - def __enter__(self): + def __enter__(self) -> 'SysCommandWorker': return self - def __exit__(self, *args): + def __exit__(self, *args :str) -> None: # b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync. # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager @@ -233,9 +242,9 @@ class SysCommandWorker: log(args[1], level=logging.ERROR, fg='red') if self.exit_code != 0: - raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}") + raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}", self.exit_code) - def is_alive(self): + def is_alive(self) -> bool: self.poll() if self.started and self.ended is None: @@ -243,22 +252,26 @@ class SysCommandWorker: return False - def write(self, data: bytes, line_ending=True): + def write(self, data: bytes, line_ending :bool = True) -> int: assert type(data) == bytes # TODO: Maybe we can support str as well and encode it self.make_sure_we_are_executing() - os.write(self.child_fd, data + (b'\n' if line_ending else b'')) + if self.child_fd: + return os.write(self.child_fd, data + (b'\n' if line_ending else b'')) - def make_sure_we_are_executing(self): + return 0 + + def make_sure_we_are_executing(self) -> bool: if not self.started: return self.execute() + return True def tell(self) -> int: self.make_sure_we_are_executing() return self._trace_log_pos - def seek(self, pos): + def seek(self, pos :int) -> None: self.make_sure_we_are_executing() # Safety check to ensure 0 < pos < len(tracelog) self._trace_log_pos = min(max(0, pos), len(self._trace_log)) @@ -271,39 +284,41 @@ class SysCommandWorker: except UnicodeDecodeError: return False - sys.stdout.write(output) + sys.stdout.write(str(output)) sys.stdout.flush() + return True - def poll(self): + def poll(self) -> None: self.make_sure_we_are_executing() - got_output = False - for fileno, event in self.poll_object.poll(0.1): - try: - output = os.read(self.child_fd, 8192) - got_output = True - self.peak(output) - self._trace_log += output - except OSError: + if self.child_fd: + got_output = False + for fileno, event in self.poll_object.poll(0.1): + try: + output = os.read(self.child_fd, 8192) + got_output = True + self.peak(output) + self._trace_log += output + except OSError: + self.ended = time.time() + break + + if self.ended or (got_output is False and pid_exists(self.pid) is False): self.ended = time.time() - break - - if self.ended or (got_output is False and pid_exists(self.pid) is False): - self.ended = time.time() - try: - self.exit_code = os.waitpid(self.pid, 0)[1] - except ChildProcessError: try: - self.exit_code = os.waitpid(self.child_fd, 0)[1] + self.exit_code = os.waitpid(self.pid, 0)[1] except ChildProcessError: - self.exit_code = 1 + try: + self.exit_code = os.waitpid(self.child_fd, 0)[1] + except ChildProcessError: + self.exit_code = 1 def execute(self) -> bool: import pty if (old_dir := os.getcwd()) != self.working_directory: - os.chdir(self.working_directory) + os.chdir(str(self.working_directory)) # Note: If for any reason, we get a Python exception between here # and until os.close(), the traceback will get locked inside @@ -320,7 +335,7 @@ class SysCommandWorker: except PermissionError: pass - os.execve(self.cmd[0], self.cmd, {**os.environ, **self.environment_vars}) + os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars}) if storage['arguments'].get('debug'): log(f"Executing: {self.cmd}", level=logging.DEBUG) @@ -334,15 +349,23 @@ class SysCommandWorker: return True - def decode(self, encoding='UTF-8'): + def decode(self, encoding :str = 'UTF-8') -> str: return self._trace_log.decode(encoding) class SysCommand: - def __init__(self, cmd, callback=None, start_callback=None, peak_output=False, environment_vars=None, working_directory='./'): + def __init__(self, + cmd :Union[str, List[str]], + callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None, + start_callback :Optional[Callable[[Any], Any]] = None, + peak_output :Optional[bool] = False, + environment_vars :Optional[Dict[str, Any]] = None, + working_directory :Optional[str] = './'): + _callbacks = {} - if callback: - _callbacks['on_end'] = callback + if callbacks: + for hook, func in callbacks.items(): + _callbacks[hook] = func if start_callback: _callbacks['on_start'] = start_callback @@ -352,26 +375,28 @@ class SysCommand: self.environment_vars = environment_vars self.working_directory = working_directory - self.session = None + self.session :Optional[SysCommandWorker] = None self.create_session() - def __enter__(self): + def __enter__(self) -> Optional[SysCommandWorker]: return self.session - def __exit__(self, *args, **kwargs): + def __exit__(self, *args :str, **kwargs :Dict[str, Any]) -> None: # b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync. # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager if len(args) >= 2 and args[1]: log(args[1], level=logging.ERROR, fg='red') - def __iter__(self, *args, **kwargs): - - for line in self.session: - yield line + def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]: + if self.session: + for line in self.session: + yield line - def __getitem__(self, key): - if type(key) is slice: + def __getitem__(self, key :slice) -> Optional[bytes]: + if not self.session: + raise KeyError(f"SysCommand() does not have an active session.") + elif type(key) is slice: start = key.start if key.start else 0 end = key.stop if key.stop else len(self.session._trace_log) @@ -379,10 +404,12 @@ class SysCommand: else: raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.") - def __repr__(self, *args, **kwargs): - return self.session._trace_log.decode('UTF-8') + def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str: + if self.session: + return self.session._trace_log.decode('UTF-8') + return '' - def __json__(self): + def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]: return { 'cmd': self.cmd, 'callbacks': self._callbacks, @@ -391,7 +418,7 @@ class SysCommand: 'session': True if self.session else False } - def create_session(self): + def create_session(self) -> bool: if self.session: return True @@ -406,16 +433,23 @@ class SysCommand: return True - def decode(self, fmt='UTF-8'): - return self.session._trace_log.decode(fmt) + def decode(self, fmt :str = 'UTF-8') -> Optional[str]: + if self.session: + return self.session._trace_log.decode(fmt) + return None @property - def exit_code(self): - return self.session.exit_code + def exit_code(self) -> Optional[int]: + if self.session: + return self.session.exit_code + else: + return None @property - def trace_log(self): - return self.session._trace_log + def trace_log(self) -> Optional[bytes]: + if self.session: + return self.session._trace_log + return None def prerequisite_check(): @@ -428,7 +462,8 @@ def prerequisite_check(): def reboot(): SysCommand("/usr/bin/reboot") -def pid_exists(pid: int): + +def pid_exists(pid: int) -> bool: try: return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip()) except subprocess.CalledProcessError: diff --git a/archinstall/lib/output.py b/archinstall/lib/output.py index b81c91a9..3ef1e234 100644 --- a/archinstall/lib/output.py +++ b/archinstall/lib/output.py @@ -1,51 +1,19 @@ -import abc import logging import os import sys from pathlib import Path +from typing import Dict, Union from .storage import storage -# TODO: use logging's built in levels instead. -# Although logging is threaded and I wish to avoid that. -# It's more Pythonistic or w/e you want to call it. -class LogLevels: - Critical = 0b001 - Error = 0b010 - Warning = 0b011 - Info = 0b101 - Debug = 0b111 - - -class Journald(dict): +class Journald: @staticmethod - @abc.abstractmethod - def log(message, level=logging.DEBUG): + def log(message :str, level :int = logging.DEBUG) -> None: try: import systemd.journal # type: ignore except ModuleNotFoundError: - return False - - # For backwards compatibility, convert old style log-levels - # to logging levels (and warn about deprecated usage) - # There's some code re-usage here but that should be fine. - # TODO: Remove these in a few versions: - if level == LogLevels.Critical: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - level = logging.CRITICAL - elif level == LogLevels.Error: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - level = logging.ERROR - elif level == LogLevels.Warning: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - level = logging.WARNING - elif level == LogLevels.Info: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - level = logging.INFO - elif level == LogLevels.Debug: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - level = logging.DEBUG + return None log_adapter = logging.getLogger('archinstall') log_fmt = logging.Formatter("[%(levelname)s]: %(message)s") @@ -65,7 +33,7 @@ class SessionLogging: # Found first reference here: https://stackoverflow.com/questions/7445658/how-to-detect-if-the-console-does-support-ansi-escape-codes-in-python # And re-used this: https://github.com/django/django/blob/master/django/core/management/color.py#L12 -def supports_color(): +def supports_color() -> bool: """ Return True if the running system's terminal supports color, and False otherwise. @@ -79,7 +47,7 @@ def supports_color(): # Heavily influenced by: https://github.com/django/django/blob/ae8338daf34fd746771e0678081999b656177bae/django/utils/termcolors.py#L13 # Color options here: https://askubuntu.com/questions/528928/how-to-do-underline-bold-italic-strikethrough-color-background-and-size-i -def stylize_output(text: str, *opts, **kwargs): +def stylize_output(text: str, *opts :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> str: opt_dict = {'bold': '1', 'italic': '3', 'underscore': '4', 'blink': '5', 'reverse': '7', 'conceal': '8'} color_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white') foreground = {color_names[x]: '3%s' % x for x in range(8)} @@ -91,9 +59,9 @@ def stylize_output(text: str, *opts, **kwargs): return '\x1b[%sm' % reset for k, v in kwargs.items(): if k == 'fg': - code_list.append(foreground[v]) + code_list.append(foreground[str(v)]) elif k == 'bg': - code_list.append(background[v]) + code_list.append(background[str(v)]) for o in opts: if o in opt_dict: code_list.append(opt_dict[o]) @@ -102,7 +70,7 @@ def stylize_output(text: str, *opts, **kwargs): return '%s%s' % (('\x1b[%sm' % ';'.join(code_list)), text or '') -def log(*args, **kwargs): +def log(*args :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> None: string = orig_string = ' '.join([str(x) for x in args]) # Attempt to colorize the output if supported @@ -132,42 +100,10 @@ def log(*args, **kwargs): with open(absolute_logfile, 'a') as log_file: log_file.write(f"{orig_string}\n") - # If we assigned a level, try to log it to systemd's journald. - # Unless the level is higher than we've decided to output interactively. - # (Remember, log files still get *ALL* the output despite level restrictions) - if 'level' in kwargs: - # For backwards compatibility, convert old style log-levels - # to logging levels (and warn about deprecated usage) - # There's some code re-usage here but that should be fine. - # TODO: Remove these in a few versions: - if kwargs['level'] == LogLevels.Critical: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - kwargs['level'] = logging.CRITICAL - elif kwargs['level'] == LogLevels.Error: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - kwargs['level'] = logging.ERROR - elif kwargs['level'] == LogLevels.Warning: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - kwargs['level'] = logging.WARNING - elif kwargs['level'] == LogLevels.Info: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - kwargs['level'] = logging.INFO - elif kwargs['level'] == LogLevels.Debug: - log("Deprecated level detected in log message, please use new logging. instead for the following log message:", fg="red", level=logging.ERROR, force=True) - kwargs['level'] = logging.DEBUG - - if kwargs['level'] < storage.get('LOG_LEVEL', logging.INFO) and 'force' not in kwargs: - # Level on log message was Debug, but output level is set to Info. - # In that case, we'll drop it. - return None - - try: - Journald.log(string, level=kwargs.get('level', logging.INFO)) - except ModuleNotFoundError: - pass # Ignore writing to journald + Journald.log(string, level=int(str(kwargs.get('level', logging.INFO)))) # Finally, print the log unless we skipped it based on level. # We use sys.stdout.write()+flush() instead of print() to try and # fix issue #94 sys.stdout.write(f"{string}\n") - sys.stdout.flush() + sys.stdout.flush() \ No newline at end of file -- cgit v1.2.3-54-g00ecf