From daa04ab1b98af08932a2e392f881f144b8133ad1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 20 May 2025 17:01:19 +0200 Subject: [PATCH 1/2] Update _rotation_forest.py --- sktime/classification/sklearn/_rotation_forest.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sktime/classification/sklearn/_rotation_forest.py b/sktime/classification/sklearn/_rotation_forest.py index f7b0cce5bef..2ef45781596 100644 --- a/sktime/classification/sklearn/_rotation_forest.py +++ b/sktime/classification/sklearn/_rotation_forest.py @@ -137,6 +137,11 @@ def __init__( super().__init__() + if self.base_estimator is None: + self._base_estimator = DecisionTreeClassifier(criterion="entropy") + else: + self._base_estimator = self.base_estimator + def fit(self, X, y): """Fit a forest of trees on cases (X,y), where y is the target variable. @@ -185,9 +190,6 @@ def fit(self, X, y): start_time = time.time() train_time = 0 - if self.base_estimator is None: - self._base_estimator = DecisionTreeClassifier(criterion="entropy") - # remove useless attributes self._useful_atts = ~np.all(X[1:] == X[:-1], axis=0) X = X[:, self._useful_atts] From 0a2e23573c4c447054fd1e55ce8cfcfeb7f1a122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 20 May 2025 17:04:04 +0200 Subject: [PATCH 2/2] Update test_all_classifiers.py --- .../classification/sklearn/tests/test_all_classifiers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sktime/classification/sklearn/tests/test_all_classifiers.py b/sktime/classification/sklearn/tests/test_all_classifiers.py index dc29a44c950..cb27689024a 100644 --- a/sktime/classification/sklearn/tests/test_all_classifiers.py +++ b/sktime/classification/sklearn/tests/test_all_classifiers.py @@ -3,6 +3,7 @@ __author__ = ["MatthewMiddlehurst"] import pytest +from sklearn.linear_model import LogisticRegression from sklearn.utils.estimator_checks import parametrize_with_checks from sktime.classification.sklearn import ContinuousIntervalTree, RotationForest @@ -10,12 +11,18 @@ ALL_SKLEARN_CLASSIFIERS = [RotationForest, ContinuousIntervalTree] +INSTANCES_TO_TEST = [ + RotationForest(n_estimators=3), + RotationForest(n_estimators=3, base_estimator=LogisticRegression()), + ContinuousIntervalTree(), +] + @pytest.mark.skipif( not run_test_for_class(ALL_SKLEARN_CLASSIFIERS), reason="run test only if softdeps are present and incrementally (if requested)", ) -@parametrize_with_checks([RotationForest(n_estimators=3), ContinuousIntervalTree()]) +@parametrize_with_checks(INSTANCES_TO_TEST) def test_sklearn_compatible_estimator(estimator, check): """Test that sklearn estimators adhere to sklearn conventions.""" try: