QuaNet
- class mlquantify.neural.QuaNet(estimator, fit_estimator: bool = True, sample_size: int = 100, n_epochs: int = 100, tr_iter: int = 500, va_iter: int = 100, lr: float = 0.001, lstm_hidden_size: int = 64, lstm_nlayers: int = 1, ff_layers: Sequence[int] = (1024, 512), bidirectional: bool = True, random_state: int | None = None, qdrop_p: float = 0.5, patience: int = 10, checkpointdir: str = './checkpoint_quanet', checkpointname: str | None = None, device: str = 'cuda')
QuaNet: deep neural quantification with an LSTM architecture.
Learns a mapping from bags of instances to class-prevalence vectors using an LSTM network. During training, artificial bags are generated via the APP protocol; for each bag the network receives document embeddings, posterior probabilities, and simple quantification statistics (CC, PCC, EMQ …) and is trained to minimise the MSE against the true bag prevalences.
Requires a base estimator that implements fit, predict_proba, and transform (the last to produce document embeddings). PyTorch must be installed.
- Parameters:
- estimator : estimator
Base probabilistic classifier with fit, predict_proba, and transform methods.
- fit_estimator : bool, default=True
If True, fit the estimator inside fit.
- sample_size : int, default=100
Bag size used by the APP protocol during training.
- n_epochs : int, default=100
Maximum number of training epochs.
- tr_iter : int, default=500
Training APP samplings per epoch.
- va_iter : int, default=100
Validation APP samplings per epoch.
- lr : float, default=1e-3
Learning rate for the Adam optimiser.
- lstm_hidden_size : int, default=64
Hidden size of the LSTM.
- lstm_nlayers : int, default=1
Number of LSTM layers.
- ff_layers : sequence of int, default=(1024, 512)
Sizes of the fully connected layers above the LSTM embedding.
- bidirectional : bool, default=True
Whether to use a bidirectional LSTM.
- qdrop_p : float, default=0.5
Dropout probability in the network.
- patience : int, default=10
Early-stopping patience (epochs without validation improvement).
- checkpointdir : str, default='./checkpoint_quanet'
Directory for saving intermediate model weights.
- checkpointname : str or None, default=None
Checkpoint filename. None generates a random name.
- device : {'cpu', 'cuda'}, default='cuda'
Device used for PyTorch computations.
- Attributes:
- classes_ : ndarray of shape (n_classes,)
Class labels seen during fit.
References
[1] Esuli, A., Moreo, A., & Sebastiani, F. (2018). A Recurrent Neural Network for Sentiment Quantification. CIKM, pp. 1775–1778.
Examples
    # Requires PyTorch and an estimator with a transform() method
    from mlquantify.neural import QuaNet
    q = QuaNet(estimator=my_embedding_classifier, device='cpu')
    q.fit(X_train, y_train)
    q.predict(X_test)
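The estimator requirement stated above (fit, predict_proba, and transform) can be checked before constructing a quantifier. The following is a minimal sketch using only the Python standard library. The names EmbeddingClassifier, ToyEstimator, and NoEmbeddings are hypothetical and are not part of mlquantify; the check only verifies that the methods exist, not that they return sensible values.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class EmbeddingClassifier(Protocol):
    """Hypothetical name for the interface the base estimator must expose."""

    def fit(self, X: Any, y: Any) -> Any: ...
    def predict_proba(self, X: Any) -> Any: ...
    def transform(self, X: Any) -> Any: ...


class ToyEstimator:
    """Illustrative stand-in with all three methods; not a useful model."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        return self

    def predict_proba(self, X):
        # Uniform posteriors over the classes seen in fit
        k = len(self.classes_)
        return [[1.0 / k] * k for _ in X]

    def transform(self, X):
        # Identity "embedding": each row is returned unchanged
        return [list(row) for row in X]


class NoEmbeddings:
    """Illustrative estimator that lacks transform."""

    def fit(self, X, y):
        return self

    def predict_proba(self, X):
        return [[1.0] for _ in X]


print(isinstance(ToyEstimator(), EmbeddingClassifier))  # True
print(isinstance(NoEmbeddings(), EmbeddingClassifier))  # False
```

A structural check like this catches a missing transform method early, before any training time is spent.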
- fit(X, y)
Fit QuaNet to the training data.
Optionally fits the base estimator, then trains the LSTM network on artificially generated bags sampled with the UPP protocol. Uses early stopping based on the validation loss.
- Parameters:
- X : array-like of shape (n_samples, n_features)
Training feature matrix. Must be compatible with both estimator.fit and estimator.transform.
- y : array-like of shape (n_samples,)
Training class labels.
- Returns:
- self : QuaNet
The fitted quantifier.
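Training on artificially generated bags relies on one basic step: drawing a bag of sample_size instances whose class proportions match a chosen prevalence. The sketch below illustrates that step with the standard library only. The function name sample_bag, the largest-remainder rounding, and sampling with replacement are assumptions made for illustration; this is not mlquantify's implementation, and how the target prevalence itself is chosen is left to the sampling protocol.

```python
import random
from collections import defaultdict


def sample_bag(labels, prevalence, sample_size, rng):
    """Return indices of a bag whose class proportions approximate `prevalence`.

    `prevalence` maps each class label to a target proportion (summing to 1).
    Indices are drawn with replacement within each class.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    classes = sorted(prevalence)
    raw = [prevalence[c] * sample_size for c in classes]
    counts = [int(r) for r in raw]

    # Hand the slots lost to truncation to the largest fractional parts
    leftover = sample_size - sum(counts)
    order = sorted(range(len(classes)), key=lambda i: raw[i] - counts[i], reverse=True)
    for i in order[:leftover]:
        counts[i] += 1

    bag = []
    for c, n in zip(classes, counts):
        if n:
            bag.extend(rng.choices(by_class[c], k=n))
    rng.shuffle(bag)
    return bag


labels = [0] * 50 + [1] * 50
bag = sample_bag(labels, {0: 0.2, 1: 0.8}, sample_size=100, rng=random.Random(0))
print(len(bag), sum(labels[i] for i in bag))  # 100 80
```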
Notes
When fit_estimator=True, the data is internally split into a classifier-training set (60%), a network-training set (32%), and a validation set (8%). When fit_estimator=False, only the train/validation split (80%/20%) is performed.
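The percentages in the note are consistent with a two-stage split: 60% goes to the classifier, and the remaining 40% is divided 80/20 between network training and validation (0.4 × 0.8 = 0.32 and 0.4 × 0.2 = 0.08). The sketch below works through that arithmetic. The function name split_sizes and its rounding are assumptions for illustration; the library's actual rounding and any stratification are not specified here.

```python
def split_sizes(n_samples, fit_estimator=True, clf_frac=0.6, val_frac=0.2):
    """Return (classifier, network-train, validation) set sizes."""
    # Stage 1: reserve data for the classifier only if it will be fitted
    n_clf = int(round(n_samples * clf_frac)) if fit_estimator else 0
    remaining = n_samples - n_clf
    # Stage 2: split what is left into network training and validation
    n_val = int(round(remaining * val_frac))
    n_train = remaining - n_val
    return n_clf, n_train, n_val


print(split_sizes(1000))                       # (600, 320, 80)
print(split_sizes(1000, fit_estimator=False))  # (0, 800, 200)
```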
- get_metadata_routing()
Get metadata routing of this object.
Please check the User Guide on how the routing mechanism works.
- Returns:
- routing : MetadataRequest
A MetadataRequest encapsulating routing information.
- get_params(deep=True)
Get parameters for this estimator.
- Parameters:
- deep : bool, default=True
If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
- params : dict
Parameter names mapped to their values.
- predict(X)
Predict class prevalences for a test bag.
Computes posterior probabilities and document embeddings with the base estimator, collects simple quantification statistics (CC, GACC, PCC, GPACC, EMQ) as auxiliary inputs, and forwards everything through the trained QuaNetModule.
- Parameters:
- X : array-like of shape (n_samples, n_features)
Test feature matrix.
- Returns:
- prevalences : ndarray of shape (n_classes,)
Estimated class prevalence vector for the test bag, normalised to sum to 1.
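Two of the auxiliary statistics named above, CC (classify and count) and PCC (probabilistic classify and count), are simple functions of the posterior probabilities. The sketch below shows their standard definitions in plain Python. The function names are hypothetical, and the code does not reproduce how mlquantify computes these statistics or passes them to the network.

```python
def classify_and_count(posteriors):
    """CC: fraction of instances whose most probable class is each class."""
    n_classes = len(posteriors[0])
    counts = [0] * n_classes
    for row in posteriors:
        predicted = max(range(n_classes), key=row.__getitem__)
        counts[predicted] += 1
    return [c / len(posteriors) for c in counts]


def probabilistic_classify_and_count(posteriors):
    """PCC: mean posterior probability of each class over the bag."""
    n_classes = len(posteriors[0])
    n = len(posteriors)
    return [sum(row[k] for row in posteriors) / n for k in range(n_classes)]


# Toy posteriors for a bag of four instances and two classes
posteriors = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.7, 0.3]]
print(classify_and_count(posteriors))                # [0.75, 0.25]
print(probabilistic_classify_and_count(posteriors))  # approximately [0.6, 0.4]
```

Both estimates sum to 1 by construction. They differ because CC discards the classifier's confidence while PCC averages it.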
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters:
- **params : dict
Estimator parameters.
- Returns:
- self : estimator instance
Estimator instance.