import numpy as np
import pandas as pd
from sklearn.feature_selection import VarianceThreshold
from poniard import PoniardClassifier
Base plugin
BasePlugin
.
BasePlugin
BasePlugin ()
Base plugin class. New plugins should inherit from this class.
`BasePlugin` defines a set of common plugin hooks that allow plugin classes to execute actions during the life of a Poniard estimator. These can be thought of as analogous to callbacks in other libraries like Keras, Transformers or fastai.
We named them plugins because callbacks in other libraries are generally not expected to significantly alter what the main code does; instead, they add functionality like logging, model saving, etc. Poniard plugins have no such restriction.
Developing plugins
Plugins allow developers to extend Poniard's functionality beyond what the main module offers. Doing so is straightforward: subclass `BasePlugin` and implement the desired methods.
Crucially, Poniard estimators inject themselves into all plugins during initialization, meaning that plugin instances have access to the estimator through the `_poniard` attribute.
The following minimal example builds a plugin that adds a new (useless) feature and modifies the preprocessor.
class StringFeaturePlugin(BasePlugin):
    """A plugin that adds a feature comprised of a single string.

    Every row of the new feature holds the same string. For pandas inputs the
    new column is also named after that string. Since the feature is constant,
    the plugin also removes a trailing ``VarianceThreshold`` step from the
    estimator's preprocessor (presumably so the constant feature is not
    discarded) and rebuilds the estimator's pipelines.

    Parameters
    ----------
    string :
        The string to add as a feature.
    """

    def __init__(self, string: str):
        super().__init__()
        self.string = string

    def on_setup_data(self) -> None:
        """Append the constant string feature to the estimator's ``X``."""
        data = self._poniard.X
        if hasattr(data, "iloc"):
            # pandas-like input: add a column named after the string, filled with it.
            self._poniard.X = data.assign(**{self.string: self.string})
        else:
            # Array input: ``np.append(data, <scalar>, axis=1)`` raises because the
            # scalar has fewer dimensions than ``data``; build an explicit
            # (n_rows, 1) column to append alongside the existing features.
            # NOTE(review): numpy promotes the result to a string dtype, so the
            # existing numeric columns become strings too — confirm this is OK.
            column = np.full((np.shape(data)[0], 1), self.string)
            self._poniard.X = np.append(data, column, axis=1)

    def on_setup_preprocessor(self) -> None:
        """Drop a trailing ``VarianceThreshold`` step and rebuild the pipelines."""
        old_preprocessor = self._poniard.preprocessor
        if isinstance(old_preprocessor[-1], VarianceThreshold):
            # Slicing a scikit-learn Pipeline returns a new Pipeline without
            # the last step.
            self._poniard.preprocessor = old_preprocessor[:-1]
            # The estimator's pipelines must be rebuilt to use the new preprocessor.
            self._poniard.pipelines = self._poniard._build_pipelines()
# Seeded generator so the example's output is reproducible between runs.
rng = np.random.default_rng(0)
features = pd.DataFrame(
    rng.normal(size=(20, 2)), columns=[f"X_{i}" for i in range(2)]
)
target = rng.choice([0, 1], size=20)
pnd = PoniardClassifier(plugins=StringFeaturePlugin("foobar")).setup(features, target)
pnd.preprocessor
Target info
-----------
Type: binary
Shape: (20,)
Unique values: 2
Main metric
-----------
roc_auc
Thresholds
----------
Minimum unique values to consider a feature numeric: 2
Minimum unique values to consider a categorical high cardinality: 20
Inferred feature types
----------------------
| | numeric | categorical_high | categorical_low | datetime |
|---|---|---|---|---|
| 0 | X_0 | | foobar | |
| 1 | X_1 | | | |
Pipeline(steps=[('type_preprocessor', ColumnTransformer(transformers=[('numeric_preprocessor', Pipeline(steps=[('numeric_imputer', SimpleImputer()), ('scaler', StandardScaler())]), ['X_0', 'X_1']), ('categorical_low_preprocessor', Pipeline(steps=[('categorical_imputer', SimpleImputer(strategy='most_frequent')), ('one-hot_encoder', OneHotEncoder(drop='if_binary', handle_unknown='ignore', sparse=False))]), ['foobar'])]))], verbose=0)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('type_preprocessor', ColumnTransformer(transformers=[('numeric_preprocessor', Pipeline(steps=[('numeric_imputer', SimpleImputer()), ('scaler', StandardScaler())]), ['X_0', 'X_1']), ('categorical_low_preprocessor', Pipeline(steps=[('categorical_imputer', SimpleImputer(strategy='most_frequent')), ('one-hot_encoder', OneHotEncoder(drop='if_binary', handle_unknown='ignore', sparse=False))]), ['foobar'])]))], verbose=0)
ColumnTransformer(transformers=[('numeric_preprocessor', Pipeline(steps=[('numeric_imputer', SimpleImputer()), ('scaler', StandardScaler())]), ['X_0', 'X_1']), ('categorical_low_preprocessor', Pipeline(steps=[('categorical_imputer', SimpleImputer(strategy='most_frequent')), ('one-hot_encoder', OneHotEncoder(drop='if_binary', handle_unknown='ignore', sparse=False))]), ['foobar'])])
['X_0', 'X_1']
SimpleImputer()
StandardScaler()
['foobar']
SimpleImputer(strategy='most_frequent')
OneHotEncoder(drop='if_binary', handle_unknown='ignore', sparse=False)