hrchy_cytocommunity.models.HRCHYCytoCommunity
- class hrchy_cytocommunity.models.HRCHYCytoCommunity(dataset, num_tcn1, num_tcn2, cell_meta, lr=0.0001, alpha=0.9, num_epoch=1500, lambda1=1.0, lambda2=1.0, lambda_balance=1.0, edge_pruning_cutoff=0.2, device=None, num_hidden=128, gt_coarse=False, gt_fine=False)
Hierarchical Community Detection Model (base model)
This class implements a hierarchical tissue structure identification base model that integrates a differentiable graph pooling mechanism (MinCut-based) with spatial feature learning on single-cell spatial omics data.
- Parameters:
dataset (object) – Input dataset object. It must expose at least a num_features attribute and provide the node features and adjacency information used for graph construction.
num_tcn1 (int) – Number of fine-grained cellular neighborhoods (TCN1 level).
num_tcn2 (int) – Number of coarse-grained tissue compartments (TCN2 level).
cell_meta (pandas.DataFrame or dict) – Metadata for each cell, typically including cell type, position, or annotations (optional).
lr (float, default=1e-4) – Learning rate for the optimizer.
alpha (float, default=0.9) – Initial balance coefficient between the fine-grained and coarse-grained objectives. Changing it is not recommended.
num_epoch (int, default=1500) – Number of training epochs.
lambda2 (float, default=1.0) – Weight of the orthogonality loss term. Changing it is not recommended.
edge_pruning_cutoff (float or None, optional) – Threshold for edge pruning. If None, defaults to 1 / (num_tcn1 - 1).
device (str or torch.device or None, optional) – Device to run the model on. If None, automatically selects ‘cuda’ if available, otherwise ‘cpu’.
num_hidden (int, default=128) – Number of hidden channels in the graph neural network.
gt_coarse (bool, default=False) – Whether to use ground truth coarse-level annotations (for benchmarking or supervision).
gt_fine (bool, default=False) – Whether to use ground truth fine-level annotations.
- model
Underlying graph neural network model instance.
- Type:
SparseNet
- device
Device string used by PyTorch (‘cuda’ or ‘cpu’).
- Type:
str
- edge_pruning_cutoff
Final threshold used for edge pruning.
- Type:
float
- lr, alpha, epochs, lambda1
Model hyperparameters stored after initialization.
- Type:
float or int
Notes
The model automatically constructs an internal SparseNet instance based on dataset feature dimensions and user-specified clustering parameters.
This class supports GPU acceleration via CUDA when available.
- The naming convention follows:
TCN1 — fine-grained tissue community nodes
TCN2 — coarse-grained tissue community nodes
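The two None defaults described in the parameter list (edge_pruning_cutoff falling back to 1 / (num_tcn1 - 1), and device falling back to 'cuda' or 'cpu') can be sketched as below. This is a minimal illustration under stated assumptions, not the library's actual code: resolve_edge_pruning_cutoff and resolve_device are hypothetical helper names, and the guard against num_tcn1 < 2 is an added assumption to avoid division by zero.

```python
def resolve_edge_pruning_cutoff(edge_pruning_cutoff, num_tcn1):
    """Hypothetical helper mirroring the documented rule for a None cutoff."""
    if edge_pruning_cutoff is not None:
        # An explicit threshold is used as given.
        return float(edge_pruning_cutoff)
    if num_tcn1 < 2:
        # Assumption: the 1 / (num_tcn1 - 1) default needs at least two TCN1 clusters.
        raise ValueError("num_tcn1 must be at least 2 when edge_pruning_cutoff is None")
    return 1.0 / (num_tcn1 - 1)


def resolve_device(device=None):
    """Hypothetical helper: 'cuda' when available, otherwise 'cpu'."""
    if device is not None:
        # Accepts a string or a torch.device and returns its string form.
        return str(device)
    try:
        import torch
    except ImportError:
        # Without PyTorch installed, this sketch simply falls back to CPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


print(resolve_edge_pruning_cutoff(None, num_tcn1=10))  # documented fallback rule
print(resolve_edge_pruning_cutoff(0.2, num_tcn1=10))   # explicit value is kept
print(resolve_device("cpu"))
```

Note that the constructor signature defaults edge_pruning_cutoff to 0.2, so the 1 / (num_tcn1 - 1) rule applies only when None is passed explicitly.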
Examples
>>> model = HRCHYCytoCommunity(
...     dataset=my_dataset,
...     num_tcn1=10,
...     num_tcn2=2,
...     cell_meta=meta_df,
...     lr=1e-3,
...     device='cuda'
... )
>>> print(model.device)
cuda
>>> print(model.model)
SparseNet(...)
- __init__(dataset, num_tcn1, num_tcn2, cell_meta, lr=0.0001, alpha=0.9, num_epoch=1500, lambda1=1.0, lambda2=1.0, lambda_balance=1.0, edge_pruning_cutoff=0.2, device=None, num_hidden=128, gt_coarse=False, gt_fine=False)
Methods
__init__(dataset, num_tcn1, num_tcn2, cell_meta)
predict([save, save_dir]) – Predict the hierarchical tissue structure assignment.
train(save_dir[, output, vis_while_training])