Model

Model Creation

coral.model.create_model(lowres_dim, hires_dim, lowres_size, hires_size, cell_type_dim, latent_dim=50, hidden_channels=16, v_dim=10, high_res_data_dist='Gamma', low_res_data_dist='NB')

CORAL Model

class coral.model.model_core.CORAL_model(*args: Any, **kwargs: Any)
property px_r
property px_r_sc
property py_r
reparameterize(mu, logvar)
encode(x_combined, cell_type, edge_index)
infer_v(x_combined, z, cell_type)
forward(batch, device)

batch: the subgraph for each cell (the `batch` argument of `forward`).

efficient_contrastive_loss(outputs, labels, margin=50)