Cell type annotation prediction¶
Run this notebook on Google Colab to use a free GPU!
In this notebook, an scGPT model is used to predict the cell type annotation of a given gene expression profile.
This follows the scGPT tutorial here, but instead of fine-tuning the entire model, a smaller neural network is trained on the gene expression embeddings to make the prediction.
The same approach is then applied with the Geneformer model and the results are compared against each other.
This approach greatly reduces training time and complexity.
# !pip install helical
# !pip install datasets --upgrade
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import anndata as ad
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn
from scipy.sparse import lil_matrix
import torch.optim as optim
from helical.models.scgpt import scGPT, scGPTConfig
from helical.models.geneformer import Geneformer, GeneformerConfig
from copy import deepcopy
from torch.nn.functional import one_hot
INFO:datasets:PyTorch version 2.3.0 available. INFO:datasets:Polars version 0.20.31 available. INFO:datasets:JAX version 0.4.31 available.
We load our dataset, originally from the EMBL website. To download it, we leverage Hugging Face's optimized datasets library!
from datasets import load_dataset
ds = load_dataset("helical-ai/yolksac_human",trust_remote_code=True, split="train[:65%]",download_mode="reuse_cache_if_exists")
Generating train split: 100%|██████████| 25344/25344 [00:14<00:00, 1693.62 examples/s] Generating test split: 100%|██████████| 6336/6336 [00:02<00:00, 2115.23 examples/s]
# Keep all observation (metadata) columns except the raw counts themselves
observation_columns = [obs for obs in list(ds.features.keys()) if not obs == 'raw_counts']
obs_data = pd.DataFrame(ds.select_columns(observation_columns).data.to_pandas(), columns=observation_columns)

# Rebuild the sparse count matrix from the stored values and row indices
lil = lil_matrix((len(ds), ds[0]['size']))
lil.data = np.array(ds['raw_counts'], dtype="object")
lil.rows = np.array(ds['rows'], dtype="object")

# Wrap everything in an AnnData object and attach the gene names
adata = ad.AnnData(lil.tocsr(), obs=obs_data)
adata.var_names = ds.features['raw_counts'].id.split(",")
adata.var['gene_name'] = adata.var_names.str.upper()
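As a quick optional sanity check (not part of the original workflow), you can print the assembled AnnData object and peek at its metadata to verify the conversion worked:

# Optional sanity check: dimensions and available metadata columns
print(adata)
print(adata.obs.head())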
Let's familiarise ourselves with the data!
We are interested in the names of the cells we want to predict. They are saved in adata.obs["LVL1"].
Additionally, we need to know how many distinct cell types/classes we have.
# get labels: the celltype
num_types = adata.obs["LVL1"].unique().shape[0]
id2type = dict(enumerate(adata.obs["LVL1"].astype("category").cat.categories))
celltypes_labels = np.array(adata.obs["LVL1"].tolist())
This is all summarized in this dictionary:
id2type
{0: 'ERYTHROID', 1: 'LYMPHOID', 2: 'MK', 3: 'MYELOID', 4: 'PROGENITOR', 5: 'STROMA'}
Use the Helical package to get the embeddings of the gene expression profiles.
The only thing we need to specify is the column containing the gene names (gene_name in this case).
The resulting embeddings are the input features x for our smaller NN model.
scGPT¶
device = "cuda" if torch.cuda.is_available() else "cpu"
scgpt_config = scGPTConfig(batch_size=50, device=device)
scgpt = scGPT(configurer = scgpt_config)
data = scgpt.process_data(adata, gene_names = "gene_name")
x_scgpt = scgpt.get_embeddings(data)
x_scgpt.shape
INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/scgpt/scGPT_CP/vocab.json' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/scgpt/scGPT_CP/vocab.json' INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/scgpt/scGPT_CP/best_model.pt' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/scgpt/scGPT_CP/best_model.pt' INFO:helical.models.scgpt.model:Model finished initializing. INFO:helical.models.scgpt.model:Filtering out 11163 genes to a total of 26155 genes with an id in the scGPT vocabulary. INFO:helical.models.scgpt.model:Inference started: Embedding cells: 100%|██████████| 330/330 [00:33<00:00, 9.84it/s]
(16474, 512)
With the input features, we also need the corresponding labels y, i.e. the cell type labels.
As this is a categorical prediction, we first encode the cell types as integer labels and then one-hot encode them so they can be used with CrossEntropyLoss later.
y = celltypes_labels
num_classes = num_types
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)
y_encoded = one_hot(torch.tensor(y_encoded),num_classes).float()
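As an optional check (not in the original tutorial), you can confirm the shape of the one-hot labels and inspect the class ordering learned by the encoder:

# Optional: one-hot labels should have shape (n_cells, num_classes);
# encoder.classes_ shows which cell type each column corresponds to
print(y_encoded.shape)
print(list(encoder.classes_))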
Define and train the model¶
input_shape = 512  # dimensionality of the embeddings produced by scGPT (and Geneformer)

# Define the model architecture
head_model = nn.Sequential(
    nn.Linear(input_shape, 128),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(32, num_classes)
)
print(head_model)
Sequential( (0): Linear(in_features=512, out_features=128, bias=True) (1): ReLU() (2): Dropout(p=0.4, inplace=False) (3): Linear(in_features=128, out_features=32, bias=True) (4): ReLU() (5): Dropout(p=0.4, inplace=False) (6): Linear(in_features=32, out_features=6, bias=True) )
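To see why this probing approach is so much lighter than full fine-tuning, you can count the trainable parameters of the head (an optional check):

# Optional: number of trainable parameters in the probing head
n_params = sum(p.numel() for p in head_model.parameters() if p.requires_grad)
print(f"Trainable parameters in the head: {n_params:,}")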
def train_model(model: nn.Sequential,
                X_train: torch.Tensor,
                y_train: torch.Tensor,
                X_val: torch.Tensor,
                y_val: torch.Tensor,
                optimizer: optim.Optimizer,
                loss_fn = nn.CrossEntropyLoss(),
                num_epochs = 50,
                batch = 64):

    # Create DataLoader for batching
    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch, shuffle=True)

    # Validation dataset
    val_dataset = TensorDataset(X_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=batch, shuffle=False)

    # Ensure model is in training mode
    model.train()

    for epoch in range(num_epochs):
        for batch_X, batch_y in train_loader:
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(batch_X)

            # Compute loss
            loss = loss_fn(outputs, batch_y)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

        # Validation phase at the end of each epoch
        model.eval()
        with torch.no_grad():
            val_losses = []
            for val_X, val_y in val_loader:
                val_outputs = model(val_X)
                val_loss = loss_fn(val_outputs, val_y)
                val_losses.append(val_loss.item())
            print(f"Epoch {epoch+1}, Validation Loss: {sum(val_losses)/len(val_losses)}")

        # Set back to training mode for the next epoch
        model.train()

    model.eval()
    return model
X_train, X_test, y_train, y_test = train_test_split(x_scgpt, y_encoded, test_size=0.1, random_state=42)
head_model_scgpt = deepcopy(head_model)
head_model_scgpt = train_model(head_model_scgpt,
                               torch.from_numpy(X_train),
                               y_train,
                               torch.from_numpy(X_test),
                               y_test,
                               optim.Adam(head_model_scgpt.parameters(), lr=0.001),
                               nn.CrossEntropyLoss())
Epoch 1, Validation Loss: 0.10509900515899062 Epoch 2, Validation Loss: 0.09210213305106243 Epoch 3, Validation Loss: 0.08708009730952863 Epoch 4, Validation Loss: 0.08486809238988477 Epoch 5, Validation Loss: 0.07630292710830243 Epoch 6, Validation Loss: 0.07355103734880686 Epoch 7, Validation Loss: 0.06611681672690722 Epoch 8, Validation Loss: 0.06312987544627574 Epoch 9, Validation Loss: 0.056430156445668005 Epoch 10, Validation Loss: 0.05318047078505445 Epoch 11, Validation Loss: 0.05370905570005281 Epoch 12, Validation Loss: 0.050026227198451616 Epoch 13, Validation Loss: 0.04953212696441019 Epoch 14, Validation Loss: 0.04679244866397662 Epoch 15, Validation Loss: 0.05049864224341805 Epoch 16, Validation Loss: 0.04329267257260373 Epoch 17, Validation Loss: 0.05164720463262011 Epoch 18, Validation Loss: 0.044551525596314326 Epoch 19, Validation Loss: 0.04502302069494572 Epoch 20, Validation Loss: 0.040517984640945755 Epoch 21, Validation Loss: 0.04486050424971976 Epoch 22, Validation Loss: 0.044199613513998114 Epoch 23, Validation Loss: 0.04122369681135751 Epoch 24, Validation Loss: 0.03933001797346291 Epoch 25, Validation Loss: 0.042462229007819235 Epoch 26, Validation Loss: 0.04435293143167375 Epoch 27, Validation Loss: 0.04156974930755006 Epoch 28, Validation Loss: 0.03967876623657783 Epoch 29, Validation Loss: 0.03859212029671583 Epoch 30, Validation Loss: 0.041997895730533995 Epoch 31, Validation Loss: 0.03875408053639918 Epoch 32, Validation Loss: 0.041402942252506576 Epoch 33, Validation Loss: 0.04364953876714795 Epoch 34, Validation Loss: 0.03740719119937589 Epoch 35, Validation Loss: 0.03878503382908933 Epoch 36, Validation Loss: 0.04403487707839723 Epoch 37, Validation Loss: 0.045056324614471614 Epoch 38, Validation Loss: 0.04669451509499385 Epoch 39, Validation Loss: 0.04211538765230216 Epoch 40, Validation Loss: 0.046166509291372046 Epoch 41, Validation Loss: 0.0411657914261853 Epoch 42, Validation Loss: 0.038813597457752064 Epoch 43, Validation Loss: 0.03979039859125176 Epoch 44, Validation Loss: 0.038490781341142095 Epoch 45, Validation Loss: 0.03988477771837587 Epoch 46, Validation Loss: 0.04218148411578463 Epoch 47, Validation Loss: 0.04319032455588548 Epoch 48, Validation Loss: 0.043836868794572256 Epoch 49, Validation Loss: 0.046494297688835874 Epoch 50, Validation Loss: 0.040836334969543926
predictions_nn = head_model_scgpt(torch.Tensor(X_test))
y_pred = np.array(torch.argmax(predictions_nn, dim=1))
y_true = np.array(y_test.argmax(axis=1))
Present the results¶
- on the test set, and
- on a separate, unseen evaluation set.
def get_evaluations(name_data_set, y_true, y_pred) -> dict:
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')

    print(f"{name_data_set} accuracy: {(accuracy*100):.1f}%")
    print(f"{name_data_set} precision: {(precision*100):.1f}%")
    print(f"{name_data_set} f1: {(f1*100):.1f}%")
    print(f"{name_data_set} recall: {(recall*100):.1f}%")

    return {
        "accuracy": accuracy,
        "precision": precision,
        "f1": f1,
        "recall": recall,
    }
get_evaluations("Test set", y_true, y_pred)
Test set accuracy: 99.3% Test set precision: 97.8% Test set f1: 97.4% Test set recall: 97.2%
{'accuracy': 0.9927184466019418, 'precision': 0.9775860337484131, 'f1': 0.9743593343813646, 'recall': 0.9720213390843576}
Load the unseen evaluation set:
ds = load_dataset("helical-ai/yolksac_human",trust_remote_code=True, split="train[70%:]",download_mode="reuse_cache_if_exists")
Generating train split: 100%|██████████| 25344/25344 [00:14<00:00, 1777.43 examples/s] Generating test split: 100%|██████████| 6336/6336 [00:02<00:00, 2123.87 examples/s]
observation_columns = [obs for obs in list(ds.features.keys()) if not obs == 'raw_counts']
obs_data = pd.DataFrame(ds.select_columns(observation_columns).data.to_pandas(),columns=observation_columns)
lil = lil_matrix((len(ds),ds[0]['size']))
lil.data = np.array(ds['raw_counts'],dtype="object")
lil.rows = np.array(ds['rows'],dtype="object")
adata_unseen = ad.AnnData(lil.tocsr(),obs=obs_data)
adata_unseen.var_names = ds.features['raw_counts'].id.split(",")
adata_unseen.var['gene_name'] = adata_unseen.var_names.str.upper()
data_unseen = scgpt.process_data(adata_unseen, gene_names="gene_name")
x_unseen = scgpt.get_embeddings(data_unseen)
predictions_nn_unseen = head_model_scgpt(torch.Tensor(x_unseen))
INFO:helical.models.scgpt.model:Filtering out 11163 genes to a total of 26155 genes with an id in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Inference started: Embedding cells: 100%|██████████| 153/153 [00:15<00:00, 9.72it/s]
We should double-check that the cell types map to the same ID numbers in both the training data and this new data set.
num_types = adata_unseen.obs["LVL1"].unique().shape[0]
id2type_unseen = dict(enumerate(adata_unseen.obs["LVL1"].astype("category").cat.categories))
id2type_unseen == id2type
True
y_true_unseen = np.array(adata_unseen.obs["LVL1"].tolist())
y_pred_unseen = [id2type[prediction] for prediction in np.array(torch.argmax(predictions_nn_unseen, dim=1))]
scgpt_results = get_evaluations("Evaluation set", y_true_unseen, y_pred_unseen)
Evaluation set accuracy: 99.2% Evaluation set precision: 90.8% Evaluation set f1: 79.1% Evaluation set recall: 80.7%
Plot a confusion matrix to visualise the classification performance for each cell type. This is done for the evaluation set.
from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# confusion_matrix orders its rows/columns by sorted label,
# so we use the same sorted ordering for the heatmap axes
labels = sorted(set(y_true_unseen) | set(y_pred_unseen))

cm = confusion_matrix(y_true_unseen, y_pred_unseen, labels=labels)
cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
cm = pd.DataFrame(cm, index=labels, columns=labels)

plt.figure(figsize=(12, 12))
sns.heatmap(cm, annot=True, fmt=".1f", cmap="Blues")
<Axes: >
Geneformer¶
Let's do the same with Geneformer.
device = "cuda" if torch.cuda.is_available() else "cpu"
if 'rows' in adata.obs:
    adata.obs['rows'] = adata.obs['rows'].astype(str)
geneformer_config = GeneformerConfig(batch_size=50, device=device)
geneformer = Geneformer(configurer = geneformer_config)
data_geneformer = geneformer.process_data(adata, gene_names = "gene_name")
x_geneformer = geneformer.get_embeddings(data_geneformer)
x_geneformer.shape
INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/geneformer/v1/gene_median_dictionary.pkl' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/geneformer/v1/gene_median_dictionary.pkl' INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/geneformer/v1/token_dictionary.pkl' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/geneformer/v1/token_dictionary.pkl' INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/geneformer/v1/ensembl_mapping_dict.pkl' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/geneformer/v1/ensembl_mapping_dict.pkl' INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/geneformer/v1/gf-12L-30M-i2048/config.json' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/geneformer/v1/gf-12L-30M-i2048/config.json' INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/geneformer/v1/gf-12L-30M-i2048/training_args.bin' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/geneformer/v1/gf-12L-30M-i2048/training_args.bin' INFO:helical.utils.downloader:File: '/home/benoit/.cache/helical/models/geneformer/v1/gf-12L-30M-i2048/pytorch_model.bin' exists already. File is not overwritten and nothing is downloaded. INFO:helical.utils.downloader:File saved to: '/home/benoit/.cache/helical/models/geneformer/v1/gf-12L-30M-i2048/pytorch_model.bin' INFO:helical.models.geneformer.model:Model finished initializing. INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/benoit/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/benoit/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/benoit/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle INFO:helical.utils.mapping:Mapped 21111 genes to Ensembl IDs from a total of 37318 genes. INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 16474 × 37318 obs: 'rows', 'size', 'LVL1', 'LVL2', 'LVL3', 'total_counts' var: 'gene_name', 'id_in_vocab', 'ensembl_id', 'gene_ids_collapsed' has no column attribute 'filter_pass'; tokenizing all cells. INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset. Map: 100%|██████████| 16474/16474 [00:05<00:00, 3025.41 examples/s] INFO:helical.models.geneformer.model:Inference started: 100%|██████████| 330/330 [05:07<00:00, 1.07it/s]
(16474, 512)
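The head defined earlier expects input_shape = 512 features; this Geneformer checkpoint happens to produce embeddings of the same dimension, so the identical architecture can be reused. An optional assertion makes that assumption explicit:

# Optional: fail early if the Geneformer embedding size does not match the head's input layer
assert x_geneformer.shape[1] == input_shape, "Embedding dimension must match the head's first Linear layer"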
X_train, X_test, y_train, y_test = train_test_split(x_geneformer, y_encoded, test_size=0.1, random_state=42)
head_model_geneformer = deepcopy(head_model)
head_model_geneformer = train_model(head_model_geneformer,
                                    torch.tensor(X_train),
                                    y_train,
                                    torch.tensor(X_test),
                                    y_test,
                                    optim.Adam(head_model_geneformer.parameters(), lr=0.001),
                                    nn.CrossEntropyLoss())
Epoch 1, Validation Loss: 0.21696129383949134 Epoch 2, Validation Loss: 0.10857223622644177 Epoch 3, Validation Loss: 0.09659991732153755 Epoch 4, Validation Loss: 0.09191516814126562 Epoch 5, Validation Loss: 0.08651317040829991 Epoch 6, Validation Loss: 0.08288597050033367 Epoch 7, Validation Loss: 0.07626076801142727 Epoch 8, Validation Loss: 0.07917614357295231 Epoch 9, Validation Loss: 0.07136841879065077 Epoch 10, Validation Loss: 0.07527488532426875 Epoch 11, Validation Loss: 0.06889283100966938 Epoch 12, Validation Loss: 0.07603136524546873 Epoch 13, Validation Loss: 0.07237893132080969 Epoch 14, Validation Loss: 0.0707903919716885 Epoch 15, Validation Loss: 0.06682811810214144 Epoch 16, Validation Loss: 0.06787337761488743 Epoch 17, Validation Loss: 0.06565282110554668 Epoch 18, Validation Loss: 0.07381324346072059 Epoch 19, Validation Loss: 0.0704241120183724 Epoch 20, Validation Loss: 0.06702452171773005 Epoch 21, Validation Loss: 0.06794494667925531 Epoch 22, Validation Loss: 0.06038204437041154 Epoch 23, Validation Loss: 0.060252198385289654 Epoch 24, Validation Loss: 0.06358248780060631 Epoch 25, Validation Loss: 0.06292168345176972 Epoch 26, Validation Loss: 0.06266387048078916 Epoch 27, Validation Loss: 0.061681864904060676 Epoch 28, Validation Loss: 0.058782341822874375 Epoch 29, Validation Loss: 0.05916480725416197 Epoch 30, Validation Loss: 0.05904467829924005 Epoch 31, Validation Loss: 0.05701122948084958 Epoch 32, Validation Loss: 0.05678210777785772 Epoch 33, Validation Loss: 0.06069203816206517 Epoch 34, Validation Loss: 0.05833638691043374 Epoch 35, Validation Loss: 0.060886409481799304 Epoch 36, Validation Loss: 0.05711668227852967 Epoch 37, Validation Loss: 0.06182696464496145 Epoch 38, Validation Loss: 0.05757013989093069 Epoch 39, Validation Loss: 0.056393348726067834 Epoch 40, Validation Loss: 0.0600513292447431 Epoch 41, Validation Loss: 0.06829129664290051 Epoch 42, Validation Loss: 0.062308046387512986 Epoch 43, Validation Loss: 0.055768700710569434 Epoch 44, Validation Loss: 0.05563147523431466 Epoch 45, Validation Loss: 0.06237108943661532 Epoch 46, Validation Loss: 0.05813543734928736 Epoch 47, Validation Loss: 0.05794153505569109 Epoch 48, Validation Loss: 0.06292974418521716 Epoch 49, Validation Loss: 0.06411586219977695 Epoch 50, Validation Loss: 0.05854638515917871
data_unseen_geneformer = geneformer.process_data(adata_unseen, gene_names = "gene_name")
x_unseen_geneformer = geneformer.get_embeddings(data_unseen_geneformer)
predictions_nn_unseen_geneformer = head_model_geneformer(torch.Tensor(x_unseen_geneformer))
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/benoit/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/benoit/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle INFO:pyensembl.sequence_data:Loaded sequence dictionary from /home/benoit/.cache/pyensembl/GRCh38/ensembl110/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle INFO:helical.utils.mapping:Mapped 21111 genes to Ensembl IDs from a total of 37318 genes. INFO:helical.models.geneformer.geneformer_tokenizer:AnnData object with n_obs × n_vars = 7603 × 37318 obs: 'rows', 'size', 'LVL1', 'LVL2', 'LVL3', 'total_counts' var: 'gene_name', 'id_in_vocab', 'ensembl_id', 'gene_ids_collapsed' has no column attribute 'filter_pass'; tokenizing all cells. INFO:helical.models.geneformer.geneformer_tokenizer:Creating dataset. Map: 100%|██████████| 7603/7603 [00:03<00:00, 2071.15 examples/s] INFO:helical.models.geneformer.model:Inference started: 100%|██████████| 153/153 [02:21<00:00, 1.08it/s]
y_true_unseen = np.array(adata_unseen.obs["LVL1"].tolist())
y_pred_unseen = [id2type[prediction] for prediction in np.array(torch.argmax(predictions_nn_unseen_geneformer, dim=1))]
geneformer_results = get_evaluations("Evaluation set", y_true_unseen, y_pred_unseen)
Evaluation set accuracy: 98.9% Evaluation set precision: 71.6% Evaluation set f1: 73.6% Evaluation set recall: 77.3%
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
values_1 = [scgpt_results["accuracy"], geneformer_results["accuracy"]]
x = ["scGPT", "Geneformer"]
axs[0, 0].bar(x, values_1, width=0.4)
axs[0, 0].set_title("Accuracy")
axs[0, 0].set_ylim([0, 1])
values_2 = [scgpt_results["precision"], geneformer_results["precision"]]
axs[0, 1].bar(x, values_2, width=0.4)
axs[0, 1].set_title("Precision")
axs[0, 1].set_ylim([0, 1])
values_3 = [scgpt_results["f1"], geneformer_results["f1"]]
axs[1, 0].bar(x, values_3, width=0.4)
axs[1, 0].set_title("F1")
axs[1, 0].set_ylim([0, 1])
values_4 = [scgpt_results["recall"], geneformer_results["recall"]]
axs[1, 1].bar(x, values_4, width=0.4)
axs[1, 1].set_title("Recall")
axs[1, 1].set_ylim([0, 1])
fig.suptitle("scGPT vs. Geneformer \n Probing Comparison")
fig.tight_layout()
plt.show()
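If you want to reuse the trained probing heads later without recomputing anything, you can save their weights; a minimal optional sketch (the file names are arbitrary):

# Optional: persist the trained probing heads for later reuse
torch.save(head_model_scgpt.state_dict(), "scgpt_probing_head.pt")
torch.save(head_model_geneformer.state_dict(), "geneformer_probing_head.pt")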
(c) Helical 2024 - Developed by the Helical Team