From 31e7526937f5abc8900e1da6c0eecd246dcd5b34 Mon Sep 17 00:00:00 2001
From: Emma Chen
Date: Wed, 7 Aug 2024 18:10:15 -0400
Subject: [PATCH] Replace COVET representation with NICER (NeighborIng Cells Expression Representation)

---
 scenvi/ENVI.py   | 379 +++++++++++++++++++++++++++++++++++++---------
 scenvi/_dists.py |  48 ++++++
 scenvi/utils.py  | 381 +++++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 724 insertions(+), 84 deletions(-)

diff --git a/scenvi/ENVI.py b/scenvi/ENVI.py
index a56ffa9..926fb34 100644
--- a/scenvi/ENVI.py
+++ b/scenvi/ENVI.py
@@ -11,17 +11,19 @@
 from flax import linen as nn
 from jax import jit, random
 from tqdm import trange
+from math import sqrt
 
 from scenvi._dists import (
     KL,
     AOT_Distance,
+    S2,
     log_nb_pdf,
     log_normal_pdf,
     log_pos_pdf,
     log_zinb_pdf,
 )
 
-from scenvi.utils import CVAE, Metrics, TrainState, compute_covet, niche_cell_type
+from scenvi.utils import CVAE, Metrics, TrainState, compute_covet, compute_niche, niche_cell_type, DefaultConfig
 
 
 class ENVI:
@@ -63,7 +67,7 @@ def __init__(
         num_neurons=1024,
         latent_dim=512,
         k_nearest=8,
-        num_cov_genes=64,
+        #num_cov_genes=64,
         cov_genes=[],
         num_HVG=2048,
         sc_genes=[],
@@ -71,7 +75,8 @@ def __init__(
         sc_dist="nb",
         spatial_coeff=1,
         sc_coeff=1,
-        cov_coeff=1,
+        #cov_coeff=1,
+        niche_coeff=0.01,
         kl_coeff=0.3,
         log_input=0.1,
         stable_eps=1e-6,
@@ -79,20 +84,18 @@ def __init__(
         self.spatial_data = spatial_data[:, np.intersect1d(spatial_data.var_names, sc_data.var_names)]
         self.sc_data = sc_data
-
-        if "highly_variable" not in self.sc_data.var.columns:
-            if 'log' in self.sc_data.layers.keys():
-                sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG, layer="log")
-            elif('log1p' in self.sc_data.layers.keys()):
-                sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG, layer="log1p")
-            elif(self.sc_data.X.min() < 0):
-                sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG)
-            else:
-                sc_data.layers["log"] = np.log(self.sc_data.X + 1)
-                sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG, layer="log")
+        if log_input > 0:
+            self.spatial_data.layers["log"] = np.log(self.spatial_data.X + log_input)
+            self.sc_data.layers["log"] = np.log(self.sc_data.X + log_input)
+
+        if "highly_variable" not in sc_data.var.columns:
+            sc_data.layers["log"] = np.log(sc_data.X + 1)
+            sc.pp.highly_variable_genes(
+                sc_data, layer="log", n_top_genes=min(num_HVG, sc_data.shape[-1])
+            )
 
         sc_genes_keep = np.union1d(
-            self.sc_data.var_names[self.sc_data.var.highly_variable], self.spatial_data.var_names
+            sc_data.var_names[sc_data.var.highly_variable], self.spatial_data.var_names
         )
         if len(sc_genes) > 0:
             sc_genes_keep = np.union1d(sc_genes_keep, sc_genes)
