index : archinstall32 | |
Archlinux32 installer | gitolite user |
summaryrefslogtreecommitdiff |
-rw-r--r-- | archinstall/lib/general.py | 426 |
diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py index 79ab024b..8dbf23ff 100644 --- a/archinstall/lib/general.py +++ b/archinstall/lib/general.py @@ -1,7 +1,6 @@ from __future__ import annotations -import hashlib + import json -import logging import os import secrets import shlex @@ -12,212 +11,117 @@ import sys import time import re import urllib.parse -import urllib.request +from urllib.request import Request, urlopen import urllib.error import pathlib from datetime import datetime, date +from enum import Enum from typing import Callable, Optional, Dict, Any, List, Union, Iterator, TYPE_CHECKING -# https://stackoverflow.com/a/39757388/929999 -if TYPE_CHECKING: - from .installer import Installer - -if sys.platform == 'linux': - from select import epoll, EPOLLIN, EPOLLHUP -else: - import select - EPOLLIN = 0 - EPOLLHUP = 0 - - class epoll(): - """ #!if windows - Create a epoll() implementation that simulates the epoll() behavior. - This so that the rest of the code doesn't need to worry weither we're using select() or epoll(). - """ - def __init__(self) -> None: - self.sockets: Dict[str, Any] = {} - self.monitoring: Dict[int, Any] = {} - - def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None: - try: - del(self.monitoring[fileno]) # noqa: E275 - except: - pass - - def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None: - self.monitoring[fileno] = True - - def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]: - try: - return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]] - except OSError: - return [] +from select import epoll, EPOLLIN, EPOLLHUP +from shutil import which from .exceptions import RequirementError, SysCallError -from .output import log +from .output import debug, error, info from .storage import storage -def gen_uid(entropy_length :int = 256) -> str: - return hashlib.sha512(os.urandom(entropy_length)).hexdigest() + +if TYPE_CHECKING: + from .installer import Installer + def generate_password(length :int = 64) -> str: - haystack = string.printable # digits, ascii_letters, punctiation (!"#$[] etc) and whitespace + haystack = string.printable # digits, ascii_letters, punctuation (!"#$[] etc) and whitespace return ''.join(secrets.choice(haystack) for i in range(length)) -def multisplit(s :str, splitters :List[str]) -> str: - s = [s, ] - for key in splitters: - ns = [] - for obj in s: - x = obj.split(key) - for index, part in enumerate(x): - if len(part): - ns.append(part) - if index < len(x) - 1: - ns.append(key) - s = ns - return s def locate_binary(name :str) -> str: - for PATH in os.environ['PATH'].split(':'): - for root, folders, files in os.walk(PATH): - for file in files: - if file == name: - return os.path.join(root, file) - break # Don't recurse - + if path := which(name): + return path raise RequirementError(f"Binary {name} does not exist.") -def clear_vt100_escape_codes(data :Union[bytes, str]): - # https://stackoverflow.com/a/43627833/929999 - if type(data) == bytes: - vt100_escape_regex = bytes(r'\x1B\[[?0-9;]*[a-zA-Z]', 'UTF-8') - else: - vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]' - - for match in re.findall(vt100_escape_regex, data, re.IGNORECASE): - data = data.replace(match, '' if type(data) == str else b'') - return data - -def json_dumps(*args :str, **kwargs :str) -> str: - return json.dumps(*args, **{**kwargs, 'cls': JSON}) +def clear_vt100_escape_codes(data :Union[bytes, str]) -> Union[bytes, str]: + # https://stackoverflow.com/a/43627833/929999 + vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]' + if isinstance(data, bytes): + return re.sub(vt100_escape_regex.encode(), b'', data) + return re.sub(vt100_escape_regex, '', data) -class JsonEncoder: - @staticmethod - def _encode(obj :Any) -> Any: - """ - This JSON encoder function will try it's best to convert - any archinstall data structures, instances or variables into - something that's understandable by the json.parse()/json.loads() lib. - _encode() will skip any dictionary key starting with an exclamation mark (!) - """ - if isinstance(obj, dict): - # We'll need to iterate not just the value that default() usually gets passed - # But also iterate manually over each key: value pair in order to trap the keys. - - copy = {} - for key, val in list(obj.items()): - if isinstance(val, dict): - # This, is a EXTREMELY ugly hack.. but it's the only quick way I can think of to trigger a encoding of sub-dictionaries. - val = json.loads(json.dumps(val, cls=JSON)) - else: - val = JsonEncoder._encode(val) - - if type(key) == str and key[0] == '!': - pass - else: - copy[JsonEncoder._encode(key)] = val - return copy - elif hasattr(obj, 'json'): - # json() is a friendly name for json-helper, it should return - # a dictionary representation of the object so that it can be - # processed by the json library. - return json.loads(json.dumps(obj.json(), cls=JSON)) - elif hasattr(obj, '__dump__'): - return obj.__dump__() - elif isinstance(obj, (datetime, date)): - return obj.isoformat() - elif isinstance(obj, (list, set, tuple)): - return [json.loads(json.dumps(item, cls=JSON)) for item in obj] - elif isinstance(obj, (pathlib.Path)): - return str(obj) - else: - return obj +def jsonify(obj: Any, safe: bool = True) -> Any: + """ + Converts objects into json.dumps() compatible nested dictionaries. + Setting safe to True skips dictionary keys starting with a bang (!) + """ - @staticmethod - def _unsafe_encode(obj :Any) -> Any: - """ - Same as _encode() but it keeps dictionary keys starting with ! - """ - if isinstance(obj, dict): - copy = {} - for key, val in list(obj.items()): - if isinstance(val, dict): - # This, is a EXTREMELY ugly hack.. but it's the only quick way I can think of to trigger a encoding of sub-dictionaries. - val = json.loads(json.dumps(val, cls=UNSAFE_JSON)) - else: - val = JsonEncoder._unsafe_encode(val) - - copy[JsonEncoder._unsafe_encode(key)] = val - return copy - else: - return JsonEncoder._encode(obj) + compatible_types = str, int, float, bool + if isinstance(obj, dict): + return { + key: jsonify(value, safe) + for key, value in obj.items() + if isinstance(key, compatible_types) + and not (isinstance(key, str) and key.startswith("!") and safe) + } + if isinstance(obj, Enum): + return obj.value + if hasattr(obj, 'json'): + # json() is a friendly name for json-helper, it should return + # a dictionary representation of the object so that it can be + # processed by the json library. + return jsonify(obj.json(), safe) + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if isinstance(obj, (list, set, tuple)): + return [jsonify(item, safe) for item in obj] + if isinstance(obj, pathlib.Path): + return str(obj) + if hasattr(obj, "__dict__"): + return vars(obj) + + return obj class JSON(json.JSONEncoder, json.JSONDecoder): """ A safe JSON encoder that will omit private information in dicts (starting with !) """ - def _encode(self, obj :Any) -> Any: - return JsonEncoder._encode(obj) - def encode(self, obj :Any) -> Any: - return super(JSON, self).encode(self._encode(obj)) + def encode(self, obj: Any) -> str: + return super().encode(jsonify(obj)) + class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder): """ UNSAFE_JSON will call/encode and keep private information in dicts (starting with !) """ - def _encode(self, obj :Any) -> Any: - return JsonEncoder._unsafe_encode(obj) - def encode(self, obj :Any) -> Any: - return super(UNSAFE_JSON, self).encode(self._encode(obj)) + def encode(self, obj: Any) -> str: + return super().encode(jsonify(obj, safe=False)) + class SysCommandWorker: - def __init__(self, + def __init__( + self, cmd :Union[str, List[str]], callbacks :Optional[Dict[str, Any]] = None, peek_output :Optional[bool] = False, - peak_output :Optional[bool] = False, environment_vars :Optional[Dict[str, Any]] = None, logfile :Optional[None] = None, working_directory :Optional[str] = './', - remove_vt100_escape_codes_from_lines :bool = True): - - if peak_output: - log("SysCommandWorker()'s peak_output is deprecated, use peek_output instead.", level=logging.WARNING, fg='red') - - if not callbacks: - callbacks = {} - if not environment_vars: - environment_vars = {} + remove_vt100_escape_codes_from_lines :bool = True + ): + callbacks = callbacks or {} + environment_vars = environment_vars or {} - if type(cmd) is str: + if isinstance(cmd, str): cmd = shlex.split(cmd) - cmd = list(cmd) # This is to please mypy - if cmd[0][0] != '/' and cmd[0][:2] != './': - # "which" doesn't work as it's a builtin to bash. - # It used to work, but for whatever reason it doesn't anymore. - # We there for fall back on manual lookup in os.PATH - cmd[0] = locate_binary(cmd[0]) + if cmd: + if cmd[0][0] != '/' and cmd[0][:2] != './': # pathlib.Path does not work well + cmd[0] = locate_binary(cmd[0]) self.cmd = cmd self.callbacks = callbacks self.peek_output = peek_output - if not self.peek_output and peak_output: - self.peek_output = peak_output # define the standard locale for command outputs. For now the C ascii one. Can be overridden self.environment_vars = {**storage.get('CMD_LOCALE',{}),**environment_vars} self.logfile = logfile @@ -237,27 +141,36 @@ class SysCommandWorker: Contains will also move the current buffert position forward. This is to avoid re-checking the same data when looking for output. """ - assert type(key) == bytes + assert isinstance(key, bytes) - if (contains := key in self._trace_log[self._trace_log_pos:]): - self._trace_log_pos += self._trace_log[self._trace_log_pos:].find(key) + len(key) + index = self._trace_log.find(key, self._trace_log_pos) + if index >= 0: + self._trace_log_pos += index + len(key) + return True - return contains + return False def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]: - for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'): - if line: - if self.remove_vt100_escape_codes_from_lines: - line = clear_vt100_escape_codes(line) + last_line = self._trace_log.rfind(b'\n') + lines = filter(None, self._trace_log[self._trace_log_pos:last_line].splitlines()) + for line in lines: + if self.remove_vt100_escape_codes_from_lines: + line = clear_vt100_escape_codes(line) # type: ignore - yield line + b'\n' + yield line + b'\n' - self._trace_log_pos = self._trace_log.rfind(b'\n') + self._trace_log_pos = last_line def __repr__(self) -> str: self.make_sure_we_are_executing() return str(self._trace_log) + def __str__(self) -> str: + try: + return self._trace_log.decode('utf-8') + except UnicodeDecodeError: + return str(self._trace_log) + def __enter__(self) -> 'SysCommandWorker': return self @@ -278,10 +191,14 @@ class SysCommandWorker: sys.stdout.flush() if len(args) >= 2 and args[1]: - log(args[1], level=logging.DEBUG, fg='red') + debug(args[1]) if self.exit_code != 0: - raise SysCallError(f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {self._trace_log[-500:]}", self.exit_code, worker=self) + raise SysCallError( + f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self)[-500:]}", + self.exit_code, + worker=self + ) def is_alive(self) -> bool: self.poll() @@ -292,12 +209,13 @@ class SysCommandWorker: return False def write(self, data: bytes, line_ending :bool = True) -> int: - assert type(data) == bytes # TODO: Maybe we can support str as well and encode it + assert isinstance(data, bytes) # TODO: Maybe we can support str as well and encode it self.make_sure_we_are_executing() if self.child_fd: return os.write(self.child_fd, data + (b'\n' if line_ending else b'')) + os.fsync(self.child_fd) return 0 @@ -317,7 +235,7 @@ class SysCommandWorker: def peak(self, output: Union[str, bytes]) -> bool: if self.peek_output: - if type(output) == bytes: + if isinstance(output, bytes): try: output = output.decode('UTF-8') except UnicodeDecodeError: @@ -330,7 +248,7 @@ class SysCommandWorker: change_perm = True with peak_logfile.open("a") as peek_output_log: - peek_output_log.write(output) + peek_output_log.write(str(output)) if change_perm: os.chmod(str(peak_logfile), stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP) @@ -355,7 +273,7 @@ class SysCommandWorker: self.ended = time.time() break - if self.ended or (got_output is False and pid_exists(self.pid) is False): + if self.ended or (not got_output and not _pid_exists(self.pid)): self.ended = time.time() try: wait_status = os.waitpid(self.pid, 0)[1] @@ -394,22 +312,20 @@ class SysCommandWorker: if change_perm: os.chmod(str(history_logfile), stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP) - except PermissionError: - pass - # If history_logfile does not exist, ignore the error - except FileNotFoundError: + except (PermissionError, FileNotFoundError): + # If history_logfile does not exist, ignore the error pass except Exception as e: exception_type = type(e).__name__ - log(f"Unexpected {exception_type} occurred in {self.cmd}: {e}", level=logging.ERROR) + error(f"Unexpected {exception_type} occurred in {self.cmd}: {e}") raise e os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars}) if storage['arguments'].get('debug'): - log(f"Executing: {self.cmd}", level=logging.DEBUG) + debug(f"Executing: {self.cmd}") except FileNotFoundError: - log(f"{self.cmd[0]} does not exist.", level=logging.ERROR, fg="red") + error(f"{self.cmd[0]} does not exist.") self.exit_code = 1 return False else: @@ -428,29 +344,19 @@ class SysCommandWorker: class SysCommand: def __init__(self, cmd :Union[str, List[str]], - callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None, + callbacks :Dict[str, Callable[[Any], Any]] = {}, start_callback :Optional[Callable[[Any], Any]] = None, peek_output :Optional[bool] = False, - peak_output :Optional[bool] = False, environment_vars :Optional[Dict[str, Any]] = None, working_directory :Optional[str] = './', remove_vt100_escape_codes_from_lines :bool = True): - if peak_output: - log("SysCommandWorker()'s peak_output is deprecated, use peek_output instead.", level=logging.WARNING, fg='red') - - _callbacks = {} - if callbacks: - for hook, func in callbacks.items(): - _callbacks[hook] = func + self._callbacks = callbacks.copy() if start_callback: - _callbacks['on_start'] = start_callback + self._callbacks['on_start'] = start_callback self.cmd = cmd - self._callbacks = _callbacks self.peek_output = peek_output - if not self.peek_output and peak_output: - self.peek_output = peak_output self.environment_vars = environment_vars self.working_directory = working_directory self.remove_vt100_escape_codes_from_lines = remove_vt100_escape_codes_from_lines @@ -466,7 +372,7 @@ class SysCommand: # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager if len(args) >= 2 and args[1]: - log(args[1], level=logging.ERROR, fg='red') + error(args[1]) def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]: if self.session: @@ -477,17 +383,15 @@ class SysCommand: if not self.session: raise KeyError(f"SysCommand() does not have an active session.") elif type(key) is slice: - start = key.start if key.start else 0 - end = key.stop if key.stop else len(self.session._trace_log) + start = key.start or 0 + end = key.stop or len(self.session._trace_log) return self.session._trace_log[start:end] else: raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.") def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str: - if self.session: - return self.session._trace_log.decode('UTF-8', errors='backslashreplace') - return '' + return self.decode('UTF-8', errors='backslashreplace') or '' def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]: return { @@ -495,7 +399,7 @@ class SysCommand: 'callbacks': self._callbacks, 'peak': self.peek_output, 'environment_vars': self.environment_vars, - 'session': True if self.session else False + 'session': self.session is not None } def create_session(self) -> bool: @@ -505,7 +409,7 @@ class SysCommand: clears any printed output if ``.peek_output=True``. """ if self.session: - return self.session + return True with SysCommandWorker( self.cmd, @@ -515,10 +419,9 @@ class SysCommand: remove_vt100_escape_codes_from_lines=self.remove_vt100_escape_codes_from_lines, working_directory=self.working_directory) as session: - if not self.session: - self.session = session + self.session = session - while self.session.ended is None: + while not self.session.ended: self.session.poll() if self.peek_output: @@ -527,10 +430,21 @@ class SysCommand: return True - def decode(self, fmt :str = 'UTF-8') -> Optional[str]: - if self.session: - return self.session._trace_log.decode(fmt) - return None + def decode(self, encoding: str = 'utf-8', errors='backslashreplace', strip: bool = True) -> str: + if not self.session: + raise ValueError('No session available to decode') + + val = self.session._trace_log.decode(encoding, errors=errors) + + if strip: + return val.strip() + return val + + def output(self) -> bytes: + if not self.session: + raise ValueError('No session available') + + return self.session._trace_log.replace(b'\r\n', b'\n') @property def exit_code(self) -> Optional[int]: @@ -546,22 +460,7 @@ class SysCommand: return None -def prerequisite_check() -> bool: - """ - This function is used as a safety check before - continuing with an installation. - - Could be anything from checking that /boot is big enough - to check if nvidia hardware exists when nvidia driver was chosen. - """ - - return True - -def reboot(): - SysCommand("/usr/bin/reboot") - - -def pid_exists(pid: int) -> bool: +def _pid_exists(pid: int) -> bool: try: return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip()) except subprocess.CalledProcessError: @@ -570,56 +469,57 @@ def pid_exists(pid: int) -> bool: def run_custom_user_commands(commands :List[str], installation :Installer) -> None: for index, command in enumerate(commands): - log(f'Executing custom command "{command}" ...', level=logging.INFO) + script_path = f"/var/tmp/user-command.{index}.sh" + chroot_path = f"{installation.target}/{script_path}" - with open(f"{installation.target}/var/tmp/user-command.{index}.sh", "w") as temp_script: - temp_script.write(command) + info(f'Executing custom command "{command}" ...') + with open(chroot_path, "w") as user_script: + user_script.write(command) - execution_output = SysCommand(f"arch-chroot {installation.target} bash /var/tmp/user-command.{index}.sh") + SysCommand(f"arch-chroot {installation.target} bash {script_path}") + + os.unlink(chroot_path) - log(execution_output) - os.unlink(f"{installation.target}/var/tmp/user-command.{index}.sh") def json_stream_to_structure(configuration_identifier : str, stream :str, target :dict) -> bool : """ - Function to load a stream (file (as name) or valid JSON string into an existing dictionary - Returns true if it could be done - Return false if operation could not be executed + Load a JSON encoded dictionary from a stream and merge it into an existing dictionary. + A stream can be a filepath, a URL or a raw JSON string. + Returns True if the operation succeeded, False otherwise. +configuration_identifier is just a parameter to get meaningful, but not so long messages """ - parsed_url = urllib.parse.urlparse(stream) + raw: Optional[str] = None + # Try using the stream as a URL that should be grabbed + if urllib.parse.urlparse(stream).scheme: + try: + with urlopen(Request(stream, headers={'User-Agent': 'ArchInstall'})) as response: + raw = response.read() + except urllib.error.HTTPError as err: + error(f"Could not fetch JSON from {stream} as {configuration_identifier}: {err}") + return False - if parsed_url.scheme: # The stream is in fact a URL that should be grabbed + # Try using the stream as a filepath that should be read + if raw is None and (path := pathlib.Path(stream)).exists(): try: - with urllib.request.urlopen(urllib.request.Request(stream, headers={'User-Agent': 'ArchInstall'})) as response: - target.update(json.loads(response.read())) - except urllib.error.HTTPError as error: - log(f"Could not load {configuration_identifier} via {parsed_url} due to: {error}", level=logging.ERROR, fg="red") + raw = path.read_text() + except Exception as err: + error(f"Could not read file {stream} as {configuration_identifier}: {err}") return False - else: - if pathlib.Path(stream).exists(): - try: - with pathlib.Path(stream).open() as fh: - target.update(json.load(fh)) - except Exception as error: - log(f"{configuration_identifier} = {stream} does not contain a valid JSON format: {error}", level=logging.ERROR, fg="red") - return False - else: - # NOTE: This is a rudimentary check if what we're trying parse is a dict structure. - # Which is the only structure we tolerate anyway. - if stream.strip().startswith('{') and stream.strip().endswith('}'): - try: - target.update(json.loads(stream)) - except Exception as e: - log(f" {configuration_identifier} Contains an invalid JSON format : {e}",level=logging.ERROR, fg="red") - return False - else: - log(f" {configuration_identifier} is neither a file nor is a JSON string:",level=logging.ERROR, fg="red") - return False + try: + # We use `or` to try the stream as raw JSON to be parsed + structure = json.loads(raw or stream) + except Exception as err: + error(f"{configuration_identifier} contains an invalid JSON format: {err}") + return False + if not isinstance(structure, dict): + error(f"{stream} passed as {configuration_identifier} is not a JSON encoded dictionary") + return False + target.update(structure) return True + def secret(x :str): """ return * with len equal to to the input string """ return '*' * len(x) |