Model

helical.models.tahoe.Tahoe

Bases: HelicalRNAModel

Tahoe-x1 Model.

The Tahoe-x1 Model is a transformer-based foundation model designed for single-cell RNA-seq data. It can extract cell and gene embeddings from raw count data. The model is available in three sizes:

  • 70m: 12-layer transformer with 512 embedding dimensions
  • 1b: 24-layer transformer with 1024 embedding dimensions (coming soon)
  • 3b: 36-layer transformer with 1536 embedding dimensions (coming soon)
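For quick sanity checks (for example, confirming an embedding width before calling `decode_embeddings`), the sizes listed above can be captured in a small lookup. This is a minimal sketch with hypothetical names (`TahoeSize`, `TAHOE_SIZES`, `expected_embedding_dim`), not part of the helical API; the numbers are copied from the list, and the "coming soon" sizes are flagged as unavailable.

```python
from typing import NamedTuple


class TahoeSize(NamedTuple):
    n_layers: int
    embedding_dim: int
    released: bool


# Hypothetical lookup built from the size list above (not a helical API).
TAHOE_SIZES = {
    "70m": TahoeSize(n_layers=12, embedding_dim=512, released=True),
    "1b": TahoeSize(n_layers=24, embedding_dim=1024, released=False),  # coming soon
    "3b": TahoeSize(n_layers=36, embedding_dim=1536, released=False),  # coming soon
}


def expected_embedding_dim(model_size: str) -> int:
    """Return the embedding dimension for a released model size."""
    if model_size not in TAHOE_SIZES:
        raise ValueError(
            f"Unknown model_size {model_size!r}; expected one of {sorted(TAHOE_SIZES)}"
        )
    spec = TAHOE_SIZES[model_size]
    if not spec.released:
        raise ValueError(f"model_size {model_size!r} is listed as coming soon")
    return spec.embedding_dim


print(expected_embedding_dim("70m"))
```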
Example
```python
from helical.models.tahoe import Tahoe, TahoeConfig
import anndata as ad

# Example configuration
tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
tahoe = Tahoe(configurer=tahoe_config)

# Load and process data - returns a DataLoader
ann_data = ad.read_h5ad("anndata_file.h5ad")
dataloader = tahoe.process_data(ann_data)

# Get embeddings from the DataLoader
embeddings = tahoe.get_embeddings(dataloader)
print("Tahoe embeddings shape:", embeddings.shape)

# Get both cell and gene embeddings
cell_embeddings, gene_embeddings = tahoe.get_embeddings(dataloader, return_gene_embeddings=True)
print("Cell embeddings shape:", cell_embeddings.shape)
print("Gene embeddings:", len(gene_embeddings), "cells")  # List of pandas Series, one per cell
print("First cell genes:", len(gene_embeddings[0]), "genes")  # Number of genes in first cell
print("Gene names for first cell:", list(gene_embeddings[0].keys())[:5])  # First 5 gene names

# Get attention weights (requires attn_impl='torch')
tahoe_config_attn = TahoeConfig(model_size="70m", batch_size=8, attn_impl='torch')
tahoe_attn = Tahoe(configurer=tahoe_config_attn)
dataloader_attn = tahoe_attn.process_data(ann_data)
cell_embeddings, attentions = tahoe_attn.get_embeddings(dataloader_attn, output_attentions=True)
print("Attention maps:", len(attentions), "cells")  # List of numpy arrays, one per cell
print("First attention map shape:", attentions[0].shape)  # (n_heads, seq_len, seq_len)
```
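Each entry in the `attentions` list is one sample's map from the selected layer, with shape `(n_heads, seq_len, seq_len)`. A common next step is to average over heads. The sketch below does that with NumPy and, optionally, drops padding positions using a boolean mask; how you build that mask (for example, from gene IDs equal to the collator's `pad_token_id`) is an assumption left to your pipeline, and `head_averaged_attention` is a hypothetical helper. The example feeds it random softmax-normalized data, not real model output.

```python
from typing import Optional

import numpy as np


def head_averaged_attention(attn: np.ndarray, valid: Optional[np.ndarray] = None) -> np.ndarray:
    """Average one sample's attention map over heads, optionally dropping padding."""
    mean_attn = attn.mean(axis=0)  # (seq_len, seq_len)
    if valid is None:
        return mean_attn
    # Keep only query/key pairs where both positions are real (non-padding) genes,
    # then renormalize each row so it sums to 1 again.
    sub = mean_attn[np.ix_(valid, valid)]
    row_sums = sub.sum(axis=1, keepdims=True)
    return sub / np.where(row_sums == 0, 1.0, row_sums)


# Synthetic stand-in for one entry of `attentions`: 4 heads, 5 positions.
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5, 5))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # row-wise softmax
valid = np.array([True, True, True, False, False])  # last two positions treated as padding

avg = head_averaged_attention(attn, valid)
print(avg.shape)
```

Renormalizing after masking is a design choice that keeps each row a distribution over real genes; skip it if you want the raw weights.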

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `configurer` | `TahoeConfig` | The model configuration. Defaults to `TahoeConfig()` with default parameters. | `TahoeConfig()` |
Notes

The Tahoe-x1 model uses Ensembl IDs to identify genes and currently supports only human genes. The model is published by Tahoe Therapeutics and available on Hugging Face at https://huggingface.co/tahoebio/Tahoe-x1.

By default, the model uses Flash Attention (attn_impl='flash') for efficient inference. To extract attention weights, use attn_impl='torch' when creating the TahoeConfig, though this will be slower and use more memory.
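`get_embeddings` enforces this at call time: it reads `attn_impl` from the loaded model config and raises a `RuntimeError` for `"flash"` or `"triton"` when `output_attentions=True`. A hypothetical standalone version of that guard, which you could run before loading the model to fail fast, might look like this (the helper name and message are illustrative, not the library's):

```python
# Attention implementations that get_embeddings rejects for attention extraction.
UNSUPPORTED_FOR_ATTENTION = {"flash", "triton"}


def check_attention_extraction(attn_impl: str, output_attentions: bool) -> None:
    """Raise early if attention weights are requested from a fused attention kernel."""
    if output_attentions and attn_impl in UNSUPPORTED_FOR_ATTENTION:
        raise RuntimeError(
            f"attn_impl={attn_impl!r} does not expose attention weights; "
            "create the TahoeConfig with attn_impl='torch' instead."
        )


check_attention_extraction("torch", output_attentions=True)  # OK
check_attention_extraction("flash", output_attentions=False)  # OK: no attention requested
```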

Source code in helical/models/tahoe/model.py
class Tahoe(HelicalRNAModel):
    """Tahoe-x1 Model.

    The Tahoe-x1 Model is a transformer-based foundation model designed for single-cell
    RNA-seq data. It can extract cell and gene embeddings from raw count data.
    The model is available in three sizes:

    - 70m: 12-layer transformer with 512 embedding dimensions
    - 1b: 24-layer transformer with 1024 embedding dimensions (coming soon)
    - 3b: 36-layer transformer with 1536 embedding dimensions (coming soon)

    Example
    -------
    ```python
    from helical.models.tahoe import Tahoe, TahoeConfig
    import anndata as ad

    # Example configuration
    tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
    tahoe = Tahoe(configurer=tahoe_config)

    # Load and process data - returns a DataLoader
    ann_data = ad.read_h5ad("anndata_file.h5ad")
    dataloader = tahoe.process_data(ann_data)

    # Get embeddings from the DataLoader
    embeddings = tahoe.get_embeddings(dataloader)
    print("Tahoe embeddings shape:", embeddings.shape)

    # Get both cell and gene embeddings
    cell_embeddings, gene_embeddings = tahoe.get_embeddings(dataloader, return_gene_embeddings=True)
    print("Cell embeddings shape:", cell_embeddings.shape)
    print("Gene embeddings:", len(gene_embeddings), "cells")  # List of pandas Series, one per cell
    print("First cell genes:", len(gene_embeddings[0]), "genes")  # Number of genes in first cell
    print("Gene names for first cell:", list(gene_embeddings[0].keys())[:5])  # First 5 gene names

    # Get attention weights (requires attn_impl='torch')
    tahoe_config_attn = TahoeConfig(model_size="70m", batch_size=8, attn_impl='torch')
    tahoe_attn = Tahoe(configurer=tahoe_config_attn)
    dataloader_attn = tahoe_attn.process_data(ann_data)
    cell_embeddings, attentions = tahoe_attn.get_embeddings(dataloader_attn, output_attentions=True)
    print("Attention maps:", len(attentions), "cells")  # List of numpy arrays, one per cell
    print("First attention map shape:", attentions[0].shape)  # (n_heads, seq_len, seq_len)
    ```

    Parameters
    ----------
    configurer : TahoeConfig, optional
        The model configuration. Defaults to TahoeConfig() with default parameters.

    Notes
    -----
    The Tahoe-x1 model uses Ensembl IDs to identify genes and currently supports only
    human genes. The model is published by Tahoe Therapeutics and available on
    Hugging Face at https://huggingface.co/tahoebio/Tahoe-x1.

    By default, the model uses Flash Attention (attn_impl='flash') for efficient inference.
    To extract attention weights, use attn_impl='torch' when creating the TahoeConfig,
    though this will be slower and use more memory.
    """

    def __init__(self, configurer: TahoeConfig = TahoeConfig()) -> None:

        super().__init__()

        self.configurer = configurer
        self.config = configurer.config
        self.device = torch.device(self.config["device"])

        LOGGER.info(
            f"Loading Tahoe model (size: {self.config['model_size']}) from Hugging Face..."
        )

        # Load model from Hugging Face
        self.model, self.vocab, self.model_cfg, self.collator_cfg = (
            TXModel.from_hf(
                repo_id=self.config["hf_repo_id"],
                model_size=self.config["model_size"],
                return_gene_embeddings=(self.config["emb_mode"] == "gene"),
                attn_impl=self.config["attn_impl"],
            )
        )

        self.model.to(self.device)
        self.model.eval()

        LOGGER.info(
            f"Model loaded with {self.model.n_layers} transformer layers."
        )
        LOGGER.info(
            f"Tahoe model is in 'eval' mode, on device '{self.device}' with embedding mode '{self.config['emb_mode']}' "
            f"and attention implementation '{self.config['attn_impl']}'."
        )

    def process_data(
        self,
        adata: AnnData,
        gene_names: str = "index",
        use_raw_counts: bool = True,
    ) -> DataLoader:
        """
        Processes the data for the Tahoe model and returns a DataLoader.

        Parameters
        ----------
        adata : AnnData
            The AnnData object containing the data to be processed. Tahoe uses Ensembl IDs
            to identify genes and currently supports only human genes. If the AnnData object
            already has an 'ensembl_id' column, pass gene_names="ensembl_id" to skip the
            mapping step.
        gene_names : str, optional, default="index"
            The column in `adata.var` that contains the gene names. If set to a value other
            than "ensembl_id", the gene symbols in that column will be mapped to Ensembl IDs
            using the 'pyensembl' package.
            - If set to "index", the index of `adata.var` will be used and mapped to Ensembl IDs.
            - If set to "ensembl_id", no mapping will occur.
        use_raw_counts : bool, optional, default=True
            Determines whether raw counts should be used.

        Returns
        -------
        DataLoader
            A PyTorch DataLoader ready for inference.
        """
        LOGGER.info("Processing data for Tahoe.")
        self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

        # Map gene symbols to Ensembl IDs if provided
        if gene_names != "ensembl_id":
            if (adata.var[gene_names].str.startswith("ENS").all()) or (
                adata.var[gene_names].str.startswith("None").any()
            ):
                message = (
                    "It seems an AnnData object with Ensembl IDs and/or 'None' values was passed. "
                    "Please set gene_names='ensembl_id' and remove the 'None' values to skip mapping."
                )
                LOGGER.error(message)
                raise ValueError(message)
            adata = map_gene_symbols_to_ensembl_ids(adata, gene_names)

            if adata.var["ensembl_id"].isnull().all():
                message = "None of the gene symbols could be mapped to Ensembl IDs. Please check the input data."
                LOGGER.error(message)
                raise ValueError(message)

        gene_id_key = "ensembl_id"

        # Map genes to vocabulary
        adata.var["id_in_vocab"] = [
            self.vocab[gene] if gene in self.vocab else -1
            for gene in adata.var[gene_id_key]
        ]

        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        n_matched = np.sum(gene_ids_in_vocab >= 0)
        n_total = len(gene_ids_in_vocab)

        LOGGER.info(
            f"Matched {n_matched}/{n_total} genes in vocabulary of size {len(self.vocab)}."
        )

        # Filter to genes in vocabulary
        adata = adata[:, adata.var["id_in_vocab"] >= 0]

        genes = adata.var[gene_id_key].tolist()
        gene_ids = np.array([self.vocab[gene] for gene in genes], dtype=int)

        if not np.all(gene_ids >= 0):
            raise ValueError("Some genes are not in the vocabulary after filtering.")

        dataloader = loader_from_adata(
            adata=adata,
            collator_cfg=self.collator_cfg,
            vocab=self.vocab,
            batch_size=self.config["batch_size"],
            max_length=self.config["max_length"],
            gene_ids=gene_ids,
            num_workers=self.config["num_workers"],
            prefetch_factor=self.config["prefetch_factor"],
        )

        LOGGER.info("Successfully processed the data for Tahoe.")
        return dataloader

    def get_embeddings(
        self,
        dataloader: DataLoader,
        return_gene_embeddings: bool = False,
        output_attentions: bool = False,
        attn_layer: int = -1,
    ) -> Union[np.ndarray, tuple]:
        """Gets the embeddings from the Tahoe model.

        Parameters
        ----------
        dataloader : DataLoader
            The DataLoader returned from process_data().
        return_gene_embeddings : bool, optional, default=False
            Whether to return gene embeddings for each cell in addition to cell embeddings.
            Gene embeddings are returned as a list of pandas Series, one per cell, where
            each Series contains the embeddings for genes expressed in that cell.
        output_attentions : bool, optional, default=False
            Whether to return attention weights from the transformer layer selected by
            `attn_layer`. Note: this requires the model to be initialized with attn_impl='torch'.
            The default Flash Attention (attn_impl='flash') does not support attention
            weight extraction for efficiency reasons.
        attn_layer : int, optional, default=-1
            Which transformer layer's attention to return. Supports negative indexing
            (e.g. -1 for the last layer). Only used when output_attentions is True.

        Returns
        -------
        np.ndarray or tuple
            Depending on the combination of flags:
            - If both False: cell_embeddings (n_cells, embedding_dim)
            - If return_gene_embeddings=True only: (cell_embeddings, gene_embeddings)
            - If output_attentions=True only: (cell_embeddings, attentions)
            - If both True: (cell_embeddings, gene_embeddings, attentions)

            Where:
            - cell_embeddings: numpy array of shape (n_cells, embedding_dim), L2-normalized per cell.
            - gene_embeddings: list of pandas Series, one per cell. Each Series contains
              L2-normalized gene embeddings indexed by Ensembl IDs for genes expressed in that cell.
            - attentions: list of per-sample numpy arrays, each of shape (n_heads, seq_length, seq_length),
              from the layer selected by attn_layer. seq_length is the padded sequence length of the
              batch the sample came from, so padding positions may be included.
        """
        LOGGER.info("Extracting embeddings from Tahoe model...")

        # Check if attention extraction is requested but not supported
        if output_attentions:
            attn_impl = self.model_cfg.get("attn_config", {}).get("attn_impl", "flash")
            if attn_impl in ["flash", "triton"]:
                raise RuntimeError(
                    f"Attention weight extraction is not supported with attn_impl='{attn_impl}'. "
                    "Flash Attention is optimized for speed and memory efficiency and does not "
                    "compute/store attention weights. To extract attention weights, initialize the model "
                    "with attn_impl='torch':\n\n"
                    "    tahoe_config = TahoeConfig(model_size='70m', attn_impl='torch')\n"
                    "    tahoe = Tahoe(configurer=tahoe_config)"
                )

        self.model.return_gene_embeddings = return_gene_embeddings

        cell_embs: List[torch.Tensor] = []
        all_attentions: List[torch.Tensor] = [] if output_attentions else None
        all_gene_embeddings: List[pd.Series] = [] if return_gene_embeddings else None

        dtype_from_string = {
            "fp32": torch.float32,
            "amp_bf16": torch.bfloat16,
            "amp_fp16": torch.float16,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        with (
            torch.no_grad(),
            torch.amp.autocast(
                enabled=True,
                dtype=dtype_from_string[self.model_cfg["precision"]],
                device_type=self.device.type,
            ),
        ):
            pbar = tqdm(total=len(dataloader), desc="Embedding cells")

            for data_dict in dataloader:
                input_gene_ids = data_dict["gene"].to(self.device)
                src_key_padding_mask = ~input_gene_ids.eq(self.collator_cfg["pad_token_id"])

                output = self.model(
                    genes=input_gene_ids,
                    values=data_dict["expr"].to(self.device),
                    gen_masks=data_dict["gen_mask"].to(self.device),
                    key_padding_mask=src_key_padding_mask,
                    drug_ids=(
                        data_dict["drug_ids"].to(self.device)
                        if "drug_ids" in data_dict
                        else None
                    ),
                    skip_decoders=True,
                    output_attentions=output_attentions,
                )

                cell_embs.append(output["cell_emb"].to("cpu").to(dtype=torch.float32))

                if output_attentions:
                    # Select the requested layer: (batch, n_heads, seq_len, seq_len)
                    layer_attn = output["attentions"][attn_layer].cpu().to(torch.float32)
                    all_attentions.append(layer_attn)

                if return_gene_embeddings:
                    # Get gene embeddings for this batch: shape (batch_size, seq_len, d_model)
                    gene_embs = output.get("gene_emb").to(torch.float32).cpu().numpy()
                    gene_ids = input_gene_ids.cpu().numpy()

                    # Create a pandas Series for each cell in the batch
                    for i in range(gene_embs.shape[0]):
                        cell_gene_dict = {}
                        for j in range(gene_embs.shape[1]):
                            gene_id = gene_ids[i, j]
                            if gene_id != self.collator_cfg["pad_token_id"]:
                                gene_name = self.vocab.index_to_token[gene_id]
                                gene_embedding = gene_embs[i, j]
                                # Normalize the gene embedding
                                gene_embedding = gene_embedding / np.linalg.norm(gene_embedding)
                                cell_gene_dict[gene_name] = gene_embedding

                        all_gene_embeddings.append(pd.Series(cell_gene_dict))

                pbar.update(1)

        # Normalize cell embeddings
        cell_array = torch.cat(cell_embs, dim=0).numpy()
        cell_array = cell_array / np.linalg.norm(
            cell_array,
            axis=1,
            keepdims=True,
        )


        # Prepare attention list if requested — one np.ndarray per sample
        if output_attentions:
            attn_list = []
            for attn in all_attentions:
                # attn shape: (batch, n_heads, seq_len, seq_len)
                attn_list.extend(attn.numpy())

        # Return based on requested outputs
        log_msg = f"Finished extracting embeddings. Cell shape: {cell_array.shape}"
        if return_gene_embeddings:
            log_msg += f", Gene embeddings: {len(all_gene_embeddings)} cells"
        if output_attentions:
            log_msg += f", Attention maps: {len(attn_list)} samples"
        LOGGER.info(log_msg)

        # Return appropriate combination
        if return_gene_embeddings and output_attentions:
            return cell_array, all_gene_embeddings, attn_list
        elif return_gene_embeddings:
            return cell_array, all_gene_embeddings
        elif output_attentions:
            return cell_array, attn_list
        else:
            return cell_array

    def get_transformer_embeddings(
        self,
        dataloader: DataLoader,
    ) -> tuple:
        """Get raw transformer embeddings before the decoder.

        This method returns the transformer output embeddings along with the gene IDs
        for each position. This is useful for perturbation experiments where you want
        to modify embeddings and then decode them to predicted expression.

        Parameters
        ----------
        dataloader : DataLoader
            The DataLoader returned from process_data().

        Returns
        -------
        tuple of (list, list)
            - transformer_embeddings: list of numpy arrays, one per cell
              Each array has shape (seq_len, embedding_dim) containing the unnormalized
              transformer output embeddings for that cell's genes (padding positions included)
            - gene_ids: list of numpy arrays, one per cell
              Each array has shape (seq_len,) containing the gene vocabulary IDs
              for that cell (pad_token_id for padding positions)

        Example
        -------
        ```python
        from helical.models.tahoe import Tahoe, TahoeConfig
        import anndata as ad

        tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
        tahoe = Tahoe(configurer=tahoe_config)

        ann_data = ad.read_h5ad("anndata_file.h5ad")
        dataloader = tahoe.process_data(ann_data)

        # Get transformer embeddings and gene IDs
        transformer_embs, gene_ids = tahoe.get_transformer_embeddings(dataloader)

        # Each is a list with one entry per cell
        print(f"Number of cells: {len(transformer_embs)}")
        print(f"First cell embedding shape: {transformer_embs[0].shape}")
        print(f"First cell gene IDs shape: {gene_ids[0].shape}")

        # Modify embeddings (e.g., perturb specific genes in first cell)
        # transformer_embs[0][5, :] += 0.1  # perturb gene at position 5

        # Decode modified embeddings to predicted expression
        expr_pred = tahoe.decode_embeddings(transformer_embs, gene_ids)
        ```
        """
        LOGGER.info("Extracting transformer embeddings...")

        self.model.return_gene_embeddings = True
        all_embeddings: List[np.ndarray] = []
        all_gene_ids: List[np.ndarray] = []

        dtype_from_string = {
            "fp32": torch.float32,
            "amp_bf16": torch.bfloat16,
            "amp_fp16": torch.float16,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        with (
            torch.no_grad(),
            torch.amp.autocast(
                enabled=True,
                dtype=dtype_from_string[self.model_cfg["precision"]],
                device_type=self.device.type,
            ),
        ):
            pbar = tqdm(total=len(dataloader), desc="Extracting transformer embeddings")

            for data_dict in dataloader:
                input_gene_ids = data_dict["gene"].to(self.device)
                src_key_padding_mask = ~input_gene_ids.eq(self.collator_cfg["pad_token_id"])

                output = self.model(
                    genes=input_gene_ids,
                    values=data_dict["expr"].to(self.device),
                    gen_masks=data_dict["gen_mask"].to(self.device),
                    key_padding_mask=src_key_padding_mask,
                    drug_ids=(
                        data_dict["drug_ids"].to(self.device)
                        if "drug_ids" in data_dict
                        else None
                    ),
                    skip_decoders=True,
                )

                # Get transformer output and gene IDs from model output
                # Shape: (batch_size, seq_len, d_model)
                transformer_output = output["gene_emb"].to("cpu").to(dtype=torch.float32).numpy()
                batch_gene_ids = output["gene_ids"].to("cpu").numpy()

                # Split batch into individual cells
                for i in range(transformer_output.shape[0]):
                    all_embeddings.append(transformer_output[i])  # (seq_len, d_model)
                    all_gene_ids.append(batch_gene_ids[i])  # (seq_len,)

                pbar.update(1)

        LOGGER.info(f"Extracted transformer embeddings for {len(all_embeddings)} cells")
        return all_embeddings, all_gene_ids

    def decode_embeddings(
        self,
        gene_embeddings: List[np.ndarray],
        gene_ids: List[np.ndarray],
    ) -> List[pd.Series]:
        """Decode gene embeddings to predict expression values.

        This method takes gene-level embeddings (e.g., from the transformer) and
        uses the Tahoe expression decoder to predict gene expression values.
        The embeddings must be in the same sequence order as the original input.

        **Important**: Use `get_transformer_embeddings()` first to get embeddings
        and gene IDs, modify them if needed, then pass both to this method.

        Parameters
        ----------
        gene_embeddings : List[np.ndarray]
            List of gene embeddings, one array per cell.
            Each array has shape (seq_len, embedding_dim) containing transformer
            output embeddings in the same sequence order as the input.
        gene_ids : List[np.ndarray]
            List of gene vocabulary IDs, one array per cell.
            Each array has shape (seq_len,) corresponding to the embeddings.

        Returns
        -------
        List[pd.Series]
            List of pandas Series, one per cell. Each Series maps gene names
            (Ensembl IDs) to predicted expression values. Only includes non-padding genes.

        Example
        -------
        ```python
        from helical.models.tahoe import Tahoe, TahoeConfig
        import anndata as ad

        tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
        tahoe = Tahoe(configurer=tahoe_config)

        ann_data = ad.read_h5ad("anndata_file.h5ad")
        dataloader = tahoe.process_data(ann_data)

        # Get transformer embeddings and gene IDs
        transformer_embs, gene_ids = tahoe.get_transformer_embeddings(dataloader)

        # Optional: Modify embeddings for perturbation experiments
        # Example - perturb gene at position 5 in first cell
        # transformer_embs[0][5, :] += 0.1

        # Decode embeddings to predicted expression
        expr_predictions = tahoe.decode_embeddings(transformer_embs, gene_ids)

        # Access predictions for first cell
        print(f"First cell predictions: {len(expr_predictions[0])} genes")
        for gene_name, pred_expr in list(expr_predictions[0].items())[:5]:
            print(f"  {gene_name}: {pred_expr:.4f}")
        ```

        Notes
        -----
        The decoder expects embeddings in the same format as the transformer output.
        Make sure your embeddings match the model's embedding dimension
        (512 for 70m, 1024 for 1b, 1536 for 3b).
        """
        LOGGER.info(f"Decoding embeddings for {len(gene_embeddings)} cells...")

        pad_token_id = self.collator_cfg["pad_token_id"]
        idx_to_gene = self.vocab.index_to_token

        all_predictions = []

        dtype_from_string = {
            "fp32": torch.float32,
            "amp_bf16": torch.bfloat16,
            "amp_fp16": torch.float16,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }

        with (
            torch.no_grad(),
            torch.amp.autocast(
                enabled=True,
                dtype=dtype_from_string[self.model_cfg["precision"]],
                device_type=self.device.type,
            ),
        ):
            for cell_emb, cell_gene_ids in zip(gene_embeddings, gene_ids):
                # Convert to tensor and add batch dimension
                # Shape: (1, seq_len, d_model)
                emb_tensor = torch.from_numpy(cell_emb).unsqueeze(0).to(torch.float32).to(self.device)

                # Pass through decoder
                decoder_output = self.model.expression_decoder(emb_tensor)
                expr_pred = decoder_output["pred"]  # (1, seq_len) or (1, seq_len, 1)

                # Flatten to (seq_len,) so each position holds a single scalar prediction
                expr_pred = expr_pred.to("cpu").to(torch.float32).numpy().reshape(-1)

                # Create Series mapping gene names to predictions (only non-padding)
                pred_dict = {}
                for pos, gene_id in enumerate(cell_gene_ids):
                    if gene_id != pad_token_id:
                        gene_name = idx_to_gene[gene_id]
                        pred_dict[gene_name] = float(expr_pred[pos])

                all_predictions.append(pd.Series(pred_dict))

        LOGGER.info(f"Finished decoding {len(all_predictions)} cells")
        return all_predictions
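`get_embeddings` divides each cell embedding by its L2 norm, so the cosine similarity between two cells reduces to a dot product. In the source that division has no guard, so an all-zero row would become NaN. The sketch below adds a small epsilon floor on the norm (an assumption, not what the source does) and builds a cell-by-cell cosine-similarity matrix; the function names are hypothetical and the input is a hand-written synthetic matrix.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise L2 normalization with a floor on the norm to avoid division by zero."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def cosine_similarity_matrix(cell_embeddings: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity; for unit-norm rows this is just X @ X.T."""
    unit = l2_normalize(cell_embeddings)
    return unit @ unit.T


# Synthetic stand-in for get_embeddings() output: 3 cells x 4 dimensions.
emb = np.array([
    [3.0, 4.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [6.0, 8.0, 0.0, 0.0],
])
sim = cosine_similarity_matrix(emb)
print(np.round(sim, 3))
```

Because `get_embeddings` already returns unit-norm rows, `cell_array @ cell_array.T` gives the same matrix directly; re-normalizing is only a safeguard.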

process_data(adata, gene_names='index', use_raw_counts=True)

Processes the data for the Tahoe model and returns a DataLoader.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `adata` | `AnnData` | The AnnData object containing the data to be processed. Tahoe uses Ensembl IDs to identify genes and currently supports only human genes. If the AnnData object already has an `ensembl_id` column, pass `gene_names="ensembl_id"` to skip the mapping step. | required |
| `gene_names` | `str` | The column in `adata.var` that contains the gene names. If set to a value other than `"ensembl_id"`, the gene symbols in that column are mapped to Ensembl IDs using the `pyensembl` package. If set to `"index"`, the index of `adata.var` is used and mapped to Ensembl IDs. If set to `"ensembl_id"`, no mapping occurs. | `"index"` |
| `use_raw_counts` | `bool` | Determines whether raw counts should be used. | `True` |

Returns:

| Type | Description |
|------|-------------|
| `DataLoader` | A PyTorch DataLoader ready for inference. |
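The vocabulary-matching step in `process_data` can be pictured with a toy vocabulary: each Ensembl ID is looked up, unmatched genes get `-1`, and only matched columns are kept. In this sketch the vocabulary is a plain dict with made-up integer IDs (the real `self.vocab` is loaded from Hugging Face), the third Ensembl ID is deliberately fake, and `match_to_vocab` is a hypothetical helper.

```python
import numpy as np


def match_to_vocab(ensembl_ids, vocab):
    """Return vocabulary IDs (-1 for unknown genes) and a boolean keep-mask."""
    ids = np.array([vocab[gene] if gene in vocab else -1 for gene in ensembl_ids], dtype=int)
    keep = ids >= 0
    return ids, keep


# Toy vocabulary: real Ensembl IDs (TP53, BRCA1) mapped to made-up token IDs.
toy_vocab = {"ENSG00000141510": 7, "ENSG00000012048": 12}
genes = ["ENSG00000141510", "ENSG00000000000", "ENSG00000012048"]  # middle ID is fake

ids, keep = match_to_vocab(genes, toy_vocab)
print(f"Matched {int(keep.sum())}/{len(genes)} genes")
kept_ids = ids[keep]  # analogous to the gene_ids passed on to the DataLoader
```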

Source code in helical/models/tahoe/model.py
def process_data(
    self,
    adata: AnnData,
    gene_names: str = "index",
    use_raw_counts: bool = True,
) -> DataLoader:
    """
    Processes the data for the Tahoe model and returns a DataLoader.

    Parameters
    ----------
    adata : AnnData
        The AnnData object containing the data to be processed. Tahoe uses Ensembl IDs
        to identify genes and currently supports only human genes. If the AnnData object
        already has an 'ensembl_id' column, pass gene_names="ensembl_id" to skip the
        mapping step.
    gene_names : str, optional, default="index"
        The column in `adata.var` that contains the gene names. If set to a value other
        than "ensembl_id", the gene symbols in that column will be mapped to Ensembl IDs
        using the 'pyensembl' package.
        - If set to "index", the index of `adata.var` will be used and mapped to Ensembl IDs.
        - If set to "ensembl_id", no mapping will occur.
    use_raw_counts : bool, optional, default=True
        Determines whether raw counts should be used.

    Returns
    -------
    DataLoader
        A PyTorch DataLoader ready for inference.
    """
    LOGGER.info("Processing data for Tahoe.")
    self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

    # Map gene symbols to Ensembl IDs if provided
    if gene_names != "ensembl_id":
        if (adata.var[gene_names].str.startswith("ENS").all()) or (
            adata.var[gene_names].str.startswith("None").any()
        ):
            message = (
                "It seems an AnnData object with Ensembl IDs and/or 'None' values was passed. "
                "Please set gene_names='ensembl_id' and remove the 'None' values to skip mapping."
            )
            LOGGER.error(message)
            raise ValueError(message)
        adata = map_gene_symbols_to_ensembl_ids(adata, gene_names)

        if adata.var["ensembl_id"].isnull().all():
            message = "None of the gene symbols could be mapped to Ensembl IDs. Please check the input data."
            LOGGER.error(message)
            raise ValueError(message)

    gene_id_key = "ensembl_id"

    # Map genes to vocabulary
    adata.var["id_in_vocab"] = [
        self.vocab[gene] if gene in self.vocab else -1
        for gene in adata.var[gene_id_key]
    ]

    gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
    n_matched = np.sum(gene_ids_in_vocab >= 0)
    n_total = len(gene_ids_in_vocab)

    LOGGER.info(
        f"Matched {n_matched}/{n_total} genes in vocabulary of size {len(self.vocab)}."
    )

    # Filter to genes in vocabulary
    adata = adata[:, adata.var["id_in_vocab"] >= 0]

    genes = adata.var[gene_id_key].tolist()
    gene_ids = np.array([self.vocab[gene] for gene in genes], dtype=int)

    if not np.all(gene_ids >= 0):
        raise ValueError("Some genes are not in the vocabulary after filtering.")

    dataloader = loader_from_adata(
        adata=adata,
        collator_cfg=self.collator_cfg,
        vocab=self.vocab,
        batch_size=self.config["batch_size"],
        max_length=self.config["max_length"],
        gene_ids=gene_ids,
        num_workers=self.config["num_workers"],
        prefetch_factor=self.config["prefetch_factor"],
    )

    LOGGER.info("Successfully processed the data for Tahoe.")
    return dataloader
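The vocabulary-matching step in process_data() can be shown in isolation. This is a minimal sketch, not the library's code: a plain dict stands in for self.vocab and a list of Ensembl IDs stands in for adata.var["ensembl_id"]. The vocabulary indices are made up, and ENSG00000000000 is a placeholder for a gene missing from the vocabulary. It mirrors the -1 sentinel and the >= 0 filter used above.

```python
import numpy as np

def match_genes_to_vocab(ensembl_ids, vocab):
    """Return (kept_ids, kept_vocab_indices, n_matched) using a -1 sentinel like process_data()."""
    # Genes missing from the vocabulary get -1, as in adata.var["id_in_vocab"]
    id_in_vocab = np.array([vocab.get(gene, -1) for gene in ensembl_ids], dtype=int)
    keep = id_in_vocab >= 0
    kept_ids = [gene for gene, k in zip(ensembl_ids, keep) if k]
    return kept_ids, id_in_vocab[keep], int(keep.sum())

# Hypothetical toy vocabulary; the real one is loaded with the model
toy_vocab = {"ENSG00000141510": 10, "ENSG00000012048": 11}
genes = ["ENSG00000141510", "ENSG00000000000", "ENSG00000012048"]
kept, idx, n = match_genes_to_vocab(genes, toy_vocab)
print(f"Matched {n}/{len(genes)} genes")
```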

get_embeddings(dataloader, return_gene_embeddings=False, output_attentions=False, attn_layer=-1)

Gets the embeddings from the Tahoe model.

Parameters:

Name Type Description Default
dataloader DataLoader

The DataLoader returned from process_data().

required
return_gene_embeddings bool

Whether to return gene embeddings for each cell in addition to cell embeddings. Gene embeddings are returned as a list of pandas Series, one per cell, where each Series contains the embeddings for genes expressed in that cell.

False
output_attentions bool

Whether to return attention weights from a single transformer layer, selected by attn_layer. Note: This requires the model to be initialized with attn_impl='torch'. Flash Attention (attn_impl='flash', the default) and attn_impl='triton' do not compute attention weights for efficiency reasons, so requesting them raises a RuntimeError.

False
attn_layer int

Which transformer layer's attention to return. Supports negative indexing (e.g. -1 for the last layer). Only used when output_attentions is True.

-1

Returns:

Type Description
ndarray or tuple

Depending on the combination of flags:

  • If both flags are False: cell_embeddings
  • If return_gene_embeddings=True only: (cell_embeddings, gene_embeddings)
  • If output_attentions=True only: (cell_embeddings, attentions)
  • If both are True: (cell_embeddings, gene_embeddings, attentions)

Where:

  • cell_embeddings: L2-normalized numpy array of shape (n_cells, embedding_dim).
  • gene_embeddings: list of pandas Series, one per cell. Each Series maps Ensembl IDs to L2-normalized gene embeddings for the genes expressed in that cell (padding positions are skipped).
  • attentions: list of per-sample numpy arrays, each of shape (n_heads, seq_length, seq_length), taken from the layer selected by attn_layer. Padding is not removed, so seq_length is the padded length of the batch the sample came from.
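Because get_embeddings() divides each row of the cell embedding matrix by its L2 norm (see the source code below), cosine similarity between cells reduces to a matrix product. A minimal sketch, using a random array as a stand-in for the returned cell_embeddings (the 4 cells and 512 dimensions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for the (n_cells, embedding_dim) array returned by get_embeddings()
raw = rng.normal(size=(4, 512))
cell_embeddings = raw / np.linalg.norm(raw, axis=1, keepdims=True)

# Rows have unit norm, so the dot product equals cosine similarity
similarity = cell_embeddings @ cell_embeddings.T
print(similarity.shape)
```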

Source code in helical/models/tahoe/model.py
def get_embeddings(
    self,
    dataloader: DataLoader,
    return_gene_embeddings: bool = False,
    output_attentions: bool = False,
    attn_layer: int = -1,
) -> Union[np.ndarray, tuple]:
    """Gets the embeddings from the Tahoe model.

    Parameters
    ----------
    dataloader : DataLoader
        The DataLoader returned from process_data().
    return_gene_embeddings : bool, optional, default=False
        Whether to return gene embeddings for each cell in addition to cell embeddings.
        Gene embeddings are returned as a list of pandas Series, one per cell, where
        each Series contains the embeddings for genes expressed in that cell.
    output_attentions : bool, optional, default=False
        Whether to return attention weights from the transformer layer selected
        by attn_layer. Note: This requires the model to be initialized with
        attn_impl='torch'. Flash Attention (attn_impl='flash', the default) and
        attn_impl='triton' do not compute attention weights for efficiency reasons,
        so requesting them raises a RuntimeError.
    attn_layer : int, optional, default=-1
        Which transformer layer's attention to return. Supports negative indexing
        (e.g. -1 for the last layer). Only used when output_attentions is True.

    Returns
    -------
    np.ndarray or tuple
        Depending on the combination of flags:
        - If both False: cell_embeddings (n_cells, embedding_dim)
        - If return_gene_embeddings=True only: (cell_embeddings, gene_embeddings)
        - If output_attentions=True only: (cell_embeddings, attentions)
        - If both True: (cell_embeddings, gene_embeddings, attentions)

        Where:
        - cell_embeddings: L2-normalized numpy array of shape (n_cells, embedding_dim)
        - gene_embeddings: list of pandas Series, one per cell. Each Series contains
          L2-normalized gene embeddings indexed by Ensembl IDs for genes expressed in that cell.
        - attentions: list of per-sample numpy arrays, each of shape (n_heads, seq_length, seq_length),
          from the layer selected by attn_layer. Padding is not removed, so seq_length is
          the padded length of the batch the sample came from.
    """
    LOGGER.info("Extracting embeddings from Tahoe model...")

    # Check if attention extraction is requested but not supported
    if output_attentions:
        attn_impl = self.model_cfg.get("attn_config", {}).get("attn_impl", "flash")
        if attn_impl in ["flash", "triton"]:
            raise RuntimeError(
                f"Attention weight extraction is not supported with attn_impl='{attn_impl}'. "
                "Flash Attention is optimized for speed and memory efficiency and does not "
                "compute/store attention weights. To extract attention weights, initialize the model "
                "with attn_impl='torch':\n\n"
                "    tahoe_config = TahoeConfig(model_size='70m', attn_impl='torch')\n"
                "    tahoe = Tahoe(configurer=tahoe_config)"
            )

    self.model.return_gene_embeddings = return_gene_embeddings

    cell_embs: List[torch.Tensor] = []
    all_attentions: List[torch.Tensor] = [] if output_attentions else None
    all_gene_embeddings: List[pd.Series] = [] if return_gene_embeddings else None

    dtype_from_string = {
        "fp32": torch.float32,
        "amp_bf16": torch.bfloat16,
        "amp_fp16": torch.float16,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    with (
        torch.no_grad(),
        torch.amp.autocast(
            enabled=True,
            dtype=dtype_from_string[self.model_cfg["precision"]],
            device_type=self.device.type,
        ),
    ):
        pbar = tqdm(total=len(dataloader), desc="Embedding cells")

        for data_dict in dataloader:
            input_gene_ids = data_dict["gene"].to(self.device)
            src_key_padding_mask = ~input_gene_ids.eq(self.collator_cfg["pad_token_id"])

            output = self.model(
                genes=input_gene_ids,
                values=data_dict["expr"].to(self.device),
                gen_masks=data_dict["gen_mask"].to(self.device),
                key_padding_mask=src_key_padding_mask,
                drug_ids=(
                    data_dict["drug_ids"].to(self.device)
                    if "drug_ids" in data_dict
                    else None
                ),
                skip_decoders=True,
                output_attentions=output_attentions,
            )

            cell_embs.append(output["cell_emb"].to("cpu").to(dtype=torch.float32))

            if output_attentions:
                # Select the requested layer: (batch, n_heads, seq_len, seq_len)
                layer_attn = output["attentions"][attn_layer].cpu().to(torch.float32)
                all_attentions.append(layer_attn)

            if return_gene_embeddings:
                # Get gene embeddings for this batch: shape (batch_size, seq_len, d_model)
                gene_embs = output.get("gene_emb").to(torch.float32).cpu().numpy()
                gene_ids = input_gene_ids.cpu().numpy()

                # Create a pandas Series for each cell in the batch
                for i in range(gene_embs.shape[0]):
                    cell_gene_dict = {}
                    for j in range(gene_embs.shape[1]):
                        gene_id = gene_ids[i, j]
                        if gene_id != self.collator_cfg["pad_token_id"]:
                            gene_name = self.vocab.index_to_token[gene_id]
                            gene_embedding = gene_embs[i, j]
                            # Normalize the gene embedding
                            gene_embedding = gene_embedding / np.linalg.norm(gene_embedding)
                            cell_gene_dict[gene_name] = gene_embedding

                    all_gene_embeddings.append(pd.Series(cell_gene_dict))

            pbar.update(1)

    # Normalize cell embeddings
    cell_array = torch.cat(cell_embs, dim=0).numpy()
    cell_array = cell_array / np.linalg.norm(
        cell_array,
        axis=1,
        keepdims=True,
    )


    # Prepare attention list if requested — one np.ndarray per sample
    if output_attentions:
        attn_list = []
        for attn in all_attentions:
            # attn shape: (batch, n_heads, seq_len, seq_len)
            attn_list.extend(attn.numpy())

    # Return based on requested outputs
    log_msg = f"Finished extracting embeddings. Cell shape: {cell_array.shape}"
    if return_gene_embeddings:
        log_msg += f", Gene embeddings: {len(all_gene_embeddings)} cells"
    if output_attentions:
        log_msg += f", Attention maps: {len(attn_list)} samples"
    LOGGER.info(log_msg)

    # Return appropriate combination
    if return_gene_embeddings and output_attentions:
        return cell_array, all_gene_embeddings, attn_list
    elif return_gene_embeddings:
        return cell_array, all_gene_embeddings
    elif output_attentions:
        return cell_array, attn_list
    else:
        return cell_array
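The attention branch above keeps one layer per batch and then splits each batch into per-sample arrays. The sketch below reproduces that bookkeeping with NumPy stand-ins. It assumes the model returns one (batch, n_heads, seq_len, seq_len) array per layer, which is inferred from the comment in the source; the layer count, head count, and sequence lengths are illustrative only.

```python
import numpy as np

def collect_layer_attention(batches, attn_layer=-1):
    """Mimic get_embeddings(): pick one layer per batch, then split into per-sample arrays."""
    per_sample = []
    for layer_outputs in batches:
        # Negative indices count back from the last layer, as with any Python sequence
        layer_attn = layer_outputs[attn_layer]  # (batch, n_heads, seq_len, seq_len)
        # Extending with an array iterates its first axis: one (n_heads, seq_len, seq_len) per sample
        per_sample.extend(layer_attn)
    return per_sample

n_layers, n_heads = 12, 8  # assumed sizes for the stand-in data
# Two hypothetical batches with different padded lengths; each layer's values equal its index
batches = [
    tuple(np.zeros((3, n_heads, 5, 5)) + layer for layer in range(n_layers)),
    tuple(np.zeros((2, n_heads, 7, 7)) + layer for layer in range(n_layers)),
]
attn_list = collect_layer_attention(batches, attn_layer=-1)
print(len(attn_list), attn_list[0].shape, attn_list[-1].shape)
```

Samples from the same batch share that batch's padded seq_len, which is why the per-sample shapes differ only across batches.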

get_transformer_embeddings(dataloader)

Get raw transformer embeddings before the decoder.

This method returns the transformer output embeddings along with the gene IDs for each position. This is useful for perturbation experiments where you want to modify embeddings and then decode them to predicted expression.
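A perturbation of this kind can be sketched directly on the per-cell arrays. The helper below is hypothetical and not part of the Helical API. It copies one cell's (seq_len, embedding_dim) array and shifts the row or rows whose vocabulary ID matches a target gene. The original array is left untouched, so the unperturbed baseline can still be decoded. The IDs, the pad value, and the shift size are illustrative.

```python
import numpy as np

def perturb_gene(cell_emb, cell_gene_ids, target_id, delta):
    """Hypothetical helper: copy cell_emb and add `delta` at positions whose gene ID equals target_id."""
    perturbed = cell_emb.copy()  # avoid mutating the baseline embeddings
    positions = np.flatnonzero(cell_gene_ids == target_id)
    if positions.size == 0:
        raise KeyError(f"Gene ID {target_id} not present in this cell")
    perturbed[positions] += delta  # broadcasts a scalar or an (embedding_dim,) vector
    return perturbed

emb = np.zeros((4, 512), dtype=np.float32)
ids = np.array([7, 42, 13, 0])  # hypothetical vocabulary IDs; 0 stands in for padding here
new_emb = perturb_gene(emb, ids, target_id=42, delta=0.1)
print(new_emb[1, :3], emb[1, :3])
```

The perturbed array can replace one entry in a copy of the embeddings list before both lists are passed to decode_embeddings().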

Parameters:

Name Type Description Default
dataloader DataLoader

The DataLoader returned from process_data().

required

Returns:

Type Description
tuple of (list, list)
  • transformer_embeddings: list of numpy arrays, one per cell. Each array has shape (seq_len, embedding_dim) and contains the transformer output embeddings for that cell's genes. This method does not L2-normalize them.
  • gene_ids: list of numpy arrays, one per cell. Each array has shape (seq_len,) and contains the gene vocabulary IDs for that cell (pad_token_id at padding positions).
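Both lists keep padding positions, so per-gene analyses usually drop them first. A minimal sketch, assuming pad_id is the collator's pad_token_id. The value 0 used below is a placeholder, because the real value comes from the model's collator configuration.

```python
import numpy as np

def strip_padding(transformer_embs, gene_ids, pad_id):
    """Drop padded positions from each cell's (seq_len, embedding_dim) array and its gene IDs."""
    trimmed_embs, trimmed_ids = [], []
    for emb, ids in zip(transformer_embs, gene_ids):
        keep = ids != pad_id  # boolean mask over sequence positions
        trimmed_embs.append(emb[keep])
        trimmed_ids.append(ids[keep])
    return trimmed_embs, trimmed_ids

PAD_ID = 0  # placeholder; read the real value from the model's collator config
embs = [np.ones((5, 512)), np.ones((5, 512))]
ids = [np.array([3, 9, 4, 0, 0]), np.array([8, 2, 6, 1, 0])]
t_embs, t_ids = strip_padding(embs, ids, PAD_ID)
print([e.shape for e in t_embs])
```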
Example
from helical.models.tahoe import Tahoe, TahoeConfig
import anndata as ad

tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
tahoe = Tahoe(configurer=tahoe_config)

ann_data = ad.read_h5ad("anndata_file.h5ad")
dataloader = tahoe.process_data(ann_data)

# Get transformer embeddings and gene IDs
transformer_embs, gene_ids = tahoe.get_transformer_embeddings(dataloader)

# Each is a list with one entry per cell
print(f"Number of cells: {len(transformer_embs)}")
print(f"First cell embedding shape: {transformer_embs[0].shape}")
print(f"First cell gene IDs shape: {gene_ids[0].shape}")

# Modify embeddings (e.g., perturb specific genes in first cell)
# transformer_embs[0][5, :] += 0.1  # perturb gene at position 5

# Decode modified embeddings to predicted expression
expr_pred = tahoe.decode_embeddings(transformer_embs, gene_ids)

Source code in helical/models/tahoe/model.py
def get_transformer_embeddings(
    self,
    dataloader: DataLoader,
) -> tuple:
    """Get raw transformer embeddings before the decoder.

    This method returns the transformer output embeddings along with the gene IDs
    for each position. This is useful for perturbation experiments where you want
    to modify embeddings and then decode them to predicted expression.

    Parameters
    ----------
    dataloader : DataLoader
        The DataLoader returned from process_data().

    Returns
    -------
    tuple of (list, list)
        - transformer_embeddings: list of numpy arrays, one per cell.
          Each array has shape (seq_len, embedding_dim) and contains the transformer
          output embeddings for that cell's genes. This method does not L2-normalize them.
        - gene_ids: list of numpy arrays, one per cell.
          Each array has shape (seq_len,) and contains the gene vocabulary IDs
          for that cell (pad_token_id at padding positions).

    Example
    -------
    ```python
    from helical.models.tahoe import Tahoe, TahoeConfig
    import anndata as ad

    tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
    tahoe = Tahoe(configurer=tahoe_config)

    ann_data = ad.read_h5ad("anndata_file.h5ad")
    dataloader = tahoe.process_data(ann_data)

    # Get transformer embeddings and gene IDs
    transformer_embs, gene_ids = tahoe.get_transformer_embeddings(dataloader)

    # Each is a list with one entry per cell
    print(f"Number of cells: {len(transformer_embs)}")
    print(f"First cell embedding shape: {transformer_embs[0].shape}")
    print(f"First cell gene IDs shape: {gene_ids[0].shape}")

    # Modify embeddings (e.g., perturb specific genes in first cell)
    # transformer_embs[0][5, :] += 0.1  # perturb gene at position 5

    # Decode modified embeddings to predicted expression
    expr_pred = tahoe.decode_embeddings(transformer_embs, gene_ids)
    ```
    """
    LOGGER.info("Extracting transformer embeddings...")

    self.model.return_gene_embeddings = True
    all_embeddings: List[np.ndarray] = []
    all_gene_ids: List[np.ndarray] = []

    dtype_from_string = {
        "fp32": torch.float32,
        "amp_bf16": torch.bfloat16,
        "amp_fp16": torch.float16,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    with (
        torch.no_grad(),
        torch.amp.autocast(
            enabled=True,
            dtype=dtype_from_string[self.model_cfg["precision"]],
            device_type=self.device.type,
        ),
    ):
        pbar = tqdm(total=len(dataloader), desc="Extracting transformer embeddings")

        for data_dict in dataloader:
            input_gene_ids = data_dict["gene"].to(self.device)
            src_key_padding_mask = ~input_gene_ids.eq(self.collator_cfg["pad_token_id"])

            output = self.model(
                genes=input_gene_ids,
                values=data_dict["expr"].to(self.device),
                gen_masks=data_dict["gen_mask"].to(self.device),
                key_padding_mask=src_key_padding_mask,
                drug_ids=(
                    data_dict["drug_ids"].to(self.device)
                    if "drug_ids" in data_dict
                    else None
                ),
                skip_decoders=True,
            )

            # Get transformer output and gene IDs from model output
            # Shape: (batch_size, seq_len, d_model)
            transformer_output = output["gene_emb"].to("cpu").to(dtype=torch.float32).numpy()
            batch_gene_ids = output["gene_ids"].to("cpu").numpy()

            # Split batch into individual cells
            for i in range(transformer_output.shape[0]):
                all_embeddings.append(transformer_output[i])  # (seq_len, d_model)
                all_gene_ids.append(batch_gene_ids[i])  # (seq_len,)

            pbar.update(1)

    LOGGER.info(f"Extracted transformer embeddings for {len(all_embeddings)} cells")
    return all_embeddings, all_gene_ids

decode_embeddings(gene_embeddings, gene_ids)

Decode gene embeddings to predict expression values.

This method takes gene-level embeddings (e.g., from the transformer) and uses the Tahoe expression decoder to predict gene expression values. The embeddings must be in the same sequence order as the original input.

Important: Use get_transformer_embeddings() first to get embeddings and gene IDs, modify them if needed, then pass both to this method.

Parameters:

Name Type Description Default
gene_embeddings List[ndarray]

List of gene embeddings, one array per cell. Each array has shape (seq_len, embedding_dim) containing transformer output embeddings in the same sequence order as the input.

required
gene_ids List[ndarray]

List of gene vocabulary IDs, one array per cell. Each array has shape (seq_len,) corresponding to the embeddings.

required

Returns:

Type Description
List[Series]

List of pandas Series, one per cell. Each Series maps gene names (Ensembl IDs) to predicted expression values. Only includes non-padding genes.

Example
from helical.models.tahoe import Tahoe, TahoeConfig
import anndata as ad

tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
tahoe = Tahoe(configurer=tahoe_config)

ann_data = ad.read_h5ad("anndata_file.h5ad")
dataloader = tahoe.process_data(ann_data)

# Get transformer embeddings and gene IDs
transformer_embs, gene_ids = tahoe.get_transformer_embeddings(dataloader)

# Optional: Modify embeddings for perturbation experiments
# Example - perturb gene at position 5 in first cell
# transformer_embs[0][5, :] += 0.1

# Decode embeddings to predicted expression
expr_predictions = tahoe.decode_embeddings(transformer_embs, gene_ids)

# Access predictions for first cell
print(f"First cell predictions: {len(expr_predictions[0])} genes")
for gene_name, pred_expr in list(expr_predictions[0].items())[:5]:
    print(f"  {gene_name}: {pred_expr:.4f}")
Notes

The decoder expects embeddings in the same format as the transformer output. Make sure your embeddings match the model's embedding dimension (512 for 70m, 1024 for 1b, 1536 for 3b).
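decode_embeddings() pairs the two lists with zip(), which silently stops at the shorter one, so a shape check before decoding can catch mismatches early. The function below is a hypothetical helper, not part of the Helical API. The EMBED_DIM table restates the dimensions from the note above; only the 70m size is currently released.

```python
import numpy as np

# Embedding dimensions restated from the note above (1b and 3b are listed as coming soon)
EMBED_DIM = {"70m": 512, "1b": 1024, "3b": 1536}

def check_decoder_inputs(gene_embeddings, gene_ids, model_size="70m"):
    """Hypothetical helper: raise ValueError if per-cell arrays do not match the expected shapes."""
    d_model = EMBED_DIM[model_size]
    if len(gene_embeddings) != len(gene_ids):
        raise ValueError(f"{len(gene_embeddings)} embedding arrays but {len(gene_ids)} gene ID arrays")
    for i, (emb, ids) in enumerate(zip(gene_embeddings, gene_ids)):
        if emb.ndim != 2 or emb.shape[1] != d_model:
            raise ValueError(f"cell {i}: expected (seq_len, {d_model}), got {emb.shape}")
        if ids.shape != (emb.shape[0],):
            raise ValueError(f"cell {i}: gene_ids shape {ids.shape} does not match seq_len {emb.shape[0]}")

check_decoder_inputs([np.zeros((6, 512))], [np.arange(6)])  # passes silently
```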

Source code in helical/models/tahoe/model.py
def decode_embeddings(
    self,
    gene_embeddings: List[np.ndarray],
    gene_ids: List[np.ndarray],
) -> List[pd.Series]:
    """Decode gene embeddings to predict expression values.

    This method takes gene-level embeddings (e.g., from the transformer) and
    uses the Tahoe expression decoder to predict gene expression values.
    The embeddings must be in the same sequence order as the original input.

    **Important**: Use `get_transformer_embeddings()` first to get embeddings
    and gene IDs, modify them if needed, then pass both to this method.

    Parameters
    ----------
    gene_embeddings : List[np.ndarray]
        List of gene embeddings, one array per cell.
        Each array has shape (seq_len, embedding_dim) containing transformer
        output embeddings in the same sequence order as the input.
    gene_ids : List[np.ndarray]
        List of gene vocabulary IDs, one array per cell.
        Each array has shape (seq_len,) corresponding to the embeddings.

    Returns
    -------
    List[pd.Series]
        List of pandas Series, one per cell. Each Series maps gene names
        (Ensembl IDs) to predicted expression values. Only includes non-padding genes.

    Example
    -------
    ```python
    from helical.models.tahoe import Tahoe, TahoeConfig
    import anndata as ad

    tahoe_config = TahoeConfig(model_size="70m", batch_size=8)
    tahoe = Tahoe(configurer=tahoe_config)

    ann_data = ad.read_h5ad("anndata_file.h5ad")
    dataloader = tahoe.process_data(ann_data)

    # Get transformer embeddings and gene IDs
    transformer_embs, gene_ids = tahoe.get_transformer_embeddings(dataloader)

    # Optional: Modify embeddings for perturbation experiments
    # Example - perturb gene at position 5 in first cell
    # transformer_embs[0][5, :] += 0.1

    # Decode embeddings to predicted expression
    expr_predictions = tahoe.decode_embeddings(transformer_embs, gene_ids)

    # Access predictions for first cell
    print(f"First cell predictions: {len(expr_predictions[0])} genes")
    for gene_name, pred_expr in list(expr_predictions[0].items())[:5]:
        print(f"  {gene_name}: {pred_expr:.4f}")
    ```

    Notes
    -----
    The decoder expects embeddings in the same format as the transformer output.
    Make sure your embeddings match the model's embedding dimension
    (512 for 70m, 1024 for 1b, 1536 for 3b).
    """
    LOGGER.info(f"Decoding embeddings for {len(gene_embeddings)} cells...")

    pad_token_id = self.collator_cfg["pad_token_id"]
    idx_to_gene = self.vocab.index_to_token

    all_predictions = []

    dtype_from_string = {
        "fp32": torch.float32,
        "amp_bf16": torch.bfloat16,
        "amp_fp16": torch.float16,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    with (
        torch.no_grad(),
        torch.amp.autocast(
            enabled=True,
            dtype=dtype_from_string[self.model_cfg["precision"]],
            device_type=self.device.type,
        ),
    ):
        for cell_emb, cell_gene_ids in zip(gene_embeddings, gene_ids):
            # Convert to tensor and add batch dimension
            # Shape: (1, seq_len, d_model)
            emb_tensor = torch.from_numpy(cell_emb).unsqueeze(0).to(torch.float32).to(self.device)

            # Pass through decoder
            decoder_output = self.model.expression_decoder(emb_tensor)
            expr_pred = decoder_output["pred"]  # (1, seq_len) or (1, seq_len, 1)

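            # Remove the batch dimension and flatten so that both (1, seq_len) and
            # (1, seq_len, 1) outputs become (seq_len,) before per-position lookup
            expr_pred = expr_pred.squeeze(0).reshape(-1).to("cpu").to(torch.float32).numpy()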

            # Create Series mapping gene names to predictions (only non-padding)
            pred_dict = {}
            for pos, gene_id in enumerate(cell_gene_ids):
                if gene_id != pad_token_id:
                    gene_name = idx_to_gene[gene_id]
                    pred_dict[gene_name] = float(expr_pred[pos])

            all_predictions.append(pd.Series(pred_dict))

    LOGGER.info(f"Finished decoding {len(all_predictions)} cells")
    return all_predictions
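The source comment notes that the decoder's "pred" output may be (1, seq_len) or (1, seq_len, 1). Flattening after removing the batch axis makes the per-position lookup work for both layouts. This sketch uses NumPy and pandas stand-ins for the decoder output and the vocabulary. The gene names, IDs, and pad value are placeholders.

```python
import numpy as np
import pandas as pd

def predictions_to_series(pred, cell_gene_ids, index_to_token, pad_id):
    """Map one cell's decoder output to a Series of gene name -> predicted expression."""
    values = np.asarray(pred).squeeze(0).reshape(-1)  # (seq_len,) for either output layout
    return pd.Series({
        index_to_token[g]: float(values[pos])
        for pos, g in enumerate(cell_gene_ids)
        if g != pad_id  # skip padding positions, as decode_embeddings() does
    })

index_to_token = {1: "ENSG_A", 2: "ENSG_B"}  # hypothetical vocabulary slice
ids = np.array([1, 2, 0])  # 0 stands in for the pad token here
s2 = predictions_to_series(np.array([[0.5, 1.5, 9.9]]), ids, index_to_token, pad_id=0)
s3 = predictions_to_series(np.array([[[0.5], [1.5], [9.9]]]), ids, index_to_token, pad_id=0)
print(s2.equals(s3))
```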