Send patches - preferably formatted by git format-patch - to patches at archlinux32 dot org.
summaryrefslogtreecommitdiff
path: root/archinstall/lib
diff options
context:
space:
mode:
authorAnton Hvornum <anton@hvornum.se>2021-12-02 20:20:31 +0000
committerGitHub <noreply@github.com>2021-12-02 21:20:31 +0100
commitb1b820f4cbf4c6a02a9c4638a1f1887d66dbde9c (patch)
tree0b9fb050f16eeacdf7881c0697158e8244f1e43c /archinstall/lib
parent908c7b8cc0a804e9522d93fcf0dc71034c53ccdb (diff)
Fixing some mypy complaints (#780)
* Fixed some mypy issues regarding SysCommand* and logging * Fixed imports and undefined variable
Diffstat (limited to 'archinstall/lib')
-rw-r--r--archinstall/lib/exceptions.py4
-rw-r--r--archinstall/lib/general.py179
-rw-r--r--archinstall/lib/output.py86
3 files changed, 121 insertions, 148 deletions
diff --git a/archinstall/lib/exceptions.py b/archinstall/lib/exceptions.py
index 147b239b..aa86124b 100644
--- a/archinstall/lib/exceptions.py
+++ b/archinstall/lib/exceptions.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
class RequirementError(BaseException):
pass
@@ -15,7 +17,7 @@ class ProfileError(BaseException):
class SysCallError(BaseException):
- def __init__(self, message, exit_code):
+ def __init__(self, message :str, exit_code :Optional[int]) -> None:
super(SysCallError, self).__init__(message)
self.message = message
self.exit_code = exit_code
diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py
index 48de4cbe..ea0bafc9 100644
--- a/archinstall/lib/general.py
+++ b/archinstall/lib/general.py
@@ -9,10 +9,11 @@ import string
import sys
import time
from datetime import datetime, date
-from typing import Union
-try:
+from typing import Callable, Optional, Dict, Any, List, Union, Iterator
+
+if sys.platform == 'linux':
from select import epoll, EPOLLIN, EPOLLHUP
-except:
+else:
import select
EPOLLIN = 0
EPOLLHUP = 0
@@ -22,20 +23,20 @@ except:
Create a epoll() implementation that simulates the epoll() behavior.
This so that the rest of the code doesn't need to worry weither we're using select() or epoll().
"""
- def __init__(self):
- self.sockets = {}
- self.monitoring = {}
+ def __init__(self) -> None:
+ self.sockets: Dict[str, Any] = {}
+ self.monitoring: Dict[int, Any] = {}
- def unregister(self, fileno, *args, **kwargs):
+ def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None:
try:
del(self.monitoring[fileno])
except:
pass
- def register(self, fileno, *args, **kwargs):
+ def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None:
self.monitoring[fileno] = True
- def poll(self, timeout=0.05, *args, **kwargs):
+ def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]:
try:
return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]]
except OSError:
@@ -66,13 +67,13 @@ def multisplit(s, splitters):
s = ns
return s
-def locate_binary(name):
+def locate_binary(name :str) -> str:
for PATH in os.environ['PATH'].split(':'):
for root, folders, files in os.walk(PATH):
for file in files:
if file == name:
return os.path.join(root, file)
- break # Don't recurse
+ break # Don't recurse
raise RequirementError(f"Binary {name} does not exist.")
@@ -157,7 +158,14 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder):
return super(UNSAFE_JSON, self).encode(self._encode(obj))
class SysCommandWorker:
- def __init__(self, cmd, callbacks=None, peak_output=False, environment_vars=None, logfile=None, working_directory='./'):
+ def __init__(self,
+ cmd :Union[str, List[str]],
+ callbacks :Optional[Dict[str, Any]] = None,
+ peak_output :Optional[bool] = False,
+ environment_vars :Optional[Dict[str, Any]] = None,
+ logfile :Optional[None] = None,
+ working_directory :Optional[str] = './'):
+
if not callbacks:
callbacks = {}
if not environment_vars:
@@ -166,6 +174,7 @@ class SysCommandWorker:
if type(cmd) is str:
cmd = shlex.split(cmd)
+ cmd = list(cmd) # This is to please mypy
if cmd[0][0] != '/' and cmd[0][:2] != './':
# "which" doesn't work as it's a builtin to bash.
# It used to work, but for whatever reason it doesn't anymore.
@@ -179,15 +188,15 @@ class SysCommandWorker:
self.logfile = logfile
self.working_directory = working_directory
- self.exit_code = None
+ self.exit_code :Optional[int] = None
self._trace_log = b''
self._trace_log_pos = 0
self.poll_object = epoll()
- self.child_fd = None
- self.started = None
- self.ended = None
+ self.child_fd :Optional[int] = None
+ self.started :Optional[float] = None
+ self.ended :Optional[float] = None
- def __contains__(self, key: bytes):
+ def __contains__(self, key: bytes) -> bool:
"""
Contains will also move the current buffert position forward.
This is to avoid re-checking the same data when looking for output.
@@ -199,21 +208,21 @@ class SysCommandWorker:
return contains
- def __iter__(self, *args, **kwargs):
+ def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]:
for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'):
if line:
yield line + b'\n'
self._trace_log_pos = self._trace_log.rfind(b'\n')
- def __repr__(self):
+ def __repr__(self) -> str:
self.make_sure_we_are_executing()
return str(self._trace_log)
- def __enter__(self):
+ def __enter__(self) -> 'SysCommandWorker':
return self
- def __exit__(self, *args):
+ def __exit__(self, *args :str) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
@@ -233,9 +242,9 @@ class SysCommandWorker:
log(args[1], level=logging.ERROR, fg='red')
if self.exit_code != 0:
- raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}")
+ raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}", self.exit_code)
- def is_alive(self):
+ def is_alive(self) -> bool:
self.poll()
if self.started and self.ended is None:
@@ -243,22 +252,26 @@ class SysCommandWorker:
return False
- def write(self, data: bytes, line_ending=True):
+ def write(self, data: bytes, line_ending :bool = True) -> int:
assert type(data) == bytes # TODO: Maybe we can support str as well and encode it
self.make_sure_we_are_executing()
- os.write(self.child_fd, data + (b'\n' if line_ending else b''))
+ if self.child_fd:
+ return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
- def make_sure_we_are_executing(self):
+ return 0
+
+ def make_sure_we_are_executing(self) -> bool:
if not self.started:
return self.execute()
+ return True
def tell(self) -> int:
self.make_sure_we_are_executing()
return self._trace_log_pos
- def seek(self, pos):
+ def seek(self, pos :int) -> None:
self.make_sure_we_are_executing()
# Safety check to ensure 0 < pos < len(tracelog)
self._trace_log_pos = min(max(0, pos), len(self._trace_log))
@@ -271,39 +284,41 @@ class SysCommandWorker:
except UnicodeDecodeError:
return False
- sys.stdout.write(output)
+ sys.stdout.write(str(output))
sys.stdout.flush()
+
return True
- def poll(self):
+ def poll(self) -> None:
self.make_sure_we_are_executing()
- got_output = False
- for fileno, event in self.poll_object.poll(0.1):
- try:
- output = os.read(self.child_fd, 8192)
- got_output = True
- self.peak(output)
- self._trace_log += output
- except OSError:
+ if self.child_fd:
+ got_output = False
+ for fileno, event in self.poll_object.poll(0.1):
+ try:
+ output = os.read(self.child_fd, 8192)
+ got_output = True
+ self.peak(output)
+ self._trace_log += output
+ except OSError:
+ self.ended = time.time()
+ break
+
+ if self.ended or (got_output is False and pid_exists(self.pid) is False):
self.ended = time.time()
- break
-
- if self.ended or (got_output is False and pid_exists(self.pid) is False):
- self.ended = time.time()
- try:
- self.exit_code = os.waitpid(self.pid, 0)[1]
- except ChildProcessError:
try:
- self.exit_code = os.waitpid(self.child_fd, 0)[1]
+ self.exit_code = os.waitpid(self.pid, 0)[1]
except ChildProcessError:
- self.exit_code = 1
+ try:
+ self.exit_code = os.waitpid(self.child_fd, 0)[1]
+ except ChildProcessError:
+ self.exit_code = 1
def execute(self) -> bool:
import pty
if (old_dir := os.getcwd()) != self.working_directory:
- os.chdir(self.working_directory)
+ os.chdir(str(self.working_directory))
# Note: If for any reason, we get a Python exception between here
# and until os.close(), the traceback will get locked inside
@@ -320,7 +335,7 @@ class SysCommandWorker:
except PermissionError:
pass
- os.execve(self.cmd[0], self.cmd, {**os.environ, **self.environment_vars})
+ os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
if storage['arguments'].get('debug'):
log(f"Executing: {self.cmd}", level=logging.DEBUG)
@@ -334,15 +349,23 @@ class SysCommandWorker:
return True
- def decode(self, encoding='UTF-8'):
+ def decode(self, encoding :str = 'UTF-8') -> str:
return self._trace_log.decode(encoding)
class SysCommand:
- def __init__(self, cmd, callback=None, start_callback=None, peak_output=False, environment_vars=None, working_directory='./'):
+ def __init__(self,
+ cmd :Union[str, List[str]],
+ callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None,
+ start_callback :Optional[Callable[[Any], Any]] = None,
+ peak_output :Optional[bool] = False,
+ environment_vars :Optional[Dict[str, Any]] = None,
+ working_directory :Optional[str] = './'):
+
_callbacks = {}
- if callback:
- _callbacks['on_end'] = callback
+ if callbacks:
+ for hook, func in callbacks.items():
+ _callbacks[hook] = func
if start_callback:
_callbacks['on_start'] = start_callback
@@ -352,26 +375,28 @@ class SysCommand:
self.environment_vars = environment_vars
self.working_directory = working_directory
- self.session = None
+ self.session :Optional[SysCommandWorker] = None
self.create_session()
- def __enter__(self):
+ def __enter__(self) -> Optional[SysCommandWorker]:
return self.session
- def __exit__(self, *args, **kwargs):
+ def __exit__(self, *args :str, **kwargs :Dict[str, Any]) -> None:
# b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync.
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
if len(args) >= 2 and args[1]:
log(args[1], level=logging.ERROR, fg='red')
- def __iter__(self, *args, **kwargs):
-
- for line in self.session:
- yield line
+ def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]:
+ if self.session:
+ for line in self.session:
+ yield line
- def __getitem__(self, key):
- if type(key) is slice:
+ def __getitem__(self, key :slice) -> Optional[bytes]:
+ if not self.session:
+ raise KeyError(f"SysCommand() does not have an active session.")
+ elif type(key) is slice:
start = key.start if key.start else 0
end = key.stop if key.stop else len(self.session._trace_log)
@@ -379,10 +404,12 @@ class SysCommand:
else:
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
- def __repr__(self, *args, **kwargs):
- return self.session._trace_log.decode('UTF-8')
+ def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str:
+ if self.session:
+ return self.session._trace_log.decode('UTF-8')
+ return ''
- def __json__(self):
+ def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]:
return {
'cmd': self.cmd,
'callbacks': self._callbacks,
@@ -391,7 +418,7 @@ class SysCommand:
'session': True if self.session else False
}
- def create_session(self):
+ def create_session(self) -> bool:
if self.session:
return True
@@ -406,16 +433,23 @@ class SysCommand:
return True
- def decode(self, fmt='UTF-8'):
- return self.session._trace_log.decode(fmt)
+ def decode(self, fmt :str = 'UTF-8') -> Optional[str]:
+ if self.session:
+ return self.session._trace_log.decode(fmt)
+ return None
@property
- def exit_code(self):
- return self.session.exit_code
+ def exit_code(self) -> Optional[int]:
+ if self.session:
+ return self.session.exit_code
+ else:
+ return None
@property
- def trace_log(self):
- return self.session._trace_log
+ def trace_log(self) -> Optional[bytes]:
+ if self.session:
+ return self.session._trace_log
+ return None
def prerequisite_check():
@@ -428,7 +462,8 @@ def prerequisite_check():
def reboot():
SysCommand("/usr/bin/reboot")
-def pid_exists(pid: int):
+
+def pid_exists(pid: int) -> bool:
try:
return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
except subprocess.CalledProcessError:
diff --git a/archinstall/lib/output.py b/archinstall/lib/output.py
index b81c91a9..3ef1e234 100644
--- a/archinstall/lib/output.py
+++ b/archinstall/lib/output.py
@@ -1,51 +1,19 @@
-import abc
import logging
import os
import sys
from pathlib import Path
+from typing import Dict, Union
from .storage import storage
-# TODO: use logging's built in levels instead.
-# Although logging is threaded and I wish to avoid that.
-# It's more Pythonistic or w/e you want to call it.
-class LogLevels:
- Critical = 0b001
- Error = 0b010
- Warning = 0b011
- Info = 0b101
- Debug = 0b111
-
-
-class Journald(dict):
+class Journald:
@staticmethod
- @abc.abstractmethod
- def log(message, level=logging.DEBUG):
+ def log(message :str, level :int = logging.DEBUG) -> None:
try:
import systemd.journal # type: ignore
except ModuleNotFoundError:
- return False
-
- # For backwards compatibility, convert old style log-levels
- # to logging levels (and warn about deprecated usage)
- # There's some code re-usage here but that should be fine.
- # TODO: Remove these in a few versions:
- if level == LogLevels.Critical:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- level = logging.CRITICAL
- elif level == LogLevels.Error:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- level = logging.ERROR
- elif level == LogLevels.Warning:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- level = logging.WARNING
- elif level == LogLevels.Info:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- level = logging.INFO
- elif level == LogLevels.Debug:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- level = logging.DEBUG
+ return None
log_adapter = logging.getLogger('archinstall')
log_fmt = logging.Formatter("[%(levelname)s]: %(message)s")
@@ -65,7 +33,7 @@ class SessionLogging:
# Found first reference here: https://stackoverflow.com/questions/7445658/how-to-detect-if-the-console-does-support-ansi-escape-codes-in-python
# And re-used this: https://github.com/django/django/blob/master/django/core/management/color.py#L12
-def supports_color():
+def supports_color() -> bool:
"""
Return True if the running system's terminal supports color,
and False otherwise.
@@ -79,7 +47,7 @@ def supports_color():
# Heavily influenced by: https://github.com/django/django/blob/ae8338daf34fd746771e0678081999b656177bae/django/utils/termcolors.py#L13
# Color options here: https://askubuntu.com/questions/528928/how-to-do-underline-bold-italic-strikethrough-color-background-and-size-i
-def stylize_output(text: str, *opts, **kwargs):
+def stylize_output(text: str, *opts :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> str:
opt_dict = {'bold': '1', 'italic': '3', 'underscore': '4', 'blink': '5', 'reverse': '7', 'conceal': '8'}
color_names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
foreground = {color_names[x]: '3%s' % x for x in range(8)}
@@ -91,9 +59,9 @@ def stylize_output(text: str, *opts, **kwargs):
return '\x1b[%sm' % reset
for k, v in kwargs.items():
if k == 'fg':
- code_list.append(foreground[v])
+ code_list.append(foreground[str(v)])
elif k == 'bg':
- code_list.append(background[v])
+ code_list.append(background[str(v)])
for o in opts:
if o in opt_dict:
code_list.append(opt_dict[o])
@@ -102,7 +70,7 @@ def stylize_output(text: str, *opts, **kwargs):
return '%s%s' % (('\x1b[%sm' % ';'.join(code_list)), text or '')
-def log(*args, **kwargs):
+def log(*args :str, **kwargs :Union[str, int, Dict[str, Union[str, int]]]) -> None:
string = orig_string = ' '.join([str(x) for x in args])
# Attempt to colorize the output if supported
@@ -132,42 +100,10 @@ def log(*args, **kwargs):
with open(absolute_logfile, 'a') as log_file:
log_file.write(f"{orig_string}\n")
- # If we assigned a level, try to log it to systemd's journald.
- # Unless the level is higher than we've decided to output interactively.
- # (Remember, log files still get *ALL* the output despite level restrictions)
- if 'level' in kwargs:
- # For backwards compatibility, convert old style log-levels
- # to logging levels (and warn about deprecated usage)
- # There's some code re-usage here but that should be fine.
- # TODO: Remove these in a few versions:
- if kwargs['level'] == LogLevels.Critical:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- kwargs['level'] = logging.CRITICAL
- elif kwargs['level'] == LogLevels.Error:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- kwargs['level'] = logging.ERROR
- elif kwargs['level'] == LogLevels.Warning:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- kwargs['level'] = logging.WARNING
- elif kwargs['level'] == LogLevels.Info:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- kwargs['level'] = logging.INFO
- elif kwargs['level'] == LogLevels.Debug:
- log("Deprecated level detected in log message, please use new logging.<level> instead for the following log message:", fg="red", level=logging.ERROR, force=True)
- kwargs['level'] = logging.DEBUG
-
- if kwargs['level'] < storage.get('LOG_LEVEL', logging.INFO) and 'force' not in kwargs:
- # Level on log message was Debug, but output level is set to Info.
- # In that case, we'll drop it.
- return None
-
- try:
- Journald.log(string, level=kwargs.get('level', logging.INFO))
- except ModuleNotFoundError:
- pass # Ignore writing to journald
+ Journald.log(string, level=int(str(kwargs.get('level', logging.INFO))))
# Finally, print the log unless we skipped it based on level.
# We use sys.stdout.write()+flush() instead of print() to try and
# fix issue #94
sys.stdout.write(f"{string}\n")
- sys.stdout.flush()
+ sys.stdout.flush() \ No newline at end of file