Skip to content

Config

helical.models.transcriptformer.TranscriptFormerConfig

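Constructor for `TranscriptFormerConfig`, which builds the inference and data configuration for a TranscriptFormer model and lists the checkpoint files to download.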

Source code in helical/models/transcriptformer/transcriptformer_config.py
class TranscriptFormerConfig:
    """
    Configuration for running inference with a TranscriptFormer model.

    Builds an OmegaConf configuration with ``model.inference_config`` and
    ``model.data_config`` sections from the constructor arguments, and records
    the checkpoint files that must be downloaded for the selected model.

    Parameters
    ----------
    model_name : Literal["tf_sapiens", "tf_metazoa", "tf_exemplar"], default="tf_sapiens"
        The name of the model to use.
    batch_size : int, default=8
        The number of samples to process in each batch.
    emb_mode : Literal["gene", "cell"], default="cell"
        The mode to use for the embeddings.
    output_keys : List[Literal["gene_llh", "llh"]], default=["llh"]
        The keys to output.
    obs_keys : List[str], default=["all"]
        The keys to include in the output.
    data_files : List[str], default=[None]
        Path(s) to input AnnData file(s).
    output_path : str, default="./inference_results"
        Directory where results will be saved.
    load_checkpoint : str, default=None
        Path to model weights file (automatically set by inference.py).
    pretrained_embedding : str, default=None
        Path to pretrained embeddings for out-of-distribution species.
    gene_col_name : str, default="index"
        Column name in AnnData.var containing gene names which will be mapped
        to Ensembl IDs. If "index" is set, ``.var_names`` will be used.
    clip_counts : int, default=30
        Maximum count value (higher values will be clipped).
    filter_to_vocabs : bool, default=True
        Whether to filter genes to only those in the vocabulary.
    filter_outliers : float, default=0.0
        Standard deviation threshold for filtering outlier cells
        (0.0 = no filtering).
    normalize_to_scale : float, default=0
        Scale factor for count normalization (0 = no normalization).
    sort_genes : bool, default=False
        Whether to sort the genes.
    randomize_genes : bool, default=False
        Whether to randomize the genes.
    min_expressed_genes : int, default=0
        Minimum number of expressed genes required per cell.
    device : str, default="cuda"
        Device string stored as ``model.inference_config.device``.

    Attributes
    ----------
    config
        OmegaConf configuration built from the arguments above.
    list_of_files_to_download : List[str]
        Repository-relative paths of the config, weights and vocabulary files
        for ``model_name``.
    model_name : str
        The selected model name.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported models.
    """

    # Per-species gene vocabulary files bundled with each checkpoint, listed in
    # the order they are downloaded (after config, weights and assay vocab).
    _MODEL_VOCAB_SPECIES = {
        "tf_sapiens": ("homo_sapiens",),
        "tf_metazoa": (
            "drosophila_melanogaster",
            "lytechinus_variegatus",
            "plasmodium_falciparum",
            "xenopus_laevis",
            "caenorhabditis_elegans",
            "gallus_gallus",
            "mus_musculus",
            "saccharomyces_cerevisiae",
            "danio_rerio",
            "oryctolagus_cuniculus",
            "spongilla_lacustris",
            "homo_sapiens",
        ),
        "tf_exemplar": (
            "danio_rerio",
            "drosophila_melanogaster",
            "homo_sapiens",
            "mus_musculus",
            "caenorhabditis_elegans",
        ),
    }
    # Tuple (not dict) membership so an unhashable model_name still yields
    # the ValueError below rather than a TypeError.
    _SUPPORTED_MODELS = tuple(_MODEL_VOCAB_SPECIES)

    def __init__(
        self,
        model_name: Literal["tf_sapiens", "tf_metazoa", "tf_exemplar"] = "tf_sapiens",
        batch_size: int = 8,
        emb_mode: Literal["gene", "cell"] = "cell",
        # NOTE(review): these list defaults are shared across calls. They are
        # only handed to OmegaConf.create, which is expected to copy them into
        # its own containers — confirm before mutating them anywhere.
        output_keys: List[Literal["gene_llh", "llh"]] = [
            "llh",
        ],
        obs_keys: List[str] = ["all"],
        data_files: List[str] = [None],
        output_path: str = "./inference_results",
        load_checkpoint: str = None,
        pretrained_embedding: str = None,
        gene_col_name: str = "index",
        clip_counts: int = 30,
        filter_to_vocabs: bool = True,
        filter_outliers: float = 0.0,
        normalize_to_scale: float = 0,
        sort_genes: bool = False,
        randomize_genes: bool = False,
        min_expressed_genes: int = 0,
        device: str = "cuda",
    ):
        # Fail fast on an unknown model before building any configuration.
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(
                f"Model name {model_name} not supported. Only tf_sapiens, tf_metazoa, and tf_exemplar are supported."
            )

        inference_config: dict = {
            "batch_size": batch_size,
            "output_keys": output_keys,
            "obs_keys": obs_keys,
            "data_files": data_files,
            "output_path": output_path,
            "load_checkpoint": load_checkpoint,
            "device": device,
            "pretrained_embedding": pretrained_embedding,
            "emb_mode": emb_mode,
        }

        data_config: dict = {
            "gene_col_name": gene_col_name,
            "clip_counts": clip_counts,
            "filter_to_vocabs": filter_to_vocabs,
            "filter_outliers": filter_outliers,
            "normalize_to_scale": normalize_to_scale,
            "sort_genes": sort_genes,
            "randomize_genes": randomize_genes,
            "min_expressed_genes": min_expressed_genes,
        }

        self.config = OmegaConf.create(
            {
                "model": {
                    "inference_config": inference_config,
                    "data_config": data_config,
                }
            }
        )

        self.list_of_files_to_download = self._files_to_download(model_name)
        self.model_name = model_name

    @classmethod
    def _files_to_download(cls, model_name: str) -> List[str]:
        """Return a fresh list of the files to download for ``model_name``.

        The list holds the model config, the weights, the assay vocabulary and
        then one ``<species>_gene.h5`` vocabulary per species, in table order.
        """
        prefix = f"transcriptformer/{model_name}"
        files = [
            f"{prefix}/config.json",
            f"{prefix}/model_weights.pt",
            f"{prefix}/vocabs/assay_vocab.json",
        ]
        files.extend(
            f"{prefix}/vocabs/{species}_gene.h5"
            for species in cls._MODEL_VOCAB_SPECIES[model_name]
        )
        return files