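ttm: Use relaunch_sudo()

Running from a wheel is more awkward than necessary when using is_root().
When the tool is not run as root, set() and clear() print
"Root privileges required" and return False, so the user has to re-run
the command with elevated privileges themselves.

Call relaunch_sudo() in set() and clear() instead. Drop the now-obsolete
"not root" tests and mock relaunch_sudo() in the remaining tests.

Signed-off-by: Mario Limonciello <superm1@kernel.org>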
diff --git a/src/amd_debug/ttm.py b/src/amd_debug/ttm.py index 585cb5c..bf4ec23 100644 --- a/src/amd_debug/ttm.py +++ b/src/amd_debug/ttm.py
@@ -9,7 +9,7 @@ bytes_to_gb, gb_to_pages, get_system_mem, - is_root, + relaunch_sudo, print_color, reboot, version, @@ -57,9 +57,7 @@ def set(self, gb_value) -> bool: """Set a new page limit""" - if not is_root(): - print_color("Root privileges required", "❌") - return False + relaunch_sudo() # Check against system memory total = get_system_mem() @@ -108,9 +106,7 @@ print_color(f"{MODPROBE_CONF_PATH} doesn't exist", "❌") return False - if not is_root(): - print_color("Root privileges required", "❌") - return False + relaunch_sudo() os.remove(MODPROBE_CONF_PATH) print_color(f"Configuration {MODPROBE_CONF_PATH} removed", "🐧")
diff --git a/src/amd_debug/test_ttm.py b/src/test_ttm.py similarity index 88% rename from src/amd_debug/test_ttm.py rename to src/test_ttm.py index bde7420..c3d5288 100644 --- a/src/amd_debug/test_ttm.py +++ b/src/test_ttm.py
@@ -199,18 +199,12 @@ ) self.assertFalse(result) - @mock.patch("amd_debug.ttm.is_root", return_value=False) - @mock.patch("amd_debug.ttm.print_color") - def test_set_not_root(self, mock_print, _mock_is_root): - """Test set() when not root""" - result = self.tool.set(2) - mock_print.assert_called_with("Root privileges required", "❌") - self.assertFalse(result) - - @mock.patch("amd_debug.ttm.is_root", return_value=True) + @mock.patch("amd_debug.ttm.relaunch_sudo", return_value=True) @mock.patch("amd_debug.ttm.get_system_mem", return_value=8.0) @mock.patch("amd_debug.ttm.print_color") - def test_set_gb_greater_than_total(self, mock_print, _mock_mem, _mock_is_root): + def test_set_gb_greater_than_total( + self, mock_print, _mock_mem, _mock_relaunch_sudo + ): """Test set() when gb_value > total system memory""" result = self.tool.set(16) mock_print.assert_any_call( @@ -218,19 +212,19 @@ ) self.assertFalse(result) - @mock.patch("amd_debug.ttm.is_root", return_value=True) + @mock.patch("amd_debug.ttm.relaunch_sudo", return_value=True) @mock.patch("amd_debug.ttm.get_system_mem", return_value=10.0) @mock.patch("amd_debug.ttm.print_color") @mock.patch("builtins.input", return_value="n") def test_set_gb_exceeds_max_percentage_cancel( - self, _mock_input, mock_print, _mock_mem, mock_is_root + self, _mock_input, mock_print, _mock_mem, mock_relaunch_sudo ): """Test set() when gb_value exceeds max percentage and user cancels""" result = self.tool.set(9.5) self.assertFalse(result) mock_print.assert_any_call("Operation cancelled.", "🚦") - @mock.patch("amd_debug.ttm.is_root", return_value=True) + @mock.patch("amd_debug.ttm.relaunch_sudo", return_value=True) @mock.patch("amd_debug.ttm.get_system_mem", return_value=10.0) @mock.patch("amd_debug.ttm.gb_to_pages", return_value=20480) @mock.patch("amd_debug.ttm.print_color") @@ -244,8 +238,8 @@ mock_open, mock_print, _mock_gb_to_pages, - mock_mem, - mock_is_root, + _mock_mem, + _relaunch_sudo, ): """Test set() success path""" result = 
self.tool.set(5) @@ -266,21 +260,12 @@ self.assertFalse(result) @mock.patch("os.path.exists", return_value=True) - @mock.patch("amd_debug.ttm.is_root", return_value=False) - @mock.patch("amd_debug.ttm.print_color") - def test_clear_not_root(self, mock_print, _mock_is_root, _mock_exists): - """Test clear() when not root""" - result = self.tool.clear() - mock_print.assert_called_with("Root privileges required", "❌") - self.assertFalse(result) - - @mock.patch("os.path.exists", return_value=True) - @mock.patch("amd_debug.ttm.is_root", return_value=True) + @mock.patch("amd_debug.ttm.relaunch_sudo", return_value=True) @mock.patch("os.remove") @mock.patch("amd_debug.ttm.print_color") @mock.patch("amd_debug.ttm.maybe_reboot", return_value=True) def test_clear_success( - self, _mock_reboot, mock_print, mock_remove, _mock_is_root, mock_exists + self, _mock_reboot, mock_print, mock_remove, _mock_relaunch_sudo, _mock_exists ): """Test clear() success path""" result = self.tool.clear()