Skip to content

Model

helical.models.c2s.Cell2Sen

Bases: HelicalBaseFoundationModel

Cell2Sen Model.

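The Cell2Sen Model is a transformer/Gemma-based causal language model that works on cell sentences (gene names ordered by decreasing expression). It converts gene expression data into cell sentences, embeds them, and can generate perturbed cell sentences.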

Example
from helical.models.c2s import Cell2Sen, Cell2SenConfig
import anndata as ad

adata = ad.read_h5ad("your_data.h5ad")

config = Cell2SenConfig(batch_size=16)
cell2sen = Cell2Sen(configurer=config)

# Turn expression data into cell sentences (a HuggingFace Dataset)
dataset = cell2sen.process_data(adata)

# Get embeddings
embeddings = cell2sen.get_embeddings(dataset)
print("Cell2Sen embeddings shape:", embeddings.shape)

Parameters:

Name Type Description Default
configurer Cell2SenConfig

The model configuration. If None, uses default Cell2SenConfig.

None
Source code in helical/models/c2s/model.py
(Lines 27–609 of helical/models/c2s/model.py.)
class Cell2Sen(HelicalBaseFoundationModel):
    """
    Cell2Sen Model.

    The Cell2Sen Model is a transformer/Gemma-based causal language model that
    works on "cell sentences" (gene names ordered by decreasing expression).
    ``process_data`` turns gene expression data into cell sentences,
    ``get_embeddings`` embeds them and ``get_perturbations`` generates
    perturbed cell sentences.

    Example
    -------
    ```python
    from helical.models.c2s import Cell2Sen, Cell2SenConfig
    import anndata as ad

    adata = ad.read_h5ad("your_data.h5ad")

    config = Cell2SenConfig(batch_size=16)
    cell2sen = Cell2Sen(configurer=config)

    # Turn expression data into cell sentences (a HuggingFace Dataset)
    dataset = cell2sen.process_data(adata)

    # Get embeddings
    embeddings = cell2sen.get_embeddings(dataset)
    print("Cell2Sen embeddings shape:", embeddings.shape)
    ```

    Parameters
    ----------
    configurer : Cell2SenConfig, optional, default=None
        The model configuration. If None, uses default Cell2SenConfig.

    """

    def __init__(self, configurer: Cell2SenConfig = None) -> None:
        super().__init__()

        # Fall back to the default configuration when none is supplied.
        if configurer is None:
            configurer = Cell2SenConfig()
        self.config = configurer.config

        self.device = self.config["device"]
        # FlashAttention-2 is only requested on CUDA devices when enabled in the
        # config; every other case uses PyTorch SDPA.
        if "cuda" in self.device and self.config["use_flash_attn"]:
            LOGGER.info("Using flash attention 2 for attention implementation")
            self.attn_implementation = "flash_attention_2"
        else:
            LOGGER.info("Using SDPA for attention implementation - default for CPU")
            self.attn_implementation = "sdpa"

        supported_dtypes = {"bfloat16": torch.bfloat16, "float32": torch.float32}
        if self.config["dtype"] not in supported_dtypes:
            raise ValueError(f"Dtype {self.config['dtype']} not supported. Please choose from 'bfloat16' or 'float32'.")
        self.torch_dtype = supported_dtypes[self.config["dtype"]]

        if self.torch_dtype == torch.bfloat16 and self.device == "cpu":
            LOGGER.warning("Bfloat16 is not supported on CPU. Defaulting to 'float32' instead.")
            self.torch_dtype = torch.float32

        # Optional 4-bit NF4 quantization (bitsandbytes, double quantization);
        # computation runs in the selected torch dtype.
        if self.config["use_quantization"]:
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=self.torch_dtype
            )
        else:
            self.bnb_config = None

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config["hf_model_path"],
            torch_dtype=self.torch_dtype,
            cache_dir=self.config["model_path"],
            quantization_config=self.bnb_config,
            attn_implementation=self.attn_implementation,
            device_map=self.device
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.config["hf_model_path"], cache_dir=self.config["model_path"])
        self.model.eval()

        self.batch_size = self.config['batch_size']
        self.max_new_tokens = self.config['max_new_tokens']
        # None means "infer the organism from each AnnData in process_data".
        self.organism = self.config['organism']
        self.perturbation_column = self.config['perturbation_column']
        self.return_fit = self.config['return_fit']
        self.max_genes = self.config['max_genes']
        self.aggregation_type = self.config["aggregation_type"]
        self.embedding_prompt_template = self.config["embedding_prompt_template"]

        LOGGER.info("Successfully loaded model")

    @staticmethod
    def _gene_ids_from_offsets(prompt, cell_sentence, offsets):
        """
        Build a per-token gene id list from character offsets.

        Tokens whose character span overlaps a gene's span get that gene's
        index (position in the cell sentence); all other tokens (prompt text,
        special tokens, padding) get None.

        Parameters
        ----------
        prompt : str
            The full prompt string.
        cell_sentence : str
            Space-separated gene names embedded in the prompt.
        offsets : list[tuple[int, int]]
            Per-token (char_start, char_end) from ``return_offsets_mapping=True``.

        Returns
        -------
        gene_ids : list[int | None]
            Gene index for each token, or None.

        Raises
        ------
        ValueError
            If a non-empty ``cell_sentence`` does not occur in ``prompt``.
        """
        from bisect import bisect_right

        genes = cell_sentence.split()
        gene_ids = [None] * len(offsets)
        if not genes:
            return gene_ids

        cs_start = prompt.find(cell_sentence)
        if cs_start < 0:
            raise ValueError("Cell sentence not found in prompt; cannot map tokens to genes.")

        # Character ranges of each gene in prompt coordinates. The genes appear in
        # order and do not overlap, so both lists are strictly increasing.
        gene_starts, gene_ends = [], []
        pos = cs_start
        for gene in genes:
            start = prompt.index(gene, pos)
            gene_starts.append(start)
            gene_ends.append(start + len(gene))
            pos = start + len(gene)

        for tok_idx, (ts, te) in enumerate(offsets):
            if ts == te:
                # zero-width span: special or padding token
                continue
            # First gene ending after the token start is the only candidate that
            # can overlap; later genes start even further right.
            gi = bisect_right(gene_ends, ts)
            if gi < len(genes) and gene_starts[gi] < te:
                gene_ids[tok_idx] = gi
        return gene_ids

    @staticmethod
    def _aggregate_token_to_word_attention(attn, word_ids):
        """
        Aggregate one sample's token-level attention to word (gene) level.

        Parameters
        ----------
        attn : torch.Tensor or np.ndarray
            Attention of a single sample, shape (num_heads, seq_len, seq_len).
            NumPy input is converted to a tensor.
        word_ids : list[int | None]
            Word index per token (None for tokens that belong to no word).

        Returns
        -------
        np.ndarray
            float32 array of shape (num_heads, num_words, num_words). For each
            head, query tokens of a word are averaged and key tokens of a word
            are summed. An empty (num_heads, 0, 0) array is returned when no
            token belongs to a word.
        """
        if not isinstance(attn, torch.Tensor):
            attn = torch.as_tensor(attn)
        num_heads, seq_len, _ = attn.shape

        word_to_tokens = {}
        for tok_idx, wid in enumerate(word_ids):
            if wid is not None:
                word_to_tokens.setdefault(wid, []).append(tok_idx)

        if not word_to_tokens:
            LOGGER.warning("No words found in attention map. Returning empty array.")
            return np.zeros((num_heads, 0, 0), dtype=np.float32)

        sorted_word_ids = sorted(word_to_tokens)
        num_words = len(sorted_word_ids)

        # W averages rows (query tokens) per word; V sums columns (key tokens) per word.
        W = torch.zeros((num_words, seq_len), dtype=attn.dtype, device=attn.device)
        V = torch.zeros((num_words, seq_len), dtype=attn.dtype, device=attn.device)
        for wi, wid in enumerate(sorted_word_ids):
            token_indices = word_to_tokens[wid]
            W[wi, token_indices] = 1.0 / len(token_indices)
            V[wi, token_indices] = 1.0

        temp = torch.einsum('wt,htk->hwk', W, attn)
        word_attn = torch.einsum('hwk,vk->hwv', temp, V)
        return word_attn.float().cpu().numpy()  # .float() converts bfloat16->float32

    @staticmethod
    def _infer_organism(adata):
        """
        Guess the organism of ``adata`` from its annotations.

        Looks at ``uns['organism']``, then the most common value of
        ``obs['organism']``, then the same for 'species'; returns "unknown"
        when none of these is available.
        """
        for key in ("organism", "species"):
            if key in adata.uns:
                return adata.uns[key]
            if key in adata.obs.columns:
                modes = adata.obs[key].mode()
                if len(modes) > 0:
                    return modes.iloc[0]
        return "unknown"

    @staticmethod
    def _expressed_genes(row):
        """
        Return ``(gene_indices, values)`` for the entries of one row slice that are > 0.

        ``row`` is a (1, n_genes) slice of the expression matrix, either a scipy
        sparse CSR matrix/array or a dense array. Explicitly stored zeros in
        sparse data are dropped so sparse and dense input behave the same.
        """
        if issparse(row):
            gene_indices, values = row.indices, row.data
        else:
            values = np.asarray(row).ravel()
            gene_indices = np.arange(values.size)
        expressed = values > 0
        return gene_indices[expressed], values[expressed]

    @staticmethod
    def _fit_rank_expression(rank_sum, rank_count):
        """
        Fit mean expression = slope * log10(rank) + intercept over observed ranks.

        ``rank_sum[r - 1]`` / ``rank_count[r - 1]`` hold the summed expression and
        the number of cells that had a gene at rank ``r``. Returns a dict with
        'slope', 'intercept' and 'r_squared', or None (with a warning) when no
        rank was observed.
        """
        observed = rank_count > 0
        if not observed.any():
            LOGGER.warning("No expressed genes found in any cell; cannot fit rank-expression model. Setting fit_parameters to None.")
            return None

        log_ranks = np.log10(np.flatnonzero(observed) + 1).reshape(-1, 1)
        mean_expr = rank_sum[observed] / rank_count[observed]

        model = LinearRegression()
        model.fit(log_ranks, mean_expr)
        r_squared = model.score(log_ranks, mean_expr)
        return {
            "slope": float(model.coef_[0]),
            "intercept": float(model.intercept_),
            "r_squared": float(r_squared),
        }

    def process_data(
        self,
        adata: anndata.AnnData,
    ):
        """
        Process anndata to create a HuggingFace Dataset with cell sentences and fit parameters.

        A copy of ``adata`` is library-size normalised (target sum 1e4) and
        log10(1 + x) transformed. Each cell then becomes a "cell sentence": the
        names of its expressed genes (value > 0), in descending order of
        expression, joined by spaces and truncated to the configured
        ``max_genes`` when that is set. Cells with no expressed genes get an
        empty sentence.

        Parameters
        ----------
        adata : AnnData
            Annotated data object with gene expression. It is not modified.
            The organism comes from the configuration; when that is None it is
            inferred from ``adata.uns`` / ``adata.obs`` ('organism', then
            'species'), defaulting to "unknown".

        Returns
        -------
        dataset : Dataset
            HuggingFace Dataset with fields: cell_sentence, fit_parameters, organism, perturbations.
            When ``return_fit`` is enabled, every row carries the same
            fit_parameters dict (slope, intercept, r_squared of a linear fit of
            mean expression against log10 rank); otherwise it is None.

        Raises
        ------
        ValueError
            If ``adata`` contains no cells.
        """

        LOGGER.info("Processing data")
        if adata.n_obs == 0:
            raise ValueError("Anndata is empty. Please provide a valid anndata object.")

        # standard log-normalization, enables accurate expression reconstruction
        adata_norm = adata.copy()
        sc.pp.normalize_total(adata_norm, target_sum=1e4)
        sc.pp.log1p(adata_norm, base=10)

        X = adata_norm.X
        if issparse(X):
            # Row slices only expose column indices in CSR format (in CSC,
            # ``.indices`` would hold row indices).
            X = X.tocsr()
        all_gene_names = adata_norm.var_names.values
        n_cells, n_genes = X.shape

        # Local variable: inferring per call keeps one dataset's organism from
        # leaking into later calls.
        organism = self.organism if self.organism is not None else self._infer_organism(adata_norm)

        cell_sentences = []
        # Per-rank expression sums and cell counts (rank r stored at index r - 1),
        # training data for the rank -> expression reconstruction model.
        rank_sum = np.zeros(n_genes, dtype=np.float64)
        rank_count = np.zeros(n_genes, dtype=np.int64)

        with tqdm(total=n_cells, desc="Processing cells") as progress_bar:
            for cell_idx in range(n_cells):
                # Slice (not index) so sparse matrices and sparse arrays both stay 2-D.
                gene_indices, expr_values = self._expressed_genes(X[cell_idx:cell_idx + 1])
                progress_bar.update(1)

                if expr_values.size == 0:
                    LOGGER.warning(f"No genes expressed above zero in cell {cell_idx}. Using empty sentence.")
                    cell_sentences.append("")
                    continue

                # Sort by expression descending, cut at max_genes if desired
                order = np.argsort(expr_values)[::-1]
                if self.max_genes:
                    order = order[:self.max_genes]
                expr_values = expr_values[order]
                cell_sentences.append(" ".join(all_gene_names[gene_indices[order]]))

                if self.return_fit:
                    n_ranked = expr_values.size
                    rank_sum[:n_ranked] += expr_values
                    rank_count[:n_ranked] += 1

        fit_parameters = self._fit_rank_expression(rank_sum, rank_count) if self.return_fit else None

        if self.perturbation_column is not None:
            perturbations = adata_norm.obs[self.perturbation_column].values.tolist()
        else:
            perturbations = [None] * n_cells

        dataset = Dataset.from_dict({
            'cell_sentence': cell_sentences,
            'fit_parameters': [fit_parameters] * n_cells,
            'organism': [organism] * n_cells,
            'perturbations': perturbations
        })

        LOGGER.info("Successfully processed data")

        return dataset

    def _pool_hidden_states(self, last_hidden, attention_mask):
        """
        Reduce per-token hidden states (B, L, H) to one vector per sequence (B, H).

        ``attention_mask`` is the tokenizer's (B, L) mask (1 = real token,
        0 = padding). 'mean_pool' averages non-padding tokens; 'last_token'
        takes the last non-padding token, located from the mask itself so it
        is correct for both left and right padding.

        Raises
        ------
        ValueError
            If ``aggregation_type`` is neither 'mean_pool' nor 'last_token'.
        """
        mask = attention_mask.float()
        if self.aggregation_type == 'mean_pool':
            summed = (last_hidden * mask.unsqueeze(-1)).sum(dim=1)
            counts = mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
            return summed / counts
        if self.aggregation_type == 'last_token':
            seq_len = mask.shape[1]
            # argmax returns the first maximum, so on the flipped mask it finds
            # the last real token of each row.
            last_idx = (seq_len - 1) - mask.flip(dims=[1]).argmax(dim=1)
            rows = torch.arange(last_hidden.size(0), device=last_hidden.device)
            return last_hidden[rows, last_idx]
        raise ValueError("Invalid aggregation type. Use 'mean_pool' or 'last_token'.")

    def get_embeddings(
        self,
        dataset: Dataset,
        output_attentions: bool = False,
        emb_layer: int = -1,
        ):
        """
        Extract embeddings from cell sentences in a HuggingFace Dataset using the last hidden layer of Gemma.

        Parameters
        ----------
        dataset : Dataset
            HuggingFace Dataset with 'cell_sentence' and 'organism' fields

        output_attentions : bool, optional
            Whether to output the attention maps from the model. If set to True, the attention maps will be returned along with the embeddings.
            If set to False, only the embeddings will be returned. **Note**: This will increase the memory usage of the model significantly, so use it only if you need the attention maps.

        emb_layer : int, optional
            Which layer to extract attention from (default: -1, i.e. last layer).
            Only used when output_attentions=True.
            Only one layer of attention can be returned at a time.

        Returns
        -------
        embeddings : np.ndarray
            Embeddings of shape (num_sentences, hidden_size)
        attn_list : list, optional
            If output_attentions=True, a list of gene-level attention arrays,
            one per sample, each of shape (num_heads, num_genes, num_genes).
        gene_names_list : list, optional
            If output_attentions=True, a list of gene name lists, one list per sample,
            e.g. [['geneA', 'geneB', ...], ['geneX', 'geneY', ...], ...]. This is used to map attention values to specific genes.

        Raises
        ------
        ValueError
            If the dataset is empty or the configured aggregation type is invalid.
        """

        LOGGER.info("Extracting embeddings from dataset")

        sentences_list = dataset['cell_sentence']
        organisms_list = dataset['organism']
        if len(sentences_list) == 0:
            raise ValueError("Dataset is empty. Please provide a dataset with at least one cell sentence.")

        template = EMBEDDING_PROMPT if self.embedding_prompt_template is None else self.embedding_prompt_template

        all_embeddings = []
        attn_list = []

        # SDPA and FlashAttention do not support returning attention maps, so the
        # model config is switched to eager for this call and always switched back,
        # even if an exception interrupts the loop.
        previous_attn_implementation = self.model.config._attn_implementation
        if output_attentions:
            self.model.config._attn_implementation = "eager"
        try:
            with tqdm(total=len(sentences_list), desc="Processing embeddings") as progress_bar:
                for start in range(0, len(sentences_list), self.batch_size):
                    batch_sentences = sentences_list[start:start + self.batch_size]
                    batch_organisms = organisms_list[start:start + self.batch_size]
                    prompts = [
                        template.format(organism=org, cell_sentence=cs)
                        for org, cs in zip(batch_organisms, batch_sentences)
                    ]

                    inputs = self.tokenizer(
                        prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=False,
                        return_offsets_mapping=output_attentions,
                    )
                    # offset_mapping is only needed for the token -> gene mapping,
                    # so it is removed before the inputs are passed to the model.
                    batch_offsets = inputs.pop("offset_mapping") if output_attentions else None
                    inputs = inputs.to(self.device)

                    with torch.no_grad():
                        outputs = self.model(
                            **inputs,
                            output_hidden_states=True,
                            output_attentions=output_attentions
                        )
                        batch_embeddings = self._pool_hidden_states(
                            outputs.hidden_states[-1], inputs['attention_mask']
                        )

                        if output_attentions:
                            # outputs.attentions holds one tensor per layer of shape
                            # (batch_size, num_heads, seq_length, seq_length); only
                            # the selected layer is aggregated.
                            layer_attn = outputs.attentions[emb_layer]
                            for b, (prompt, sentence) in enumerate(zip(prompts, batch_sentences)):
                                gene_ids = self._gene_ids_from_offsets(
                                    prompt, sentence, batch_offsets[b].tolist()
                                )
                                attn_list.append(
                                    self._aggregate_token_to_word_attention(layer_attn[b], gene_ids)
                                )
                        del outputs

                    all_embeddings.append(batch_embeddings.float().cpu().numpy())
                    progress_bar.update(len(batch_sentences))
        finally:
            if output_attentions:
                self.model.config._attn_implementation = previous_attn_implementation

        LOGGER.info("Successfully extracted embeddings")

        embeddings = np.concatenate(all_embeddings, axis=0)
        if not output_attentions:
            return embeddings

        # Gene names per sample, aligned with the gene axes of each attention array
        gene_names_list = [sentence.split() for sentence in sentences_list]
        return embeddings, attn_list, gene_names_list

    def _generate_batch(self, inputs):
        """
        Greedily decode up to ``max_new_tokens`` tokens for a tokenized batch.

        For quantized models torch.compile (dynamo) is disabled for the duration
        of the call and its previous settings are restored afterwards.
        """
        # An explicit None check: a pad id of 0 is valid and must not be
        # replaced by the EOS id.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        generate_kwargs = dict(
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=pad_token_id,
        )

        if not self.config["use_quantization"]:
            return self.model.generate(**inputs, **generate_kwargs)

        # Disable torch.compile entirely for quantized models;
        # suppress_errors alone isn't sufficient.
        dynamo_config = torch._dynamo.config
        saved = (dynamo_config.disable, dynamo_config.suppress_errors)
        dynamo_config.disable = True
        dynamo_config.suppress_errors = True
        try:
            return self.model.generate(**inputs, **generate_kwargs)
        finally:
            dynamo_config.disable, dynamo_config.suppress_errors = saved

    def get_perturbations(
        self,
        dataset: Dataset,
        perturbations_list: list[str] = None,
        ):
        """
        Generate perturbed cell sentences using the model.

        Rows whose perturbation is None are skipped and get None as their
        perturbed cell sentence.

        Parameters
        ----------
        dataset : Dataset
            HuggingFace Dataset with 'cell_sentence', 'organism' and 'perturbations' fields

        perturbations_list : list[str], optional
            List of perturbations to apply to the cells. If None, uses the perturbations from the dataset.
            If provided, overrides the perturbations in the dataset. E.g. ["pert1", "pert2", "pert3", ...]

        Returns
        -------
        perturbed_dataset : Dataset
            The input dataset with a 'perturbed_cell_sentence' column
            (an existing column of that name is replaced).

        perturbed_sentences : list
            List of perturbed cell sentences (strings, or None for skipped rows)

        Raises
        ------
        ValueError
            If ``perturbations_list`` length differs from the dataset length, or
            if every perturbation is None.
        """

        LOGGER.info("Generating perturbed cell sentences")

        sentences_list = dataset['cell_sentence']
        organisms_list = dataset['organism']
        if perturbations_list is None:
            perturbations_list = dataset['perturbations']
        elif len(perturbations_list) != len(sentences_list):
            raise ValueError(f"perturbations_list length ({len(perturbations_list)}) must match dataset length ({len(sentences_list)})")

        # Entries without a perturbation are skipped
        valid_indices = [i for i, p in enumerate(perturbations_list) if p is not None]
        if not valid_indices:
            raise ValueError("No valid perturbations found in dataset. All perturbations are None.")

        all_perturbed = []
        with tqdm(total=len(valid_indices), desc="Processing valid perturbations") as progress_bar:
            for start in range(0, len(valid_indices), self.batch_size):
                batch_indices = valid_indices[start:start + self.batch_size]
                prompts = [
                    PERTURBATION_PROMPT.format(
                        organism=organisms_list[i],
                        perturbation=perturbations_list[i],
                        cell_sentence=sentences_list[i]
                    )
                    for i in batch_indices
                ]

                # NOTE(review): batched decoder-only generation expects left padding;
                # confirm the tokenizer's padding_side.
                inputs = self.tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=False,
                ).to(self.device)

                with torch.no_grad():
                    outputs = self._generate_batch(inputs)

                # generate() returns the (padded) prompt followed by the new tokens,
                # so the generated part of every row starts at the padded prompt width,
                # not at the row's count of non-padding tokens.
                prompt_width = inputs['input_ids'].shape[1]
                all_perturbed.extend(
                    self.tokenizer.decode(sequence[prompt_width:], skip_special_tokens=True).strip()
                    for sequence in outputs
                )
                progress_bar.update(len(batch_indices))

        # Result list with None for entries without perturbations
        perturbed_sentences = [None] * len(sentences_list)
        for idx, perturbed in zip(valid_indices, all_perturbed):
            perturbed_sentences[idx] = perturbed

        # Replace the column from a previous call instead of failing on a duplicate name.
        if 'perturbed_cell_sentence' in dataset.column_names:
            dataset = dataset.remove_columns('perturbed_cell_sentence')
        dataset = dataset.add_column('perturbed_cell_sentence', perturbed_sentences)

        LOGGER.info("Successfully generated perturbed cell sentences")

        return dataset, perturbed_sentences
process_data(adata)

Process anndata to create a HuggingFace Dataset with cell sentences and fit parameters.

Parameters:

adata : AnnData Annotated data object with gene expression; it is copied, not modified. The maximum number of genes kept per cell (in descending expression order) comes from the configuration's max_genes setting, not from an argument.

Returns:

dataset : Dataset HuggingFace Dataset with fields: cell_sentence, fit_parameters, organism, perturbations

Source code in helical/models/c2s/model.py
def process_data(
    self,
    adata: anndata.AnnData,
):
    """
    Process anndata to create a HuggingFace Dataset with cell sentences and fit parameters.

    A copy of ``adata`` is library-size normalised (target sum 1e4) and
    log10(1 + x) transformed. Each cell then becomes a "cell sentence": the
    names of its expressed genes (value > 0), in descending order of
    expression, joined by spaces and truncated to the configured
    ``max_genes`` when that is set. Cells with no expressed genes get an
    empty sentence.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with gene expression. It is not modified.
        The organism comes from the configuration; when that is None it is
        inferred from ``adata.uns`` / ``adata.obs`` ('organism', then
        'species'), defaulting to "unknown".

    Returns
    -------
    dataset : Dataset
        HuggingFace Dataset with fields: cell_sentence, fit_parameters, organism, perturbations.
        When ``return_fit`` is enabled, every row carries the same
        fit_parameters dict (slope, intercept, r_squared of a linear fit of
        mean expression against log10 rank); otherwise it is None.

    Raises
    ------
    ValueError
        If ``adata`` contains no cells.
    """

    LOGGER.info("Processing data")
    if adata.n_obs == 0:
        raise ValueError("Anndata is empty. Please provide a valid anndata object.")

    # standard log-normalization, enables accurate expression reconstruction
    adata_norm = adata.copy()
    sc.pp.normalize_total(adata_norm, target_sum=1e4)
    sc.pp.log1p(adata_norm, base=10)

    X = adata_norm.X
    if issparse(X):
        # Row slices only expose column indices in CSR format (in CSC,
        # ``.indices`` would hold row indices).
        X = X.tocsr()
    all_gene_names = adata_norm.var_names.values
    n_cells, n_genes = X.shape

    # Local variable: inferring per call keeps one dataset's organism from
    # leaking into later calls.
    organism = self.organism
    if organism is None:
        organism = "unknown"  # Default if no annotation is found
        for key in ("organism", "species"):
            if key in adata_norm.uns:
                organism = adata_norm.uns[key]
                break
            if key in adata_norm.obs.columns:
                # If the value varies per cell, use the most common one
                modes = adata_norm.obs[key].mode()
                if len(modes) > 0:
                    organism = modes.iloc[0]
                    break

    cell_sentences = []
    # Per-rank expression sums and cell counts (rank r stored at index r - 1),
    # training data for the rank -> expression reconstruction model.
    rank_sum = np.zeros(n_genes, dtype=np.float64)
    rank_count = np.zeros(n_genes, dtype=np.int64)

    with tqdm(total=n_cells, desc="Processing cells") as progress_bar:
        for cell_idx in range(n_cells):
            # Slice (not index) so sparse matrices and sparse arrays both stay 2-D.
            row = X[cell_idx:cell_idx + 1]
            if issparse(row):
                gene_indices, expr_values = row.indices, row.data
            else:
                expr_values = np.asarray(row).ravel()
                gene_indices = np.arange(expr_values.size)
            # Drop zeros (including explicitly stored sparse zeros)
            expressed = expr_values > 0
            gene_indices, expr_values = gene_indices[expressed], expr_values[expressed]
            progress_bar.update(1)

            if expr_values.size == 0:
                LOGGER.warning(f"No genes expressed above zero in cell {cell_idx}. Using empty sentence.")
                cell_sentences.append("")
                continue

            # Sort by expression descending, cut at max_genes if desired
            order = np.argsort(expr_values)[::-1]
            if self.max_genes:
                order = order[:self.max_genes]
            expr_values = expr_values[order]
            cell_sentences.append(" ".join(all_gene_names[gene_indices[order]]))

            if self.return_fit:
                n_ranked = expr_values.size
                rank_sum[:n_ranked] += expr_values
                rank_count[:n_ranked] += 1

    fit_parameters = None
    if self.return_fit:
        observed = rank_count > 0
        if not observed.any():
            LOGGER.warning("No expressed genes found in any cell; cannot fit rank-expression model. Setting fit_parameters to None.")
        else:
            # Fit expr(g) = slope * log10(rank(g)) + intercept on per-rank mean expression
            log_ranks = np.log10(np.flatnonzero(observed) + 1).reshape(-1, 1)
            mean_expr = rank_sum[observed] / rank_count[observed]
            model = LinearRegression()
            model.fit(log_ranks, mean_expr)
            r_squared = model.score(log_ranks, mean_expr)
            fit_parameters = {
                "slope": float(model.coef_[0]),
                "intercept": float(model.intercept_),
                "r_squared": float(r_squared),
            }

    if self.perturbation_column is not None:
        perturbations = adata_norm.obs[self.perturbation_column].values.tolist()
    else:
        perturbations = [None] * n_cells

    dataset = Dataset.from_dict({
        'cell_sentence': cell_sentences,
        'fit_parameters': [fit_parameters] * n_cells,
        'organism': [organism] * n_cells,
        'perturbations': perturbations
    })

    LOGGER.info("Successfully processed data")

    return dataset

get_embeddings(dataset, output_attentions=False, emb_layer=-1)

Extract embeddings from cell sentences in a HuggingFace Dataset using the last hidden layer of Gemma.

Parameters:

dataset : Dataset HuggingFace Dataset with 'cell_sentence' and 'organism' fields

output_attentions : bool, optional Whether to output the attention maps from the model. If set to True, the attention maps will be returned along with the embeddings. If set to False, only the embeddings will be returned. Note: This will increase the memory usage of the model significantly, so use it only if you need the attention maps.

emb_layer : int, optional Which layer to extract attention from (default: -1, i.e. last layer). Only used when output_attentions=True. Only one layer of attention can be returned at a time.

Returns:

embeddings : np.ndarray Embeddings of shape (num_sentences, hidden_size) attn_list : list, optional If output_attentions=True, a list of gene-level attention arrays, one per sample, each of shape (num_heads, num_genes, num_genes). gene_names_list : list, optional If output_attentions=True, a list of gene name lists, one list per sample, e.g. [['geneA', 'geneB', ...], ['geneX', 'geneY', ...], ...]. This is used to map attention values to specific genes.

Source code in helical/models/c2s/model.py
def get_embeddings(
    self,
    dataset: Dataset,
    output_attentions: bool = False,
    emb_layer: int = -1,
    ):
    """
    Extract embeddings from cell sentences in a HuggingFace Dataset using the last hidden layer of Gemma.

    Each cell sentence is formatted into a prompt together with its organism
    (``self.embedding_prompt_template`` when set, otherwise ``EMBEDDING_PROMPT``),
    tokenized in batches of ``self.batch_size`` and passed through the model.
    The last hidden layer is pooled according to ``self.aggregation_type``:

    - ``'mean_pool'``: average of the hidden states of all non-padding tokens.
    - ``'last_token'``: hidden state of the last non-padding token. This works
      for both left- and right-padded batches.

    Parameters:
    -----------
    dataset : Dataset
        HuggingFace Dataset with 'cell_sentence' and 'organism' fields

    output_attentions : bool, optional
        Whether to output the attention maps from the model. If set to True, the attention maps will be returned along with the embeddings.
        If set to False, only the embeddings will be returned. **Note**: This will increase the memory usage of the model significantly, so use it only if you need the attention maps.

    emb_layer : int, optional
        Which layer to extract attention from (default: -1, i.e. last layer).
        Only used when output_attentions=True; the embeddings themselves always
        come from the last hidden layer.
        Only one layer of attention can be returned at a time.

    Returns:
    --------
    embeddings : np.ndarray
        Embeddings of shape (num_sentences, hidden_size)
    attn_list : list, optional
        If output_attentions=True, a list of gene-level attention arrays,
        one per sample, each of shape (num_heads, num_genes, num_genes).
    gene_names_list : list, optional
        If output_attentions=True, a list of gene name lists, one list per sample,
        e.g. [['geneA', 'geneB', ...], ['geneX', 'geneY', ...], ...]. This is used to map attention values to specific genes.

    Raises:
    -------
    ValueError
        If ``self.aggregation_type`` is neither 'mean_pool' nor 'last_token'.
    """

    LOGGER.info("Extracting embeddings from dataset")

    # Validate up front so an invalid setting fails before any forward pass.
    if self.aggregation_type not in ('mean_pool', 'last_token'):
        raise ValueError("Invalid aggregation type. Use 'mean_pool' or 'last_token'.")

    sentences_list = dataset['cell_sentence']
    organisms_list = dataset['organism']
    prompt_template = (
        EMBEDDING_PROMPT
        if self.embedding_prompt_template is None
        else self.embedding_prompt_template
    )

    all_embeddings = []
    attn_list = []  # one gene-level attention array per sample, selected layer only

    if output_attentions:
        # SDPA and FlashAttention do not support returning attention maps,
        # so switch all layers to eager for this call. The previous setting is
        # restored in the `finally` below, even if the forward pass raises.
        previous_attn_implementation = self.model.config._attn_implementation
        self.model.config._attn_implementation = "eager"

    progress_bar = tqdm(total=len(sentences_list), desc="Processing embeddings")
    try:
        for i in range(0, len(sentences_list), self.batch_size):
            batch_sentences = sentences_list[i:i + self.batch_size]
            batch_organisms = organisms_list[i:i + self.batch_size]

            prompts = [
                prompt_template.format(organism=org, cell_sentence=cs)
                for org, cs in zip(batch_organisms, batch_sentences)
            ]

            inputs = self.tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=False,
                return_offsets_mapping=output_attentions,
            )
            # offset_mapping is not a model input; take it out before .to(device)
            if output_attentions:
                batch_offsets = inputs.pop("offset_mapping")
            inputs = inputs.to(self.device)

            with torch.no_grad():
                outputs = self.model(
                    **inputs,
                    output_hidden_states=True,
                    output_attentions=output_attentions,
                )
                last_hidden = outputs.hidden_states[-1]              # (B, L, H)
                attention_mask = inputs['attention_mask'].float()   # (B, L)

                if self.aggregation_type == 'mean_pool':
                    # Mean over non-padding tokens only.
                    masked_hidden = last_hidden * attention_mask.unsqueeze(-1)
                    token_counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
                    batch_embeddings = masked_hidden.sum(dim=1) / token_counts   # (B, H)
                else:  # 'last_token'
                    # Highest position holding a real token. Unlike `mask.sum() - 1`,
                    # this is correct whether the tokenizer pads on the left or the right.
                    positions = torch.arange(
                        attention_mask.size(1), device=attention_mask.device
                    )
                    last_idx = (attention_mask.long() * positions).argmax(dim=1)  # (B,)
                    batch_embeddings = last_hidden[
                        torch.arange(last_hidden.size(0), device=last_hidden.device),
                        last_idx.to(last_hidden.device),
                    ]  # (B, H)

                if output_attentions:
                    # outputs.attentions holds one tensor per layer, each
                    # (batch_size, num_heads, seq_length, seq_length); keep only emb_layer.
                    layer_attn = outputs.attentions[emb_layer]
                    for b in range(inputs['input_ids'].shape[0]):
                        gene_ids = self._gene_ids_from_offsets(
                            prompts[b], batch_sentences[b], batch_offsets[b].tolist()
                        )
                        attn_list.append(
                            self._aggregate_token_to_word_attention(layer_attn[b], gene_ids)
                        )
                del outputs

            all_embeddings.append(batch_embeddings.float().cpu().numpy())
            progress_bar.update(len(batch_sentences))
    finally:
        progress_bar.close()
        if output_attentions:
            self.model.config._attn_implementation = previous_attn_implementation

    LOGGER.info("Successfully extracted embeddings")

    embeddings = np.concatenate(all_embeddings, axis=0)
    if output_attentions:
        # Gene names per sample, in cell-sentence order.
        gene_names_list = [sentence.split() for sentence in sentences_list]
        return embeddings, attn_list, gene_names_list
    return embeddings

get_perturbations(dataset, perturbations_list=None)

Generate perturbed cell sentences using the model.

Parameters:

dataset : Dataset HuggingFace Dataset with 'cell_sentence', 'organism' and 'perturbations' fields ('perturbations' is only read when perturbations_list is None)

perturbations_list : list[str], optional List of perturbations to apply to the cells. If None, uses the perturbations from the dataset. If provided, overrides the perturbations in the dataset. E.g. ["pert1", "pert2", "pert3", ...]

Returns:

perturbed_dataset : Dataset HuggingFace Dataset with 'cell_sentence' and 'perturbations' fields and a new column 'perturbed_cell_sentence'

perturbed_sentences : list List of perturbed cell sentences (strings)

Source code in helical/models/c2s/model.py
def get_perturbations(
    self, 
    dataset: Dataset, 
    perturbations_list: list[str] = None, 
    ):
    """
    Generate perturbed cell sentences using the model.

    For every cell with a non-None perturbation, a prompt is built from
    ``PERTURBATION_PROMPT`` (organism, perturbation, cell sentence). The model
    then greedily generates up to ``self.max_new_tokens`` new tokens, and only
    those newly generated tokens are decoded. Cells whose perturbation is None
    are skipped and receive None in the output.

    Parameters:
    -----------
    dataset : Dataset
        HuggingFace Dataset with 'cell_sentence' and 'organism' fields, plus a
        'perturbations' field when ``perturbations_list`` is None.

    perturbations_list : list[str], optional
        List of perturbations to apply to the cells. If None, uses the perturbations from the dataset.
        If provided, overrides the perturbations in the dataset. E.g. ["pert1", "pert2", "pert3", ...]

    Returns:
    --------
    perturbed_dataset : Dataset
        HuggingFace Dataset with 'cell_sentence' and 'perturbations' fields and a new column 'perturbed_cell_sentence'

    perturbed_sentences : list
        List of perturbed cell sentences (strings), with None for cells that had no perturbation.

    Raises:
    -------
    ValueError
        If ``perturbations_list`` length does not match the dataset, or if no
        cell has a non-None perturbation.
    """

    LOGGER.info("Generating perturbed cell sentences")

    sentences_list = dataset['cell_sentence']
    organisms_list = dataset['organism']
    if perturbations_list is None:
        perturbations_list = dataset['perturbations']
    elif len(perturbations_list) != len(sentences_list):
        raise ValueError(f"perturbations_list length ({len(perturbations_list)}) must match dataset length ({len(sentences_list)})")

    # Cells without a perturbation (None) are skipped; they get None in the result.
    valid_indices = [i for i, p in enumerate(perturbations_list) if p is not None]
    if not valid_indices:
        raise ValueError("No valid perturbations found in dataset. All perturbations are None.")

    valid_sentences = [sentences_list[i] for i in valid_indices]
    valid_perturbations = [perturbations_list[i] for i in valid_indices]
    valid_organisms = [organisms_list[i] for i in valid_indices]

    # A pad id of 0 is valid but falsy, so test for None instead of using `or`.
    pad_token_id = self.tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = self.tokenizer.eos_token_id
    generation_kwargs = {
        "max_new_tokens": self.max_new_tokens,
        "do_sample": False,
        "pad_token_id": pad_token_id,
    }

    all_perturbed = []
    progress_bar = tqdm(total=len(valid_sentences), desc="Processing valid perturbations")
    try:
        for i in range(0, len(valid_sentences), self.batch_size):
            batch_cells = valid_sentences[i:i + self.batch_size]
            batch_perturbs = valid_perturbations[i:i + self.batch_size]
            batch_organisms = valid_organisms[i:i + self.batch_size]

            prompts = [
                PERTURBATION_PROMPT.format(
                    organism=org,
                    perturbation=pert,
                    cell_sentence=cs
                )
                for org, pert, cs in zip(batch_organisms, batch_perturbs, batch_cells)
            ]

            inputs = self.tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=False,
            ).to(self.device)

            with torch.no_grad():
                if self.config["use_quantization"]:
                    # Disable torch.compile entirely for quantized models;
                    # suppress_errors alone isn't sufficient.
                    original_disable = torch._dynamo.config.disable
                    original_suppress = torch._dynamo.config.suppress_errors
                    torch._dynamo.config.disable = True
                    torch._dynamo.config.suppress_errors = True
                    try:
                        outputs = self.model.generate(**inputs, **generation_kwargs)
                    finally:
                        torch._dynamo.config.disable = original_disable
                        torch._dynamo.config.suppress_errors = original_suppress
                else:
                    outputs = self.model.generate(**inputs, **generation_kwargs)

            # Every returned row starts with the full *padded* prompt, so the new
            # tokens begin at the padded length for every row. Slicing at the
            # count of non-pad tokens would leak prompt tokens on left-padded rows.
            prompt_length = inputs['input_ids'].shape[1]
            for output in outputs:
                decoded = self.tokenizer.decode(output[prompt_length:], skip_special_tokens=True)
                all_perturbed.append(decoded.strip())

            progress_bar.update(len(batch_cells))
    finally:
        progress_bar.close()

    # Scatter results back to their original positions; skipped cells stay None.
    perturbed_sentences = [None] * len(sentences_list)
    for idx, perturbed in zip(valid_indices, all_perturbed):
        perturbed_sentences[idx] = perturbed

    dataset = dataset.add_column('perturbed_cell_sentence', perturbed_sentences)

    LOGGER.info("Successfully generated perturbed cell sentences")

    return dataset, perturbed_sentences