DiagonalDisplay

class mlquantify.visualization.DiagonalDisplay(true_prevalences, predicted_prevalences, *, class_names=None)

True vs. predicted prevalence diagonal plot.

The signature diagnostic of quantification evaluation: across the many test samples generated by an evaluation protocol (e.g. APP/UPP), each sample’s predicted prevalence is plotted against its true prevalence, together with the \(y = x\) reference line. Points above the diagonal are over-estimates; points below are under-estimates. Tight clustering around the diagonal across the whole prevalence range indicates a good quantifier.

This is a multiple-sample display: it summarises a protocol run rather than a single prediction.
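Reading the plot comes down to the sign of predicted minus true prevalence. The NumPy sketch below uses made-up binary values (positive-class prevalence only, not output from mlquantify) to count the points above and below the \(y = x\) line. It also summarises how tightly they cluster using the mean absolute error, which is an assumed choice of summary rather than something DiagonalDisplay computes.

```python
import numpy as np

# Made-up binary results: positive-class prevalence of each evaluation sample.
true_pos = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
pred_pos = np.array([0.10, 0.20, 0.50, 0.80, 0.90])

# Signed vertical distance from the y = x line: positive means the point lies
# above the diagonal (over-estimate), negative means below (under-estimate).
signed_err = pred_pos - true_pos
over = int(np.sum(signed_err > 0))
under = int(np.sum(signed_err < 0))
on_line = int(np.sum(np.isclose(signed_err, 0.0)))

# Mean absolute error: small values mean the cloud hugs the diagonal.
mae = float(np.mean(np.abs(signed_err)))
print(over, under, on_line, round(mae, 3))  # 2 2 1 0.06
```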

Parameters:
true_prevalences : ndarray of shape (n_samples, n_classes)

True class prevalence of each evaluation sample.

predicted_prevalences : ndarray of shape (n_samples, n_classes)

Predicted class prevalence of each evaluation sample.

class_names : list of str, default=None

Class labels, in the column order of the prevalence arrays.
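Both arrays are indexed the same way: row i is one evaluation sample and column j is one class, in the order given by class_names. The sketch below builds binary inputs from positive-class prevalences and checks them with check_prevalence_arrays, a hypothetical helper that is not part of mlquantify. In particular, the row-sum check is an assumption about well-formed prevalence vectors, not documented library behaviour.

```python
import numpy as np


def check_prevalence_arrays(true_prev, pred_prev, class_names=None):
    """Hypothetical validator for the (n_samples, n_classes) input contract."""
    true_prev = np.asarray(true_prev, dtype=float)
    pred_prev = np.asarray(pred_prev, dtype=float)
    if true_prev.ndim != 2 or true_prev.shape != pred_prev.shape:
        raise ValueError("both arrays must have shape (n_samples, n_classes)")
    # Assumption: each row is a prevalence vector, so it should sum to 1.
    for arr in (true_prev, pred_prev):
        if not np.allclose(arr.sum(axis=1), 1.0):
            raise ValueError("each row must sum to 1")
    if class_names is not None and len(class_names) != true_prev.shape[1]:
        raise ValueError("class_names needs one entry per column")
    return true_prev, pred_prev


# Binary case: build two-column arrays from the positive-class prevalence,
# keeping the column order [negative, positive].
p_true = np.array([0.1, 0.5, 0.9])
p_pred = np.array([0.2, 0.45, 0.8])
true_prev = np.column_stack([1 - p_true, p_true])
pred_prev = np.column_stack([1 - p_pred, p_pred])
t, p = check_prevalence_arrays(true_prev, pred_prev, class_names=["neg", "pos"])
print(t.shape)  # (3, 2)
```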

Attributes:
scatter_ : matplotlib PathCollection or list of PathCollection

The scatter artist(s). A list when several classes are drawn.

line_ : matplotlib Line2D

The \(y = x\) reference line.

ax_ : matplotlib Axes

The axes with the plot.

figure_ : matplotlib Figure

The figure containing the axes.
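The attributes are ordinary matplotlib artists. The sketch below draws a color-coded multiclass diagonal with plain matplotlib on synthetic Dirichlet data to show which calls produce objects of the documented types. It is an illustration under that assumption, not the DiagonalDisplay implementation; the variable names only mirror scatter_, line_, ax_ and figure_.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

# Synthetic 3-class results; convex mixtures of prevalence vectors still sum to 1.
rng = np.random.default_rng(0)
true_prev = rng.dirichlet(np.ones(3), size=20)
pred_prev = 0.8 * true_prev + 0.2 * rng.dirichlet(np.ones(3), size=20)

fig, ax = plt.subplots()
# One scatter call per class yields one PathCollection each, hence a list.
scatter = [
    ax.scatter(true_prev[:, k], pred_prev[:, k], label=f"class {k}")
    for k in range(true_prev.shape[1])
]
# ax.plot returns a list of Line2D; the y = x reference is its only element.
(line,) = ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
ax.set_xlabel("True prevalence")
ax.set_ylabel("Predicted prevalence")
ax.legend()
# scatter, line, ax and fig play the roles of scatter_, line_, ax_ and figure_.
```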

See also

BiasDisplay

Boxplots of signed estimation error.

ErrorByShiftDisplay

Error as a function of prior-probability shift.

Examples

>>> from mlquantify.visualization import DiagonalDisplay
>>> from mlquantify.counting import CC
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.datasets import make_classification
>>> X, y = make_classification(n_samples=400, random_state=0)
>>> disp = DiagonalDisplay.from_protocol(
...     CC(LogisticRegression()), X, y, protocol="app", n_prevalences=11)
>>> _ = disp.ax_.set_title("CC")

classmethod from_predictions(true_prevalences, predicted_prevalences, *, class_names=None, ax=None, **kwargs)

Build a DiagonalDisplay from precomputed prevalence arrays.

Parameters:
true_prevalences, predicted_prevalences : ndarray of shape (n_samples, n_classes)

True and predicted prevalence of each evaluation sample, e.g. the 'true_prevalences' / 'predicted_prevalences' arrays returned by mlquantify.model_selection.apply_protocol.

class_names : list of str, default=None

Class labels in column order.

ax : matplotlib Axes, default=None

Axes to draw on.

**kwargs

Passed to plot.

Returns:
display : DiagonalDisplay

classmethod from_protocol(quantifier, X, y, *, protocol='app', ax=None, name=None, class_index=None, plot_diagonal=True, diagonal_kw=None, **protocol_kwargs)

Run an evaluation protocol and plot the resulting diagonal.

Thin wrapper around mlquantify.model_selection.apply_protocol.
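from_protocol delegates sampling to apply_protocol. To show roughly what such a protocol produces, the sketch below implements a simplified binary APP-style loop with classify-and-count, written directly against scikit-learn. The function app_sketch, its sample_size and repeats parameters, and sampling with replacement are all assumptions for illustration; they do not describe how mlquantify's apply_protocol draws its samples.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def app_sketch(clf, X, y, n_prevalences=11, sample_size=100, repeats=1, seed=0):
    """Hypothetical binary APP-style sampler scored with classify-and-count.

    Returns (true_prevalences, predicted_prevalences), each of shape
    (n_prevalences * repeats, 2), with columns ordered [negative, positive].
    """
    rng = np.random.default_rng(seed)
    pos_idx = np.flatnonzero(y == 1)
    neg_idx = np.flatnonzero(y == 0)
    true_rows, pred_rows = [], []
    # Walk a regular grid of target positive-class prevalences.
    for p in np.linspace(0.0, 1.0, n_prevalences):
        for _ in range(repeats):
            n_pos = int(round(p * sample_size))
            # Sample with replacement so every grid point is reachable.
            idx = np.concatenate([
                rng.choice(pos_idx, size=n_pos, replace=True),
                rng.choice(neg_idx, size=sample_size - n_pos, replace=True),
            ])
            true_pos = y[idx].mean()
            # Classify and count: fraction of the sample predicted positive.
            cc_pos = clf.predict(X[idx]).mean()
            true_rows.append([1.0 - true_pos, true_pos])
            pred_rows.append([1.0 - cc_pos, cc_pos])
    return np.array(true_rows), np.array(pred_rows)


X, y = make_classification(n_samples=400, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.5, random_state=0, stratify=y)
clf = LogisticRegression().fit(X_tr, y_tr)
true_prev, pred_prev = app_sketch(clf, X_te, y_te)
print(true_prev.shape, pred_prev.shape)  # (11, 2) (11, 2)
```

The returned pair has the (n_samples, n_classes) layout that from_predictions expects.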

Parameters:
quantifier : BaseQuantifier

Quantifier to evaluate.

X, y : array-like

Data the protocol samples from.

protocol : {'app', 'npp', 'upp', 'ppp'} or BaseProtocol, default='app'

Sampling protocol.

ax : matplotlib Axes, default=None

Axes to draw on.

name : str, default=None

Legend label (binary case).

class_index : int, default=None

Class column to plot (see plot).

plot_diagonal : bool, default=True

Whether to draw the reference line.

diagonal_kw : dict, default=None

Reference-line styling.

**protocol_kwargs

Forwarded to apply_protocol (e.g. n_prevalences, batch_size, random_state). Styling kwargs for the scatter are not accepted here; call plot directly for that.

Returns:
display : DiagonalDisplay

plot(ax=None, *, class_index=None, name=None, plot_diagonal=True, diagonal_kw=None, **kwargs)

Plot the diagonal scatter.

Parameters:
ax : matplotlib Axes, default=None

Axes to draw on. A new figure/axes is created when None.

class_index : int, default=None

Which class column to plot. Defaults to the last class for binary problems (the conventional “positive” class) and to all classes (color-coded) for multiclass problems.

name : str, default=None

Label used in the legend (binary / single-class case).

plot_diagonal : bool, default=True

Whether to draw the \(y = x\) reference line.

diagonal_kw : dict, default=None

Keyword arguments forwarded to the reference-line ax.plot call.

**kwargs

Forwarded to ax.scatter (the primary artist).

Returns:
display : DiagonalDisplay

Object that stores the computed artists.
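The default for class_index can be expressed as a small rule. resolve_class_columns below is a hypothetical helper that mirrors the documented behaviour (last column for binary problems, all columns for multiclass ones). How mlquantify handles an out-of-range class_index is not documented, so the IndexError is an assumption.

```python
def resolve_class_columns(n_classes, class_index=None):
    """Return the prevalence-array columns to scatter (hypothetical helper)."""
    if class_index is not None:
        # An explicit choice always wins; the range check is an assumption.
        if not 0 <= class_index < n_classes:
            raise IndexError(
                f"class_index={class_index} is out of range for {n_classes} classes")
        return [class_index]
    if n_classes == 2:
        # Binary: the last column is the conventional "positive" class.
        return [n_classes - 1]
    # Multiclass: every column, each drawn in its own color.
    return list(range(n_classes))


print(resolve_class_columns(2))     # [1]
print(resolve_class_columns(4))     # [0, 1, 2, 3]
print(resolve_class_columns(4, 2))  # [2]
```

When a single column is selected, one scatter artist is drawn and name labels it in the legend; when several classes are drawn, there is one artist per class, which is why scatter_ may be a list.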