From a502c1b5ff94c4b6b88b8f7150351d06038220c5 Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Mon, 5 May 2025 16:37:35 +0200 Subject: [PATCH] Selectors do not raise a warning if all input values are NaN --- src/peppr/selector.py | 10 ++++++++++ tests/test_selectors.py | 26 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/src/peppr/selector.py b/src/peppr/selector.py index fcd4728..c9016d0 100644 --- a/src/peppr/selector.py +++ b/src/peppr/selector.py @@ -74,6 +74,8 @@ def name(self) -> str: return "mean" def select(self, values: np.ndarray, smaller_is_better: bool) -> float: + if np.isnan(values).all(): + return np.nan return np.nanmean(values) @@ -87,6 +89,8 @@ def name(self) -> str: return "median" def select(self, values: np.ndarray, smaller_is_better: bool) -> float: + if np.isnan(values).all(): + return np.nan return np.nanmedian(values) @@ -100,6 +104,8 @@ def name(self) -> str: return "Oracle" def select(self, values: np.ndarray, smaller_is_better: bool) -> float: + if np.isnan(values).all(): + return np.nan if smaller_is_better: return np.nanmin(values) else: @@ -126,6 +132,8 @@ def name(self) -> str: def select(self, values: np.ndarray, smaller_is_better: bool) -> float: top_values = values[: self._k] + if np.isnan(top_values).all(): + return np.nan if smaller_is_better: return np.nanmin(top_values) else: @@ -160,6 +168,8 @@ def select(self, values: np.ndarray, smaller_is_better: bool) -> float: range(len(values)), size=self._k, replace=False ) top_values = values[random_indices] + if np.isnan(top_values).all(): + return np.nan if smaller_is_better: return np.nanmin(top_values) else: diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 6ad7d8b..bfa32c1 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -28,6 +28,32 @@ def test_selectors(selector, expected_value): assert selected_value == expected_value +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("smaller_is_better", [False, 
True]) +@pytest.mark.parametrize( + "selector", + [ + peppr.MeanSelector(), + peppr.OracleSelector(), + peppr.TopSelector(5), + peppr.TopSelector(1), + peppr.RandomSelector(5, seed=0), + ], + ids=lambda selector: selector.name, +) +def test_nan_values(selector, smaller_is_better): + """ + Check that :meth:`Selector.select()` returns NaN if all values are NaN, without + raising a warning. + If any value is not NaN, the selector should return an actual value. + """ + values = np.full(10, np.nan) + assert np.isnan(selector.select(values, smaller_is_better)) + # Expect a non-NaN value if the input contains any non-NaN value + values = np.concatenate([np.arange(9), [np.nan]]) + assert not np.isnan(selector.select(values, smaller_is_better)) + + def test_random_selector(): """ Test the RandomSelector's statistical behavior.