Send patches - preferably formatted by git format-patch - to patches at archlinux32 dot org.
summary | refs | log | tree | commit | diff
path: root/archinstall/lib/general.py
diff options
context:
space:
mode:
Diffstat (limited to 'archinstall/lib/general.py')
-rw-r--r-- archinstall/lib/general.py 426
1 files changed, 163 insertions, 263 deletions
diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py
index 79ab024b..8dbf23ff 100644
--- a/archinstall/lib/general.py
+++ b/archinstall/lib/general.py
@@ -1,7 +1,6 @@
from __future__ import annotations
-import hashlib
+
import json
-import logging
import os
import secrets
import shlex
@@ -12,212 +11,117 @@ import sys
import time
import re
import urllib.parse
-import urllib.request
+from urllib.request import Request, urlopen
import urllib.error
import pathlib
from datetime import datetime, date
+from enum import Enum
from typing import Callable, Optional, Dict, Any, List, Union, Iterator, TYPE_CHECKING
-# https://stackoverflow.com/a/39757388/929999
-if TYPE_CHECKING:
- from .installer import Installer
-
-if sys.platform == 'linux':
- from select import epoll, EPOLLIN, EPOLLHUP
-else:
- import select
- EPOLLIN = 0
- EPOLLHUP = 0
-
- class epoll():
- """ #!if windows
- Create a epoll() implementation that simulates the epoll() behavior.
- This so that the rest of the code doesn't need to worry weither we're using select() or epoll().
- """
- def __init__(self) -> None:
- self.sockets: Dict[str, Any] = {}
- self.monitoring: Dict[int, Any] = {}
-
- def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None:
- try:
- del(self.monitoring[fileno]) # noqa: E275
- except:
- pass
-
- def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None:
- self.monitoring[fileno] = True
-
- def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]:
- try:
- return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]]
- except OSError:
- return []
+from select import epoll, EPOLLIN, EPOLLHUP
+from shutil import which
from .exceptions import RequirementError, SysCallError
-from .output import log
+from .output import debug, error, info
from .storage import storage
-def gen_uid(entropy_length :int = 256) -> str:
- return hashlib.sha512(os.urandom(entropy_length)).hexdigest()
+
+if TYPE_CHECKING:
+ from .installer import Installer
+
def generate_password(length :int = 64) -> str:
- haystack = string.printable # digits, ascii_letters, punctiation (!"#$[] etc) and whitespace
+ haystack = string.printable # digits, ascii_letters, punctuation (!"#$[] etc) and whitespace
return ''.join(secrets.choice(haystack) for i in range(length))
-def multisplit(s :str, splitters :List[str]) -> str:
- s = [s, ]
- for key in splitters:
- ns = []
- for obj in s:
- x = obj.split(key)
- for index, part in enumerate(x):
- if len(part):
- ns.append(part)
- if index < len(x) - 1:
- ns.append(key)
- s = ns
- return s
def locate_binary(name :str) -> str:
- for PATH in os.environ['PATH'].split(':'):
- for root, folders, files in os.walk(PATH):
- for file in files:
- if file == name:
- return os.path.join(root, file)
- break # Don't recurse
-
+ if path := which(name):
+ return path
raise RequirementError(f"Binary {name} does not exist.")
-def clear_vt100_escape_codes(data :Union[bytes, str]):
- # https://stackoverflow.com/a/43627833/929999
- if type(data) == bytes:
- vt100_escape_regex = bytes(r'\x1B\[[?0-9;]*[a-zA-Z]', 'UTF-8')
- else:
- vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]'
-
- for match in re.findall(vt100_escape_regex, data, re.IGNORECASE):
- data = data.replace(match, '' if type(data) == str else b'')
- return data
-
-def json_dumps(*args :str, **kwargs :str) -> str:
- return json.dumps(*args, **{**kwargs, 'cls': JSON})
+def clear_vt100_escape_codes(data :Union[bytes, str]) -> Union[bytes, str]:
+ # https://stackoverflow.com/a/43627833/929999
+ vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]'
+ if isinstance(data, bytes):
+ return re.sub(vt100_escape_regex.encode(), b'', data)
+ return re.sub(vt100_escape_regex, '', data)
-class JsonEncoder:
- @staticmethod
- def _encode(obj :Any) -> Any:
- """
- This JSON encoder function will try it's best to convert
- any archinstall data structures, instances or variables into
- something that's understandable by the json.parse()/json.loads() lib.
- _encode() will skip any dictionary key starting with an exclamation mark (!)
- """
- if isinstance(obj, dict):
- # We'll need to iterate not just the value that default() usually gets passed
- # But also iterate manually over each key: value pair in order to trap the keys.
-
- copy = {}
- for key, val in list(obj.items()):
- if isinstance(val, dict):
- # This, is a EXTREMELY ugly hack.. but it's the only quick way I can think of to trigger a encoding of sub-dictionaries.
- val = json.loads(json.dumps(val, cls=JSON))
- else:
- val = JsonEncoder._encode(val)
-
- if type(key) == str and key[0] == '!':
- pass
- else:
- copy[JsonEncoder._encode(key)] = val
- return copy
- elif hasattr(obj, 'json'):
- # json() is a friendly name for json-helper, it should return
- # a dictionary representation of the object so that it can be
- # processed by the json library.
- return json.loads(json.dumps(obj.json(), cls=JSON))
- elif hasattr(obj, '__dump__'):
- return obj.__dump__()
- elif isinstance(obj, (datetime, date)):
- return obj.isoformat()
- elif isinstance(obj, (list, set, tuple)):
- return [json.loads(json.dumps(item, cls=JSON)) for item in obj]
- elif isinstance(obj, (pathlib.Path)):
- return str(obj)
- else:
- return obj
+def jsonify(obj: Any, safe: bool = True) -> Any:
+ """
+ Converts objects into json.dumps() compatible nested dictionaries.
+ Setting safe to True skips dictionary keys starting with a bang (!)
+ """
- @staticmethod
- def _unsafe_encode(obj :Any) -> Any:
- """
- Same as _encode() but it keeps dictionary keys starting with !
- """
- if isinstance(obj, dict):
- copy = {}
- for key, val in list(obj.items()):
- if isinstance(val, dict):
- # This, is a EXTREMELY ugly hack.. but it's the only quick way I can think of to trigger a encoding of sub-dictionaries.
- val = json.loads(json.dumps(val, cls=UNSAFE_JSON))
- else:
- val = JsonEncoder._unsafe_encode(val)
-
- copy[JsonEncoder._unsafe_encode(key)] = val
- return copy
- else:
- return JsonEncoder._encode(obj)
+ compatible_types = str, int, float, bool
+ if isinstance(obj, dict):
+ return {
+ key: jsonify(value, safe)
+ for key, value in obj.items()
+ if isinstance(key, compatible_types)
+ and not (isinstance(key, str) and key.startswith("!") and safe)
+ }
+ if isinstance(obj, Enum):
+ return obj.value
+ if hasattr(obj, 'json'):
+ # json() is a friendly name for json-helper, it should return
+ # a dictionary representation of the object so that it can be
+ # processed by the json library.
+ return jsonify(obj.json(), safe)
+ if isinstance(obj, (datetime, date)):
+ return obj.isoformat()
+ if isinstance(obj, (list, set, tuple)):
+ return [jsonify(item, safe) for item in obj]
+ if isinstance(obj, pathlib.Path):
+ return str(obj)
+ if hasattr(obj, "__dict__"):
+ return vars(obj)
+
+ return obj
class JSON(json.JSONEncoder, json.JSONDecoder):
"""
A safe JSON encoder that will omit private information in dicts (starting with !)
"""
- def _encode(self, obj :Any) -> Any:
- return JsonEncoder._encode(obj)
- def encode(self, obj :Any) -> Any:
- return super(JSON, self).encode(self._encode(obj))
+ def encode(self, obj: Any) -> str:
+ return super().encode(jsonify(obj))
+
class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder):
"""
UNSAFE_JSON will call/encode and keep private information in dicts (starting with !)
"""
- def _encode(self, obj :Any) -> Any:
- return JsonEncoder._unsafe_encode(obj)
- def encode(self, obj :Any) -> Any:
- return super(UNSAFE_JSON, self).encode(self._encode(obj))
+ def encode(self, obj: Any) -> str:
+ return super().encode(jsonify(obj, safe=False))
+
class SysCommandWorker:
- def __init__(self,
+ def __init__(
+ self,
cmd :Union[str, List[str]],
callbacks :Optional[Dict[str, Any]] = None,
peek_output :Optional[bool] = False,
- peak_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
logfile :Optional[None] = None,
working_directory :Optional[str] = './',
- remove_vt100_escape_codes_from_lines :bool = True):
-
- if peak_output:
- log("SysCommandWorker()'s peak_output is deprecated, use peek_output instead.", level=logging.WARNING, fg='red')
-
- if not callbacks:
- callbacks = {}
- if not environment_vars:
- environment_vars = {}
+ remove_vt100_escape_codes_from_lines :bool = True
+ ):
+ callbacks = callbacks or {}
+ environment_vars = environment_vars or {}
- if type(cmd) is str:
+ if isinstance(cmd, str):
cmd = shlex.split(cmd)
- cmd = list(cmd) # This is to please mypy
- if cmd[0][0] != '/' and cmd[0][:2] != './':
- # "which" doesn't work as it's a builtin to bash.
- # It used to work, but for whatever reason it doesn't anymore.
- # We there for fall back on manual lookup in os.PATH
- cmd[0] = locate_binary(cmd[0])
+ if cmd:
+ if cmd[0][0] != '/' and cmd[0][:2] != './': # pathlib.Path does not work well
+ cmd[0] = locate_binary(cmd[0])
self.cmd = cmd
self.callbacks = callbacks
self.peek_output = peek_output
- if not self.peek_output and peak_output:
- self.peek_output = peak_output
# define the standard locale for command outputs. For now the C ascii one. Can be overridden
self.environment_vars = {**storage.get('CMD_LOCALE',{}),**environment_vars}
self.logfile = logfile
@@ -237,27 +141,36 @@ class SysCommandWorker:
Contains will also move the current buffer position forward.
This is to avoid re-checking the same data when looking for output.
"""
- assert type(key) == bytes
+ assert isinstance(key, bytes)
- if (contains := key in self._trace_log[self._trace_log_pos:]):
- self._trace_log_pos += self._trace_log[self._trace_log_pos:].find(key) + len(key)
+ index = self._trace_log.find(key, self._trace_log_pos)
+ if index >= 0:
+ self._trace_log_pos += index + len(key)
+ return True
- return contains
+ return False
def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]:
- for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'):
- if line:
- if self.remove_vt100_escape_codes_from_lines:
- line = clear_vt100_escape_codes(line)
+ last_line = self._trace_log.rfind(b'\n')
+ lines = filter(None, self._trace_log[self._trace_log_pos:last_line].splitlines())
+ for line in lines:
+ if self.remove_vt100_escape_codes_from_lines:
+ line = clear_vt100_escape_codes(line) # type: ignore
- yield line + b'\n'
+ yield line + b'\n'
- self._trace_log_pos = self._trace_log.rfind(b'\n')
+ self._trace_log_pos = last_line
def __repr__(self) -> str:
self.make_sure_we_are_executing()
return str(self._trace_log)
+ def __str__(self) -> str:
+ try:
+ return self._trace_log.decode('utf-8')
+ except UnicodeDecodeError:
+ return str(self._trace_log)
+
def __enter__(self) -> 'SysCommandWorker':
return self
@@ -278,10 +191,14 @@ class SysCommandWorker:
sys.stdout.flush()
if len(args) >= 2 and args[1]:
- log(args[1], level=logging.DEBUG, fg='red')
+ debug(args[1])
if self.exit_code != 0:
- raise SysCallError(f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {self._trace_log[-500:]}", self.exit_code, worker=self)
+ raise SysCallError(
+ f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self)[-500:]}",
+ self.exit_code,
+ worker=self
+ )
def is_alive(self) -> bool:
self.poll()
@@ -292,12 +209,13 @@ class SysCommandWorker:
return False
def write(self, data: bytes, line_ending :bool = True) -> int:
- assert type(data) == bytes # TODO: Maybe we can support str as well and encode it
+ assert isinstance(data, bytes) # TODO: Maybe we can support str as well and encode it
self.make_sure_we_are_executing()
if self.child_fd:
return os.write(self.child_fd, data + (b'\n' if line_ending else b''))
+ os.fsync(self.child_fd)
return 0
@@ -317,7 +235,7 @@ class SysCommandWorker:
def peak(self, output: Union[str, bytes]) -> bool:
if self.peek_output:
- if type(output) == bytes:
+ if isinstance(output, bytes):
try:
output = output.decode('UTF-8')
except UnicodeDecodeError:
@@ -330,7 +248,7 @@ class SysCommandWorker:
change_perm = True
with peak_logfile.open("a") as peek_output_log:
- peek_output_log.write(output)
+ peek_output_log.write(str(output))
if change_perm:
os.chmod(str(peak_logfile), stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP)
@@ -355,7 +273,7 @@ class SysCommandWorker:
self.ended = time.time()
break
- if self.ended or (got_output is False and pid_exists(self.pid) is False):
+ if self.ended or (not got_output and not _pid_exists(self.pid)):
self.ended = time.time()
try:
wait_status = os.waitpid(self.pid, 0)[1]
@@ -394,22 +312,20 @@ class SysCommandWorker:
if change_perm:
os.chmod(str(history_logfile), stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP)
- except PermissionError:
- pass
- # If history_logfile does not exist, ignore the error
- except FileNotFoundError:
+ except (PermissionError, FileNotFoundError):
+ # If history_logfile does not exist, ignore the error
pass
except Exception as e:
exception_type = type(e).__name__
- log(f"Unexpected {exception_type} occurred in {self.cmd}: {e}", level=logging.ERROR)
+ error(f"Unexpected {exception_type} occurred in {self.cmd}: {e}")
raise e
os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars})
if storage['arguments'].get('debug'):
- log(f"Executing: {self.cmd}", level=logging.DEBUG)
+ debug(f"Executing: {self.cmd}")
except FileNotFoundError:
- log(f"{self.cmd[0]} does not exist.", level=logging.ERROR, fg="red")
+ error(f"{self.cmd[0]} does not exist.")
self.exit_code = 1
return False
else:
@@ -428,29 +344,19 @@ class SysCommandWorker:
class SysCommand:
def __init__(self,
cmd :Union[str, List[str]],
- callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None,
+ callbacks :Dict[str, Callable[[Any], Any]] = {},
start_callback :Optional[Callable[[Any], Any]] = None,
peek_output :Optional[bool] = False,
- peak_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
working_directory :Optional[str] = './',
remove_vt100_escape_codes_from_lines :bool = True):
- if peak_output:
- log("SysCommandWorker()'s peak_output is deprecated, use peek_output instead.", level=logging.WARNING, fg='red')
-
- _callbacks = {}
- if callbacks:
- for hook, func in callbacks.items():
- _callbacks[hook] = func
+ self._callbacks = callbacks.copy()
if start_callback:
- _callbacks['on_start'] = start_callback
+ self._callbacks['on_start'] = start_callback
self.cmd = cmd
- self._callbacks = _callbacks
self.peek_output = peek_output
- if not self.peek_output and peak_output:
- self.peek_output = peak_output
self.environment_vars = environment_vars
self.working_directory = working_directory
self.remove_vt100_escape_codes_from_lines = remove_vt100_escape_codes_from_lines
@@ -466,7 +372,7 @@ class SysCommand:
# TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager
if len(args) >= 2 and args[1]:
- log(args[1], level=logging.ERROR, fg='red')
+ error(args[1])
def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]:
if self.session:
@@ -477,17 +383,15 @@ class SysCommand:
if not self.session:
raise KeyError(f"SysCommand() does not have an active session.")
elif type(key) is slice:
- start = key.start if key.start else 0
- end = key.stop if key.stop else len(self.session._trace_log)
+ start = key.start or 0
+ end = key.stop or len(self.session._trace_log)
return self.session._trace_log[start:end]
else:
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str:
- if self.session:
- return self.session._trace_log.decode('UTF-8', errors='backslashreplace')
- return ''
+ return self.decode('UTF-8', errors='backslashreplace') or ''
def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]:
return {
@@ -495,7 +399,7 @@ class SysCommand:
'callbacks': self._callbacks,
'peak': self.peek_output,
'environment_vars': self.environment_vars,
- 'session': True if self.session else False
+ 'session': self.session is not None
}
def create_session(self) -> bool:
@@ -505,7 +409,7 @@ class SysCommand:
clears any printed output if ``.peek_output=True``.
"""
if self.session:
- return self.session
+ return True
with SysCommandWorker(
self.cmd,
@@ -515,10 +419,9 @@ class SysCommand:
remove_vt100_escape_codes_from_lines=self.remove_vt100_escape_codes_from_lines,
working_directory=self.working_directory) as session:
- if not self.session:
- self.session = session
+ self.session = session
- while self.session.ended is None:
+ while not self.session.ended:
self.session.poll()
if self.peek_output:
@@ -527,10 +430,21 @@ class SysCommand:
return True
- def decode(self, fmt :str = 'UTF-8') -> Optional[str]:
- if self.session:
- return self.session._trace_log.decode(fmt)
- return None
+ def decode(self, encoding: str = 'utf-8', errors='backslashreplace', strip: bool = True) -> str:
+ if not self.session:
+ raise ValueError('No session available to decode')
+
+ val = self.session._trace_log.decode(encoding, errors=errors)
+
+ if strip:
+ return val.strip()
+ return val
+
+ def output(self) -> bytes:
+ if not self.session:
+ raise ValueError('No session available')
+
+ return self.session._trace_log.replace(b'\r\n', b'\n')
@property
def exit_code(self) -> Optional[int]:
@@ -546,22 +460,7 @@ class SysCommand:
return None
-def prerequisite_check() -> bool:
- """
- This function is used as a safety check before
- continuing with an installation.
-
- Could be anything from checking that /boot is big enough
- to check if nvidia hardware exists when nvidia driver was chosen.
- """
-
- return True
-
-def reboot():
- SysCommand("/usr/bin/reboot")
-
-
-def pid_exists(pid: int) -> bool:
+def _pid_exists(pid: int) -> bool:
try:
return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip())
except subprocess.CalledProcessError:
@@ -570,56 +469,57 @@ def pid_exists(pid: int) -> bool:
def run_custom_user_commands(commands :List[str], installation :Installer) -> None:
for index, command in enumerate(commands):
- log(f'Executing custom command "{command}" ...', level=logging.INFO)
+ script_path = f"/var/tmp/user-command.{index}.sh"
+ chroot_path = f"{installation.target}/{script_path}"
- with open(f"{installation.target}/var/tmp/user-command.{index}.sh", "w") as temp_script:
- temp_script.write(command)
+ info(f'Executing custom command "{command}" ...')
+ with open(chroot_path, "w") as user_script:
+ user_script.write(command)
- execution_output = SysCommand(f"arch-chroot {installation.target} bash /var/tmp/user-command.{index}.sh")
+ SysCommand(f"arch-chroot {installation.target} bash {script_path}")
+
+ os.unlink(chroot_path)
- log(execution_output)
- os.unlink(f"{installation.target}/var/tmp/user-command.{index}.sh")
def json_stream_to_structure(configuration_identifier : str, stream :str, target :dict) -> bool :
"""
- Function to load a stream (file (as name) or valid JSON string into an existing dictionary
- Returns true if it could be done
- Return false if operation could not be executed
+ Load a JSON encoded dictionary from a stream and merge it into an existing dictionary.
+ A stream can be a filepath, a URL or a raw JSON string.
+ Returns True if the operation succeeded, False otherwise.
+configuration_identifier is just a parameter to get meaningful, but not so long messages
"""
- parsed_url = urllib.parse.urlparse(stream)
+ raw: Optional[str] = None
+ # Try using the stream as a URL that should be grabbed
+ if urllib.parse.urlparse(stream).scheme:
+ try:
+ with urlopen(Request(stream, headers={'User-Agent': 'ArchInstall'})) as response:
+ raw = response.read()
+ except urllib.error.HTTPError as err:
+ error(f"Could not fetch JSON from {stream} as {configuration_identifier}: {err}")
+ return False
- if parsed_url.scheme: # The stream is in fact a URL that should be grabbed
+ # Try using the stream as a filepath that should be read
+ if raw is None and (path := pathlib.Path(stream)).exists():
try:
- with urllib.request.urlopen(urllib.request.Request(stream, headers={'User-Agent': 'ArchInstall'})) as response:
- target.update(json.loads(response.read()))
- except urllib.error.HTTPError as error:
- log(f"Could not load {configuration_identifier} via {parsed_url} due to: {error}", level=logging.ERROR, fg="red")
+ raw = path.read_text()
+ except Exception as err:
+ error(f"Could not read file {stream} as {configuration_identifier}: {err}")
return False
- else:
- if pathlib.Path(stream).exists():
- try:
- with pathlib.Path(stream).open() as fh:
- target.update(json.load(fh))
- except Exception as error:
- log(f"{configuration_identifier} = {stream} does not contain a valid JSON format: {error}", level=logging.ERROR, fg="red")
- return False
- else:
- # NOTE: This is a rudimentary check if what we're trying parse is a dict structure.
- # Which is the only structure we tolerate anyway.
- if stream.strip().startswith('{') and stream.strip().endswith('}'):
- try:
- target.update(json.loads(stream))
- except Exception as e:
- log(f" {configuration_identifier} Contains an invalid JSON format : {e}",level=logging.ERROR, fg="red")
- return False
- else:
- log(f" {configuration_identifier} is neither a file nor is a JSON string:",level=logging.ERROR, fg="red")
- return False
+ try:
+ # We use `or` to try the stream as raw JSON to be parsed
+ structure = json.loads(raw or stream)
+ except Exception as err:
+ error(f"{configuration_identifier} contains an invalid JSON format: {err}")
+ return False
+ if not isinstance(structure, dict):
+ error(f"{stream} passed as {configuration_identifier} is not a JSON encoded dictionary")
+ return False
+ target.update(structure)
return True
+
def secret(x :str):
""" return * with len equal to the input string """
return '*' * len(x)