Use Helix-mRNA¶
In this notebook we will dive into using our latest mRNA Bio Foundation Model, Helix-mRNA.
We will get embeddings for our data and fine-tune the model using the Helical package.
If running on a CUDA device compatible with mamba-ssm and causal-conv1d, install the package below; otherwise remove the [mamba-ssm] optional dependency.
- If running on Colab, remove the [mamba-ssm] dependency
In [ ]:
!pip install --upgrade helical[mamba-ssm]
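On machines without a compatible CUDA setup (including Colab), install the base package instead:

!pip install --upgrade helical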
Imports¶
In [1]:
from helical.models.helix_mrna import HelixmRNAConfig, HelixmRNA, HelixmRNAFineTuningModel
import subprocess
import torch
import pandas as pd
device = "cuda" if torch.cuda.is_available() else "cpu"
INFO:datasets:PyTorch version 2.3.0 available.
INFO:datasets:Polars version 0.20.31 available.
INFO:datasets:JAX version 0.4.31 available.
Download one of CodonBERT's fine-tuning benchmarks¶
In [ ]:
url = "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv"
output_filename = "mRFP_Expression.csv"
wget_command = ["wget", "-O", output_filename, url]
try:
    subprocess.run(wget_command, check=True)
    print(f"File downloaded successfully as {output_filename}")
except subprocess.CalledProcessError as e:
    print(f"Error occurred: {e}")
url = "https://raw.githubusercontent.com/Sanofi-Public/CodonBERT/refs/heads/master/benchmarks/CodonBERT/data/fine-tune/mRFP_Expression.csv"
output_filename = "mRFP_Expression.csv"
wget_command = ["wget", "-O", output_filename, url]
try:
subprocess.run(wget_command, check=True)
print(f"File downloaded successfully as {output_filename}")
except subprocess.CalledProcessError as e:
print(f"Error occurred: {e}")
Load the dataset as a pandas dataframe and get the splits¶
- For this example we take a small subset of each split; feel free to run it on the entire dataset!
In [ ]:
dataset = pd.read_csv(output_filename)
train_data = dataset[dataset["Split"] == "train"][:10]
eval_data = dataset[dataset["Split"] == "val"][:5]
test_data = dataset[dataset["Split"] == "test"][:5]
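Optionally, inspect the dataframe before subsetting; the columns this notebook relies on are Sequence, Split and Value:

# Quick look at the data; this notebook uses the Sequence, Split and Value columns
print(dataset.head())
print(dataset["Split"].value_counts())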
Define our Helix-mRNA model and desired configs¶
In [4]:
# We set the max length to the maximum sequence length in the training data + 10 to leave room for special tokens
helix_mrna_config = HelixmRNAConfig(device=device, batch_size=1, max_length=max(len(s) for s in train_data["Sequence"])+10)
helix_mrna = HelixmRNA(helix_mrna_config)
INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.
Process our training sequences to tokenize them and prepare them for the model¶
In [5]:
processed_train_data = helix_mrna.process_data(train_data["Sequence"].to_list())
Generate embeddings for the train data¶
- We get an embedding for each letter/token in the sequence; here that is 100 sequences of 688 tokens each, with an embedding dimension of 256
- Because the model is recurrent, the final non-special-token embedding, at the second-to-last position, encapsulates everything that came before it
In [19]:
embeddings = helix_mrna.get_embeddings(processed_train_data)
embeddings = embeddings[:, -2, :]
print(embeddings.shape)
print(embeddings[:1])
Getting embeddings: 100%|██████████| 20/20 [00:00<00:00, 71.50it/s]
(100, 256)
[[-9.70379915e-04  5.94667019e-03  1.07590854e-02 ...  1.99010246e-03 -1.32145118e-02  1.12704304e-03]]
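Helical does not plot for you; here is a minimal visualization sketch that projects the per-sequence embeddings to 2D with PCA, assuming scikit-learn and matplotlib are installed. Colouring by the regression target assumes the embeddings align row-for-row with train_data.

# Minimal sketch (not part of Helical): PCA projection of the per-sequence embeddings.
# Assumes scikit-learn and matplotlib are installed.
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

coords = PCA(n_components=2).fit_transform(embeddings)  # (n_sequences, 2)

# Colouring by "Value" assumes one embedding per row of train_data
plt.scatter(coords[:, 0], coords[:, 1], c=train_data["Value"], cmap="viridis")
plt.colorbar(label="Expression value")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("PCA of Helix-mRNA sequence embeddings")
plt.show()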
Fine-tuning the model on our data¶
- This is a regression task, so the output is a single continuous value
In [6]:
helix_mrna_fine_tuning_model = HelixmRNAFineTuningModel(helix_mrna_config=helix_mrna_config, fine_tuning_head="regression", output_size=1)
INFO:helical.models.helix_mrna.model:Helix-mRNA initialized successfully.
Our training data is already processed since the standard Helix-mRNA model and fine-tuning model take the same input!¶
- We process our eval and test data
In [7]:
processed_eval_data = helix_mrna_fine_tuning_model.process_data(eval_data["Sequence"].to_list())
processed_test_data = helix_mrna_fine_tuning_model.process_data(test_data["Sequence"].to_list())
Run fine-tuning on the model for this small sample of data¶
In [13]:
helix_mrna_fine_tuning_model.train(train_dataset=processed_train_data,
                                   train_labels=train_data["Value"].to_numpy().reshape(-1, 1),
                                   validation_dataset=processed_eval_data,
                                   validation_labels=eval_data["Value"].to_numpy().reshape(-1, 1),
                                   epochs=5,
                                   loss_function=torch.nn.MSELoss(),
                                   trainable_layers=2)
INFO:helical.models.helix_mrna.fine_tuning_model:Unfreezing the last 2 layers of the Helix_mRNA model.
INFO:helical.models.helix_mrna.fine_tuning_model:Starting Fine-Tuning
Fine-Tuning: epoch 1/5: 100%|██████████| 20/20 [00:00<00:00, 49.31it/s, loss=86.9]
Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 56.86it/s, val_loss=88.8]
Fine-Tuning: epoch 2/5: 100%|██████████| 20/20 [00:00<00:00, 59.90it/s, loss=80.6]
Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 53.01it/s, val_loss=83.3]
Fine-Tuning: epoch 3/5: 100%|██████████| 20/20 [00:00<00:00, 51.32it/s, loss=74.9]
Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 167.73it/s, val_loss=76.6]
Fine-Tuning: epoch 4/5: 100%|██████████| 20/20 [00:00<00:00, 49.26it/s, loss=67.3]
Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 54.98it/s, val_loss=67.4]
Fine-Tuning: epoch 5/5: 100%|██████████| 20/20 [00:00<00:00, 60.40it/s, loss=59.8]
Fine-Tuning Validation: 100%|██████████| 4/4 [00:00<00:00, 50.23it/s, val_loss=61.1]
INFO:helical.models.helix_mrna.fine_tuning_model:Fine-Tuning Complete. Epochs: 5
Get outputs from our model on the test data¶
In [14]:
outputs = helix_mrna_fine_tuning_model.get_outputs(processed_test_data)
print(outputs)
Generating outputs: 100%|██████████| 2/2 [00:00<00:00, 46.17it/s]
[[2.682281 ]
 [2.4182768]
 [2.4362845]
 [2.6120207]
 [2.6543183]
 [2.6988027]
 [2.671821 ]
 [2.144202 ]
 [2.6866376]
 [2.6734226]]
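These are raw regression outputs; to sanity-check them, one could compare against the held-out labels. A minimal sketch, assuming the predictions align row-for-row with test_data and that scipy is installed for the rank correlation:

# Minimal evaluation sketch (assumes one prediction per row of test_data)
import numpy as np
from scipy.stats import spearmanr  # scipy assumed available

y_true = test_data["Value"].to_numpy()
y_pred = np.asarray(outputs).flatten()

rho, _ = spearmanr(y_true, y_pred)
print(f"Test MSE: {np.mean((y_pred - y_true) ** 2):.3f}")
print(f"Spearman rho: {rho:.3f}")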