| #!/usr/bin/python3 |
| # SPDX-License-Identifier: MIT |
| |
| """ |
| This module contains unit tests for the common functions in the amd-debug-tools package. |
| """ |
| from unittest.mock import patch, mock_open, call |
| |
| import logging |
| import tempfile |
| import unittest |
| import os |
| from platform import uname_result |
| |
| |
| from amd_debug.common import ( |
| apply_prefix_wrapper, |
| Colors, |
| convert_string_to_bool, |
| colorize_choices, |
| check_lockdown, |
| compare_file, |
| find_ip_version, |
| fatal_error, |
| get_distro, |
| get_log_priority, |
| get_pretty_distro, |
| is_root, |
| minimum_kernel, |
| print_color, |
| run_countdown, |
| systemd_in_use, |
| running_ssh, |
| ) |
| |
| color_dict = { |
| "š¦": Colors.WARNING, |
| "š¦": Colors.DEBUG, |
| "ā": Colors.FAIL, |
| "š": Colors.FAIL, |
| "ā
": Colors.OK, |
| "š": Colors.OK, |
| "š§": Colors.OK, |
| "š»": Colors.OK, |
| "ā": Colors.OK, |
| "š¤": Colors.OK, |
| "šÆ": Colors.UNDERLINE, |
| "š£ļø": Colors.HEADER, |
| } |
| |
| |
| class TestCommon(unittest.TestCase): |
| """Test common functions""" |
| |
| @classmethod |
| def setUpClass(cls): |
| logging.basicConfig(filename="/dev/null", level=logging.DEBUG) |
| |
| def test_read_compare_file(self): |
| """Test read_file and compare_file strip files correctly""" |
| |
| f = tempfile.NamedTemporaryFile() |
| f.write("foo bar baz\n ".encode("utf-8")) |
| f.seek(0) |
| self.assertTrue(compare_file(f.name, "foo bar baz")) |
| |
| def test_countdown(self): |
| """Test countdown function""" |
| |
| result = run_countdown("Full foo", 1) |
| self.assertTrue(result) |
| result = run_countdown("Half foo", 0.5) |
| self.assertTrue(result) |
| result = run_countdown("No foo", 0) |
| self.assertTrue(result) |
| result = run_countdown("Negative foo", -1) |
| self.assertFalse(result) |
| |
| @patch("os.path.exists", return_value=True) |
| @patch("builtins.open", new_callable=mock_open, read_data="ID=foo\nVERSION_ID=bar") |
| def test_get_distro_known(self, mock_exists, _mock_open): |
| """Test get_distro function""" |
| distro = get_distro() |
| mock_exists.assert_has_calls( |
| [ |
| call("/etc/os-release", "r", encoding="utf-8"), |
| call().__enter__(), |
| call().__iter__(), |
| call().__exit__(None, None, None), |
| ] |
| ) |
| self.assertEqual(distro, "foo") |
| |
| @patch("os.path.exists", return_value=False) |
| def test_get_distro_unknown(self, mock_exists): |
| """Test get_distro function""" |
| distro = get_distro() |
| mock_exists.assert_has_calls( |
| [ |
| call("/etc/os-release"), |
| call("/etc/arch-release"), |
| call("/etc/fedora-release"), |
| call("/etc/debian_version"), |
| ] |
| ) |
| self.assertEqual(distro, "unknown") |
| |
| @patch("os.path.exists", return_value=True) |
| @patch("builtins.open", new_callable=mock_open, read_data="PRETTY_NAME=Foo") |
| def test_get_pretty_distro_known(self, mock_exists, _mock_open): |
| """Test get_distro function""" |
| distro = get_pretty_distro() |
| self.assertEqual(distro, "Foo") |
| mock_exists.assert_has_calls( |
| [ |
| call("/etc/os-release", "r", encoding="utf-8"), |
| call().__enter__(), |
| call().__iter__(), |
| call().__exit__(None, None, None), |
| ] |
| ) |
| |
| @patch("os.path.exists", return_value=False) |
| def test_get_pretty_distro_unknown(self, mock_exists): |
| """Test get_distro function""" |
| distro = get_pretty_distro() |
| self.assertEqual(distro, "Unknown") |
| mock_exists.assert_has_calls( |
| [ |
| call("/etc/os-release"), |
| ] |
| ) |
| |
| @patch("os.path.exists", return_value=True) |
| @patch( |
| "builtins.open", |
| new_callable=mock_open, |
| read_data="[none] integrity confidentiality", |
| ) |
| def test_lockdown_pass(self, mock_exists, _mock_file): |
| """Test lockdown function""" |
| lockdown = check_lockdown() |
| mock_exists.assert_called_once_with( |
| "/sys/kernel/security/lockdown", "r", encoding="utf-8" |
| ) |
| self.assertFalse(lockdown) |
| |
| @patch("os.path.exists", return_value=True) |
| @patch( |
| "builtins.open", |
| new_callable=mock_open, |
| read_data="none [integrity] confidentiality", |
| ) |
| def test_lockdown_fail_integrity(self, mock_exists, _mock_file): |
| """Test lockdown function""" |
| lockdown = check_lockdown() |
| mock_exists.assert_called_once_with( |
| "/sys/kernel/security/lockdown", "r", encoding="utf-8" |
| ) |
| self.assertTrue(lockdown) |
| |
| @patch("os.path.exists", return_value=True) |
| @patch( |
| "builtins.open", |
| new_callable=mock_open, |
| read_data="none integrity [confidentiality]", |
| ) |
| def test_lockdown_fail_confidentiality(self, mock_exists, _mock_file): |
| """Test lockdown function""" |
| lockdown = check_lockdown() |
| mock_exists.assert_called_once_with( |
| "/sys/kernel/security/lockdown", "r", encoding="utf-8" |
| ) |
| self.assertTrue(lockdown) |
| |
| @patch("os.path.exists", return_value=False) |
| def test_lockdown_missing(self, mock_exists): |
| """Test lockdown function""" |
| lockdown = check_lockdown() |
| mock_exists.assert_called_once_with("/sys/kernel/security/lockdown") |
| self.assertFalse(lockdown) |
| |
| @patch("builtins.print") |
| def test_print_color(self, mocked_print): |
| """Test print_color function for all expected levels""" |
| message = "foo" |
| # test all color groups |
| for group, color in color_dict.items(): |
| prefix = f"{group} " |
| print_color(message, group) |
| mocked_print.assert_called_once_with( |
| f"{prefix}{color}{message}{Colors.ENDC}" |
| ) |
| mocked_print.reset_mock() |
| |
| # call without a group |
| print_color(message, Colors.WARNING) |
| mocked_print.assert_called_once_with(f"{Colors.WARNING}{message}{Colors.ENDC}") |
| mocked_print.reset_mock() |
| |
| # test dumb terminal |
| os.environ["TERM"] = "dumb" |
| print_color(message, Colors.WARNING) |
| mocked_print.assert_called_once_with(f"{message}") |
| |
| @patch("builtins.print") |
| def test_fatal_error(self, mocked_print): |
| """Test fatal_error function""" |
| with patch("sys.exit") as mock_exit: |
| fatal_error("foo") |
| mocked_print.assert_called_once_with(f"š {Colors.FAIL}foo{Colors.ENDC}") |
| mock_exit.assert_called_once_with(1) |
| |
| @patch("os.geteuid", return_value=0) |
| def test_is_root_true(self, mock_geteuid): |
| """Test is_root function when user is root""" |
| self.assertTrue(is_root()) |
| mock_geteuid.assert_called_once() |
| self.assertEqual(mock_geteuid.call_count, 1) |
| |
| @patch("os.geteuid", return_value=1000) |
| def test_is_root_false(self, mock_geteuid): |
| """Test is_root function when user is not root""" |
| self.assertFalse(is_root()) |
| mock_geteuid.assert_called_once() |
| self.assertEqual(mock_geteuid.call_count, 1) |
| |
| def test_get_log_priority(self): |
| """Test get_log_priority works for expected values""" |
| ret = get_log_priority(None) |
| self.assertEqual(ret, "ā") |
| ret = get_log_priority("foo") |
| self.assertEqual(ret, "foo") |
| ret = get_log_priority("3") |
| self.assertEqual(ret, "ā") |
| ret = get_log_priority(4) |
| self.assertEqual(ret, "š¦") |
| ret = get_log_priority(7) |
| self.assertEqual(ret, "š¦") |
| |
| def test_minimum_kernel(self): |
| """Test minimum_kernel function""" |
| with patch("platform.uname") as mock_uname: |
| mock_uname.return_value = uname_result( |
| system="Linux", |
| node="foo", |
| release="6.12.0-rc5", |
| version="baz", |
| machine="x86_64", |
| ) |
| self.assertTrue(minimum_kernel("6", "12")) |
| self.assertFalse(minimum_kernel("6", "13")) |
| self.assertTrue(minimum_kernel(5, 1)) |
| self.assertFalse(minimum_kernel(7, 1)) |
| with self.assertRaises(ValueError): |
| minimum_kernel("foo", "bar") |
| with self.assertRaises(TypeError): |
| minimum_kernel(None, None) |
| |
| def test_systemd_in_use(self): |
| """Test systemd_in_use function""" |
| with patch( |
| "builtins.open", new_callable=mock_open, read_data="systemd" |
| ) as mock_file: |
| self.assertTrue(systemd_in_use()) |
| mock_file.assert_called_once_with("/proc/1/comm", "r", encoding="utf-8") |
| with patch( |
| "builtins.open", new_callable=mock_open, read_data="upstart" |
| ) as mock_file: |
| self.assertFalse(systemd_in_use()) |
| mock_file.assert_called_once_with("/proc/1/comm", "r", encoding="utf-8") |
| |
| def test_running_in_ssh(self): |
| """Test running_in_ssh function""" |
| with patch("os.environ", {"SSH_TTY": "/dev/pts/0"}): |
| self.assertTrue(running_ssh()) |
| with patch("os.environ", {}): |
| self.assertFalse(running_ssh()) |
| |
| def test_apply_prefix_wrapper(self): |
| """Test apply_prefix_wrapper function""" |
| header = "Header:" |
| message = "Line 1\nLine 2\nLine 3" |
| expected_output = "Header:\n" "ā Line 1\n" "ā Line 2\n" "āā Line 3\n" |
| self.assertEqual(apply_prefix_wrapper(header, message), expected_output) |
| |
| # Test with a single line message |
| message = "Single Line" |
| expected_output = "Header:\nāā Single Line\n" |
| self.assertEqual(apply_prefix_wrapper(header, message), expected_output) |
| |
| # Test with an empty message |
| message = "" |
| expected_output = "Header:\n" |
| self.assertEqual(apply_prefix_wrapper(header, message), expected_output) |
| |
| # Test with leading/trailing whitespace in the message |
| message = " Line 1\nLine 2 \n Line 3 " |
| expected_output = "Header:\n" "ā Line 1\n" "ā Line 2\n" "āā Line 3\n" |
| self.assertEqual(apply_prefix_wrapper(header, message), expected_output) |
| |
| def test_colorize_choices_with_default(self): |
| """Test colorize_choices function with a default value""" |
| choices = ["option1", "option2", "option3"] |
| default = "option2" |
| expected_output = f"{Colors.OK}{default}{Colors.ENDC}, option1, option3" |
| self.assertEqual(colorize_choices(choices, default), expected_output) |
| |
| def test_colorize_choices_without_default(self): |
| """Test colorize_choices function when default is not in choices""" |
| choices = ["option1", "option2", "option3"] |
| default = "option4" |
| with self.assertRaises(ValueError) as context: |
| colorize_choices(choices, default) |
| self.assertEqual( |
| str(context.exception), "Default choice 'option4' not in choices" |
| ) |
| |
| def test_colorize_choices_empty_list(self): |
| """Test colorize_choices function with an empty list""" |
| choices = [] |
| default = "option1" |
| with self.assertRaises(ValueError) as context: |
| colorize_choices(choices, default) |
| self.assertEqual( |
| str(context.exception), "Default choice 'option1' not in choices" |
| ) |
| |
| def test_colorize_choices_single_choice(self): |
| """Test colorize_choices function with a single choice""" |
| choices = ["option1"] |
| default = "option1" |
| expected_output = f"{Colors.OK}{default}{Colors.ENDC}" |
| self.assertEqual(colorize_choices(choices, default), expected_output) |
| |
| @patch("amd_debug.common.read_file") |
| @patch("os.path.exists") |
| def test_find_ip_version_found(self, mock_exists, mock_read_file): |
| """Test find_ip_version returns True when expected value is found""" |
| base_path = "/foo" |
| kind = "bar" |
| hw_ver = {"baz": 42} |
| |
| # Simulate file exists and value matches |
| def exists_side_effect(path): |
| return True |
| |
| mock_exists.side_effect = exists_side_effect |
| mock_read_file.return_value = "42" |
| result = find_ip_version(base_path, kind, hw_ver) |
| self.assertTrue(result) |
| b = os.path.join(base_path, "ip_discovery", "die", "0", kind, "0") |
| expected_path = os.path.join(b, "baz") |
| mock_exists.assert_any_call(expected_path) |
| mock_read_file.assert_any_call(expected_path) |
| |
| @patch("amd_debug.common.read_file") |
| @patch("os.path.exists") |
| def test_find_ip_version_not_found_due_to_missing_file( |
| self, mock_exists, mock_read_file |
| ): |
| """Test find_ip_version returns False if file does not exist""" |
| base_path = "/foo" |
| kind = "bar" |
| hw_ver = {"baz": 42} |
| # Simulate file does not exist |
| mock_exists.return_value = False |
| result = find_ip_version(base_path, kind, hw_ver) |
| self.assertFalse(result) |
| b = os.path.join(base_path, "ip_discovery", "die", "0", kind, "0") |
| expected_path = os.path.join(b, "baz") |
| mock_exists.assert_any_call(expected_path) |
| mock_read_file.assert_not_called() |
| |
| @patch("amd_debug.common.read_file") |
| @patch("os.path.exists") |
| def test_find_ip_version_not_found_due_to_value_mismatch( |
| self, mock_exists, mock_read_file |
| ): |
| """Test find_ip_version returns False if value does not match""" |
| base_path = "/foo" |
| kind = "bar" |
| hw_ver = {"baz": 42} |
| # Simulate file exists but value does not match |
| mock_exists.return_value = True |
| mock_read_file.return_value = "99" |
| result = find_ip_version(base_path, kind, hw_ver) |
| self.assertFalse(result) |
| b = os.path.join(base_path, "ip_discovery", "die", "0", kind, "0") |
| expected_path = os.path.join(b, "baz") |
| mock_exists.assert_any_call(expected_path) |
| mock_read_file.assert_any_call(expected_path) |
| |
| @patch("amd_debug.common.read_file") |
| @patch("os.path.exists") |
| def test_find_ip_version_multiple_keys(self, mock_exists, mock_read_file): |
| """Test find_ip_version with multiple keys in hw_ver""" |
| base_path = "/foo" |
| kind = "bar" |
| hw_ver = {"baz": 42, "qux": 99} |
| |
| # First key: file exists, value does not match |
| # Second key: file exists, value matches |
| def exists_side_effect(path): |
| return True |
| |
| def read_file_side_effect(path): |
| if path.endswith("baz"): |
| return "0" |
| if path.endswith("qux"): |
| return "99" |
| return "0" |
| |
| mock_exists.side_effect = exists_side_effect |
| mock_read_file.side_effect = read_file_side_effect |
| result = find_ip_version(base_path, kind, hw_ver) |
| self.assertFalse(result) |
| |
| def test_convert_string_to_bool_true_values(self): |
| """Test convert_string_to_bool returns True for truthy string values""" |
| self.assertTrue(convert_string_to_bool("True")) |
| self.assertTrue(convert_string_to_bool("1")) |
| self.assertTrue(convert_string_to_bool("'nonempty'")) |
| self.assertTrue(convert_string_to_bool('"nonempty"')) |
| |
| def test_convert_string_to_bool_false_values(self): |
| """Test convert_string_to_bool returns False for falsy string values""" |
| self.assertFalse(convert_string_to_bool("False")) |
| self.assertFalse(convert_string_to_bool("0")) |
| self.assertFalse(convert_string_to_bool("''")) |
| self.assertFalse(convert_string_to_bool('""')) |
| self.assertFalse(convert_string_to_bool("None")) |
| |
| def test_convert_string_to_bool_invalid_syntax(self): |
| """Test convert_string_to_bool exits on invalid syntax""" |
| with patch("sys.exit") as mock_exit: |
| convert_string_to_bool("not_a_bool") |
| mock_exit.assert_called_once_with("Invalid entry: not_a_bool") |
| |
| def test_convert_string_to_bool_invalid_value(self): |
| """Test convert_string_to_bool exits on invalid value""" |
| with patch("sys.exit") as mock_exit: |
| convert_string_to_bool("[unclosed_list") |
| mock_exit.assert_called_once_with("Invalid entry: [unclosed_list") |