@@ -121,27 +124,40 @@ def __init__(
         self.spatial_key = spatial_key
         self.batch_key = batch_key
         self.cov_genes = cov_genes
-        self.num_cov_genes = num_cov_genes
+        #self.num_cov_genes = num_cov_genes
+        self.overlap_num = self.overlap_genes.shape[0]
+        # self.cov_gene_num = self.spatial_data.obsm["COVET_SQRT"].shape[-1]
+        self.n_niche_genes = self.spatial_data.X.shape[-1]
+        self.full_trans_gene_num = self.sc_data.shape[-1]
+
-        print("Computing Niche Covariance Matrices")
+        # print("Computing Niche Covariance Matrices")
+        # (
+        #     self.spatial_data.obsm["COVET"],
+        #     self.spatial_data.obsm["COVET_SQRT"],
+        #     self.CovGenes,
+        # ) = compute_covet(
+        #     self.spatial_data,
+        #     self.k_nearest,
+        #     self.num_cov_genes,
+        #     self.cov_genes,
+        #     spatial_key=self.spatial_key,
+        #     batch_key=self.batch_key,
+        # )
+
+        print("Computing Niche Matrices")
         (
-            self.spatial_data.obsm["COVET"],
-            self.spatial_data.obsm["COVET_SQRT"],
-            self.CovGenes,
-        ) = compute_covet(
+            self.gene_mins,
+            self.gene_maxs
+        ) = compute_niche(
             self.spatial_data,
+            self.n_niche_genes,
             self.k_nearest,
-            self.num_cov_genes,
-            self.cov_genes,
             spatial_key=self.spatial_key,
             batch_key=self.batch_key,
         )
 
-        self.overlap_num = self.overlap_genes.shape[0]
-        self.cov_gene_num = self.spatial_data.obsm["COVET_SQRT"].shape[-1]
-        self.full_trans_gene_num = self.sc_data.shape[-1]
-
         self.num_layers = num_layers
         self.num_neurons = num_neurons
         self.latent_dim = latent_dim
@@ -158,10 +174,11 @@ def __init__(
 
         self.spatial_coeff = spatial_coeff
         self.sc_coeff = sc_coeff
-        self.cov_coeff = cov_coeff
+        # self.cov_coeff = cov_coeff
+        self.niche_coeff = niche_coeff
         self.kl_coeff = kl_coeff
 
-        if self.sc_dist == "norm" or self.spatial_dist == "norm" or self.spatial_data.X.min()<0 or self.sc_data.X.min()<0:
+        if self.sc_dist == "norm" or self.spatial_dist == "norm":
             self.log_input = -1
         else:
             self.log_input = log_input
@@ -175,7 +192,9 @@ def __init__(
             n_neurons=self.num_neurons,
             n_latent=self.latent_dim,
             n_output_exp=self.exp_dec_size,
-            n_output_cov=int(self.cov_gene_num * (self.cov_gene_num + 1) / 2),
+            # n_output_cov=int(self.cov_gene_num * (self.cov_gene_num + 1) / 2),
+            k_nearest=self.k_nearest,
+            n_niche_genes=self.n_niche_genes,
         )
 
         print("Finished Initializing ENVI")
@@ -320,6 +339,16 @@ def grammian_cov(self, dec_cov):
         dec_cov = jax_prob.math.fill_triangular(dec_cov)
         return jnp.matmul(dec_cov, dec_cov.transpose([0, 2, 1]))
 
+    def batched_S2(self, x, y, epsilon, lse_mode):
+        """
+        Helper function to run S2 on x and y (batches of 2D matrices)
+
+        :return: mean_OT_dists, OT_dists, where OT_dists is a (batchsize,) array
+        """
+
+        OT_dists = jax.vmap(lambda x, y: S2(x, y, epsilon, lse_mode), in_axes=[0, 0])(x, y)
+        return jnp.mean(OT_dists), OT_dists
+
     def create_train_state(self, key=random.key(0), init_lr=0.0001, decay_steps=4000):
         """
         :meta private:
@@ -341,15 +370,16 @@ def create_train_state(self, key=random.key(0), init_lr=0.0001, decay_steps=4000
         )
 
     @partial(jit, static_argnums=(0,))
-    def train_step(self, state, spatial_inp, spatial_COVET, sc_inp, key=random.key(0)):
+    def train_step(self, state, spatial_inp, spatial_niche, sc_inp, key=random.key(0)):
         """
         :meta private:
         """
 
         key, subkey1, subkey2 = random.split(key, num=3)
 
-        def loss_fn(params):
-            spatial_enc_mu, spatial_enc_logstd, spatial_dec_exp, spatial_dec_cov = (
+        def loss_fn(params):  # redefine loss function
+            # spatial_enc_mu, spatial_enc_logstd, spatial_dec_exp, spatial_dec_cov = (
+            spatial_enc_mu, spatial_enc_logstd, spatial_dec_exp, spatial_dec_niche = (
                 state.apply_fn(
                     {"params": params},
                     x=self.inp_log_fn(spatial_inp),
@@ -366,23 +396,32 @@ def loss_fn(params):
             spatial_exp_like = self.factor_spatial(spatial_inp, spatial_dec_exp)
             sc_exp_like = self.factor_sc(sc_inp, sc_dec_exp)
 
-            spatial_cov_like = jnp.mean(
-                AOT_Distance(spatial_COVET, self.grammian_cov(spatial_dec_cov))
-            )
+            # spatial_cov_like = jnp.mean(
+            #     AOT_Distance(spatial_COVET, self.grammian_cov(spatial_dec_cov))
+            # )
+
+            spatial_niche_cost = self.batched_S2(spatial_niche,
+                                                 spatial_dec_niche,
+                                                 epsilon=1e-2,
+                                                 lse_mode=True)[0]
+
             kl_div = jnp.mean(KL(spatial_enc_mu, spatial_enc_logstd)) + jnp.mean(
                 KL(sc_enc_mu, sc_enc_logstd)
             )
 
             loss = (
-                -self.spatial_coeff * spatial_exp_like
+                - self.spatial_coeff * spatial_exp_like
                 - self.sc_coeff * sc_exp_like
-                - self.cov_coeff * spatial_cov_like
+                # - self.cov_coeff * spatial_cov_like
+                + self.niche_coeff * spatial_niche_cost
                 + self.kl_coeff * kl_div
             )
 
             return (
                 loss,
-                [sc_exp_like, spatial_exp_like, spatial_cov_like, kl_div * 0.5],
+                # [sc_exp_like, spatial_exp_like, spatial_cov_like, kl_div * 0.5],
+                [sc_exp_like, spatial_exp_like, spatial_niche_cost, kl_div * 0.5],
             )
 
         grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
@@ -421,10 +460,11 @@ def train(
         state = self.create_train_state(
             subkey, init_lr=init_lr, decay_steps=decay_steps
         )
+        self.state = state  # save state for later
        self.params = state.params
 
         tq = trange(training_steps, leave=True, desc="")
-        sc_loss_mean, spatial_loss_mean, cov_loss_mean, kl_loss_mean, count = (
+        sc_loss_mean, spatial_loss_mean, niche_loss_mean, kl_loss_mean, count = (
             0,
             0,
             0,
@@ -434,7 +474,8 @@ def train(
 
         sc_X = self.sc_data.X
         spatial_X = self.spatial_data.X
-        spatial_COVET = self.spatial_data.obsm["COVET_SQRT"]
+        # spatial_COVET = self.spatial_data.obsm["COVET_SQRT"]
+        spatial_niche = self.spatial_data.obsm["scaled_niche"]
 
         for training_step in tq:
             key, subkey1, subkey2 = random.split(key, num=3)
@@ -449,24 +490,29 @@ def train(
                 key=subkey2, a=self.sc_data.shape[0], shape=[batch_size], replace=False
             )
 
-            batch_spatial_exp, batch_spatial_cov = (
+            # batch_spatial_exp, batch_spatial_cov = (
+            #     spatial_X[batch_spatial_ind],
+            #     spatial_COVET[batch_spatial_ind],
+            # )
+            batch_spatial_exp, batch_spatial_niche = (
                 spatial_X[batch_spatial_ind],
-                spatial_COVET[batch_spatial_ind],
+                spatial_niche[batch_spatial_ind],
             )
             batch_sc_exp = sc_X[batch_sc_ind]
 
             key, subkey = random.split(key)
 
             state, loss = self.train_step(
-                state, batch_spatial_exp, batch_spatial_cov, batch_sc_exp, key=subkey
+                # state, batch_spatial_exp, batch_spatial_cov, batch_sc_exp, key=subkey
+                state, batch_spatial_exp, batch_spatial_niche, batch_sc_exp, key=subkey
             )
 
             self.params = state.params
 
-            sc_loss_mean, spatial_loss_mean, cov_loss_mean, kl_loss_mean, count = (
+            sc_loss_mean, spatial_loss_mean, niche_loss_mean, kl_loss_mean, count = (
                 sc_loss_mean + loss[1][0],
                 spatial_loss_mean + loss[1][1],
-                cov_loss_mean + loss[1][2],
+                niche_loss_mean + loss[1][2],
                 kl_loss_mean + loss[1][3],
                 count + 1,
             )
@@ -474,8 +520,8 @@ def train(
             if training_step % verbose == 0:
                 print_statement = ""
                 for metric, value in zip(
-                    ["spatial", "sc", "cov", "kl"],
-                    [spatial_loss_mean, sc_loss_mean, cov_loss_mean, kl_loss_mean],
+                    ["spatial", "sc", "niche", "kl"],
+                    [spatial_loss_mean, sc_loss_mean, niche_loss_mean, kl_loss_mean],
                 ):
                     print_statement = (
                         print_statement
                         + " "
                         + metric
                         + ": {:.3e}".format(value / count)
                     )
 
-                sc_loss_mean, spatial_loss_mean, cov_loss_mean, kl_loss_mean, count = (
+                sc_loss_mean, spatial_loss_mean, niche_loss_mean, kl_loss_mean, count = (
                     0,
                     0,
                     0,
@@ -496,7 +542,117 @@ def train(
 
         self.latent_rep()
 
-    @partial(jit, static_argnums=(0,))
+    def start_train(
+        self,
+        init_lr=0.0001,
+        decay_steps=4000,
+        key=random.key(0)
+    ):
+        """
+        Initialize a fresh train state without running any optimization steps
+        """
+        key, subkey = random.split(key)
+        state = self.create_train_state(
+            subkey, init_lr=init_lr, decay_steps=decay_steps
+        )
+        self.state = state  # save state for later
+        self.params = state.params
+
+    def continue_train(
+        self,
+        training_steps=16000,
+        batch_size=128,
+        verbose=16,
+        key=random.key(0),
+    ):
+        """
+        Continue training with a pre-existing train state
+        """
+
+        batch_size = min(
+            self.sc_data.shape[0], min(self.spatial_data.shape[0], batch_size)
+        )
+
+        state = self.state
+        self.params = state.params
+
+        tq = trange(training_steps, leave=True, desc="")
+        sc_loss_mean, spatial_loss_mean, niche_loss_mean, kl_loss_mean, count = (
+            0,
+            0,
+            0,
+            0,
+            0,
+        )
+
+        sc_X = self.sc_data.X
+        spatial_X = self.spatial_data.X
+        # spatial_COVET = self.spatial_data.obsm["COVET_SQRT"]
+        spatial_niche = self.spatial_data.obsm["scaled_niche"]
+
+        for training_step in tq:
+            key, subkey1, subkey2 = random.split(key, num=3)
+
+            batch_spatial_ind = random.choice(
+                key=subkey1,
+                a=self.spatial_data.shape[0],
+                shape=[batch_size],
+                replace=False,
+            )
+            batch_sc_ind = random.choice(
+                key=subkey2, a=self.sc_data.shape[0], shape=[batch_size], replace=False
+            )
+
+            # batch_spatial_exp, batch_spatial_cov = (
+            #     spatial_X[batch_spatial_ind],
+            #     spatial_COVET[batch_spatial_ind],
+            # )
+            batch_spatial_exp, batch_spatial_niche = (
+                spatial_X[batch_spatial_ind],
+                spatial_niche[batch_spatial_ind],
+            )
+            batch_sc_exp = sc_X[batch_sc_ind]
+
+            key, subkey = random.split(key)
+
+            state, loss = self.train_step(
+                # state, batch_spatial_exp, batch_spatial_cov, batch_sc_exp, key=subkey
+                state, batch_spatial_exp, batch_spatial_niche, batch_sc_exp, key=subkey
+            )
+
+            self.params = state.params
+
+            sc_loss_mean, spatial_loss_mean, niche_loss_mean, kl_loss_mean, count = (
+                sc_loss_mean + loss[1][0],
+                spatial_loss_mean + loss[1][1],
+                niche_loss_mean + loss[1][2],
+                kl_loss_mean + loss[1][3],
+                count + 1,
+            )
+
+            if training_step % verbose == 0:
+                print_statement = ""
+                for metric, value in zip(
+                    ["spatial", "sc", "niche", "kl"],
+                    [spatial_loss_mean, sc_loss_mean, niche_loss_mean, kl_loss_mean],
+                ):
+                    print_statement = (
+                        print_statement
+                        + " "
+                        + metric
+                        + ": {:.3e}".format(value / count)
+                    )
+
+                sc_loss_mean, spatial_loss_mean, niche_loss_mean, kl_loss_mean, count = (
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                )
+                tq.set_description(print_statement)
+                tq.refresh()  # to show the updated description
+
+        self.latent_rep()
+
+    #@partial(jit, static_argnums=(0,))
     def model_encoder(self, x):
         """
         :meta private:
         """
 
         return self.model.bind({"params": self.params}).encoder(x)
 
-    @partial(jit, static_argnums=(0,))
+    #@partial(jit, static_argnums=(0,))
     def model_decoder_exp(self, x):
         """
         :meta private:
         """
 
         return self.model.bind({"params": self.params}).decoder_exp(x)
 
-    @partial(jit, static_argnums=(0,))
-    def model_decoder_cov(self, x):
-        """
-        :meta private:
-        """
-        return self.model.bind({"params": self.params}).decoder_cov(x)
+    # @partial(jit, static_argnums=(0,))
+    # def model_decoder_cov(self, x):
+    #     """
+    #     :meta private:
+    #     """
+    #     return self.model.bind({"params": self.params}).decoder_cov(x)
+
+    #@partial(jit, static_argnums=(0,))
+    def model_decoder_niche(self, x):
+        """
+        :meta private:
+        """
+        return self.model.bind({"params": self.params}).decoder_niche(x)
+
 
     def encode(self, x, mode="spatial", max_batch=256):
         """
@@ -591,7 +755,33 @@ def decode_exp(self, x, mode="spatial", max_batch=256):
         )
         return dec
 
-    def decode_cov(self, x, max_batch=256):
+    # def decode_cov(self, x, max_batch=256):
+    #     """
+    #     :meta private:
+    #     """
+
+    #     conf_const = 0
+    #     conf_neurons = jax.nn.one_hot(
+    #         conf_const * jnp.ones(x.shape[0], dtype=jnp.int8), 2, dtype=jnp.float32
+    #     )
+
+    #     x_conf = jnp.concatenate([x, conf_neurons], axis=-1)
+
+    #     if x_conf.shape[0] < max_batch:
+    #         dec = self.grammian_cov(self.model_decoder_cov(x_conf))
+    #     else:  # For when the GPU can't pass all point-clouds at once
+    #         num_split = int(x_conf.shape[0] / max_batch) + 1
+    #         x_conf_split = np.array_split(x_conf, num_split)
+    #         dec = np.concatenate(
+    #             [
+    #                 self.grammian_cov(self.model_decoder_cov(x_conf_split[split_ind]))
+    #                 for split_ind in range(num_split)
+    #             ],
+    #             axis=0,
+    #         )
+    #     return dec
+
+    def decode_niche(self, x, max_batch=256):
         """
         :meta private:
         """
 
         conf_const = 0
         conf_neurons = jax.nn.one_hot(
             conf_const * jnp.ones(x.shape[0], dtype=jnp.int8), 2, dtype=jnp.float32
         )
 
         x_conf = jnp.concatenate([x, conf_neurons], axis=-1)
 
         if x_conf.shape[0] < max_batch:
-            dec = self.grammian_cov(self.model_decoder_cov(x_conf))
+            dec = self.model_decoder_niche(x_conf)
         else:  # For when the GPU can't pass all point-clouds at once
             num_split = int(x_conf.shape[0] / max_batch) + 1
             x_conf_split = np.array_split(x_conf, num_split)
             dec = np.concatenate(
                 [
-                    self.grammian_cov(self.model_decoder_cov(x_conf_split[split_ind]))
+                    self.model_decoder_niche(x_conf_split[split_ind])
                     for split_ind in range(num_split)
                 ],
                 axis=0,
             )
+
         return dec
 
     def latent_rep(self):
@@ -631,12 +822,28 @@ def latent_rep(self):
             self.sc_data[:, self.spatial_data.var_names].X, mode="sc"
         )
 
+    def latent_rep_x(self, x, mode):
+        """
+        Compute latent embeddings for a given AnnData object
+
+        :return: nothing, adds 'envi_latent' to x.obsm
+        """
+        if mode == "spatial":
+            inp = x.X
+        elif mode == "sc":
+            inp = x[:, self.spatial_data.var_names].X
+
+        x.obsm["envi_latent"] = self.encode(
+            inp, mode=mode
+        )
+
     def impute_genes(self):
         """
         Impute full transcriptome for spatial data
 
         :return: nothing, adds 'imputation' to self.spatial_data.obsm
         """
+        # does not re-call latent_rep; make sure the latent is up to date
 
         self.spatial_data.obsm["imputation"] = pd.DataFrame(
             self.decode_exp(self.spatial_data.obsm["envi_latent"], mode="sc"),
@@ -648,19 +855,53 @@ def impute_genes(self):
             "Finished imputing missing gene for spatial data! See 'imputation' in obsm of ENVI.spatial_data"
         )
 
-    def infer_niche_covet(self):
-        """
-        Predict COVET representation for single-cell data
-
-        :return: nothing, adds 'COVET_SQRT' and 'COVET' to self.sc_data.obsm
-        """
-
-        self.sc_data.obsm["COVET_SQRT"] = self.decode_cov(
-            self.sc_data.obsm["envi_latent"]
-        )
-        self.sc_data.obsm["COVET"] = np.matmul(
-            self.sc_data.obsm["COVET_SQRT"], self.sc_data.obsm["COVET_SQRT"]
-        )
+    # def infer_niche_covet(self):
+    #     """
+    #     Predict COVET representation for single-cell data
+
+    #     :return: nothing, adds 'COVET_SQRT' and 'COVET' to self.sc_data.obsm
+    #     """
+
+    #     self.sc_data.obsm["COVET_SQRT"] = self.decode_cov(
+    #         self.sc_data.obsm["envi_latent"]
+    #     )
+    #     self.sc_data.obsm["COVET"] = np.matmul(
+    #         self.sc_data.obsm["COVET_SQRT"], self.sc_data.obsm["COVET_SQRT"]
+    #     )
+
+    def infer_niche_x(self, x, mode):
+        """
+        Predict niche representation for a given AnnData object
+
+        :return: nothing, adds 'inferred_scaled_niche' and 'inferred_niche' to x.obsm
+        """
+
+        self.latent_rep_x(x, mode=mode)
+
+        x.obsm["inferred_scaled_niche"] = self.decode_niche(
+            x.obsm["envi_latent"]
+        )
+        x.obsm["inferred_niche"] = (x.obsm["inferred_scaled_niche"] * sqrt(self.n_niche_genes) + 1) / 2 * (self.gene_maxs - self.gene_mins) + self.gene_mins
+
+    def infer_niche_sc(self):
+        """
+        Predict niche representation for single-cell data
+
+        :return: nothing, adds 'inferred_scaled_niche' and 'inferred_niche' to self.sc_data.obsm
+        """
+        self.infer_niche_x(self.sc_data, mode="sc")
+
+    def infer_niche_st(self):
+        """
+        Re-predict the niche representation for spatial data, for validation
+
+        :return: nothing, adds 'inferred_scaled_niche' and 'inferred_niche' to self.spatial_data.obsm
+        """
+        self.infer_niche_x(self.spatial_data, mode="spatial")
 
     def infer_niche_celltype(self, cell_type_key="cell_type"):
         """
@@ -680,14 +921,14 @@ def infer_niche_celltype(self, cell_type_key="cell_type"):
         )
 
         regression_model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=5).fit(
-            self.spatial_data.obsm["COVET_SQRT"].reshape(
+            self.spatial_data.obsm["scaled_niche"].reshape(
                 [self.spatial_data.shape[0], -1]
             ),
             self.spatial_data.obsm["cell_type_niche"],
         )
 
         sc_cell_type = regression_model.predict(
-            self.sc_data.obsm["COVET_SQRT"].reshape([self.sc_data.shape[0], -1])
+            self.sc_data.obsm["inferred_scaled_niche"].reshape([self.sc_data.shape[0], -1])
         )
 
         self.sc_data.obsm["cell_type_niche"] = pd.DataFrame(
@@ -695,3 +936,5 @@ def infer_niche_celltype(self, cell_type_key="cell_type"):
             index=self.sc_data.obs_names,
             columns=self.spatial_data.obsm["cell_type_niche"].columns,
         )
+
+
diff --git a/scenvi/_dists.py b/scenvi/_dists.py
index 89488bf..2f78847 100644
--- a/scenvi/_dists.py
+++ b/scenvi/_dists.py
@@ -1,5 +1,7 @@
 import jax.numpy as jnp
 import tensorflow_probability.substrates.jax.distributions as jnd
+import ott
+from ott.solvers import linear
 
 
 def KL(mean, log_std):
@@ -57,3 +59,49 @@ def AOT_Distance(sample, mean):
     mean = jnp.reshape(mean, [mean.shape[0], -1])
     log_prob = -jnp.square(sample - mean)
     return jnp.mean(log_prob, axis=-1)
+
+
+# from Wasserstein Wormhole
+def S2(x, y, eps, lse_mode):
+    """
+    Calculate Sinkhorn Divergence (S2) between two point clouds with uniform weights
+
+    :param x: (array) coordinates of the first point cloud
+    :param y: (array) coordinates of the second point cloud
+    :param eps: (float) coefficient of entropic regularization
+    :param lse_mode: (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode
+
+    :return S2: Sinkhorn Divergence between x and y
+    """
+
+    # x, a = x[0], x[1]
+    # y, b = y[0], y[1]
+    # default to uniform weights by not specifying a and b
+
+    ot_solve_xy = linear.solve(
+        ott.geometry.pointcloud.PointCloud(x, y, cost_fn=None, epsilon=eps),
+        # a=a,
+        # b=b,
+        lse_mode=lse_mode,
+        min_iterations=0,
+        max_iterations=100)
+
+    ot_solve_xx = linear.solve(
+        ott.geometry.pointcloud.PointCloud(x, x, cost_fn=None, epsilon=eps),
+        # a=a,
+        # b=a,
+        lse_mode=lse_mode,
+        min_iterations=0,
+        max_iterations=100)
+
+    ot_solve_yy = linear.solve(
+        ott.geometry.pointcloud.PointCloud(y, y, cost_fn=None, epsilon=eps),
+        # a=b,
+        # b=b,
+        lse_mode=lse_mode,
+        min_iterations=0,
+        max_iterations=100)
+
+    return (ot_solve_xy.reg_ot_cost - 0.5 * ot_solve_xx.reg_ot_cost - 0.5 * ot_solve_yy.reg_ot_cost)
diff --git a/scenvi/utils.py b/scenvi/utils.py
index 0d1fd93..fc2c6b7 100644
--- a/scenvi/utils.py
+++ b/scenvi/utils.py
@@ -9,7 +9,22 @@
 from flax import struct
 from flax.training import train_state
 from jax import random
-import scipy.sparse
+from math import sqrt
+
+
+# def MaxMinScale(arr):
+#     """
+#     :meta private:
+#     """
+
+#     arr = (
+#         2
+#         * (arr - arr.min(axis=0, keepdims=True))
+#         / (arr.max(axis=0, keepdims=True) - arr.min(axis=0, keepdims=True))
+#         - 1
+#     )
+#     return arr
+
 
 class FeedForward(nn.Module):
     """
@@ -69,7 +84,9 @@ class CVAE(nn.Module):
     n_neurons: int
     n_latent: int
     n_output_exp: int
-    n_output_cov: int
+    # n_output_cov: int
+    k_nearest: int
+    n_niche_genes: int
 
     def setup(self):
         """
@@ -80,7 +97,9 @@ def setup(self):
         n_neurons = self.n_neurons
         n_latent = self.n_latent
         n_output_exp = self.n_output_exp
-        n_output_cov = self.n_output_cov
+        # n_output_cov = self.n_output_cov
+        k_nearest = self.k_nearest
+        n_niche_genes = self.n_niche_genes
 
         self.encoder = FeedForward(
             n_layers=n_layers, n_neurons=n_neurons, n_output=n_latent * 2
@@ -90,8 +109,17 @@ def setup(self):
             n_layers=n_layers, n_neurons=n_neurons, n_output=n_output_exp
         )
 
-        self.decoder_cov = FeedForward(
-            n_layers=n_layers, n_neurons=n_neurons, n_output=n_output_cov
+        # self.decoder_cov = FeedForward(
+        #     n_layers=n_layers, n_neurons=n_neurons, n_output=n_output_cov
+        # )
+        # incorporated this into AttentionDecoderModel
+
+        self.decoder_niche = AttentionDecoderModel(
+            n_layers=n_layers,
+            n_neurons=n_neurons,
+            config=DefaultConfig(),
+            out_seq_len=k_nearest,
+            inp_dim=n_niche_genes,
         )
 
     def __call__(self, x, mode="spatial", key=random.key(0)):
@@ -116,8 +144,10 @@ def __call__(self, x, mode="spatial", key=random.key(0)):
         dec_exp = self.decoder_exp(z_conf)
 
         if mode == "spatial":
-            dec_cov = self.decoder_cov(z_conf)
-            return (enc_mu, enc_logstd, dec_exp, dec_cov)
+            # dec_cov = self.decoder_cov(z_conf)
+            # return (enc_mu, enc_logstd, dec_exp, dec_cov)
+            dec_niche = self.decoder_niche(z_conf)
+            return (enc_mu, enc_logstd, dec_exp, dec_niche)
 
         return (enc_mu, enc_logstd, dec_exp)
 
@@ -176,6 +206,35 @@ def BatchKNN(data, batch, k):
     return kNNGraphIndex.astype("int")
 
 
+def CalcNicheMats(spatial_data, kNN, spatial_key="spatial", batch_key=-1):
+    """
+    :meta private:
+    """
+
+    ExpData = spatial_data.layers["scaled_log"]  # constructed from scaled log data
+
+    if batch_key == -1:
+        kNNGraph = sklearn.neighbors.kneighbors_graph(
+            spatial_data.obsm[spatial_key],
+            n_neighbors=kNN,
+            mode="connectivity",
+            n_jobs=-1,
+        ).tocoo()
+        kNNGraphIndex = np.reshape(
+            np.asarray(kNNGraph.col), [spatial_data.obsm[spatial_key].shape[0], kNN]
+        )
+    else:
+        kNNGraphIndex = BatchKNN(
+            spatial_data.obsm[spatial_key],
+            spatial_data.obs[batch_key], kNN
+        )
+
+    # index [cells, kNN] into [cells, genes] -> niche tensor of shape [cells, kNN, genes]
+    NicheMats = ExpData[kNNGraphIndex]
+
+    return NicheMats
+
+
 def CalcCovMats(spatial_data, kNN, genes, spatial_key="spatial", batch_key=-1):
     """
     :meta private:
@@ -276,15 +335,8 @@ def compute_covet(
         CovGenes = spatial_data.var_names
     else:
         if "highly_variable" not in spatial_data.var.columns:
-            if 'log' in spatial_data.layers.keys():
-                sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log")
-            elif('log1p' in spatial_data.layers.keys()):
-                sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log1p")
-            elif(spatial_data.X.min() < 0):
-                sc.pp.highly_variable_genes(spatial_data, n_top_genes=g)
-            else:
-                spatial_data.layers["log"] = np.log(spatial_data.X + 1)
-                sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log")
+            spatial_data.layers["log"] = np.log(spatial_data.X + 1)
+            sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log")
 
         CovGenes = np.asarray(spatial_data.var_names[spatial_data.var.highly_variable])
         if len(genes) > 0:
@@ -303,3 +355,300 @@ def compute_covet(
         COVET_SQRT.astype("float32"),
         np.asarray(CovGenes).astype("str"),
     )
+
+def compute_niche(
+    spatial_data, n_niche_genes, k=8, spatial_key="spatial", batch_key=-1
+):
+    """
+    Compute niche matrices for spatial data in log space
+
+    :param spatial_data: (anndata) spatial data, with an obsm indicating spatial location of spot/segmented cell
+    :param n_niche_genes: (int) number of genes in the niche representation (all genes in spatial_data)
+    :param k: (int) number of nearest neighbours to define niche (default 8)
+    :param spatial_key: (str) obsm key name with physical location of spots/cells (default 'spatial')
+    :param batch_key: (str) obs key name of batch/sample of spatial data (default 'batch' if in spatial_data.obs, else -1)
+
+    :return gene_mins: per-gene minima of log expression, needed to un-scale inferred niches
+    :return gene_maxs: per-gene maxima of log expression, needed to un-scale inferred niches
+    (also adds 'scaled_niche' and 'niche' to spatial_data.obsm)
+    """
+
+    # get min and max of each gene in spatial_data and scale to -1 to 1
+    gene_mins, gene_maxs = spatial_data.layers['log'].min(axis=0), spatial_data.layers['log'].max(axis=0)
+    spatial_data.layers["scaled_log"] = ((spatial_data.layers["log"] - gene_mins) / (gene_maxs - gene_mins) * 2 - 1) / sqrt(n_niche_genes)  # divide by sqrt(n_niche_genes) so each cell's scaled expression vector has norm at most 1
+
+    if batch_key not in spatial_data.obs.columns:
+        batch_key = -1
+
+    Niches = CalcNicheMats(
+        spatial_data, k, spatial_key=spatial_key, batch_key=batch_key
+    )
+
+    spatial_data.obsm["scaled_niche"] = Niches.astype("float32")
+    spatial_data.obsm["niche"] = (spatial_data.obsm["scaled_niche"] * sqrt(n_niche_genes) + 1) / 2 * (gene_maxs - gene_mins) + gene_mins  # still in log space
+
+    return (
+        gene_mins,
+        gene_maxs
+    )
+
+
+# from Wasserstein Wormhole
+from typing import Callable, Any, Optional
+from jax.typing import ArrayLike
+
+
+@struct.dataclass
+class DefaultConfig:
+    """
+    Object with configuration parameters for Wormhole
+
+    :param dtype: (data type) floating point precision for Wormhole model (default jnp.float32)
+    :param dist_func_enc: (str) OT metric used for embedding space (default 'S2', could be 'W1', 'S1', 'W2', 'S2', 'GW' and 'GS')
+    :param dist_func_dec: (str) OT metric used for Wormhole decoder loss (default 'S2', could be 'W1', 'S1', 'W2', 'S2', 'GW' and 'GS')
+    :param eps_enc: (float) entropic regularization for embedding OT (default 0.1)
+    :param eps_dec: (float) entropic regularization for Wormhole decoder loss (default 0.01)
+    :param lse_enc: (bool) whether to use log-sum-exp mode or kernel mode for embedding OT (default False)
+    :param lse_dec: (bool) whether to use log-sum-exp mode or kernel mode for decoder OT (default True)
+    :param coeff_dec: (float) coefficient for decoder loss (default 1)
+    :param scale: (str) how to scale input point clouds ('min_max_total' scales all point clouds jointly so values are between -1 and 1)
+    :param factor: (float) multiplicative factor applied on point cloud coordinates after scaling (default 1)
+    :param emb_dim: (int) Wormhole embedding dimension (default 128)
+    :param num_heads: (int) number of heads in multi-head attention (default 4)
+    :param num_layers: (int) number of layers of multi-head attention for Wormhole encoder and decoder (default 3)
+    :param qkv_dim: (int) dimension of query, key and value attributes in attention (default 512)
+    :param mlp_dim: (int) dimension of hidden layer for the fully-connected network after every multi-head attention layer (default 512)
+    :param attention_dropout_rate: (float) dropout rate for attention matrices during training (default 0.1)
+    :param kernel_init: (Callable) initializer of kernel weights (default nn.initializers.glorot_uniform())
+    :param bias_init: (Callable) initializer of bias weights (default nn.initializers.zeros_init())
+    """
+
+    dtype: Any = jnp.float32
+    dist_func_enc: str = 'S2'
+    dist_func_dec: str = 'S2'
+    eps_enc: float = 0.1
+    eps_dec: float = 0.01
+    lse_enc: bool = False
+    lse_dec: bool = True
+    coeff_dec: float = 1
+    scale: str = 'min_max_total'
+    factor: float = 1.0
+    emb_dim: int = 128
+    num_heads: int = 4
+    num_layers: int = 3
+    qkv_dim: int = 512
+    mlp_dim: int = 512
+    attention_dropout_rate: float = 0.1
+    kernel_init: Callable = nn.initializers.glorot_uniform()
+    bias_init: Callable = nn.initializers.zeros_init()
+
+def expand_weights(weights):
+    # broadcast attention weights up to the [batch, heads, query, key] shape
+    if weights.ndim == 2:
+        weights = weights[:, None, None, :]
+    if weights.ndim == 3:
+        weights = weights[:, None, :, :]
+    while weights.ndim < 4:
+        weights = weights[None, ...]
+    return weights
+
+def scaled_dot_product(q,
+                       k,
+                       v,
+                       weights: Optional[ArrayLike] = None,
+                       scale_weights: float = 1,
+                       deterministic: bool = False,
+                       dropout_rng: Optional[ArrayLike] = random.key(0),
+                       dropout_rate: float = 0.0,
+                       ):
+
+    dtype, d_k = q.dtype, q.shape[-1]
+
+    attn_logits = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
+    attn_logits = attn_logits / jnp.sqrt(d_k)
+
+    if weights is not None:
+        # attn_logits = attn_logits + jnp.tan(math.pi*(jnp.clip(weights, 1e-7, 1-1e-7)-1/2)) - jnp.tan(math.pi*(1/q.shape[-2]-1/2))
+        attn_logits = attn_logits + jnp.log(weights/scale_weights + jnp.finfo(jnp.float32).tiny)
+        attn_logits = jnp.where(weights == 0, -9e15, attn_logits)
+        attn_logits = jnp.where(weights == 1, 9e15, attn_logits)
+
+    attention = nn.softmax(attn_logits, axis=-1)
+
+    # apply attention dropout
+    if not deterministic and dropout_rate > 0.0:
+        keep_prob = 1.0 - dropout_rate
+        keep = random.bernoulli(dropout_rng, keep_prob, attention.shape)  # type: ignore
+        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
+        attention = attention * multiplier
+
+    values = jnp.matmul(attention, v)
+    return values, attention
+
+class WormholeFeedForward(nn.Module):
+    """Transformer MLP / feed-forward block.
+
+    Attributes:
+      config: DefaultConfig dataclass containing hyperparameters.
+    """
+
+    config: DefaultConfig
+
+    @nn.compact
+    def __call__(self, inputs):
+        config = self.config
+        x = nn.Dense(
+            config.mlp_dim,
+            dtype=config.dtype,
+            kernel_init=config.kernel_init,
+            bias_init=config.bias_init,
+        )(inputs)
+        x = nn.relu(x)
+        output = nn.Dense(
+            inputs.shape[-1],
+            dtype=config.dtype,
+            kernel_init=config.kernel_init,
+            bias_init=config.bias_init,
+        )(x) + inputs
+        return output
+
+class WeightedMultiheadAttention(nn.Module):
+
+    config: DefaultConfig
+    scale_weights: Optional[float] = 1
+
+    def setup(self):
+        config = self.config
+        # Stack all weight matrices 1...h and W^Q, W^K, W^V together for efficiency
+        # Note that in many implementations you see "bias=False", which is optional
+        self.qkv_proj = nn.Dense(3 * config.emb_dim,
+                                 dtype=config.dtype,
+                                 kernel_init=config.kernel_init,
+                                 bias_init=config.bias_init)
+
+    def __call__(self,
+                 x,
+                 weights: Optional[ArrayLike] = None,
+                 deterministic: Optional[bool] = True,
+                 dropout_rng: Optional[ArrayLike] = random.key(0)):
+
+        config = self.config
+        scale_weights = self.scale_weights
+
+        batch_size, seq_length, _ = x.shape
+
+        # assert x.shape[-1] == config.emb_dim
+
+        if weights is not None:
+            weights = expand_weights(weights)
+
+        qkv = self.qkv_proj(x)
+
+        # Separate Q, K, V from linear output
+        qkv = qkv.reshape(batch_size, seq_length, config.num_heads, -1)
+        qkv = qkv.transpose(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
+        q, k, v = jnp.array_split(qkv, 3, axis=-1)
+
+        # Determine value outputs
+        values, attention = scaled_dot_product(q, k, v, weights=weights,
+                                               scale_weights=scale_weights,
+                                               deterministic=deterministic,
+                                               dropout_rng=dropout_rng,
+                                               dropout_rate=config.attention_dropout_rate)
+        values = values.transpose(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
+        values = values.reshape(batch_size, seq_length, config.emb_dim)
+
+        return values
+
+
+class DecoderBlock(nn.Module):
+    """Transformer decoder layer.
+
+    Attributes:
+      config: DefaultConfig dataclass containing hyperparameters.
+    """
+
+    config: DefaultConfig
+
+    @nn.compact
+    def __call__(self, inputs, deterministic, dropout_rng):
+        config = self.config
+
+        # Attention block.
+        x = WeightedMultiheadAttention(config)(x=inputs,
+                                               deterministic=deterministic,
+                                               dropout_rng=dropout_rng) + inputs
+
+        #x = nn.Dropout(rate=config.attention_dropout_rate)(x, deterministic=deterministic)
+        x = nn.LayerNorm(dtype=config.dtype)(x)
+        x = WormholeFeedForward(config=config)(x)
+        output = nn.LayerNorm(dtype=config.dtype)(x)
+        return output
+
+
+class Unembedding(nn.Module):
+    """Transformer unembedding block, projecting embeddings back to point coordinates.
+
+    Attributes:
+      config: DefaultConfig dataclass containing hyperparameters.
+      inp_dim: output dimension (number of niche genes).
+    """
+
+    config: DefaultConfig
+    inp_dim: int
+
+    @nn.compact
+    def __call__(self, inputs):
+        config = self.config
+        output = nn.Dense(
+            self.inp_dim,
+            dtype=config.dtype,
+            kernel_init=config.kernel_init,
+            bias_init=config.bias_init,
+        )(inputs)
+        return output
+
+
+class AttentionDecoderModel(nn.Module):
+    """Transformer decoder network.
+
+    Attributes:
+      config: DefaultConfig dataclass containing hyperparameters.
+      out_seq_len: length of the decoded point cloud (k nearest neighbours).
+      inp_dim: dimension of each decoded point (number of niche genes).
+    """
+
+    n_layers: int
+    n_neurons: int
+    config: DefaultConfig
+    out_seq_len: int
+    inp_dim: int  # this is the number of niche genes
+
+    @nn.compact
+    def __call__(self, inputs, deterministic=False, dropout_rng=random.key(0)):
+
+        config = self.config
+
+        x = inputs
+
+        # x = Multiplyer(config, self.out_seq_len)(x)
+        x = FeedForward(n_layers=self.n_layers,
+                        n_neurons=self.n_neurons,
+                        n_output=self.out_seq_len * config.emb_dim)(x)  # project to out_seq_len * emb_dim to get ready for the attention layers
+        x = jnp.reshape(x, [x.shape[0], self.out_seq_len, config.emb_dim])
+        for _ in range(config.num_layers):
+            x = DecoderBlock(config)(inputs=x,
+                                     deterministic=deterministic,
+                                     dropout_rng=dropout_rng)
+        x = WormholeFeedForward(config)(x)
+        x = Unembedding(config, self.inp_dim)(x)
+
+        # scale outputs into the same range as the scaled niche: (-1, 1) / sqrt(inp_dim)
+        output = (nn.sigmoid(x) * 2 - 1) / sqrt(self.inp_dim)
+
+        return output