Tutorial 1 Run HRCHY-CytoCommunity on 18_Cell_TNBC_MIBI-TOF

Creator: Runzhi xie (rzxie@stu.xidian.edu.cn).

Affiliation: xidian University, Gao Lab

Date of Creation: 10.10.2025

Date of Last Modification: 10.10.2025

import warnings
warnings.filterwarnings("ignore")
import scanpy as sc
import numpy as np
import pandas as pd
import os
from sklearn.neighbors import kneighbors_graph
import datetime
from typing import Optional
from hrchy_cytocommunity.models.dataset import SpatialOmicsImageDataset
from hrchy_cytocommunity.models import HRCHYCytoCommunity, HRCHYCytoCommunityGrand
from hrchy_cytocommunity.visualization.visualization import load_base_data, vis_heatmap
from hrchy_cytocommunity.models.auto_k import HRCHYClusterAutoK, _dd_list,_dd_float

prepare input data

construct k-nn graph

def compute_knn(coords, K, sample_id, save_folder: Optional[str] = None):
    """
    construct KNN graph and save it into file
    
    参数:
    coords: (n, 2) ndarray, the coordinates of cells
    K: the number of nearest neighbors
    sample_id: sample id
    
    save_folder: the path of HRCHY-CytoCommunity input data
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    print(f'Constructing KNN graph for {len(coords)} points...')
    A = kneighbors_graph(coords, K, mode='connectivity', include_self=False, n_jobs=-1)  # CSR

    A = A.maximum(A.T).tocsr()
    A.eliminate_zeros()
    A.sort_indices()


    src, dst = A.nonzero()
    edge_index = np.vstack((src, dst)).astype(np.int64) 
    edge_index = edge_index.T # or int32
    if save_folder is not None:
        filename = os.path.join(save_folder, f"{sample_id}_EdgeIndex.txt")
        np.savetxt(filename, edge_index, delimiter='\t', fmt='%d')
        print(f"Saved {len(edge_index)} edges to {filename}")
        print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    return edge_index
K = 50 # number of nearest neighbors
input_dir0 = './example_data/MIBI-TOF/raw'
sample_list = ['patient4']
setting = f'KNN_{K}'
for i,sample_id in enumerate(sample_list):
    coords = np.loadtxt(f"{input_dir0}/{sample_id}_Coordinates.txt")
    print(f"{sample_id} is processing!")
    compute_knn(coords,K=K,sample_id=sample_id,save_folder=input_dir0)
    #print(len(set(adata.obs['region'])))
patient4 is processing!
2025-10-10 23:08:29
Constructing KNN graph for 6667 points...
Saved 357786 edges to ./example_data/MIBI-TOF/raw/patient4_EdgeIndex.txt
2025-10-10 23:08:30

run HRCHY-CytoCommunity pipeline

finish training within 2mins (RTX4090)

data_input_dir = './example_data/MIBI-TOF'
save_dir = './results/MIBI-TOF/'
dataset = SpatialOmicsImageDataset(data_input_dir)
graph_dict = {
    'patient4':0,
}
model_params = {
        'mode' : 'full',       # full HRCHYCytoCommunity model
        's' : 5,                # number of perturbations
        'num_tcn1' : 14,        # number of fine-grained TC
        'num_tcn2' : 2,        # number of coarse-grained TC
        'num_epoch' : 1500,
        'lambda1':1,           # Coefficient of consistency regularization
        'lambda_balance':1,     # Coefficient of cluster balance regularization
        'num_hidden' : 128,     # the dimension of hidden layer
        'lr' : 1e-4,            # learning rate
        'drop_rate' : 0.5,      # rate of drop node
        'gt_fine':False,        # whether input data contain fine-grained CN ground truth
        'gt_coarse':False,       # whether input data contain coarse-grained TC ground truth
        'device':'cuda:0'       # training device, if no gpu, set 'cpu'
        }
HyperPara_df = pd.DataFrame(model_params.items(), columns=['Parameter', 'Value'])
if not os.path.exists(os.path.join(save_dir)):
    os.makedirs(os.path.join(save_dir))
HyperPara_df.to_csv(os.path.join(save_dir,'HyperPara.csv'))
for slice_name,graph_idx in graph_dict.items():
    print(f"{slice_name} is processing")
    train_dataset = dataset[graph_idx]
    cell_meta = load_base_data(os.path.join(data_input_dir,'raw'),
                            graph_idx,fine_GT=model_params['gt_fine'],
                            coarse_GT=model_params['gt_coarse'])
    cell_meta = cell_meta
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
    if model_params['mode'] == 'base':
        hrchycytocommunity = HRCHYCytoCommunity(
            dataset = train_dataset,
            num_tcn1 = model_params['num_tcn1'],
            num_tcn2 = model_params['num_tcn2'],
            cell_meta = cell_meta,
            lr = model_params['lr'],
            num_epoch = model_params['num_epoch'],
            lambda1 = model_params['lambda1'],
            lambda_balance = model_params['lambda_balance'],
            device = model_params['device'],
            gt_coarse=model_params['gt_coarse'],
            gt_fine=model_params['gt_fine'],
        )
    elif  model_params['mode'] == 'full':
        hrchycytocommunity = HRCHYCytoCommunityGrand(
            dataset = train_dataset,
            num_tcn1 = model_params['num_tcn1'],
            num_tcn2 = model_params['num_tcn2'],
            cell_meta = cell_meta,
            lr = model_params['lr'],
            num_epoch = model_params['num_epoch'],
            lambda1 = model_params['lambda1'],
            lambda_balance = model_params['lambda_balance'],
            s = model_params['s'],
            drop_rate = model_params['drop_rate'],
            device = model_params['device'],
            gt_coarse=model_params['gt_coarse'],
            gt_fine=model_params['gt_fine'],
        )
    ret_output_dir = os.path.join(save_dir,slice_name)
    hrchycytocommunity.train(save_dir=ret_output_dir, # path to save output results
                             output=False,    # whether print out training information during training
                             vis_while_training=True    # whether visualize clustering results during training
                             )
    hrchycytocommunity.predict(save = True,save_dir=ret_output_dir)
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
patient4 is processing
2025-10-10 23:09:33
edge_pruning_cutoff = 0.07692307692307693
  0%|          | 0/1500 [00:00<?, ?it/s]
100%|██████████| 1500/1500 [00:43<00:00, 34.47it/s]
2025-10-10 23:10:16

visualize result

x_label = 'x_coordinate'
y_label = 'y_coordinate'
dict_color_ct = {"CD4+ T": "#fee08b", 
                 "B cells": "Red", 
                 "Dendritic cells": "Black", 
                 "CD8+ T": "MediumBlue", 
                 "CD3+ T": "Purple", 
                 "Macrophages": "#00A087",
                 "Dendritic cells and monocytes mixed subpopulation": "#1F77B4", 
                 "NK cells": "#a50026", 
                 "Regulatory T": "#FF7F0E", 
                 "Other immune cells": "#9467BD",
                "Monocytes and neutrophils mixed subpopulation": "#2CA02C", 
                "Neutrophil": "#8C564B", 
                "Neoplastic cells": "#E377C2", 
                "Mesenchymal stromal cells": "#7F7F7F",
                "Keratin+ neoplastic cells": "#543005", 
                "Unknown subpopulation": "#BCBD22", 
                "Endothelial cells": "#17BECF"}

dict_color_TC = {
    0:"#7fc97f",
    1:"#beaed4"
}
dict_color_TCN = {0: "#5a3b1c", 1: "#939396", 2: "#2c663b", 3: "#d63189", 4: "#54a9dd",
                   5: "#813188", 6: "Orange", 7: "#231f1f", 8: "#912b61",  9: "#dc143c",
                   10: "#ba55d3", 11: "#7b68ee", 12: "#00008b", 13: "#cd853f", 
                   #"15": "#5f9eA0"
                   }

load hierarchical tissue structure assignment

cell_meta = load_base_data(os.path.join(data_input_dir,'raw'),
                        graph_idx,fine_GT=model_params['gt_fine'],
                        coarse_GT=model_params['gt_coarse'])
fine_cluster_id = pd.read_csv(f"{save_dir}/{slice_name}/fine_ClusterAssignMatrix_hard.csv",header=None)[0].tolist()
coarse_cluster_id = pd.read_csv(f"{save_dir}/{slice_name}/coarse_ClusterAssignMatrix_hard.csv",header=None)[0]
coarse_cluster_id = coarse_cluster_id[fine_cluster_id].tolist()
cell_meta['coarse_cluster'] = coarse_cluster_id
cell_meta['fine_cluster'] = fine_cluster_id
import seaborn as sns
import matplotlib.pyplot as plt
df = cell_meta[[x_label,y_label]].copy()
df['coarse_cluster'] = cell_meta['coarse_cluster'].to_list()
df['fine_cluster'] = cell_meta['fine_cluster'].to_list()

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# left:coarse cluster
sns.scatterplot(
    x=x_label, y=y_label,
    data=df,
    hue='coarse_cluster',
    legend=True,
    s=10,
    palette=dict_color_TC,
    alpha=1.0,
    ax=axes[0]
)
axes[0].set_title("Coarse-grained TC")
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_xlabel(None)
axes[0].set_ylabel(None)
axes[0].invert_yaxis()
sns.despine(ax=axes[0], left=True, bottom=True)

# right:fine cluster
sns.scatterplot(
    x=x_label, y=y_label,
    data=df,
    hue='fine_cluster',
    legend=True,
    s=10,
    palette=dict_color_TCN,
    alpha=1.0,
    ax=axes[1]
)
axes[1].set_title("Fine-grained CN")
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_xlabel(None)
axes[1].set_ylabel(None)
axes[1].invert_yaxis()  # Invert y-axis to match image coordinate system
sns.despine(ax=axes[1], left=True, bottom=True)
plt.show()
../_images/0b0137d0a11213d5c9de6738ef2bbc63da70d17a5414ff406e7f388abc4fdbf4.png