MMD_RKHS#
- class mlquantify.mixture.MMD_RKHS(kernel='rbf', gamma=None, degree=3, coef0=0.0)[source]#
Maximum Mean Discrepancy in RKHS (MMD-RKHS) quantification method.
This method estimates class prevalences in an unlabeled test set by matching the kernel mean embedding of the test distribution to a convex combination of the class-conditional training embeddings.
Let \(\mathcal{X} \subseteq \mathbb{R}^d\) be the input space and \(\mathcal{Y} = \{0, \dots, C-1\}\) the label set. Let \(K\) be a positive definite kernel with RKHS \(\mathcal{H}\) and feature map \(\phi\), so that \(K(x, x') = \langle \phi(x), \phi(x') \rangle_{\mathcal{H}}\).
For each class \(y\), the class-conditional kernel mean embedding is
\[\mu_y \;=\; \mathbb{E}_{x \sim P_{D}(x \mid y)}[\phi(x)] \in \mathcal{H},\]
and the test mean embedding is
\[\mu_U \;=\; \mathbb{E}_{x \sim P_{U}(x)}[\phi(x)] \in \mathcal{H}.\]
Under prior probability shift, the test distribution satisfies
\[P_U(x) = \sum_{y=0}^{C-1} \theta_y \, P_D(x \mid y),\]
which implies
\[\mu_U = \sum_{y=0}^{C-1} \theta_y \, \mu_y,\]
where \(\theta \in \Delta^{C-1}\) is the class prevalence vector. The MMD-RKHS estimator solves
\[\hat{\theta} \;=\; \arg\min_{\theta \in \Delta^{C-1}} \big\lVert \textstyle\sum_{y=0}^{C-1} \theta_y \mu_y - \mu_U \big\rVert_{\mathcal{H}}^2.\]
In practice, the embeddings are approximated by empirical means. Via the kernel trick, the objective reduces to the quadratic program
\[\hat{\theta} \;=\; \arg\min_{\theta \in \Delta^{C-1}} \big( \theta^\top G \, \theta - 2 \, h^\top \theta \big),\]
with
\[G_{yy'} = \langle \hat{\mu}_y, \hat{\mu}_{y'} \rangle_{\mathcal{H}}, \qquad h_y = \langle \hat{\mu}_y, \hat{\mu}_U \rangle_{\mathcal{H}}.\]
The solution \(\hat{\theta}\) is the estimated prevalence vector.
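The quadratic program above can be sketched end to end with NumPy alone. This is an illustrative reimplementation, not the library's code: the function names (`rbf_kernel`, `project_simplex`, `mmd_prevalence`) are hypothetical, the kernel is fixed to RBF, and the simplex-constrained QP is solved by projected gradient descent rather than a dedicated QP solver.

```python
import numpy as np

def rbf_kernel(A, B, gamma=1.0):
    # Pairwise RBF kernel: K[i, j] = exp(-gamma * ||A[i] - B[j]||^2)
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-gamma * sq)

def project_simplex(v):
    # Euclidean projection of v onto the probability simplex (Duchi et al.)
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u + (1.0 - css) / (np.arange(len(v)) + 1) > 0)[0][-1]
    tau = (1.0 - css[rho]) / (rho + 1)
    return np.maximum(v + tau, 0.0)

def mmd_prevalence(X_train, y_train, X_test, gamma=1.0, n_iter=2000, lr=0.05):
    # Build G_{yy'} = <mu_y, mu_y'> and h_y = <mu_y, mu_U> from empirical
    # kernel means, then minimise theta^T G theta - 2 h^T theta over the simplex.
    classes = np.unique(y_train)
    C = len(classes)
    G = np.empty((C, C))
    h = np.empty(C)
    for i, ci in enumerate(classes):
        Xi = X_train[y_train == ci]
        h[i] = rbf_kernel(Xi, X_test, gamma).mean()
        for j, cj in enumerate(classes):
            Xj = X_train[y_train == cj]
            G[i, j] = rbf_kernel(Xi, Xj, gamma).mean()
    theta = np.full(C, 1.0 / C)           # start from the uniform prevalence
    for _ in range(n_iter):
        grad = 2.0 * (G @ theta - h)       # gradient of the QP objective
        theta = project_simplex(theta - lr * grad)
    return theta
```

On well-separated classes under pure prior probability shift, the recovered `theta` closely tracks the true test prevalences; the library additionally exposes the kernel choice and its hyperparameters, which this sketch fixes for brevity.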
- Parameters:
- kernel{‘rbf’, ‘linear’, ‘poly’, ‘sigmoid’, ‘cosine’}, default=’rbf’
Kernel used to build the RKHS where MMD is computed.
- gammafloat or None, default=None
Kernel coefficient for ‘rbf’ and ‘sigmoid’.
- degreeint, default=3
Degree of the polynomial kernel.
- coef0float, default=0.0
Independent term in ‘poly’ and ‘sigmoid’ kernels.
- strategy{‘ovr’, ‘ovo’}, default=’ovr’
Multiclass quantification strategy flag (for consistency with other mixture-based quantifiers).
- Attributes:
- classes_ndarray of shape (n_classes,)
Class labels seen during fitting.
- X_train_ndarray of shape (n_train, n_features)
Training feature matrix.
- y_train_ndarray of shape (n_train,)
Training labels.
- class_means_ndarray of shape (n_classes, n_train)
Empirical class-wise kernel mean embeddings in the span of training samples.
- K_train_ndarray of shape (n_train, n_train)
Gram matrix of training samples under the chosen kernel.
References
[1] Iyer, A., Nath, S., & Sarawagi, S. (2014). Maximum Mean Discrepancy for Class Ratio Estimation: Convergence Bounds and Kernel Selection. ICML.
[2] Esuli, A., Moreo, A., & Sebastiani, F. (2023). Learning to Quantify. Springer.
- best_mixture(X_test, X_train, y_train)[source]#
Implements MMD-based class ratio estimation:
\[\min_{\theta \in \Delta^{C-1}} \Big\lVert \sum_{y=0}^{C-1} \theta_y \mu_y - \mu_U \Big\rVert_{\mathcal{H}}^2\]
and returns (theta, objective_value).
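For the two-class case this minimisation has a closed form: writing \(\theta = (t, 1-t)\), the objective is a scalar quadratic in \(t\) whose unconstrained minimiser is clipped to \([0, 1]\). A small sketch under that assumption (the helper name `best_mixture_2class` is illustrative, not part of mlquantify):

```python
import numpy as np

def best_mixture_2class(G, h):
    # Minimise theta^T G theta - 2 h^T theta over the 1-simplex for C = 2.
    # With theta = (t, 1 - t), setting the derivative in t to zero gives
    # t* = (G11 - G01 + h0 - h1) / (G00 - 2 G01 + G11), clipped to [0, 1].
    denom = G[0, 0] - 2.0 * G[0, 1] + G[1, 1]
    t = float(np.clip((G[1, 1] - G[0, 1] + h[0] - h[1]) / denom, 0.0, 1.0))
    theta = np.array([t, 1.0 - t])
    obj = theta @ G @ theta - 2.0 * h @ theta
    return theta, obj
```

For example, with orthonormal class embeddings (\(G = I\)) and \(h = (0.3, 0.7)\), the minimiser is exactly \(\theta = (0.3, 0.7)\).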
- get_best_distance(*args, **kwargs)[source]#
Get the best distance value from the mixture fitting process.
Notes
If the quantifier has not been fitted yet, it will first be fitted in order to compute the best distance.
- classmethod get_distance(dist_train, dist_test, measure='hellinger')[source]#
Compute distance between two distributions.
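The source does not show the internals of get_distance. Assuming the default measure='hellinger' follows the standard Hellinger distance between two discrete distributions, a self-contained sketch (the standalone `hellinger` function is illustrative, not the classmethod itself):

```python
import numpy as np

def hellinger(p, q):
    # Hellinger distance between discrete distributions p and q:
    # H(p, q) = sqrt(0.5 * sum_i (sqrt(p_i) - sqrt(q_i))^2), in [0, 1].
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sqrt(0.5 * ((np.sqrt(p) - np.sqrt(q)) ** 2).sum()))
```

It is 0 for identical distributions and 1 for distributions with disjoint support, which makes it a convenient bounded measure for comparing train and test histograms.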
- get_metadata_routing()[source]#
Get metadata routing of this object.
Please check User Guide on how the routing mechanism works.
- Returns:
- routingMetadataRequest
A MetadataRequest encapsulating routing information.
- get_params(deep=True)[source]#
Get parameters for this estimator.
- Parameters:
- deepbool, default=True
If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
- paramsdict
Parameter names mapped to their values.
- set_params(**params)[source]#
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters:
- **paramsdict
Estimator parameters.
- Returns:
- selfestimator instance
Estimator instance.