Tutorial 3: Run HRCHY-CytoCommunity on 18 Cell MouseSpleen CODEX

Creator: Runzhi Xie (rzxie@stu.xidian.edu.cn).

Affiliation: Xidian University, Gao Lab

Date of Creation: 10.10.2025

Date of Last Modification: 10.10.2025

import warnings
warnings.filterwarnings("ignore")
import scanpy as sc
import numpy as np
import pandas as pd
import os
from sklearn.neighbors import kneighbors_graph
import datetime
from typing import Optional
from hrchy_cytocommunity.models.dataset import SpatialOmicsImageDataset
from hrchy_cytocommunity.models import HRCHYCytoCommunity, HRCHYCytoCommunityGrand
from hrchy_cytocommunity.visualization.visualization import load_base_data, vis_heatmap
from hrchy_cytocommunity.models.auto_k import HRCHYClusterAutoK, _dd_list,_dd_float

Prepare input data

Construct the k-NN graph

def compute_knn(coords, K, sample_id, save_folder: Optional[str] = None):
    """
    Construct a symmetric KNN graph over cell coordinates and optionally save it.

    Parameters:
    coords: ndarray of cell coordinates, one row per cell (the tutorial passes
        an (n, 2) array loaded from "<sample_id>_Coordinates.txt")
    K: the number of nearest neighbors queried for each cell
    sample_id: sample id, used as the prefix of the output file name
    save_folder: folder in which "<sample_id>_EdgeIndex.txt" is written
        (created if it does not exist); if None, nothing is written

    Returns:
    edge_index: (num_edges, 2) int64 ndarray, one (source, target) pair per
        row. The graph is symmetrized, so each undirected edge appears once
        in each direction.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print(f'Constructing KNN graph for {len(coords)} points...')
    # Directed KNN adjacency as a sparse matrix; a cell is never its own neighbor.
    A = kneighbors_graph(coords, K, mode='connectivity', include_self=False, n_jobs=-1)

    # Symmetrize: keep an edge if either endpoint lists the other as a neighbor.
    A = A.maximum(A.T).tocsr()
    A.eliminate_zeros()
    A.sort_indices()

    # One row per directed edge; int64 indices (int32 would also be sufficient).
    src, dst = A.nonzero()
    edge_index = np.column_stack((src, dst)).astype(np.int64)

    if save_folder is not None:
        # np.savetxt does not create missing directories. An empty string
        # means "current directory" and must not be passed to os.makedirs.
        if save_folder:
            os.makedirs(save_folder, exist_ok=True)
        filename = os.path.join(save_folder, f"{sample_id}_EdgeIndex.txt")
        np.savetxt(filename, edge_index, delimiter='\t', fmt='%d')
        print(f"Saved {len(edge_index)} edges to {filename}")
        print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    return edge_index
# Build and save the KNN edge list for every sample in `sample_list`.
K = 50 # number of nearest neighbors
# Folder holding the per-sample "<sample_id>_Coordinates.txt" files; the edge
# lists are written back into the same folder.
input_dir0 = './example_data/CODEX/raw'
sample_list = ['BALBc-2']
# Label describing the KNN configuration (not used in the visible part of this tutorial).
setting = f'KNN_{K}'
for i,sample_id in enumerate(sample_list):
    # Load the cell coordinates of this sample from a plain-text file.
    coords = np.loadtxt(f"{input_dir0}/{sample_id}_Coordinates.txt")
    print(f"{sample_id} is processing!")
    # The return value is ignored; only the saved "<sample_id>_EdgeIndex.txt" is needed.
    compute_knn(coords,K=K,sample_id=sample_id,save_folder=input_dir0)
    #print(len(set(adata.obs['region'])))
BALBc-2 is processing!
2025-10-10 21:44:18
Constructing KNN graph for 80832 points...
Saved 4272730 edges to ./example_data/CODEX/raw/BALBc-2_EdgeIndex.txt
2025-10-10 21:44:24

Run the HRCHY-CytoCommunity pipeline

Training finishes within 2 minutes (on an RTX 4090).

# Input/output locations and the dataset wrapper for the pipeline.
data_input_dir = './example_data/CODEX'
save_dir = './results/CODEX/'
dataset = SpatialOmicsImageDataset(data_input_dir)
# Maps each slice name to its index in `dataset`.
graph_dict = {
    'BALBc-2':0,
}
model_params = {
        'mode' : 'full',       # 'full' -> HRCHYCytoCommunityGrand, 'base' -> HRCHYCytoCommunity
        's' : 5,                # number of perturbations (only passed in 'full' mode)
        'num_tcn1' : 7,        # number of fine-grained CN
        'num_tcn2' : 2,        # number of coarse-grained TC
        'num_epoch' : 1500,     # number of training epochs
        'lambda1':1,           # Coefficient of consistency regularization
        'lambda_balance':1,     # Coefficient of cluster balance regularization
        'num_hidden' : 128,     # the dimension of hidden layer
        'lr' : 1e-4,            # learning rate
        'drop_rate' : 0.5,      # node drop rate (only passed in 'full' mode)
        'gt_fine':False,        # whether input data contain fine-grained CN ground truth
        'gt_coarse':True,       # whether input data contain coarse-grained TC ground truth
        'device':'cuda:0'       # training device, if no gpu, set 'cpu'
        }
# Record the hyper-parameters next to the results.
HyperPara_df = pd.DataFrame(model_params.items(), columns=['Parameter', 'Value'])
# exist_ok avoids the check-then-create race and is a no-op if the folder exists.
os.makedirs(save_dir, exist_ok=True)
HyperPara_df.to_csv(os.path.join(save_dir,'HyperPara.csv'))
# Train one model per slice and save its predictions under save_dir/<slice_name>.
for slice_name,graph_idx in graph_dict.items():
    print(f"{slice_name} is processing")
    train_dataset = dataset[graph_idx]
    # Per-cell metadata for this slice; the same object is handed to the model
    # and read again after prediction for plotting.
    cell_meta = load_base_data(os.path.join(data_input_dir,'raw'),
                               graph_idx,fine_GT=model_params['gt_fine'],
                               coarse_GT=model_params['gt_coarse'])
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

    # Constructor arguments shared by both model variants.
    common_kwargs = dict(
        dataset=train_dataset,
        num_tcn1=model_params['num_tcn1'],
        num_tcn2=model_params['num_tcn2'],
        cell_meta=cell_meta,
        lr=model_params['lr'],
        num_epoch=model_params['num_epoch'],
        lambda1=model_params['lambda1'],
        lambda_balance=model_params['lambda_balance'],
        device=model_params['device'],
        gt_coarse=model_params['gt_coarse'],
        gt_fine=model_params['gt_fine'],
    )
    if model_params['mode'] == 'base':
        hrchycytocommunity = HRCHYCytoCommunity(**common_kwargs)
    elif model_params['mode'] == 'full':
        # The full model additionally takes the perturbation settings.
        hrchycytocommunity = HRCHYCytoCommunityGrand(
            s=model_params['s'],
            drop_rate=model_params['drop_rate'],
            **common_kwargs,
        )
    else:
        # Fail loudly instead of hitting a NameError or reusing a stale model.
        raise ValueError(
            f"Unknown mode {model_params['mode']!r}; expected 'base' or 'full'"
        )

    ret_output_dir = os.path.join(save_dir,slice_name)
    hrchycytocommunity.train(save_dir=ret_output_dir, # path to save output results
                             output=False,    # whether print out training information during training
                             vis_while_training=True    # whether visualize clustering results during training
                             )
    hrchycytocommunity.predict(save = True,save_dir=ret_output_dir)
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
Processing...
BALBc-2 is processing !
Done!
BALBc-2 is processing
2025-10-10 21:47:43
edge_pruning_cutoff = 0.16666666666666666
100%|██████████| 1500/1500 [01:45<00:00, 14.27it/s]
2025-10-10 21:49:29

Visualize the results

import matplotlib.pyplot as plt
import seaborn as sns

# Column names of the spatial coordinates in `cell_meta`.
x_label = 'x_coordinate'
y_label = 'y_coordinate'

# Colour of each coarse-grained TC label.
dict_color_TC = {'1': "#818181", '2': "#E96D5C"}
# Colour of each fine-grained CN label.
dict_color_TCN = {
    '6': '#83CA83',
    '4': '#CBBDDD',
    '3': '#FFFFAD',
    '1': '#F33599',
    '2': '#FF9DA3',
    '5': '#4072B2',
    '7': '#C05C19',
}

# Plotting frame: coordinates plus the two cluster assignments.
df = cell_meta[[x_label, y_label]].copy()
df['coarse_cluster'] = cell_meta['coarse_cluster_id'].to_list()
df['fine_cluster'] = cell_meta['fine_cluster_id'].to_list()

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# (axis, hue column, palette, title) for the left and right panels.
panels = (
    (axes[0], 'coarse_cluster', dict_color_TC, "Coarse-grained TC"),
    (axes[1], 'fine_cluster', dict_color_TCN, "Fine-grained CN"),
)
for ax, hue_col, palette, title in panels:
    sns.scatterplot(
        data=df,
        x=x_label,
        y=y_label,
        hue=hue_col,
        palette=palette,
        s=1,
        alpha=1.0,
        legend=True,
        ax=ax,
    )
    ax.set_title(title)
    # Hide ticks and axis labels; only the spatial layout matters.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.invert_yaxis()  # Invert y-axis to match image coordinate system
    sns.despine(ax=ax, left=True, bottom=True)
plt.show()
../_images/e8d1a5835c0b5ed0fd009bf280b876acd5d5484068bf1b3d80bea6d5a288ffcf.png