From 1ae1f2ff1144d502830834ba5a64262f7d195e91 Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee <107522312+lavafroth@users.noreply.github.com> Date: Wed, 28 Jun 2023 11:42:53 +0000 Subject: Refactor installer and general design patterns (#1895) * fix: refactor clear_vt100_escape_codes * fix: check for structure being a dict after handling potential parsing errors * refactor: use short circuit logic than if-elif-else chains * fix: use or for nullish moutpoint attribute * fix: better error handling for JSON from urls and paths * chore: json_stream_to_structure documentation * refactor: dry up relative and chroot path for custom command scripts * refactor: use write_text for pathlib.Path object * refactor: use sets to find intersection instead of filter and list * refactor: replace loop with dictionary comprehension in preparing luks partition * refactor: use walrus operator to check if luks_handler exists * refactor: use read_text and splitlines for potential Path object * fix: use keepends in splitlines for compatibility * fix: use keepends in splitlines for compatibility * feat: set pacman_conf Path as an attribute of installer * fix: empty string is a part of any string, avoid tuples * refactor: use iterator patterns to uncomment multilib and testing blocks * fix: don't json.loads an already loaded structure * fix: use fstab_path uniformly in genfstab * fix: remove unused variable matched * refactor: create separate class to modify pacman.conf in a single pass * fix: remove unused attribute pacman_conf from installer * fix: remove unused attribute pacman_conf from installer * feat: add persist method for pacman.conf, rewrite only when needed * fix: use path.write_text for locale.conf * use `or` operator for nullish new_conf * refactor: Installer.target is always a pathlib.Path object, do not check for string type * fix: use Optional[str] in function type definition instead of sumtype of str and None * fix: mypy type annotation * fix: make flake8 happy * chore: move pacman config and repo into pacman module * refactor: use Pacman object instead of Installer's pacstrap method * fix: break after first sync * fix: keep old build script for now * use nullish operator for base_packages and disk_encryption of Installer * feat: use shutil.which instead of rolling our own implementation * fix: check for binary only if list is not empty * fix: import Enum and fix mypy errors * refactor: use nullish operator for default values * refactor: linear search for key in Installer._trace_log only once * fix: use logs instead of the entirety of self._trace_log when searching for key * refactor: do not copy slice of bytes for search * refactor: use rfind only once to iterate over logs, do not raise ValueError in clear_vt100_escape_codes since TYPE_CHECKING will take care of it. * refactor: try decoding trace log before falling back to strigification * refactor: use an empty dict as default for callbacks in SysCommand.__init__ * refactor: use nullish or operator for slice start and end when not specified * refactor: use nullish or operator for SysCommand session * refactor: use pre-existing decode method in __repr__ for SysCommand * fix: overindentation * fix: use shallow copy of callbacks to prevent mutating the key-value relationships of the argument dict * refactor: use truthy value of self.session is not None for json encoding SysCommand * refactor: directly assign to SysCommand.session in create_session since it short circuits to True if already present * refactor: use dict.items() instead of manually retrieving the value using the key * refactor: user_config_to_json method sounds pretty self explanatory * refactor: store path validity as boolean for return * refactor: use pathlib.Path.write_text to write configs to destinations * fix: cannot use assignment expressions with expression * fix: use config_output.save for saving both config and creds * refactor: switch dictionary keys and values for options to avoid redundancy * refactor: use itertools.takewhile to collect locale.gen entries until the empty line * refactor: use iterative approach for nvidia driver fix * refactor: install packages if not nvidia * refactor: return early if no profile is selected * refactor: use strip to remove commented lines * fix: install additional packages only when we have a driver * fix: path with one command is matched as relative to '.' * fix: remove translation for debug log --------- Co-authored-by: Anton Hvornum --- archinstall/lib/general.py | 184 ++++++++++++++++++++------------------------- 1 file changed, 82 insertions(+), 102 deletions(-) (limited to 'archinstall/lib/general.py') diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py index f43d4f57..c85208ec 100644 --- a/archinstall/lib/general.py +++ b/archinstall/lib/general.py @@ -11,13 +11,14 @@ import sys import time import re import urllib.parse -import urllib.request +from urllib.request import Request, urlopen import urllib.error import pathlib from datetime import datetime, date from enum import Enum from typing import Callable, Optional, Dict, Any, List, Union, Iterator, TYPE_CHECKING from select import epoll, EPOLLIN, EPOLLHUP +from shutil import which from .exceptions import RequirementError, SysCallError from .output import debug, error, info @@ -34,28 +35,17 @@ def generate_password(length :int = 64) -> str: def locate_binary(name :str) -> str: - for PATH in os.environ['PATH'].split(':'): - for root, folders, files in os.walk(PATH): - for file in files: - if file == name: - return os.path.join(root, file) - break # Don't recurse - + if path := which(name): + return path raise RequirementError(f"Binary {name} does not exist.") def clear_vt100_escape_codes(data :Union[bytes, str]) -> Union[bytes, str]: # https://stackoverflow.com/a/43627833/929999 - if type(data) == bytes: - byte_vt100_escape_regex = bytes(r'\x1B\[[?0-9;]*[a-zA-Z]', 'UTF-8') - data = re.sub(byte_vt100_escape_regex, b'', data) - elif type(data) == str: - vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]' - data = re.sub(vt100_escape_regex, '', data) - else: - raise ValueError(f'Unsupported data type: {type(data)}') - - return data + vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]' + if isinstance(data, bytes): + return re.sub(vt100_escape_regex.encode(), b'', data) + return re.sub(vt100_escape_regex, '', data) def jsonify(obj: Any, safe: bool = True) -> Any: @@ -120,21 +110,15 @@ class SysCommandWorker: working_directory :Optional[str] = './', remove_vt100_escape_codes_from_lines :bool = True ): - if not callbacks: - callbacks = {} + callbacks = callbacks or {} + environment_vars = environment_vars or {} - if not environment_vars: - environment_vars = {} - - if type(cmd) is str: + if isinstance(cmd, str): cmd = shlex.split(cmd) - cmd = list(cmd) # This is to please mypy - if cmd[0][0] != '/' and cmd[0][:2] != './': - # "which" doesn't work as it's a builtin to bash. - # It used to work, but for whatever reason it doesn't anymore. - # We there for fall back on manual lookup in os.PATH - cmd[0] = locate_binary(cmd[0]) + if cmd: + if cmd[0][0] != '/' and cmd[0][:2] != './': # pathlib.Path does not work well + cmd[0] = locate_binary(cmd[0]) self.cmd = cmd self.callbacks = callbacks @@ -158,29 +142,36 @@ class SysCommandWorker: Contains will also move the current buffert position forward. This is to avoid re-checking the same data when looking for output. """ - assert type(key) == bytes + assert isinstance(key, bytes) - if (contains := key in self._trace_log[self._trace_log_pos:]): - self._trace_log_pos += self._trace_log[self._trace_log_pos:].find(key) + len(key) + index = self._trace_log.find(key, self._trace_log_pos) + if index >= 0: + self._trace_log_pos += index + len(key) + return True - return contains + return False def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]: - for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'): - if line: - escaped_line: bytes = line - - if self.remove_vt100_escape_codes_from_lines: - escaped_line = clear_vt100_escape_codes(line) # type: ignore + last_line = self._trace_log.rfind(b'\n') + lines = filter(None, self._trace_log[self._trace_log_pos:last_line].splitlines()) + for line in lines: + if self.remove_vt100_escape_codes_from_lines: + line = clear_vt100_escape_codes(line) # type: ignore - yield escaped_line + b'\n' + yield line + b'\n' - self._trace_log_pos = self._trace_log.rfind(b'\n') + self._trace_log_pos = last_line def __repr__(self) -> str: self.make_sure_we_are_executing() return str(self._trace_log) + def __str__(self) -> str: + try: + return self._trace_log.decode('utf-8') + except UnicodeDecodeError: + return str(self._trace_log) + def __enter__(self) -> 'SysCommandWorker': return self @@ -205,7 +196,7 @@ class SysCommandWorker: if self.exit_code != 0: raise SysCallError( - f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self._trace_log[-500:])}", + f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self)[-500:]}", self.exit_code, worker=self ) @@ -244,7 +235,7 @@ class SysCommandWorker: def peak(self, output: Union[str, bytes]) -> bool: if self.peek_output: - if type(output) == bytes: + if isinstance(output, bytes): try: output = output.decode('UTF-8') except UnicodeDecodeError: @@ -282,7 +273,7 @@ class SysCommandWorker: self.ended = time.time() break - if self.ended or (got_output is False and _pid_exists(self.pid) is False): + if self.ended or (not got_output and not _pid_exists(self.pid)): self.ended = time.time() try: wait_status = os.waitpid(self.pid, 0)[1] @@ -321,10 +312,8 @@ class SysCommandWorker: if change_perm: os.chmod(str(history_logfile), stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP) - except PermissionError: - pass + except (PermissionError, FileNotFoundError): # If history_logfile does not exist, ignore the error - except FileNotFoundError: pass except Exception as e: exception_type = type(e).__name__ @@ -355,22 +344,18 @@ class SysCommandWorker: class SysCommand: def __init__(self, cmd :Union[str, List[str]], - callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None, + callbacks :Dict[str, Callable[[Any], Any]] = {}, start_callback :Optional[Callable[[Any], Any]] = None, peek_output :Optional[bool] = False, environment_vars :Optional[Dict[str, Any]] = None, working_directory :Optional[str] = './', remove_vt100_escape_codes_from_lines :bool = True): - _callbacks = {} - if callbacks: - for hook, func in callbacks.items(): - _callbacks[hook] = func + self._callbacks = callbacks.copy() if start_callback: - _callbacks['on_start'] = start_callback + self._callbacks['on_start'] = start_callback self.cmd = cmd - self._callbacks = _callbacks self.peek_output = peek_output self.environment_vars = environment_vars self.working_directory = working_directory @@ -398,17 +383,15 @@ class SysCommand: if not self.session: raise KeyError(f"SysCommand() does not have an active session.") elif type(key) is slice: - start = key.start if key.start else 0 - end = key.stop if key.stop else len(self.session._trace_log) + start = key.start or 0 + end = key.stop or len(self.session._trace_log) return self.session._trace_log[start:end] else: raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.") def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str: - if self.session: - return self.session._trace_log.decode('UTF-8', errors='backslashreplace') - return '' + return self.decode('UTF-8', errors='backslashreplace') or '' def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]: return { @@ -416,7 +399,7 @@ class SysCommand: 'callbacks': self._callbacks, 'peak': self.peek_output, 'environment_vars': self.environment_vars, - 'session': True if self.session else False + 'session': self.session is not None } def create_session(self) -> bool: @@ -436,10 +419,9 @@ class SysCommand: remove_vt100_escape_codes_from_lines=self.remove_vt100_escape_codes_from_lines, working_directory=self.working_directory) as session: - if not self.session: - self.session = session + self.session = session - while self.session.ended is None: + while not self.session.ended: self.session.poll() if self.peek_output: @@ -448,9 +430,9 @@ class SysCommand: return True - def decode(self, fmt :str = 'UTF-8') -> Optional[str]: + def decode(self, *args, **kwargs) -> Optional[str]: if self.session: - return self.session._trace_log.decode(fmt) + return self.session._trace_log.decode(*args, **kwargs) return None @property @@ -476,54 +458,52 @@ def _pid_exists(pid: int) -> bool: def run_custom_user_commands(commands :List[str], installation :Installer) -> None: for index, command in enumerate(commands): + script_path = f"/var/tmp/user-command.{index}.sh" + chroot_path = installation.target / script_path + info(f'Executing custom command "{command}" ...') - - with open(f"{installation.target}/var/tmp/user-command.{index}.sh", "w") as temp_script: - temp_script.write(command) - - SysCommand(f"arch-chroot {installation.target} bash /var/tmp/user-command.{index}.sh") - - os.unlink(f"{installation.target}/var/tmp/user-command.{index}.sh") + chroot_path.write_text(command) + SysCommand(f"arch-chroot {installation.target} bash {script_path}") + + os.unlink(chroot_path) def json_stream_to_structure(configuration_identifier : str, stream :str, target :dict) -> bool : """ - Function to load a stream (file (as name) or valid JSON string into an existing dictionary - Returns true if it could be done - Return false if operation could not be executed + Load a JSON encoded dictionary from a stream and merge it into an existing dictionary. + A stream can be a filepath, a URL or a raw JSON string. + Returns True if the operation succeeded, False otherwise. +configuration_identifier is just a parameter to get meaningful, but not so long messages """ - parsed_url = urllib.parse.urlparse(stream) - - if parsed_url.scheme: # The stream is in fact a URL that should be grabbed + raw: Optional[str] = None + # Try using the stream as a URL that should be grabbed + if urllib.parse.urlparse(stream).scheme: try: - with urllib.request.urlopen(urllib.request.Request(stream, headers={'User-Agent': 'ArchInstall'})) as response: - target.update(json.loads(response.read())) + with urlopen(Request(stream, headers={'User-Agent': 'ArchInstall'})) as response: + raw = response.read() except urllib.error.HTTPError as err: - error(f"Could not load {configuration_identifier} via {parsed_url} due to: {err}") + error(f"Could not fetch JSON from {stream} as {configuration_identifier}: {err}") return False - else: - if pathlib.Path(stream).exists(): - try: - with pathlib.Path(stream).open() as fh: - target.update(json.load(fh)) - except Exception as err: - error(f"{configuration_identifier} = {stream} does not contain a valid JSON format: {err}") - return False - else: - # NOTE: This is a rudimentary check if what we're trying parse is a dict structure. - # Which is the only structure we tolerate anyway. - if stream.strip().startswith('{') and stream.strip().endswith('}'): - try: - target.update(json.loads(stream)) - except Exception as e: - error(f"{configuration_identifier} Contains an invalid JSON format: {e}") - return False - else: - error(f"{configuration_identifier} is neither a file nor is a JSON string") - return False + # Try using the stream as a filepath that should be read + if raw is None and (path := pathlib.Path(stream)).exists(): + try: + raw = path.read_text() + except Exception as err: + error(f"Could not read file {stream} as {configuration_identifier}: {err}") + return False + + try: + # We use `or` to try the stream as raw JSON to be parsed + structure = json.loads(raw or stream) + except Exception as err: + error(f"{configuration_identifier} contains an invalid JSON format: {err}") + return False + if not isinstance(structure, dict): + error(f"{stream} passed as {configuration_identifier} is not a JSON encoded dictionary") + return False + target.update(structure) return True -- cgit v1.2.3-54-g00ecf