Predictions using different embedding modes
In this notebook, we show the different embedding modes that are available for the single-cell RNA models in the package.
from helical.models.scgpt import scGPT, scGPTConfig
import torch
import anndata
from pathlib import Path
from helical.utils.downloader import Downloader
import os
from helical.constants.paths import CACHE_DIR_HELICAL
We demonstrate the working principle with the scGPT model. First, download the data if you don't have it already:
scgpt = scGPT()

# Download the dataset into the helical cache directory if it is not there yet
path = Path(CACHE_DIR_HELICAL) / "17_04_24_YolkSacRaw_F158_WE_annots.h5ad"
if not os.path.exists(path):
    downloader = Downloader()
    downloader.download_via_name("17_04_24_YolkSacRaw_F158_WE_annots.h5ad")
data = anndata.read_h5ad(path)
INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cpu' with embedding mode 'cls'.
To explain how the different embedding modes work, it is easier to simulate the embeddings returned by the model. We do this in the following cell:
- we define a torch tensor that simulates the embeddings,
- we overwrite the scgpt.model._encode function so that it returns those embeddings, and
- we skip the scgpt._normalize_embeddings function by making it return its input unchanged.
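# Mock the methods directly on the instance.
# Shape: (1 cell, 4 rows, 5 dimensions). Row 0 plays the role of the cls token,
# rows 1-3 play the role of three genes.
mocked_embeddings = torch.tensor([
    [[1.0, 1.0, 1.0, 1.0, 1.0],
     [5.0, 5.0, 5.0, 5.0, 5.0],
     [1.0, 2.0, 3.0, 2.0, 1.0],
     [6.0, 6.0, 6.0, 6.0, 6.0]],
])
# Every call to the encoder now returns the mocked tensor
scgpt.model._encode = lambda *args, **kwargs: mocked_embeddings
# Skip normalisation so that the values stay easy to read
scgpt._normalize_embeddings = lambda x: x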
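The cell above replaces the methods by assigning to instance attributes, so the patches stay in place for as long as scgpt exists (which is why the notebook later deletes it and creates a fresh instance). As an alternative, here is a minimal sketch using the standard library's unittest.mock.patch.object, which restores the original attribute when the with block exits. DummyModel and its encode method are hypothetical stand-ins, not part of helical.

```python
from unittest import mock


class DummyModel:
    """Hypothetical stand-in for a model object; not part of helical."""

    def encode(self, x):
        return f"real embeddings for {x}"


model = DummyModel()
fake = [[1.0, 1.0], [5.0, 5.0]]

# Inside the context manager, encode is replaced on this instance only.
with mock.patch.object(model, "encode", lambda *args, **kwargs: fake):
    patched_result = model.encode("cell_0")

# On exit, the instance attribute is removed and the class method is visible again.
restored_result = model.encode("cell_0")

print(patched_result)   # [[1.0, 1.0], [5.0, 5.0]]
print(restored_result)  # real embeddings for cell_0
```

The design trade-off: plain assignment is shorter and fine for a throwaway notebook, while patch.object avoids having to rebuild the object afterwards.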
With this, we can run scGPT in its 3 different modes: gene, cell and cls.
- The gene mode returns an embedding for every gene.
- The cell mode returns the average of the gene embeddings.
- The cls mode returns the cls-specific row returned by the model. It can be thought of as a summary of the observation.
We run scGPT on a single observation (cell) to explain the process.
dataset = scgpt.process_data(data[0])
scgpt.config["emb_mode"] = "gene"
gene_embeddings = scgpt.get_embeddings(dataset)
scgpt.config["emb_mode"] = "cell"
cell_embeddings = scgpt.get_embeddings(dataset)
scgpt.config["emb_mode"] = "cls"
cls_embeddings = scgpt.get_embeddings(dataset)
INFO:helical.models.scgpt.model:Processing data for scGPT.
INFO:helical.models.scgpt.model:Filtering out 10801 genes to a total of 26517 genes with an ID in the scGPT vocabulary.
INFO:helical.models.scgpt.model:Successfully processed the data for scGPT.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 8.50it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 1123.88it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 1723.92it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
In gene mode, we get one embedding per gene. Note that the first row of the mocked tensor (the cls row) does not appear:
gene_embeddings[0]
SLC39A14    [5.0, 5.0, 5.0, 5.0, 5.0]
MPDU1       [1.0, 2.0, 3.0, 2.0, 1.0]
GPHN        [6.0, 6.0, 6.0, 6.0, 6.0]
dtype: object
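The output is a pandas Series indexed by gene name whose values are per-gene vectors (hence dtype: object). For downstream work a 2D table is often handier. A minimal sketch, rebuilding that Series by hand from the values printed above instead of calling the model; gene_series and gene_df are illustrative names.

```python
import numpy as np
import pandas as pd

# Hand-built copy of the printed Series: gene name -> embedding vector.
gene_series = pd.Series({
    "SLC39A14": np.array([5.0, 5.0, 5.0, 5.0, 5.0], dtype=np.float32),
    "MPDU1": np.array([1.0, 2.0, 3.0, 2.0, 1.0], dtype=np.float32),
    "GPHN": np.array([6.0, 6.0, 6.0, 6.0, 6.0], dtype=np.float32),
})

# Stack the per-gene vectors into a (genes x dimensions) matrix...
matrix = np.stack(gene_series.to_numpy())
# ...and keep the gene names as the row index.
gene_df = pd.DataFrame(matrix, index=gene_series.index)

print(gene_df.shape)  # (3, 5)
```

The same two lines should apply to the real output, where each vector has 512 entries, assuming it has the same Series-of-arrays layout.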
The cell embedding is the element-wise average of the gene embeddings, again without the cls row:
cell_embeddings[0]
array([4. , 4.3333335, 4.6666665, 4.3333335, 4. ], dtype=float32)
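These numbers can be checked by hand: each entry is the mean of the three gene rows, for example (5 + 1 + 6) / 3 = 4 and (5 + 2 + 6) / 3 ≈ 4.3333. The short NumPy check below also shows that including the cls row in the average would give 3.25 for the first entry instead, so the cls row is indeed left out.

```python
import numpy as np

cls_row = np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32)
gene_rows = np.array([
    [5.0, 5.0, 5.0, 5.0, 5.0],
    [1.0, 2.0, 3.0, 2.0, 1.0],
    [6.0, 6.0, 6.0, 6.0, 6.0],
], dtype=np.float32)

# Average over the gene rows only: matches the cell-mode output above.
mean_genes = gene_rows.mean(axis=0)
# Average that also includes the cls row: does not match.
mean_all = np.vstack([cls_row, gene_rows]).mean(axis=0)

print(mean_genes)
print(mean_all)
```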
The cls embedding corresponds to the first row returned by the model. This means that scGPT in cls mode ignores the remaining 3 rows.
cls_embeddings[0]
array([1., 1., 1., 1., 1.], dtype=float32)
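To see that cls mode really discards the gene rows, one can change them and compare the results. A minimal NumPy sketch, again assuming row 0 is the cls token; cls_pool and cell_pool are hypothetical helpers mirroring the behaviour described above, not helical functions.

```python
import numpy as np


def cls_pool(encoded):
    """Hypothetical helper: keep only row 0 (the cls token) of each cell."""
    return encoded[:, 0, :]


def cell_pool(encoded):
    """Hypothetical helper: average the gene rows (row 1 onwards) of each cell."""
    return encoded[:, 1:, :].mean(axis=1)


# Same values as the mocked tensor: shape (1, 4, 5).
original = np.array([[[1.0] * 5, [5.0] * 5, [1.0, 2.0, 3.0, 2.0, 1.0], [6.0] * 5]])
perturbed = original.copy()
perturbed[:, 1:, :] *= 10.0  # scale every gene row, leave the cls row untouched

# cls mode is unaffected by the gene rows; cell mode changes with them.
print(np.array_equal(cls_pool(original), cls_pool(perturbed)))    # True
print(np.array_equal(cell_pool(original), cell_pool(perturbed)))  # False
```

With the mocked values this is trivially true; in the real model the cls row is produced by attending to the genes, so the analogy only holds for the pooling step itself.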
We can run this on real data too, but the output is harder to interpret and visualise.
First, we remove our modified scGPT model and instantiate a new one.
del scgpt
device = "cuda" if torch.cuda.is_available() else "cpu"
scgpt = scGPT(configurer=scGPTConfig(device=device))
INFO:helical.models.scgpt.model:Model finished initializing.
INFO:helical.models.scgpt.model:'scGPT' model is in 'eval' mode, on device 'cuda' with embedding mode 'cls'.
scgpt.config["emb_mode"] = "gene"
gene_embeddings = scgpt.get_embeddings(dataset)
gene_embeddings[0]
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 5.18it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
SLC39A14    [-0.0011799249, 0.0031951678, -0.0037296554, 0...
MPDU1       [-0.0013756858, 0.017062135, -0.007849643, 0.0...
GPHN        [0.007110456, 0.025636358, 0.0028697518, 0.005...
AGFG2       [-0.008909806, 0.01073429, 0.006347002, 0.0071...
POLR3B      [-0.012140153, 0.04901718, 0.02245722, 0.00043...
                                 ...
TMEM258     [-0.0077039357, 0.017461302, 0.002785733, 0.01...
BNIP3L      [-0.0103421565, 0.035706572, 0.011275602, 0.00...
KPNB1       [0.0004736521, 0.032073762, 0.0024564175, 0.00...
ZSWIM5      [-0.012645806, 0.048165236, 0.02488112, -0.006...
REPIN1      [0.00678998, 0.019529147, -0.0017630243, 0.001...
Length: 1199, dtype: object
With real data, it is easier to analyse the output sizes:
print(f"Number of genes with embeddings: {gene_embeddings[0].shape}")
# .iloc selects the first gene by position (the Series is indexed by gene name)
print(f"Embedding size per gene: {gene_embeddings[0].iloc[0].shape}")
Number of genes with embeddings: (1199,)
Embedding size per gene: (512,)
scgpt.config["emb_mode"] = "cell"
cell_embeddings = scgpt.get_embeddings(dataset)
cell_embeddings[0]
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 116.03it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
array([-8.15041643e-03, 2.63333023e-02, 7.81406183e-03, 9.81337484e-03, 1.40524013e-02, -2.95165903e-03, -1.69922467e-02, -4.81104245e-03, -7.60848820e-03, 4.59150635e-02, 5.87423565e-03, -5.45159075e-03, 2.13732291e-02, 9.06534586e-03, 1.08101517e-02, -4.91651380e-03, -1.33220525e-02, -2.16271300e-02, 1.46969380e-02, -2.12924127e-02, -1.29918456e-02, -8.90347362e-03, -3.24050188e-02, 1.69960652e-02, 6.72291825e-03, 3.31430845e-02, -3.28126512e-02, -2.09503025e-02, -2.97727287e-02, 7.01213209e-03, 2.35989746e-02, -2.13149730e-02, -2.99774809e-03, -1.64881814e-02, -1.13256490e-02, -3.82535718e-03, -1.26695279e-02, 5.00416942e-03, -6.52590021e-02, -4.58240882e-03, -2.73203477e-02, -5.88836940e-03, -2.24745702e-02, -5.89525886e-03, -3.59893106e-02, -2.78945770e-02, 8.49726028e-04, 1.33508444e-02, 1.13855442e-02, 2.42668390e-02, -1.94337759e-02, 1.32291475e-02, -4.44257283e-04, -1.83102898e-02, -2.08448507e-02, 1.82826556e-02, 1.19242212e-02, -3.09342373e-04, 4.26567607e-02, 1.78650152e-02, -1.22072417e-02, 1.22858435e-02, 3.05658821e-02, 9.43962764e-03, 4.51750346e-02, -2.96661500e-02, -5.75119108e-02, -5.81376860e-03, -1.80639122e-02, -1.69001874e-02, 1.67023279e-02, -2.39931233e-03, -1.83809933e-03, 3.20161544e-02, -3.58780213e-02, 2.30181478e-02, -1.01301484e-02, -8.31659790e-03, 1.36515126e-02, -1.00036536e-03, -7.31237512e-03, -2.01087706e-02, 6.31964812e-03, 5.51546691e-03, 2.41016336e-02, 4.74413810e-03, -3.58009264e-02, 1.88941117e-02, 3.33672091e-02, 2.22256090e-02, 1.24969874e-02, 3.52528156e-03, 1.23718660e-02, 2.18623634e-02, -2.01625726e-03, 1.86573213e-03, -3.18950601e-02, -9.83154459e-05, 2.58983634e-02, -1.17942384e-02, -9.07449238e-03, 1.70814723e-03, 1.31330080e-02, 4.50423360e-02, -5.97934751e-03, 8.15686863e-03, -7.17658550e-03, -3.37880524e-03, -1.05222268e-02, 7.52747199e-03, -1.90593023e-02, 2.03801859e-02, -2.10038084e-03, -2.45623812e-02, 3.24325897e-02, -4.20615524e-02, -3.38341546e-04, 4.45089638e-02, -7.43741263e-03, 1.41986310e-02, 
2.59784493e-03, 1.08500132e-02, 3.10978964e-02, -7.94472266e-03, 2.52118427e-03, -1.47453938e-02, -9.28190630e-03, 7.39593059e-03, -1.47715509e-02, -9.86053608e-03, 3.98509763e-02, 3.56436428e-03, 8.54810979e-03, 5.99423714e-04, -1.20271575e-02, -8.72934796e-03, -1.60508286e-02, 2.57443152e-02, 9.43687651e-03, 1.95260067e-02, 2.37606536e-03, -1.94778237e-02, 2.16944255e-02, 2.05736980e-02, -6.41872967e-03, -1.56051095e-03, -1.59282275e-02, 2.52879895e-02, -1.93675328e-02, -4.43797708e-02, 1.20970234e-02, 4.58547957e-02, 7.58533133e-03, 1.02704652e-02, -1.28064835e-02, 2.98559740e-02, -2.03474611e-02, -5.76473475e-01, -5.13960654e-03, 2.43202727e-02, 4.93220724e-02, 1.12590883e-02, -1.69760119e-02, -2.47351043e-02, 2.91592497e-02, 2.48088595e-02, -3.54069695e-02, -3.00732683e-02, -1.46070169e-02, 1.72611121e-02, 2.25844923e-02, -2.45836191e-02, 1.40006086e-02, -4.39725034e-02, -2.84583420e-02, -2.03797631e-02, 1.95964333e-02, 6.25951355e-03, -2.55475808e-02, 5.02686277e-02, 2.04409156e-02, -3.24284844e-02, -3.34916124e-03, -3.73260230e-02, 1.40493819e-02, -1.19160721e-02, -7.24646961e-03, 4.72128317e-02, -6.35270076e-03, 3.85226980e-02, -8.69447645e-03, 1.94680113e-02, 1.27684288e-02, 2.93238671e-03, 1.31395962e-02, -1.21938772e-02, -8.36459640e-03, -5.38320187e-03, 3.49017006e-04, -2.18278822e-02, -9.98797244e-04, 6.98351813e-03, 2.31984966e-02, 1.62911341e-02, -1.99025515e-02, 8.43606051e-03, -1.17899561e-02, 2.57441169e-03, -4.76897210e-02, 5.17554879e-02, -1.16848052e-02, 4.94676270e-03, 1.86417084e-02, -2.96746138e-02, -4.41117473e-02, 2.91055930e-03, 4.77770017e-03, -8.67339503e-03, -1.43343639e-02, -6.08797045e-03, 3.59102711e-02, 5.66549180e-03, 1.50430361e-02, 4.81424406e-02, -1.86657626e-02, 9.20644403e-03, 1.14493174e-02, 1.21861871e-03, 1.30169988e-02, -3.25866719e-03, -1.95563585e-02, -1.34446248e-02, -5.34193264e-03, 1.15789995e-02, -3.59822251e-03, -6.05917443e-03, -5.88430315e-02, 6.70885667e-03, -2.40413994e-02, 1.33898249e-02, -1.53990593e-02, 
-1.54478103e-02, 1.97891481e-02, 1.68796023e-03, 4.15964276e-02, -7.24982703e-03, 5.10265715e-02, -6.14837324e-03, -8.88798106e-03, -2.81753968e-02, -4.49333660e-04, -5.43319527e-03, 2.33879238e-02, -5.35442606e-02, -4.07010689e-03, -1.10896630e-02, 2.38413922e-02, -2.05681357e-03, -7.62468611e-04, -1.29195498e-02, 3.25112402e-01, -2.09854450e-03, -1.52156102e-02, 2.79599242e-03, -3.00400592e-02, 2.00533817e-04, 4.48551625e-02, -7.97613803e-03, -1.24780852e-02, 2.27570944e-02, 3.98986116e-02, -3.02351173e-02, 4.93414933e-03, -7.46077276e-04, -4.29602712e-03, 8.70624091e-04, 2.66363956e-02, -8.91784951e-03, -5.09983748e-02, -3.93554047e-02, -8.97834636e-03, 5.34497380e-01, 4.15710807e-02, -2.93631069e-02, 1.20375818e-02, -1.28079718e-02, -1.42050358e-02, 3.97833362e-02, 1.40107870e-02, -2.39920523e-03, -4.70105046e-03, 4.68210801e-02, 1.54475784e-02, 5.61998133e-03, 2.43147742e-02, -2.24164352e-02, 1.46648940e-02, 8.99495464e-03, -6.53423090e-03, 1.79887563e-02, 1.85154285e-02, -3.26716118e-02, -1.27531597e-02, 5.07580070e-03, -1.71143860e-02, 1.28918644e-02, -1.35411078e-03, 2.26482712e-02, 1.18187880e-02, 1.74715593e-02, 6.35542581e-03, -7.73004442e-03, 6.38830243e-03, 2.84507312e-02, 1.76105853e-02, -2.28062086e-02, -1.62149465e-03, 6.47854060e-02, 1.48503073e-02, 1.37997037e-02, 2.97018103e-02, 8.74831621e-03, -2.74825264e-02, -1.43635236e-02, -2.59066490e-03, -8.15996062e-03, -1.75764989e-02, 2.93594282e-02, -5.96561050e-03, 5.27633261e-03, -2.72250082e-02, 3.07980534e-02, 8.68623145e-03, 1.27530685e-02, 9.36667155e-03, 1.43885938e-02, -2.89556384e-02, 2.75708195e-02, -1.02682346e-02, 1.21141179e-02, -3.35273594e-02, 2.95504229e-03, 1.72882657e-02, -6.86635673e-02, -3.61121632e-02, -2.03203037e-02, -1.13087250e-02, -1.38313007e-02, -2.70826034e-02, 2.37206947e-02, -2.15230137e-02, 4.15991666e-03, 8.86991248e-03, 2.63496563e-02, 2.41908673e-02, -3.67345810e-02, 1.83664281e-02, -4.39537130e-02, 5.98312495e-03, -3.49636190e-03, 1.00252330e-02, 1.33393332e-02, 
7.60935480e-03, 5.20363031e-03, -1.65041406e-02, -4.52627009e-03, 3.55193093e-02, 8.90749320e-03, 1.27909388e-02, 8.33811518e-03, 1.94471348e-02, 1.68389231e-02, -1.93751212e-02, -1.10085038e-02, 1.83981564e-02, -2.40139961e-02, -2.94874515e-02, 1.01125650e-02, -2.05271598e-02, -2.76341336e-03, -1.32646821e-02, 3.76331359e-02, 8.15028697e-03, -3.10920514e-02, -3.44311609e-03, -6.40069786e-03, -5.01741320e-02, -1.87176559e-02, 8.99801496e-03, -2.24871212e-03, -3.84792278e-04, -3.24686430e-02, 1.87653247e-02, -1.03269462e-02, -2.04273276e-02, 5.72628248e-03, 3.20323631e-02, -5.93421690e-04, -3.81626301e-02, 1.03982529e-02, 9.31399036e-03, -9.05160414e-05, 3.46498378e-02, -7.75304995e-03, -9.04363301e-03, 2.21235268e-02, 6.51048403e-03, 2.03500427e-02, 1.03350561e-02, 7.78120058e-03, 1.49649344e-02, 2.23369263e-02, 2.59858426e-02, -5.30170510e-03, 2.96027455e-02, -2.79687159e-03, -2.20495407e-02, 1.88293215e-02, 7.58425985e-03, -5.03022000e-02, -8.56653601e-03, 2.05458421e-03, -2.22983491e-03, -1.31371897e-02, 3.72169213e-03, -1.90856550e-02, -2.25026943e-02, 3.84237472e-04, -8.41809437e-03, 3.35728265e-02, -2.40656547e-02, 1.65674407e-02, 4.37575877e-02, -2.22611297e-02, 1.56481992e-02, -4.64896252e-03, 1.13456706e-02, -3.59299663e-03, -2.44456939e-02, -5.62684610e-02, -6.08444796e-04, -7.87207112e-03, -1.70890689e-02, -1.05043147e-02, 2.19502654e-02, 1.09337000e-02, -2.74148956e-02, -3.31539623e-02, -1.42244538e-02, 2.58499496e-02, -1.75755024e-02, -2.55133826e-02, 7.51739135e-04, 2.48481100e-03, -3.04994378e-02, -1.40746264e-02, 2.61744000e-02, 5.87598886e-03, 2.33898405e-02, -6.32737800e-02, 1.78641230e-02, -1.75519601e-01, -8.24389141e-03, 1.98861826e-02, 3.37476358e-02, 7.28844898e-03, 3.08266692e-02, 4.87708626e-03, 2.09365971e-02, 1.73568614e-02, 3.69173177e-02, 2.79097166e-02, 2.77034808e-02, -2.53460445e-02, -2.13546418e-02, -5.77083193e-02, -1.14561366e-02, 2.31049191e-02, 3.03287655e-02, -3.81544698e-04, -1.71796624e-02, -4.23317999e-02, 1.18717095e-02, 
2.22866610e-02, -1.19446928e-03, 1.10752694e-02, -5.63540356e-03, 2.86641587e-02, 1.95889026e-02, 1.96824670e-02, -2.92297900e-02, 2.18246505e-02, -2.48097051e-02, 1.63886491e-02, 3.13611962e-02, -1.68342353e-03, 1.88188329e-02, 2.00219527e-02, -5.88387949e-03, 2.13814694e-02, 7.82733504e-03, 7.68757053e-03, 3.46655548e-02, -1.79147162e-02, 1.58434100e-02, -3.25256586e-02, 9.75492597e-03, 8.89630523e-03, 3.20187919e-02, -2.86053754e-02, -4.22061840e-03, -1.45295085e-02], dtype=float32)
print(f"Embedding size per cell: {cell_embeddings[0].shape}")
Embedding size per cell: (512,)
scgpt.config["emb_mode"] = "cls"
cls_embeddings = scgpt.get_embeddings(dataset)
cls_embeddings[0]
INFO:helical.models.scgpt.model:Started getting embeddings:
Embedding cells: 100%|██████████| 1/1 [00:00<00:00, 240.78it/s]
INFO:helical.models.scgpt.model:Finished getting embeddings.
array([-2.45153867e-02, 6.76829070e-02, 9.26359184e-03, -2.09791050e-03, 2.24513449e-02, 2.08408223e-03, -2.53852215e-02, 3.11240065e-03, -6.62306789e-03, 3.31380554e-02, 2.76659112e-02, 4.46426682e-03, 3.08492407e-02, 1.56556666e-02, 1.73311401e-02, -1.20286504e-02, -3.41191003e-03, -2.74797548e-02, -8.76120233e-04, -1.59723293e-02, -1.63279437e-02, -1.06245605e-02, -1.66277196e-02, -2.04682187e-03, 1.51277408e-02, 5.14600612e-02, -5.13950512e-02, -3.17132138e-02, -2.95655672e-02, -2.13937908e-02, 1.59325860e-02, -2.51241624e-02, 4.61029354e-03, -2.76504010e-02, -1.92681160e-02, -3.63738127e-02, -6.18889090e-03, -4.07493208e-03, -7.18622655e-02, 4.02772194e-03, -3.14314142e-02, -7.31843291e-03, -5.05978167e-02, -9.09360871e-03, -2.41975486e-02, -1.72051415e-02, 7.02964189e-03, 3.86377797e-02, 9.64887347e-03, 5.07224277e-02, -2.60307938e-02, 2.90858541e-02, 7.36009097e-03, 1.74028252e-03, -2.98949555e-02, 2.33976301e-02, 1.02939699e-02, 2.12073866e-02, 4.89693619e-02, 2.88790110e-02, -2.77336799e-02, -4.46378300e-03, 4.08341996e-02, -3.96675081e-04, 5.43113872e-02, -4.32825685e-02, -7.32045248e-02, 9.76623502e-03, -1.32888434e-02, -3.96071039e-02, 7.86911510e-03, -1.11062010e-03, 1.97083247e-03, 1.95169840e-02, -3.42093743e-02, 2.74168756e-02, -3.95796215e-03, -7.52964709e-03, 2.74726804e-02, 2.59213755e-03, -6.16539083e-03, -1.58587899e-02, 2.82052979e-02, -1.15956198e-02, 3.54868509e-02, 2.10501440e-02, -2.49753688e-02, 3.49047594e-02, 3.15196700e-02, 2.41361558e-02, -3.76607967e-03, 3.49559938e-03, 1.93185955e-02, 3.04949172e-02, -1.16525106e-02, -3.90736852e-03, -1.58282351e-02, 3.65261957e-02, 3.23544145e-02, -2.08612122e-02, -1.35824289e-02, 1.10562518e-02, 2.50375569e-02, 7.57342279e-02, -1.28078144e-02, 7.01878732e-03, -5.82913216e-03, -4.30742698e-03, 4.81213778e-02, -6.27691951e-03, -2.15034243e-02, 2.16278862e-02, -1.80496816e-02, -4.38741744e-02, 3.37660015e-02, -4.03557792e-02, 2.39017084e-02, 4.97239716e-02, 8.25660117e-03, -1.15766805e-02, 
6.64573535e-03, 8.37536808e-03, 1.50651345e-02, -3.06401737e-02, 1.06578423e-02, -2.15192046e-02, 1.51927005e-02, 7.00843893e-03, -1.39160594e-02, -1.12494053e-02, 5.43792397e-02, 4.49348055e-02, -8.21045507e-03, 6.14139019e-03, -1.22415740e-02, 4.16773022e-04, -2.25790031e-02, 2.32806262e-02, -2.98567116e-03, -1.85264312e-02, 1.34663479e-02, -2.40827929e-02, -4.38230578e-04, 2.45382376e-02, -1.47257214e-02, -1.46871654e-03, -3.57116312e-02, 2.32473407e-02, -2.97865532e-02, -5.13614155e-02, 1.58212744e-02, 8.06239173e-02, -2.04761140e-02, -1.62395532e-04, -1.97993778e-02, 6.55988678e-02, -3.79635133e-02, -4.76373076e-01, 1.76495910e-02, 1.76523570e-02, 4.72312011e-02, 3.60119641e-02, -4.50604688e-03, -9.84678417e-03, -1.48310815e-03, 6.49683643e-04, -5.05734533e-02, -2.23090481e-02, -2.09893622e-02, 2.83944681e-02, 3.98665443e-02, -3.97819020e-02, 3.02289277e-02, -6.41960055e-02, -3.12850736e-02, -2.14051344e-02, 9.48422309e-03, 1.04411719e-02, -2.30573844e-02, 5.35256118e-02, 1.43965846e-02, -4.82759178e-02, -4.62439191e-03, -6.90604970e-02, 1.99532807e-02, -2.63340026e-03, -7.42220599e-03, 4.60068770e-02, -2.90969908e-02, 2.18398310e-02, -2.34585330e-02, 2.26608873e-03, -2.22239364e-03, -2.21185423e-02, 3.63707938e-03, -2.51304284e-02, -1.48233669e-02, -1.08530521e-02, 1.45721203e-02, -2.17926800e-02, -4.89135645e-03, 8.29203892e-03, 2.51028836e-02, 1.03409085e-02, -1.78557765e-02, -6.04140945e-03, 2.05238699e-03, 2.82709692e-02, -3.24503630e-02, 3.54559459e-02, -3.53872031e-02, 3.11379209e-02, 4.96259928e-02, -8.66587460e-03, -7.14815855e-02, 1.15210470e-02, 1.80784240e-02, -5.19741587e-02, -6.71254983e-03, -1.26365563e-02, 4.98214737e-02, 8.07952322e-03, 2.27515530e-02, 6.07486628e-02, -1.38171967e-02, 3.36158723e-02, 5.27171185e-03, -2.48884223e-02, 2.67648492e-02, 7.27484655e-03, -7.95399770e-03, -3.81333083e-02, 2.81804637e-03, 1.15901353e-02, -2.18091030e-02, -2.17718966e-02, -4.97795902e-02, 4.76862397e-03, -2.01723278e-02, 2.65548769e-02, -1.98825561e-02, 
-3.65659930e-02, 3.32224146e-02, -1.67506132e-02, 4.91061732e-02, -1.20102270e-02, 2.07320806e-02, 6.19586110e-02, 3.37009295e-03, -1.01545528e-02, -9.37982090e-03, 3.72043811e-03, 3.52773704e-02, -5.10503836e-02, 1.51191065e-02, 1.35486033e-02, 2.85783764e-02, -6.08509965e-03, 8.12581927e-03, -9.47200111e-04, 4.36377168e-01, -1.17061613e-02, -3.76500525e-02, -1.73743330e-02, -3.62554304e-02, -7.12229521e-04, 7.17211589e-02, -1.11871678e-03, -2.81942077e-03, 1.81365814e-02, 4.89433147e-02, -5.35011478e-02, -2.05936991e-02, -2.75716581e-03, -1.34341754e-02, 1.71495937e-02, 3.41272503e-02, -9.98707488e-04, -4.40332443e-02, -2.61033233e-02, -5.76740783e-03, 3.51253748e-01, 3.57882269e-02, 2.57680687e-04, 8.88473634e-03, -4.53057215e-02, -3.88377905e-02, 4.76049073e-02, -1.86286587e-02, -6.37506600e-03, 1.03429775e-03, 4.43534069e-02, -3.41886049e-03, 1.73910931e-02, 4.96236533e-02, -3.81948426e-02, 1.57124866e-02, 2.50780489e-03, -3.65860411e-03, 2.53409501e-02, -7.24622048e-04, -1.44517347e-02, -1.94737352e-02, 3.76115702e-02, -3.52155380e-02, 3.56680248e-03, -2.10319012e-02, 2.98389420e-02, -5.44308778e-03, 1.40576540e-02, -2.62061512e-04, -1.53648760e-02, 1.98135469e-02, 3.25694233e-02, 4.54310477e-02, -1.63885728e-02, 1.20220520e-02, 6.35347366e-02, 1.08912708e-02, 2.02855766e-02, 3.82428616e-02, 1.10052777e-02, -2.17193719e-02, -1.58026423e-02, -2.14802809e-02, -5.12225088e-03, -2.51318403e-02, 6.36755824e-02, -3.40835005e-02, 2.07607169e-03, -2.16116831e-02, 7.36416802e-02, 1.74865592e-02, 3.75458077e-02, 4.12650825e-03, 8.52579810e-03, -3.39522772e-02, 2.49971002e-02, -2.51762550e-02, 8.62706732e-03, -4.32539880e-02, -8.96353833e-03, 6.73645409e-03, -7.29783550e-02, -6.26313537e-02, -2.44746829e-04, -9.67285596e-03, -3.47111858e-02, -1.16014006e-02, 2.85756849e-02, -2.09196750e-02, -1.40493680e-02, 8.25099554e-03, 5.86548038e-02, 1.85851846e-02, -5.52713126e-02, 3.80891152e-02, -6.55806288e-02, 3.01559316e-03, -1.53750554e-02, -1.32344523e-02, -1.39216371e-02, 
2.48366762e-02, 1.19781457e-02, -3.60681303e-03, -8.88350792e-03, 2.99509112e-02, 1.41091915e-02, 3.02721411e-02, 2.74510216e-02, 3.79114896e-02, 6.18577981e-03, -2.02162806e-02, 8.86006933e-03, 4.37244959e-03, -1.69865005e-02, -3.95388640e-02, -6.01971615e-03, -4.53112926e-03, -3.33280605e-03, -7.93357007e-03, 6.15930259e-02, 7.47404993e-03, -5.24884239e-02, -1.02607217e-02, -4.16327454e-02, -6.04979992e-02, -4.74545322e-02, 5.26728295e-03, -1.57921184e-02, -4.90473490e-03, -1.73121970e-02, -3.25186062e-03, -1.78076476e-02, -1.31681720e-02, -1.48400199e-02, 3.45820636e-02, -3.79318222e-02, -3.98465209e-02, 5.30735124e-03, 2.45902408e-02, 3.14300656e-02, 6.30108267e-02, -1.24083590e-02, -1.89693309e-02, 2.83043850e-02, -1.92273594e-02, -4.13932558e-03, 2.84970496e-02, 2.58654524e-02, 1.38802072e-02, 1.09579228e-03, 4.01955470e-02, -1.92459077e-02, 6.25998676e-02, -3.43498192e-03, -1.17322234e-02, 4.65216972e-02, -1.45426631e-04, -6.76989853e-02, 3.45930792e-02, 1.78965426e-03, -4.68827598e-03, -2.11978052e-02, 1.68629754e-02, -2.95592397e-02, -2.81799436e-02, 2.55590193e-02, 1.55170374e-02, 4.00625654e-02, -1.83338411e-02, 5.29153924e-03, 2.13614181e-02, -2.45864578e-02, 2.74816044e-02, -5.72612137e-03, 2.09344216e-02, -7.11009139e-03, -1.07827922e-02, -5.92843071e-02, -2.96214614e-02, -6.93373266e-04, -3.61338742e-02, -2.20520645e-02, 3.14995237e-02, -2.18518469e-02, -5.05348034e-02, -5.09402566e-02, -2.13348819e-03, 1.54039720e-02, 3.19538489e-02, -4.11356539e-02, 3.83283990e-03, -4.44084872e-03, -3.83663289e-02, -2.78404041e-04, 4.50295173e-02, -1.53042013e-02, 2.65689809e-02, -5.76632135e-02, 2.92779431e-02, -1.99816048e-01, -2.14710343e-03, 8.64214171e-03, 1.44068627e-02, 2.82023028e-02, 3.13479416e-02, -2.64063105e-02, 3.91461104e-02, 4.71042767e-02, 4.99161296e-02, 1.69883831e-03, 4.15363535e-02, -2.07511596e-02, -2.40179505e-02, -7.79969022e-02, 2.89676187e-04, 3.72988544e-02, 8.40481650e-03, 5.54896705e-03, -3.84214185e-02, -1.18556013e-02, 
-1.60694565e-03, 3.53423283e-02, -3.45797054e-02, 2.27767508e-02, -4.68104240e-03, 8.65800027e-03, 3.38197835e-02, 1.60816684e-02, -3.27193998e-02, 1.10832201e-02, -1.63125228e-02, 2.71654911e-02, 7.86161283e-04, -1.17324255e-02, 3.07963714e-02, 2.13367324e-02, 3.76224564e-03, -2.13393662e-03, 1.07512320e-03, 3.83881922e-03, 3.68292853e-02, -1.65097248e-02, 2.31586695e-02, -5.19343726e-02, 4.10866551e-02, 5.31221693e-03, 3.09123825e-02, -1.99824646e-02, -4.52331454e-02, -4.32595611e-03], dtype=float32)
print(f"Embedding size per cell (cls mode): {cls_embeddings[0].shape}")
Embedding size per cell (cls mode): (512,)