GridSearchQ

class mlquantify.model_selection.GridSearchQ(quantifier, param_grid, protocol='app', samples_sizes=100, n_repetitions=10, scoring=<function MAE>, refit=True, val_split=0.4, n_jobs=1, random_seed=42, verbose=False)

Grid search over quantifier hyperparameters with evaluation protocols.

Evaluates all combinations in param_grid using a held-out validation split and a sampling protocol (APP, NPP, or UPP). Selects the combination with the lowest score on the chosen metric, then optionally refits on the full training data.
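
As a rough illustration of the search loop described above, the standalone sketch below enumerates every combination in a grid and keeps the one with the lowest score. It does not call mlquantify and is not the library's implementation: expand_grid, select_best, toy_error, and the smoothing parameter are hypothetical names introduced here, and toy_error merely stands in for "fit on the training split, score on protocol-generated validation batches".

```python
from itertools import product


def expand_grid(param_grid):
    """Yield one dict per combination of the values in param_grid."""
    names = sorted(param_grid)  # fixed order keeps the enumeration reproducible
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


def select_best(param_grid, evaluate):
    """Return (best_params, best_score); lower scores are better."""
    best_params, best_score = None, float("inf")
    for params in expand_grid(param_grid):
        score = evaluate(params)
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


def toy_error(params):
    """Hypothetical stand-in for a validation error computed by a protocol."""
    return abs(params["threshold"] - 0.5) + 0.1 * params["smoothing"]


grid = {"threshold": [0.3, 0.5, 0.7], "smoothing": [0, 1]}
combos = list(expand_grid(grid))  # 3 * 2 = 6 combinations
best_params, best_score = select_best(grid, toy_error)
print(len(combos), best_params, best_score)
```

The number of fits grows with the product of the list lengths in the grid, so short value lists keep the search cheap.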

Parameters:
quantifier : BaseQuantifier

Quantifier instance whose hyperparameters are searched.

param_grid : dict

Mapping of parameter names to lists of values to try.

protocol : {'app', 'npp', 'upp'}, default='app'

Evaluation protocol used to generate validation batches.

samples_sizes : int or list of int, default=100

Batch size(s) for protocol evaluation.

n_repetitions : int, default=10

Number of repetitions per evaluation batch.

scoring : callable, default=MAE

Scoring function (true_prev, predicted_prev) -> float.

refit : bool, default=True

If True, refit the quantifier on the full data after search.

val_split : float, default=0.4

Fraction of data held out for validation.

n_jobs : int or None, default=1

Number of parallel evaluation jobs.

random_seed : int or None, default=42

Random seed for reproducibility.

verbose : bool, default=False

Print progress messages.
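
The scoring parameter accepts any callable with the signature (true_prev, predicted_prev) -> float, and the search keeps the combination with the lowest value. The sketch below shows two such callables under a stated assumption: prevalences are represented as dicts keyed by class label, mirroring the predict output shown in the Examples section. The library may instead pass arrays, in which case the bodies would need adapting. Both function names are hypothetical and are not mlquantify functions.

```python
def mean_absolute_error(true_prev, predicted_prev):
    """Average absolute prevalence difference over all classes.

    Assumes both arguments are dicts of the form {class_label: prevalence}.
    """
    classes = sorted(set(true_prev) | set(predicted_prev))
    errors = [
        abs(true_prev.get(c, 0.0) - predicted_prev.get(c, 0.0)) for c in classes
    ]
    return sum(errors) / len(errors)


def max_absolute_error(true_prev, predicted_prev):
    """Worst-case absolute prevalence difference; penalises the largest miss."""
    classes = set(true_prev) | set(predicted_prev)
    return max(
        abs(true_prev.get(c, 0.0) - predicted_prev.get(c, 0.0)) for c in classes
    )


true_prev = {0: 0.5, 1: 0.25, 2: 0.25}
predicted_prev = {0: 0.25, 1: 0.25, 2: 0.5}
mae = mean_absolute_error(true_prev, predicted_prev)
worst = max_absolute_error(true_prev, predicted_prev)
print(mae, worst)
```

Because the lowest score wins, a custom callable should return an error (smaller is better), not a similarity.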

Attributes:
best_score_ : float

Lowest score found during the search.

best_params_ : dict

Hyperparameter combination that achieved best_score_.

best_model_ : BaseQuantifier

Quantifier refitted with best_params_ on the full training data.

Examples

>>> from mlquantify.model_selection import GridSearchQ
>>> from mlquantify.counting import CC
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.datasets import make_classification
>>> X, y = make_classification(n_samples=300, random_state=42)
>>> param_grid = {'threshold': [0.3, 0.5, 0.7]}
>>> gs = GridSearchQ(
...     CC(LogisticRegression()),
...     param_grid=param_grid,
...     protocol='npp',
...     n_repetitions=3,
... ).fit(X, y)
>>> gs.best_params_
{'threshold': 0.5}
>>> gs.predict(X)
{0: 0.49, 1: 0.51}
best_model()

Return the best model after fitting.

Returns:
Quantifier

The best fitted model.

Raises:
ValueError

If called before fitting.

best_params()

Return the best parameters found during fitting.

Returns:
dict

The best parameters.

Raises:
ValueError

If called before fitting.

fit(X, y)

Fit the quantifier for every parameter combination in the grid and score each one with the evaluation protocol.

Splits the data into training and validation sets according to val_split, then evaluates each parameter combination repeatedly on protocol-generated batches.

Parameters:
X : array-like

Feature matrix for training.

y : array-like

Target labels for training.

Returns:
self : object

Returns self for chaining.
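
To make the two stages of fit more concrete, here is a minimal standalone sketch of a hold-out split followed by batch generation at controlled prevalences, in the spirit of an artificial-prevalence protocol for a binary problem. It rests on several assumptions that the documentation above does not confirm: a plain shuffled (non-stratified) split, sampling with replacement, and a fixed prevalence grid. holdout_split and sample_batch are hypothetical helpers, not mlquantify functions.

```python
import random


def holdout_split(n_items, val_split, seed):
    """Shuffle indices and hold out a val_split fraction for validation."""
    rng = random.Random(seed)
    indices = list(range(n_items))
    rng.shuffle(indices)
    n_val = round(n_items * val_split)
    return indices[n_val:], indices[:n_val]  # (train, validation)


def sample_batch(labels, pool, positive_prev, batch_size, rng):
    """Draw a batch from pool, with replacement, at a fixed positive prevalence."""
    positives = [i for i in pool if labels[i] == 1]
    negatives = [i for i in pool if labels[i] == 0]
    n_pos = round(batch_size * positive_prev)
    return [rng.choice(positives) for _ in range(n_pos)] + [
        rng.choice(negatives) for _ in range(batch_size - n_pos)
    ]


labels = [i % 2 for i in range(300)]  # toy binary labels, half positive
train_idx, val_idx = holdout_split(len(labels), val_split=0.4, seed=42)

rng = random.Random(42)
batches = {
    prev: sample_batch(labels, val_idx, prev, batch_size=100, rng=rng)
    for prev in (0.0, 0.25, 0.5, 0.75, 1.0)
}
print(len(train_idx), len(val_idx), sorted(batches))
```

Each batch has a known true prevalence, so a fitted quantifier's prediction on it can be passed to the scoring callable together with that prevalence.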

get_metadata_routing()

Get metadata routing of this object.

See the User Guide for how the routing mechanism works.

Returns:
routing : MetadataRequest

A MetadataRequest encapsulating routing information.

get_params(deep=True)

Get parameters for this estimator.

Parameters:
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:
params : dict

Parameter names mapped to their values.

predict(X)

Predict using the best found model.

Parameters:
X : array-like

Data for prediction.

Returns:
predictions : array-like

Prevalence predictions.

Raises:
RuntimeError

If called before fitting.

save_quantifier(path: str | None = None) -> None

Save the quantifier instance to a file.

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:
**params : dict

Estimator parameters.

Returns:
self : estimator instance

Estimator instance.
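
The <component>__<parameter> naming convention can be illustrated with a short standalone sketch that groups flat parameter names by the component they address. This only demonstrates how such names decompose and is not the estimator's own code; split_nested_params and the learner component name are hypothetical, while threshold is taken from the Examples section.

```python
def split_nested_params(params):
    """Group 'component__parameter' keys by component.

    Keys without a double underscore address the estimator itself and are
    stored under the key None.
    """
    grouped = {}
    for key, value in params.items():
        # Split on the first '__' only, so deeper nesting stays in sub_key.
        component, separator, sub_key = key.partition("__")
        if separator:
            grouped.setdefault(component, {})[sub_key] = value
        else:
            grouped.setdefault(None, {})[key] = value
    return grouped


params = {"threshold": 0.5, "learner__C": 10.0, "learner__max_iter": 200}
grouped = split_nested_params(params)
print(grouped)
```

The same naming scheme would apply to the keys of param_grid whenever a hyperparameter belongs to a nested component and not to the quantifier itself.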

sout(msg)

Print msg if verbose is True.