DiagonalDisplay
- class mlquantify.visualization.DiagonalDisplay(true_prevalences, predicted_prevalences, *, class_names=None)
True vs. predicted prevalence diagonal plot.
This is the signature diagnostic of quantification evaluation. Across the many test samples generated by an evaluation protocol (e.g. APP or UPP), each sample’s predicted prevalence is plotted against its true prevalence, together with the \(y = x\) reference line. Points above the diagonal are over-estimates and points below it are under-estimates. Tight clustering around the diagonal across the whole prevalence range indicates a good quantifier.
This is a multiple-sample display: it summarises a protocol run rather than a single prediction.
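The geometry of the plot can be reproduced with plain NumPy and matplotlib. The following is a minimal stand-alone sketch, not mlquantify's implementation: it simulates a biased binary quantifier on hypothetical data, draws the scatter with the \(y = x\) line, and classifies each point as an over- or under-estimate by the sign of predicted minus true prevalence.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical protocol output: 11 evaluation samples whose true positive-class
# prevalence sweeps [0, 1], and a simulated quantifier that shrinks estimates
# toward 0.5 (a typical classify-and-count bias) plus a little noise.
true_pos = np.linspace(0.0, 1.0, 11)
pred_pos = np.clip(0.5 + 0.8 * (true_pos - 0.5) + rng.normal(0, 0.02, 11), 0, 1)

fig, ax = plt.subplots()
ax.scatter(true_pos, pred_pos, label="simulated quantifier")
ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="y = x")
ax.set_xlabel("True prevalence")
ax.set_ylabel("Predicted prevalence")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend()

# Above the diagonal -> over-estimate; below the diagonal -> under-estimate.
signed_error = pred_pos - true_pos
n_over = int(np.sum(signed_error > 0))
n_under = int(np.sum(signed_error < 0))
print(n_over, n_under)
```

With this simulated shrinkage, low true prevalences land above the diagonal and high ones below it, which is the pattern the display is designed to expose.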
- Parameters:
- true_prevalences : ndarray of shape (n_samples, n_classes)
True class prevalence of each evaluation sample.
- predicted_prevalences : ndarray of shape (n_samples, n_classes)
Predicted class prevalence of each evaluation sample.
- class_names : list of str, default=None
Class labels, in the column order of the prevalence arrays.
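Both arrays hold one prevalence vector per row, so each row is non-negative and sums to 1. The sketch below builds arrays in that layout for a hypothetical binary problem and checks them; the helper `check_prevalence_arrays` is illustrative only and is not part of mlquantify.

```python
import numpy as np

def check_prevalence_arrays(true_prev, pred_prev, class_names=None):
    """Hypothetical validator for the (n_samples, n_classes) layout."""
    true_prev = np.asarray(true_prev, dtype=float)
    pred_prev = np.asarray(pred_prev, dtype=float)
    if true_prev.ndim != 2 or true_prev.shape != pred_prev.shape:
        raise ValueError("expected two arrays of identical shape (n_samples, n_classes)")
    for name, arr in (("true", true_prev), ("predicted", pred_prev)):
        # Each row is a prevalence vector: non-negative entries summing to 1.
        if np.any(arr < 0) or not np.allclose(arr.sum(axis=1), 1.0):
            raise ValueError(f"{name} prevalences must be non-negative rows summing to 1")
    if class_names is not None and len(class_names) != true_prev.shape[1]:
        raise ValueError("class_names must have one entry per column")
    return true_prev, pred_prev

# Three evaluation samples of a binary problem; column 1 is the "positive" class.
true_prev = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]]
pred_prev = [[0.8, 0.2], [0.55, 0.45], [0.3, 0.7]]
t, p = check_prevalence_arrays(true_prev, pred_prev, class_names=["neg", "pos"])
print(t.shape)  # (3, 2)
```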
- Attributes:
- scatter_ : matplotlib PathCollection or list of PathCollection
The scatter artist(s). A list when several classes are drawn.
- line_ : matplotlib Line2D
The \(y = x\) reference line.
- ax_ : matplotlib Axes
The axes with the plot.
- figure_ : matplotlib Figure
The figure containing the axes.
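Because the attributes are ordinary matplotlib artists, they can be restyled after plotting. The sketch below uses a hypothetical stand-in class, `SketchDiagonalDisplay`, that only mirrors the documented attribute names in the binary case; it is not mlquantify's class, but it shows which matplotlib objects those names refer to and how post-hoc styling works.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

class SketchDiagonalDisplay:
    """Hypothetical stand-in that mirrors the documented attribute names."""

    def __init__(self, true_prevalences, predicted_prevalences):
        self.true_prevalences = np.asarray(true_prevalences, dtype=float)
        self.predicted_prevalences = np.asarray(predicted_prevalences, dtype=float)

    def plot(self, ax=None):
        if ax is None:
            _, ax = plt.subplots()
        # Binary case only: plot the last ("positive") column.
        self.scatter_ = ax.scatter(self.true_prevalences[:, -1],
                                   self.predicted_prevalences[:, -1])
        # ax.plot returns a list of Line2D; unpack the single reference line.
        (self.line_,) = ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
        self.ax_ = ax
        self.figure_ = ax.figure
        return self

disp = SketchDiagonalDisplay([[0.7, 0.3], [0.4, 0.6]],
                             [[0.6, 0.4], [0.45, 0.55]]).plot()
# Stored artists can be restyled without recomputing anything.
disp.line_.set_linestyle(":")
disp.scatter_.set_sizes([60])
disp.ax_.set_title("restyled")
```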
See also
BiasDisplay : Boxplots of signed estimation error.
ErrorByShiftDisplay : Error as a function of prior-probability shift.
Examples
>>> from mlquantify.visualization import DiagonalDisplay
>>> from mlquantify.counting import CC
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.datasets import make_classification
>>> X, y = make_classification(n_samples=400, random_state=0)
>>> disp = DiagonalDisplay.from_protocol(
...     CC(LogisticRegression()), X, y, protocol="app", n_prevalences=11)
>>> _ = disp.ax_.set_title("CC")
- classmethod from_predictions(true_prevalences, predicted_prevalences, *, class_names=None, ax=None, **kwargs)
Build a DiagonalDisplay from precomputed prevalence arrays.
- Parameters:
- true_prevalences, predicted_prevalences : ndarray of shape (n_samples, n_classes)
True and predicted prevalence of each evaluation sample, e.g. the 'true_prevalences' / 'predicted_prevalences' arrays returned by mlquantify.model_selection.apply_protocol.
- class_names : list of str, default=None
Class labels in column order.
- ax : matplotlib Axes, default=None
Axes to draw on.
- **kwargs
Passed to plot.
- Returns:
- display : DiagonalDisplay
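When the evaluation loop runs outside mlquantify, the two arrays can be built by hand. This sketch assumes a simple APP-style loop on synthetic 1-D data, a fixed-threshold classifier, and classify-and-count (CC) estimates; every name in it is illustrative, and the only point is that the result has the (n_samples, n_classes) layout from_predictions expects.

```python
import numpy as np

rng = np.random.default_rng(0)

def draw_sample(prevalence, size, rng):
    """Draw a labelled 1-D sample whose positive-class share equals `prevalence`."""
    n_pos = int(round(prevalence * size))
    x_pos = rng.normal(1.0, 1.0, n_pos)          # positives centred at +1
    x_neg = rng.normal(-1.0, 1.0, size - n_pos)  # negatives centred at -1
    x = np.concatenate([x_neg, x_pos])
    y = np.concatenate([np.zeros(size - n_pos, dtype=int), np.ones(n_pos, dtype=int)])
    return x, y

def classify_and_count(x, threshold=0.0):
    """CC: fraction of items the classifier (x > threshold) labels positive."""
    return float(np.mean(x > threshold))

true_rows, pred_rows = [], []
for p in np.linspace(0.0, 1.0, 11):
    x, y = draw_sample(p, size=200, rng=rng)
    true_pos = y.mean()
    pred_pos = classify_and_count(x)
    true_rows.append([1 - true_pos, true_pos])
    pred_rows.append([1 - pred_pos, pred_pos])

true_prevalences = np.array(true_rows)       # shape (11, 2)
predicted_prevalences = np.array(pred_rows)  # shape (11, 2)
# These arrays could then be passed to
# DiagonalDisplay.from_predictions(true_prevalences, predicted_prevalences).
```

CC's false positives push estimates up at low prevalence and its false negatives pull them down at high prevalence, so this data would show the usual flattening against the diagonal.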
- classmethod from_protocol(quantifier, X, y, *, protocol='app', ax=None, name=None, class_index=None, plot_diagonal=True, diagonal_kw=None, **protocol_kwargs)
Run an evaluation protocol and plot the resulting diagonal.
Thin wrapper around mlquantify.model_selection.apply_protocol.
- Parameters:
- quantifier : BaseQuantifier
Quantifier to evaluate.
- X, y : array-like
Data the protocol samples from.
- protocol : {'app', 'npp', 'upp', 'ppp'} or BaseProtocol, default='app'
Sampling protocol.
- ax : matplotlib Axes, default=None
Axes to draw on.
- name : str, default=None
Legend label (binary case).
- class_index : int, default=None
Class column to plot (see plot).
- plot_diagonal : bool, default=True
Whether to draw the reference line.
- diagonal_kw : dict, default=None
Reference-line styling.
- **protocol_kwargs
Forwarded to apply_protocol (e.g. n_prevalences, batch_size, random_state). Styling keyword arguments for the scatter are not accepted here; call plot directly for that.
- Returns:
- display : DiagonalDisplay
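For orientation on what n_prevalences controls, here is a sketch of the prevalence grid that an artificial-prevalence protocol (APP) commonly enumerates: every prevalence vector whose entries are multiples of 1/(n_prevalences - 1) and sum to 1. This grid convention is an assumption about mlquantify's 'app' protocol; consult apply_protocol for the exact grid, repeats, and sample sizes it uses.

```python
import itertools
import numpy as np

def app_prevalence_grid(n_classes, n_prevalences):
    """Enumerate prevalence vectors on a regular grid of the probability simplex.

    Each entry is a multiple of 1 / (n_prevalences - 1); the last class takes
    whatever mass remains, so every vector sums to 1.
    """
    steps = n_prevalences - 1
    grid = []
    for head in itertools.product(range(n_prevalences), repeat=n_classes - 1):
        remainder = steps - sum(head)
        if remainder >= 0:  # keep only points that lie on the simplex
            grid.append([*head, remainder])
    return np.array(grid, dtype=float) / steps

binary = app_prevalence_grid(n_classes=2, n_prevalences=11)
ternary = app_prevalence_grid(n_classes=3, n_prevalences=11)
print(binary.shape, ternary.shape)  # (11, 2) (66, 3)
```

Under this convention each grid point yields at least one evaluation sample, so the number of points in the diagonal plot grows quickly with the number of classes.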
- plot(ax=None, *, class_index=None, name=None, plot_diagonal=True, diagonal_kw=None, **kwargs)
Plot the diagonal scatter.
- Parameters:
- ax : matplotlib Axes, default=None
Axes to draw on. A new figure/axes is created when None.
- class_index : int, default=None
Which class column to plot. Defaults to the last class for binary problems (the conventional “positive” class) and to all classes (color-coded) for multiclass problems.
- name : str, default=None
Label used in the legend (binary / single-class case).
- plot_diagonal : bool, default=True
Whether to draw the \(y = x\) reference line.
- diagonal_kw : dict, default=None
Keyword arguments forwarded to the reference-line ax.plot call.
- **kwargs
Forwarded to ax.scatter (the primary artist).
- Returns:
- display : DiagonalDisplay
Object that stores the computed artists.
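The class_index default described above can be expressed as a small selection rule. The sketch below is a hypothetical re-implementation for illustration, not mlquantify's code: it picks the last column for binary problems and every column otherwise, then draws one color-coded scatter per selected class. This is why scatter_ is a list in the multiclass case.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

def select_class_columns(n_classes, class_index=None):
    """Hypothetical rule mirroring the documented class_index default."""
    if class_index is not None:
        return [class_index]
    return [n_classes - 1] if n_classes == 2 else list(range(n_classes))

def plot_diagonal(true_prev, pred_prev, class_names=None, class_index=None, ax=None):
    """Hypothetical plotting routine: one scatter per selected class column."""
    true_prev, pred_prev = np.asarray(true_prev), np.asarray(pred_prev)
    n_classes = true_prev.shape[1]
    names = class_names or [f"class {k}" for k in range(n_classes)]
    if ax is None:
        _, ax = plt.subplots()
    artists = [
        # Each call takes the next color from matplotlib's default cycle.
        ax.scatter(true_prev[:, k], pred_prev[:, k], label=names[k])
        for k in select_class_columns(n_classes, class_index)
    ]
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
    ax.legend()
    # A single artist for one class, a list when several classes are drawn.
    return artists[0] if len(artists) == 1 else artists

true3 = np.array([[0.2, 0.3, 0.5], [0.6, 0.2, 0.2]])
pred3 = np.array([[0.25, 0.3, 0.45], [0.5, 0.25, 0.25]])
multi = plot_diagonal(true3, pred3)  # list of three PathCollections
```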