cellcharter.tl.Cluster#
- class cellcharter.tl.Cluster(n_clusters=1, *, covariance_type='full', init_strategy='kmeans', init_means=None, convergence_tolerance=0.001, covariance_regularization=1e-06, batch_size=None, trainer_params=None, random_state=0)#
Cluster cells or spots based on the neighborhood aggregated features from CellCharter.
- Parameters:
n_clusters (int (default: 1)) – The number of components in the GMM. The dimensionality of each component is automatically inferred from the data.
covariance_type (str (default: 'full')) – The type of covariance to assume for all Gaussian components.
init_strategy (str (default: 'kmeans')) – The strategy for initializing component means and covariances.
init_means (Optional[Tensor] (default: None)) – An optional initial guess for the means of the components. If provided, must be a tensor of shape [num_components, num_features]. If this is given, the init_strategy is ignored and the means are handled as if K-means initialization has been run.
convergence_tolerance (float (default: 0.001)) – The change in the per-datapoint negative log-likelihood which implies that training has converged.
covariance_regularization (float (default: 1e-06)) – A small value which is added to the diagonal of the covariance matrix to ensure that it is positive semi-definite.
batch_size (default: None) – The batch size to use when fitting the model. If not provided, the full data will be used as a single batch. Set this if the full data does not fit into memory.
trainer_params (Optional[dict] (default: None)) – Initialization parameters to use when initializing a PyTorch Lightning trainer. By default, it disables various stdout logs unless TorchGMM is configured to do verbose logging. Checkpointing and logging are disabled regardless of the log level. This estimator further sets the following overridable default: max_epochs=100.
random_state (Union[int, RandomState, None] (default: 0)) – Initialization seed.
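To make the roles of convergence_tolerance and covariance_regularization more concrete, here is a minimal pure-Python sketch. It is not the library's implementation: the helper names has_converged and regularize_covariance are hypothetical, and the exact stopping rule used during training may differ.

```python
def has_converged(prev_nll, curr_nll, tolerance=1e-3):
    """Return True when the per-datapoint NLL changed by less than `tolerance`."""
    return abs(prev_nll - curr_nll) < tolerance


def regularize_covariance(cov, eps=1e-6):
    """Return a copy of the square matrix `cov` with `eps` added to its diagonal."""
    return [
        [value + eps if i == j else value for j, value in enumerate(row)]
        for i, row in enumerate(cov)
    ]


# Illustrative values only (assumptions, not taken from a real training run).
small_change = has_converged(1.2345, 1.2340)
large_change = has_converged(1.30, 1.20)
regularized = regularize_covariance([[1.0, 0.5], [0.5, 2.0]], eps=0.5)
```

Only the diagonal entries change, so the off-diagonal covariances between features are left untouched.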
Examples
>>> import anndata
>>> import squidpy as sq
>>> import cellcharter as cc
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> sq.gr.spatial_neighbors(adata, coord_type='generic', delaunay=True)
>>> cc.gr.remove_long_links(adata)
>>> cc.gr.aggregate_neighbors(adata, n_layers=3)
>>> model = cc.tl.Cluster(n_clusters=11)
>>> model.fit(adata, use_rep='X_cellcharter')
Attributes table#
| persistent_attributes | Returns the list of fitted attributes that ought to be saved and loaded. |
| model_ | The fitted PyTorch module with all estimated parameters. |
| converged_ | A boolean indicating whether the model converged during training. |
| num_iter_ | The number of iterations the model was fitted for, excluding initialization. |
| nll_ | The average per-datapoint negative log-likelihood at the last training step. |
Methods table#
| clone() | Clones the estimator without copying any fitted attributes. |
| fit(adata, use_rep='X_cellcharter') | Fit data into n_clusters clusters. |
| fit_predict(data) | Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator. |
| get_params(deep=True) | Returns the estimator's parameters as passed to the initializer. |
| load(path) | Loads the estimator and (if available) the fitted model. |
| load_attributes(path) | Loads the fitted attributes that are stored at the fitted path. |
| load_parameters(path) | Initializes this estimator by loading its parameters. |
| predict(adata, use_rep='X_cellcharter') | Predict the labels for the data in use_rep using the fitted model. |
| predict_proba(data) | Computes a distribution over the components for each of the provided datapoints. |
| sample(num_datapoints) | Samples datapoints from the fitted Gaussian mixture. |
| save(path) | Saves the estimator to the provided directory. |
| save_attributes(path) | Saves the fitted attributes of this estimator. |
| save_parameters(path) | Saves the parameters of this estimator. |
| score(data) | Computes the average negative log-likelihood (NLL) of the provided datapoints. |
| score_samples(data) | Computes the negative log-likelihood (NLL) of each of the provided datapoints. |
| set_params(values) | Sets the provided values on the estimator. |
| trainer(**kwargs) | Returns the trainer as configured by the estimator. |
Attributes#
- Cluster.persistent_attributes#
Returns the list of fitted attributes that ought to be saved and loaded.
By default, this encompasses all annotations.
- Cluster.model_: GaussianMixtureModel#
The fitted PyTorch module with all estimated parameters.
Methods#
- Cluster.clone()#
Clones the estimator without copying any fitted attributes. All parameters of this estimator are copied via copy.deepcopy().
- Return type:
TypeVar(E, bound=BaseEstimator)
- Returns:
The cloned estimator with the same parameters.
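The cloning behaviour described here can be sketched with a small stand-in class. ToyEstimator is hypothetical and is not part of CellCharter or TorchGMM. It only illustrates that parameters are deep-copied while fitted attributes are left behind.

```python
import copy


class ToyEstimator:
    """Hypothetical stand-in for an estimator; not part of CellCharter or TorchGMM."""

    def __init__(self, n_clusters=1, trainer_params=None):
        self.n_clusters = n_clusters
        self.trainer_params = trainer_params

    def get_params(self):
        # Parameters exactly as passed to the initializer.
        return {"n_clusters": self.n_clusters, "trainer_params": self.trainer_params}

    def clone(self):
        # Deep-copy every parameter; fitted attributes are not carried over.
        return type(self)(**copy.deepcopy(self.get_params()))


original = ToyEstimator(n_clusters=11, trainer_params={"max_epochs": 50})
original.converged_ = True  # pretend the estimator has been fitted

cloned = original.clone()
```

Because the parameters are deep-copied, mutating cloned.trainer_params afterwards does not affect the original estimator.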
- Cluster.fit(adata, use_rep='X_cellcharter')#
Fit data into n_clusters clusters.
- Parameters:
adata (AnnData) – Annotated data object.
use_rep (str (default: 'X_cellcharter')) – Key in anndata.AnnData.obsm to use as data to fit the clustering model.
- Cluster.fit_predict(data)#
Fits the estimator using the provided data and subsequently predicts the labels for the data using the fitted estimator. It simply chains calls to fit() and predict().
- Args:
data: The data to use for fitting and to predict labels for. The data must have the same type as for the fit() method.
- Cluster.get_params(deep=True)#
Returns the estimator’s parameters as passed to the initializer.
- Args:
deep: Ignored. For Scikit-learn compatibility.
- classmethod Cluster.load(path)#
Loads the estimator and (if available) the fitted model. This method should only be expected to work to load an estimator that has previously been saved via save().
- Args:
path: The directory from which to load the estimator.
- Return type:
TypeVar(E, bound=BaseEstimator)
- Returns:
The loaded estimator, either fitted or not.
- Cluster.load_attributes(path)#
Loads the fitted attributes that are stored at the provided path. If subclasses overwrite save_attributes(), this method should also be overwritten.
Typically, this method should not be called directly. It is called as part of load().
- Return type:
None
- Args:
path: The directory from which the parameters should be loaded.
- Raises:
FileNotFoundError – If no fitted attributes have been stored.
- classmethod Cluster.load_parameters(path)#
Initializes this estimator by loading its parameters. If subclasses overwrite save_parameters(), this method should also be overwritten.
Typically, this method should not be called directly. It is called as part of load().
- Return type:
TypeVar(E, bound=BaseEstimator)
- Args:
path: The directory from which the parameters should be loaded.
- Cluster.predict(adata, use_rep='X_cellcharter')#
Predict the labels for the data in use_rep using the fitted model.
- Parameters:
adata (AnnData) – Annotated data object.
use_rep (str (default: 'X_cellcharter')) – Key in anndata.AnnData.obsm used as data to fit the clustering model. If None, uses anndata.AnnData.X.
k – Number of clusters to predict using the fitted model. If None, the number of clusters with the highest stability will be selected. If max_runs > 1, the model with the largest marginal likelihood will be used among the ones fitted on k.
- Return type:
Categorical
- Cluster.predict_proba(data)#
Computes a distribution over the components for each of the provided datapoints.
- Parameters:
data (Union[ndarray[Any, dtype[float32]], Tensor]) – The datapoints for which to compute the component assignment probabilities.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints, num_components] with the assignment probabilities for each component and datapoint. Note that each row of the tensor sums to 1, i.e. the returned tensor provides a proper distribution over the components for each datapoint.
Attention
When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
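As an illustration of what a distribution over the components means, the following pure-Python sketch computes assignment probabilities for a one-dimensional toy mixture. It is an assumption-laden sketch, not the library's code: the functions, the mixture values and the restriction to one dimension are all chosen for illustration.

```python
import math


def gaussian_pdf(x, mean, var):
    """Density of a 1-D Gaussian with the given mean and variance."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def predict_proba(points, weights, means, variances):
    """Return one row of component probabilities per 1-D datapoint."""
    rows = []
    for x in points:
        # Weighted density of x under each component, then normalize across components.
        joint = [w * gaussian_pdf(x, m, v) for w, m, v in zip(weights, means, variances)]
        total = sum(joint)
        rows.append([p / total for p in joint])
    return rows


# Two-component toy mixture (values chosen arbitrarily for illustration).
proba = predict_proba(
    [-2.0, 0.0, 3.0], weights=[0.5, 0.5], means=[-2.0, 3.0], variances=[1.0, 1.0]
)
```

Each row of proba has one entry per component and sums to 1, mirroring the [num_datapoints, num_components] shape described above.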
- Cluster.sample(num_datapoints)#
Samples datapoints from the fitted Gaussian mixture.
- Args:
num_datapoints: The number of datapoints to sample.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints, dim] providing the samples.
- Note:
This method does not parallelize across multiple processes, i.e. performs no synchronization.
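Sampling from a Gaussian mixture is commonly done in two steps: pick a component according to the mixture weights, then draw from that component's Gaussian. The sketch below shows this for a one-dimensional toy mixture. The function sample_mixture and its arguments are hypothetical, and the fitted model may sample differently internally.

```python
import random


def sample_mixture(num_datapoints, weights, means, stddevs, seed=0):
    """Draw `num_datapoints` values from a 1-D Gaussian mixture."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_datapoints):
        # Pick a component according to the mixture weights, then draw from its Gaussian.
        k = rng.choices(range(len(weights)), weights=weights)[0]
        samples.append(rng.gauss(means[k], stddevs[k]))
    return samples


# Toy mixture with well-separated components (illustrative values only).
samples = sample_mixture(1000, weights=[0.3, 0.7], means=[-5.0, 5.0], stddevs=[1.0, 1.0])
```

With well-separated components, roughly 70% of the samples fall near the second mean, reflecting its larger weight.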
- Cluster.save(path)#
Saves the estimator to the provided directory. It saves a file named estimator.pickle for the configuration of the estimator and additional files for the fitted model (if applicable). For more information on the files saved for the fitted model or for more customization, look at get_params() and torchgmm.base.nn.Configurable.save().
- Return type:
None
- Args:
path: The directory to which all files should be saved.
- Note:
This method may be called regardless of whether the estimator has already been fitted.
- Attention:
If the dictionary returned by get_params() is not JSON-serializable, this method uses pickle, which is not necessarily backwards-compatible.
- Cluster.save_attributes(path)#
Saves the fitted attributes of this estimator. By default, it uses JSON and falls back to pickle. Subclasses should overwrite this method if non-primitive attributes are fitted.
Typically, this method should not be called directly. It is called as part of save().
- Return type:
None
- Args:
path: The directory to which the fitted attributes should be saved.
- Raises:
NotFittedError – If the estimator has not been fitted.
- Cluster.save_parameters(path)#
Saves the parameters of this estimator. By default, it uses JSON and falls back to pickle. If subclasses use non-primitive types as parameters, they should overwrite this method.
Typically, this method should not be called directly. It is called as part of save().
- Return type:
None
- Args:
path: The directory to which the parameters should be saved.
- Cluster.score(data)#
Computes the average negative log-likelihood (NLL) of the provided datapoints.
- Args:
data: The datapoints for which to evaluate the NLL.
- Return type:
float
- Returns:
The average NLL of all datapoints.
- Note:
See score_samples() to obtain NLL values for individual datapoints.
- Cluster.score_samples(data)#
Computes the negative log-likelihood (NLL) of each of the provided datapoints.
- Parameters:
data (Union[ndarray[Any, dtype[float32]], Tensor]) – The datapoints for which to compute the NLL.
- Return type:
Tensor
- Returns:
A tensor of shape [num_datapoints] with the NLL for each datapoint.
Attention
When calling this function in a multi-process environment, each process receives only a subset of the predictions. If you want to aggregate predictions, make sure to gather the values returned from this method.
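The relationship between score() and score_samples() can be illustrated with a one-dimensional toy mixture: the average NLL is the mean of the per-datapoint NLL values. The functions below are a pure-Python sketch under that assumption. They are not the library's code, and the mixture values are arbitrary.

```python
import math


def score_samples(points, weights, means, variances):
    """Per-datapoint negative log-likelihood under a 1-D Gaussian mixture."""
    nlls = []
    for x in points:
        density = sum(
            w * math.exp(-0.5 * (x - m) ** 2 / v) / math.sqrt(2.0 * math.pi * v)
            for w, m, v in zip(weights, means, variances)
        )
        nlls.append(-math.log(density))
    return nlls


def score(points, weights, means, variances):
    """Average NLL over all datapoints."""
    nlls = score_samples(points, weights, means, variances)
    return sum(nlls) / len(nlls)


points = [-2.0, 0.0, 3.0]
mixture = dict(weights=[0.5, 0.5], means=[-2.0, 3.0], variances=[1.0, 1.0])
per_point = score_samples(points, **mixture)
average = score(points, **mixture)
```

A datapoint lying between the two component means, such as 0.0 here, receives a higher NLL than one sitting on a mean.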
- Cluster.set_params(values)#
Sets the provided values on the estimator. The estimator is returned as well, but the estimator on which this function is called is also modified.
- Args:
values: The values to set.
- Return type:
TypeVar(E, bound=BaseEstimator)
- Returns:
The estimator where the values have been set.
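The in-place behaviour of set_params can be sketched with a small stand-in class. ToyEstimator is hypothetical and not part of CellCharter or TorchGMM. It only shows that the call mutates the instance and returns that same instance.

```python
class ToyEstimator:
    """Hypothetical stand-in; not part of CellCharter or TorchGMM."""

    def __init__(self, n_clusters=1, random_state=0):
        self.n_clusters = n_clusters
        self.random_state = random_state

    def set_params(self, values):
        # Mutate this instance and return it, so calls can be chained.
        for name, value in values.items():
            setattr(self, name, value)
        return self


estimator = ToyEstimator()
returned = estimator.set_params({"n_clusters": 5})
```

Since the returned object is the estimator itself, code that needs an unmodified copy should call clone() before setting new values.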
- Cluster.trainer(**kwargs)#
Returns the trainer as configured by the estimator. Typically, this method is only called by functions in the estimator.
- Args:
kwargs: Additional arguments that override the trainer arguments registered in the initializer of the estimator.
- Return type:
Trainer
- Returns:
A fully initialized PyTorch Lightning trainer.
- Note:
This function should be preferred over initializing the trainer directly. It ensures that the returned trainer correctly deals with TorchGMM components that may be introduced in the future.
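One way to picture the override order is a dictionary merge in which later sources win: the estimator default max_epochs=100, then trainer_params from the initializer, then the keyword arguments passed to trainer(). The helper resolve_trainer_kwargs below is hypothetical, and this precedence is an assumption drawn from the descriptions above, not a copy of the library's logic.

```python
def resolve_trainer_kwargs(trainer_params=None, **kwargs):
    """Merge settings so that later sources override earlier ones."""
    defaults = {"max_epochs": 100}  # the overridable default named in the parameter list
    return {**defaults, **(trainer_params or {}), **kwargs}


# Illustrative settings only; "accelerator" is just an example key.
from_defaults = resolve_trainer_kwargs()
from_init = resolve_trainer_kwargs({"max_epochs": 50, "accelerator": "cpu"})
from_call = resolve_trainer_kwargs({"max_epochs": 50}, max_epochs=10)
```

The resulting dictionary is what would then be handed to the trainer constructor in this sketch.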