diff --git a/sklbench/benchmarks/sklearn_estimator.py b/sklbench/benchmarks/sklearn_estimator.py index 5b3032ea..ae0c0400 100644 --- a/sklbench/benchmarks/sklearn_estimator.py +++ b/sklbench/benchmarks/sklearn_estimator.py @@ -134,6 +134,16 @@ def get_subset_metrics_of_estimator( and isinstance(iterations[0], Union[Numeric, NumpyNumeric].__args__) ): metrics.update({"iterations": int(iterations[0])}) + if hasattr(estimator_instance, "estimators_"): + estimators_with_trees = [ + t + for t in estimator_instance.estimators_ + if hasattr(t, "tree_") and hasattr(t.tree_, "node_count") + ] + if estimators_with_trees: + metrics["n_nodes"] = sum( + t.tree_.node_count for t in estimators_with_trees + ) if task == "classification": y_pred = convert_to_numpy(estimator_instance.predict(x)) metrics.update( diff --git a/sklbench/report/implementation.py b/sklbench/report/implementation.py index 7861e3b5..5275b63b 100644 --- a/sklbench/report/implementation.py +++ b/sklbench/report/implementation.py @@ -71,6 +71,8 @@ # NB: 'n_clusters' is parameter of KMeans while # 'clusters' is number of computer clusters by DBSCAN "clusters", + # tree ensembles + "n_nodes", ], "incomparable": [ "1st-mean run ratio",