blob: 9b128670112049b63818741c12ae474d1cd36a79 [file] [log] [blame]
# SPDX-License-Identifier: MIT
"""
This module contains common utility functions and classes for various amd-debug-tools.
"""
import asyncio
import importlib.metadata
import logging
import os
import platform
import time
import struct
import subprocess
import sys
from ast import literal_eval
from datetime import date, timedelta
class Colors:
    """ANSI SGR escape sequences used to colorize terminal output"""

    # Bright black (gray) foreground
    DEBUG = "\033[90m"
    # Bright magenta foreground
    HEADER = "\033[95m"
    # Bright blue foreground
    OK = "\033[94m"
    # Green foreground (code 32)
    WARNING = "\033[32m"
    # Bright red foreground
    FAIL = "\033[91m"
    # Reset all attributes; emitted after a colored message
    ENDC = "\033[0m"
    # Underlined text
    UNDERLINE = "\033[4m"
def read_file(fn) -> str:
    """Return the UTF-8 text of ``fn`` with surrounding whitespace removed"""
    with open(fn, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents.strip()
def compare_file(fn, expect) -> bool:
    """Return True when the stripped contents of ``fn`` equal ``expect``"""
    actual = read_file(fn)
    return actual == expect
def get_group_color(group) -> str:
    """Map a group marker to the terminal color used to print it

    The first two markers must match exactly; the remaining ones match when
    the marker appears anywhere in ``group``.  When nothing matches,
    ``group`` itself is returned unchanged.
    """
    if group == "🚦":
        return Colors.WARNING
    if group == "🗣️":
        return Colors.HEADER

    # Checked in order; the first set with a marker contained in group wins
    marker_colors = (
        (("💯", "🚫"), Colors.UNDERLINE),
        (("🦟", "🖴"), Colors.DEBUG),
        (("❌", "👀"), Colors.FAIL),
        (("✅", "🔋", "🐧", "💻", "○", "💤", "🥱"), Colors.OK),
    )
    for markers, color in marker_colors:
        if any(marker in group for marker in markers):
            return color
    return group
def print_color(message, group) -> None:
    """Log ``message`` and print it to the console in the color of ``group``

    The log level follows the color: OK/HEADER/UNDERLINE are logged as info,
    WARNING as warning, FAIL as error and everything else as debug.  Color
    codes are left out of the printed text when ``TERM`` is ``dumb``.
    """
    color = get_group_color(group)
    # get_group_color() hands the group back when it has no color for it;
    # in that case the group is not repeated as a prefix.
    prefix = "" if color == group else f"{group} "

    log_txt = f"{prefix}{message}".strip()
    info_colors = (Colors.OK, Colors.HEADER, Colors.UNDERLINE)
    if any(code in color for code in info_colors):
        logging.info(log_txt)
    elif color == Colors.WARNING:
        logging.warning(log_txt)
    elif color == Colors.FAIL:
        logging.error(log_txt)
    else:
        logging.debug(log_txt)

    suffix = Colors.ENDC
    if os.environ.get("TERM") == "dumb":
        suffix = ""
        color = ""
    print(f"{prefix}{color}{message}{suffix}")
def colorize_choices(choices, default) -> str:
    """Join ``choices`` with commas, listing the highlighted default first

    Raises ValueError when ``default`` is not one of ``choices``.
    """
    if default not in choices:
        raise ValueError(f"Default choice '{default}' not in choices")
    highlighted = f"{Colors.OK}{default}{Colors.ENDC}"
    others = [choice for choice in choices if choice != default]
    return ", ".join([highlighted, *others])
def fatal_error(message):
    """Report ``message`` as a failure and exit with status 1"""
    # Make sure logging is set up (without a log file) before reporting
    _configure_log(None)
    failure_group = "👀"
    print_color(message, failure_group)
    sys.exit(1)
def apply_prefix_wrapper(header, message):
    """Render a newline delimited message as a tree underneath a header

    Blank lines are dropped.  The final line is prefixed with "└─" and every
    other line with "│".
    """
    pieces = [f"{header.strip()}\n"]
    rows = message.strip().split("\n")
    last_index = len(rows) - 1
    for index, raw in enumerate(rows):
        text = raw.strip()
        if not text:
            continue
        marker = "└─" if index == last_index else "│"
        pieces.append(f"{marker} {text}\n")
    return "".join(pieces)
def show_log_info():
    """Print the path of every root logger file handler except /dev/null"""
    for handler in logging.getLogger().handlers:
        if not isinstance(handler, logging.FileHandler):
            continue
        filename = handler.baseFilename
        if filename == "/dev/null":
            continue
        print(f"Debug logs are saved to: {filename}")
def _configure_log(prefix) -> str:
"""Configure logging for the tool"""
if len(logging.root.handlers) > 0:
return
if prefix:
user = os.environ.get("SUDO_USER")
home = os.path.expanduser(f"~{user if user else ''}")
path = os.environ.get("XDG_DATA_HOME") or os.path.join(
home, ".local", "share", "amd-debug-tools"
)
os.makedirs(path, exist_ok=True)
log = os.path.join(
path,
f"{prefix}-{date.today()}.txt",
)
if not os.path.exists(log):
with open(log, "w", encoding="utf-8") as f:
f.write("")
if "SUDO_UID" in os.environ:
os.chown(path, int(os.environ["SUDO_UID"]), int(os.environ["SUDO_GID"]))
os.chown(log, int(os.environ["SUDO_UID"]), int(os.environ["SUDO_GID"]))
level = logging.DEBUG
else:
log = "/dev/null"
level = logging.WARNING
# for saving a log file for analysis
logging.basicConfig(
format="%(asctime)s %(levelname)s:\t%(message)s",
filename=log,
level=level,
)
return log
def check_lockdown():
    """Check if the system is in lockdown

    Returns the contents of /sys/kernel/security/lockdown when its first
    token is anything other than "[none]".  Returns False otherwise, or
    when the file does not exist.
    """
    fn = os.path.join("/", "sys", "kernel", "security", "lockdown")
    if os.path.exists(fn):
        state = read_file(fn)
        first_token = state.split()[0]
        if first_token != "[none]":
            return state
    return False
def print_temporary_message(msg) -> int:
    """Print ``msg`` followed by a carriage return and return its length

    The carriage return (instead of a newline) leaves the cursor at the
    start of the line so later output can overwrite the message.
    """
    width = len(msg)
    print(msg, end="\r", flush=True)
    return width
def clear_temporary_message(msg_len) -> None:
    """Overwrite ``msg_len`` characters of the current line with spaces"""
    blanks = " " * msg_len
    print(blanks, end="\r")
def run_countdown(action, t) -> bool:
    """Show a one-second-per-step countdown of ``t`` seconds for ``action``

    Returns False for a negative ``t`` and True otherwise.  A ``t`` of zero
    returns immediately without printing anything.
    """
    if t <= 0:
        return t == 0

    msg = ""
    remaining = t
    while remaining > 0:
        msg = f"{action} in {timedelta(seconds=remaining)}"
        print_temporary_message(msg)
        time.sleep(1)
        remaining -= 1
    # Wipe whatever the last countdown message left on the line
    clear_temporary_message(len(msg))
    return True
def get_distro() -> str:
    """Get the distribution name

    Uses the ``ID=`` entry of /etc/os-release when there is one, otherwise
    falls back to distribution specific marker files.  Returns "unknown"
    when neither identifies the distribution.
    """
    os_release = "/etc/os-release"
    if os.path.exists(os_release):
        with open(os_release, "r", encoding="utf-8") as stream:
            for entry in stream:
                if entry.startswith("ID="):
                    return entry.split("=")[1].strip().strip('"')

    # Marker files checked in order when os-release gave no answer
    fallbacks = (
        ("/etc/arch-release", "arch"),
        ("/etc/fedora-release", "fedora"),
        ("/etc/debian_version", "debian"),
    )
    for marker, name in fallbacks:
        if os.path.exists(marker):
            return name
    return "unknown"
def get_pretty_distro() -> str:
    """Get the ``PRETTY_NAME`` from /etc/os-release, or "Unknown" if absent"""
    os_release = "/etc/os-release"
    if not os.path.exists(os_release):
        return "Unknown"
    with open(os_release, "r", encoding="utf-8") as stream:
        for entry in stream:
            if entry.startswith("PRETTY_NAME="):
                return entry.split("=")[1].strip().strip('"')
    return "Unknown"
def bytes_to_gb(bytes_value):
    """Convert a count of 4096-byte units to GB

    NOTE(review): despite the name, the argument is multiplied by 4096
    before being divided by 1024**3, so it is presumably a number of
    memory pages rather than a number of bytes.
    """
    unit_size = 4096
    bytes_per_gb = 1024 * 1024 * 1024
    return bytes_value * unit_size / bytes_per_gb
def gb_to_pages(gb_value):
    """Convert GB into a whole number of 4096-byte pages (truncated)"""
    bytes_per_gb = 1024 * 1024 * 1024
    page_size = 4096
    return int(gb_value * bytes_per_gb / page_size)
def reboot():
    """Reboot the system through the org.freedesktop.login1 D-Bus service

    dbus-fast is tried first; when it cannot be imported, python-dbus is
    used instead.

    Returns True when a reboot request was issued and False when neither
    D-Bus binding could be imported.
    """

    async def reboot_dbus_fast():
        """Reboot using dbus-fast"""
        try:
            from dbus_fast.aio import (  # pylint: disable=import-outside-toplevel
                MessageBus,
            )
            from dbus_fast import BusType  # pylint: disable=import-outside-toplevel

            bus = await MessageBus(bus_type=BusType.SYSTEM).connect()
            introspection = await bus.introspect(
                "org.freedesktop.login1", "/org/freedesktop/login1"
            )
            proxy_obj = bus.get_proxy_object(
                "org.freedesktop.login1", "/org/freedesktop/login1", introspection
            )
            interface = proxy_obj.get_interface("org.freedesktop.login1.Manager")
            await interface.call_reboot(True)
        except ImportError:
            return False
        return True

    def reboot_dbus():
        """Reboot using python-dbus"""
        try:
            import dbus  # pylint: disable=import-outside-toplevel

            bus = dbus.SystemBus()
            obj = bus.get_object("org.freedesktop.login1", "/org/freedesktop/login1")
            intf = dbus.Interface(obj, "org.freedesktop.login1.Manager")
            intf.Reboot(True)
        except ImportError:
            return False
        return True

    # asyncio.run() creates and closes its own event loop.  Calling
    # asyncio.get_event_loop() without a running loop is deprecated and
    # raises RuntimeError on newer Python versions.
    if asyncio.run(reboot_dbus_fast()):
        return True
    return reboot_dbus()
def get_system_mem():
    """Get the total system memory in GB using /proc/meminfo

    Raises ValueError when /proc/meminfo has no ``MemTotal:`` line.
    """
    meminfo_path = os.path.join("/", "proc", "meminfo")
    with open(meminfo_path, "r", encoding="utf-8") as meminfo:
        for entry in meminfo:
            if not entry.startswith("MemTotal:"):
                continue
            # Line format: "MemTotal: 16384516 kB"; the second field is
            # the amount in kB, converted here to GB.
            fields = entry.split()
            return int(fields[1]) / (1024 * 1024)
    raise ValueError("Could not find MemTotal in /proc/meminfo")
def is_root() -> bool:
    """Return True when the effective user id of the process is 0"""
    effective_uid = os.geteuid()
    return effective_uid == 0
def BIT(num):  # pylint: disable=invalid-name
    """Return 1 shifted left by ``num`` positions"""
    mask = 1 << num
    return mask
def get_log_priority(num):
    """Maps an integer debug level to a priority type

    Falsy input maps to "○".  Input that cannot be converted with ``int()``
    is returned unchanged.  Level 7 maps to "🦟", 4 to "🚦", 3 and below to
    "❌" and all other levels to "○".
    """
    if not num:
        return "○"
    try:
        level = int(num)
    except ValueError:
        return num
    if level == 7:
        return "🦟"
    if level == 4:
        return "🚦"
    if level <= 3:
        return "❌"
    return "○"
def minimum_kernel(major, minor) -> bool:
    """Checks if the running kernel release is at least major.minor"""
    release_parts = platform.uname().release.split(".")
    running_major = int(release_parts[0])
    running_minor = int(release_parts[1])
    wanted_major = int(major)
    # The minor version only matters when the major versions are equal
    if running_major != wanted_major:
        return running_major > wanted_major
    return running_minor >= int(minor)
def systemd_in_use() -> bool:
    """Check if systemd is in use"""
    # /proc/1/comm holds the command name of PID 1; the read is not guarded,
    # so an error from read_file() propagates to the caller.
    pid1_name = read_file("/proc/1/comm")
    return pid1_name == "systemd"
def get_property_pyudev(properties, key, fallback=""):
    """Look up ``key`` in ``properties``, defaulting to ``fallback``

    An empty string (not ``fallback``) is returned when the lookup raises
    UnicodeDecodeError.
    """
    try:
        value = properties.get(key, fallback)
    except UnicodeDecodeError:
        value = ""
    return value
def find_ip_version(base_path, kind, hw_ver) -> bool:
    """Determine if an IP version is present on the system

    Every key of ``hw_ver`` names a file below
    ``<base_path>/ip_discovery/die/0/<kind>/0``.  True is returned only
    when each file exists and its integer content equals the mapped value.
    """
    ip_dir = os.path.join(base_path, "ip_discovery", "die", "0", kind, "0")
    for attribute, wanted in hw_ver.items():
        attribute_path = os.path.join(ip_dir, attribute)
        if not os.path.exists(attribute_path):
            return False
        if int(read_file(attribute_path)) != wanted:
            return False
    return True
def read_msr(msr, cpu):
    """Read a Model-Specific Register (MSR) value from the CPU.

    Reads 8 bytes at offset ``msr`` from ``/dev/cpu/<cpu>/msr`` and returns
    them unpacked as an unsigned 64-bit integer.  Any OSError while
    opening, seeking or reading is re-raised as PermissionError.
    """
    device = f"/dev/cpu/{cpu}/msr"
    # When the device node is missing, root tries loading the msr module
    if not os.path.exists(device) and is_root():
        os.system("modprobe msr")

    try:
        fd = os.open(device, os.O_RDONLY)
    except OSError as exc:
        raise PermissionError from exc

    try:
        os.lseek(fd, msr, os.SEEK_SET)
        raw = os.read(fd, 8)
        (val,) = struct.unpack("Q", raw)
    except OSError as exc:
        raise PermissionError from exc
    finally:
        os.close(fd)
    return val
def relaunch_sudo() -> None:
    """Replace the process with ``sudo -E <argv>`` unless already root"""
    if is_root():
        return
    logging.debug("Relaunching with sudo")
    os.execvp("sudo", ["sudo", "-E", *sys.argv])
def running_ssh():
    """Return True when SSH_CLIENT or SSH_TTY is set in the environment"""
    ssh_variables = ("SSH_CLIENT", "SSH_TTY")
    return any(name in os.environ for name in ssh_variables)
def convert_string_to_bool(str_value) -> bool:
    """Evaluate ``str_value`` as a Python literal and return its truthiness

    Exits with an "Invalid entry" message when the string is not a valid
    literal.
    """
    try:
        parsed = literal_eval(str_value)
    except (SyntaxError, ValueError):
        # Keeps the result defined should sys.exit() ever return
        parsed = None
        sys.exit(f"Invalid entry: {str_value}")
    return bool(parsed)
def _git_describe() -> str:
"""Get the git description of the current commit"""
try:
result = subprocess.check_output(
["git", "log", "-1", '--format=commit %h ("%s")'],
cwd=os.path.dirname(__file__),
text=True,
stderr=subprocess.DEVNULL,
)
return result.strip()
except subprocess.CalledProcessError:
return None
except FileNotFoundError:
return None
def version() -> str:
    """Get version of the tool

    Uses the installed ``amd-debug-tools`` package version, or "unknown"
    when the package metadata is not found.  The git description is
    appended in square brackets when one is available.
    """
    try:
        ver = importlib.metadata.version("amd-debug-tools")
    except importlib.metadata.PackageNotFoundError:
        ver = "unknown"
    describe = _git_describe()
    return f"{ver} [{describe}]" if describe else ver
class AmdTool:
    """Base class for AMD tools"""

    def __init__(self, prefix):
        """Set up logging, record the invocation and require Linux

        Raises RuntimeError when the system is not Linux.
        """
        # Path of the log file as reported by _configure_log()
        self.log = _configure_log(prefix)
        tool_name = type(self).__name__
        logging.debug("command: %s (module: %s)", sys.argv, tool_name)
        logging.debug("Version: %s", version())
        sysname = os.uname().sysname
        if sysname != "Linux":
            raise RuntimeError("This tool only runs on Linux")