Send patches - preferably formatted by git format-patch - to patches at archlinux32 dot org.
summaryrefslogtreecommitdiff
path: root/archinstall/lib/general.py
diff options
context:
space:
mode:
Diffstat (limited to 'archinstall/lib/general.py')
-rw-r--r--archinstall/lib/general.py184
1 files changed, 82 insertions, 102 deletions
diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py
index f43d4f57..c85208ec 100644
--- a/archinstall/lib/general.py
+++ b/archinstall/lib/general.py
@@ -11,13 +11,14 @@ import sys
import time
import re
import urllib.parse
-import urllib.request
+from urllib.request import Request, urlopen
import urllib.error
import pathlib
from datetime import datetime, date
from enum import Enum
from typing import Callable, Optional, Dict, Any, List, Union, Iterator, TYPE_CHECKING
from select import epoll, EPOLLIN, EPOLLHUP
+from shutil import which
from .exceptions import RequirementError, SysCallError
from .output import debug, error, info
@@ -34,28 +35,17 @@ def generate_password(length :int = 64) -> str:
def locate_binary(name :str) -> str:
- for PATH in os.environ['PATH'].split(':'):
- for root, folders, files in os.walk(PATH):
- for file in files:
- if file == name:
- return os.path.join(root, file)
- break # Don't recurse
-
+ if path := which(name):
+ return path
raise RequirementError(f"Binary {name} does not exist.")
def clear_vt100_escape_codes(data :Union[bytes, str]) -> Union[bytes, str]:
# https://stackoverflow.com/a/43627833/929999
- if type(data) == bytes:
- byte_vt100_escape_regex = bytes(r'\x1B\[[?0-9;]*[a-zA-Z]', 'UTF-8')
- data = re.sub(byte_vt100_escape_regex, b'', data)
- elif type(data) == str:
- vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]'
- data = re.sub(vt100_escape_regex, '', data)
- else:
- raise ValueError(f'Unsupported data type: {type(data)}')
-
- return data
+ vt100_escape_regex = r'\x1B\[[?0-9;]*[a-zA-Z]'
+ if isinstance(data, bytes):
+ return re.sub(vt100_escape_regex.encode(), b'', data)
+ return re.sub(vt100_escape_regex, '', data)
def jsonify(obj: Any, safe: bool = True) -> Any:
@@ -120,21 +110,15 @@ class SysCommandWorker:
working_directory :Optional[str] = './',
remove_vt100_escape_codes_from_lines :bool = True
):
- if not callbacks:
- callbacks = {}
+ callbacks = callbacks or {}
+ environment_vars = environment_vars or {}
- if not environment_vars:
- environment_vars = {}
-
- if type(cmd) is str:
+ if isinstance(cmd, str):
cmd = shlex.split(cmd)
- cmd = list(cmd) # This is to please mypy
- if cmd[0][0] != '/' and cmd[0][:2] != './':
- # "which" doesn't work as it's a builtin to bash.
- # It used to work, but for whatever reason it doesn't anymore.
- # We there for fall back on manual lookup in os.PATH
- cmd[0] = locate_binary(cmd[0])
+ if cmd:
+ if cmd[0][0] != '/' and cmd[0][:2] != './': # pathlib.Path does not work well
+ cmd[0] = locate_binary(cmd[0])
self.cmd = cmd
self.callbacks = callbacks
@@ -158,29 +142,36 @@ class SysCommandWorker:
Contains will also move the current buffert position forward.
This is to avoid re-checking the same data when looking for output.
"""
- assert type(key) == bytes
+ assert isinstance(key, bytes)
- if (contains := key in self._trace_log[self._trace_log_pos:]):
- self._trace_log_pos += self._trace_log[self._trace_log_pos:].find(key) + len(key)
+ index = self._trace_log.find(key, self._trace_log_pos)
+ if index >= 0:
+ self._trace_log_pos += index + len(key)
+ return True
- return contains
+ return False
def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]:
- for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'):
- if line:
- escaped_line: bytes = line
-
- if self.remove_vt100_escape_codes_from_lines:
- escaped_line = clear_vt100_escape_codes(line) # type: ignore
+ last_line = self._trace_log.rfind(b'\n')
+ lines = filter(None, self._trace_log[self._trace_log_pos:last_line].splitlines())
+ for line in lines:
+ if self.remove_vt100_escape_codes_from_lines:
+ line = clear_vt100_escape_codes(line) # type: ignore
- yield escaped_line + b'\n'
+ yield line + b'\n'
- self._trace_log_pos = self._trace_log.rfind(b'\n')
+ self._trace_log_pos = last_line
def __repr__(self) -> str:
self.make_sure_we_are_executing()
return str(self._trace_log)
+ def __str__(self) -> str:
+ try:
+ return self._trace_log.decode('utf-8')
+ except UnicodeDecodeError:
+ return str(self._trace_log)
+
def __enter__(self) -> 'SysCommandWorker':
return self
@@ -205,7 +196,7 @@ class SysCommandWorker:
if self.exit_code != 0:
raise SysCallError(
- f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self._trace_log[-500:])}",
+ f"{self.cmd} exited with abnormal exit code [{self.exit_code}]: {str(self)[-500:]}",
self.exit_code,
worker=self
)
@@ -244,7 +235,7 @@ class SysCommandWorker:
def peak(self, output: Union[str, bytes]) -> bool:
if self.peek_output:
- if type(output) == bytes:
+ if isinstance(output, bytes):
try:
output = output.decode('UTF-8')
except UnicodeDecodeError:
@@ -282,7 +273,7 @@ class SysCommandWorker:
self.ended = time.time()
break
- if self.ended or (got_output is False and _pid_exists(self.pid) is False):
+ if self.ended or (not got_output and not _pid_exists(self.pid)):
self.ended = time.time()
try:
wait_status = os.waitpid(self.pid, 0)[1]
@@ -321,10 +312,8 @@ class SysCommandWorker:
if change_perm:
os.chmod(str(history_logfile), stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP)
- except PermissionError:
- pass
+ except (PermissionError, FileNotFoundError):
# If history_logfile does not exist, ignore the error
- except FileNotFoundError:
pass
except Exception as e:
exception_type = type(e).__name__
@@ -355,22 +344,18 @@ class SysCommandWorker:
class SysCommand:
def __init__(self,
cmd :Union[str, List[str]],
- callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None,
+ callbacks :Dict[str, Callable[[Any], Any]] = {},
start_callback :Optional[Callable[[Any], Any]] = None,
peek_output :Optional[bool] = False,
environment_vars :Optional[Dict[str, Any]] = None,
working_directory :Optional[str] = './',
remove_vt100_escape_codes_from_lines :bool = True):
- _callbacks = {}
- if callbacks:
- for hook, func in callbacks.items():
- _callbacks[hook] = func
+ self._callbacks = callbacks.copy()
if start_callback:
- _callbacks['on_start'] = start_callback
+ self._callbacks['on_start'] = start_callback
self.cmd = cmd
- self._callbacks = _callbacks
self.peek_output = peek_output
self.environment_vars = environment_vars
self.working_directory = working_directory
@@ -398,17 +383,15 @@ class SysCommand:
if not self.session:
raise KeyError(f"SysCommand() does not have an active session.")
elif type(key) is slice:
- start = key.start if key.start else 0
- end = key.stop if key.stop else len(self.session._trace_log)
+ start = key.start or 0
+ end = key.stop or len(self.session._trace_log)
return self.session._trace_log[start:end]
else:
raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.")
def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str:
- if self.session:
- return self.session._trace_log.decode('UTF-8', errors='backslashreplace')
- return ''
+ return self.decode('UTF-8', errors='backslashreplace') or ''
def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]:
return {
@@ -416,7 +399,7 @@ class SysCommand:
'callbacks': self._callbacks,
'peak': self.peek_output,
'environment_vars': self.environment_vars,
- 'session': True if self.session else False
+ 'session': self.session is not None
}
def create_session(self) -> bool:
@@ -436,10 +419,9 @@ class SysCommand:
remove_vt100_escape_codes_from_lines=self.remove_vt100_escape_codes_from_lines,
working_directory=self.working_directory) as session:
- if not self.session:
- self.session = session
+ self.session = session
- while self.session.ended is None:
+ while not self.session.ended:
self.session.poll()
if self.peek_output:
@@ -448,9 +430,9 @@ class SysCommand:
return True
- def decode(self, fmt :str = 'UTF-8') -> Optional[str]:
+ def decode(self, *args, **kwargs) -> Optional[str]:
if self.session:
- return self.session._trace_log.decode(fmt)
+ return self.session._trace_log.decode(*args, **kwargs)
return None
@property
@@ -476,54 +458,52 @@ def _pid_exists(pid: int) -> bool:
def run_custom_user_commands(commands :List[str], installation :Installer) -> None:
for index, command in enumerate(commands):
+ script_path = f"/var/tmp/user-command.{index}.sh"
+ chroot_path = installation.target / script_path
+
info(f'Executing custom command "{command}" ...')
-
- with open(f"{installation.target}/var/tmp/user-command.{index}.sh", "w") as temp_script:
- temp_script.write(command)
-
- SysCommand(f"arch-chroot {installation.target} bash /var/tmp/user-command.{index}.sh")
-
- os.unlink(f"{installation.target}/var/tmp/user-command.{index}.sh")
+ chroot_path.write_text(command)
+ SysCommand(f"arch-chroot {installation.target} bash {script_path}")
+
+ os.unlink(chroot_path)
def json_stream_to_structure(configuration_identifier : str, stream :str, target :dict) -> bool :
"""
- Function to load a stream (file (as name) or valid JSON string into an existing dictionary
- Returns true if it could be done
- Return false if operation could not be executed
+ Load a JSON encoded dictionary from a stream and merge it into an existing dictionary.
+ A stream can be a filepath, a URL or a raw JSON string.
+ Returns True if the operation succeeded, False otherwise.
+configuration_identifier is just a parameter to get meaningful, but not so long messages
"""
- parsed_url = urllib.parse.urlparse(stream)
-
- if parsed_url.scheme: # The stream is in fact a URL that should be grabbed
+ raw: Optional[str] = None
+ # Try using the stream as a URL that should be grabbed
+ if urllib.parse.urlparse(stream).scheme:
try:
- with urllib.request.urlopen(urllib.request.Request(stream, headers={'User-Agent': 'ArchInstall'})) as response:
- target.update(json.loads(response.read()))
+ with urlopen(Request(stream, headers={'User-Agent': 'ArchInstall'})) as response:
+ raw = response.read()
except urllib.error.HTTPError as err:
- error(f"Could not load {configuration_identifier} via {parsed_url} due to: {err}")
+ error(f"Could not fetch JSON from {stream} as {configuration_identifier}: {err}")
return False
- else:
- if pathlib.Path(stream).exists():
- try:
- with pathlib.Path(stream).open() as fh:
- target.update(json.load(fh))
- except Exception as err:
- error(f"{configuration_identifier} = {stream} does not contain a valid JSON format: {err}")
- return False
- else:
- # NOTE: This is a rudimentary check if what we're trying parse is a dict structure.
- # Which is the only structure we tolerate anyway.
- if stream.strip().startswith('{') and stream.strip().endswith('}'):
- try:
- target.update(json.loads(stream))
- except Exception as e:
- error(f"{configuration_identifier} Contains an invalid JSON format: {e}")
- return False
- else:
- error(f"{configuration_identifier} is neither a file nor is a JSON string")
- return False
+ # Try using the stream as a filepath that should be read
+ if raw is None and (path := pathlib.Path(stream)).exists():
+ try:
+ raw = path.read_text()
+ except Exception as err:
+ error(f"Could not read file {stream} as {configuration_identifier}: {err}")
+ return False
+
+ try:
+ # We use `or` to try the stream as raw JSON to be parsed
+ structure = json.loads(raw or stream)
+ except Exception as err:
+ error(f"{configuration_identifier} contains an invalid JSON format: {err}")
+ return False
+ if not isinstance(structure, dict):
+ error(f"{stream} passed as {configuration_identifier} is not a JSON encoded dictionary")
+ return False
+ target.update(structure)
return True