
models

Concrete model implementations that subclass the core abstractions. Each model wraps an external library or checkpoint (for example ESM, ProteinMPNN, or DPLM-2) in ProteinGen's unified interface.

Tokenization landscape

Three tokenizer ecosystems coexist in the library — cross-tokenizer mapping is handled by GuidanceProjection (see guide):

Tokenizer | Vocab size | Special tokens | Used by
ESM (EsmSequenceTokenizer) | 33 | <cls>=0, <pad>=1, <eos>=2, <unk>=3, <mask>=32 | ESMC, ESM3
MPNN (MPNNTokenizer) | 21 (or 22 with mask) | UNK (X)=20, optional <mask>=21 | StabilityPMPNN
Simple 20-AA | 21 | pad=20 | Custom predictors
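Because the vocabularies order amino acids differently, mapping between them amounts to matching tokens by residue letter. The sketch below builds an ESM-to-MPNN index map that way. It is an illustration only, not GuidanceProjection's implementation; the ESM list is the standard 33-token ESM sequence vocabulary, and the MPNN alphabet string `ACDEFGHIKLMNPQRSTVWYX` is an assumption based on ProteinMPNN's conventional ordering (consistent with X=20 in the table).

```python
# Hypothetical sketch: map token IDs between two amino-acid vocabularies by
# matching residue letters. Not GuidanceProjection's actual implementation.

# ESM sequence vocabulary (33 tokens, list index = token ID).
ESM_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", "O",
    ".", "-", "|", "<mask>",
]

# Assumed MPNN alphabet: 20 standard amino acids followed by X (UNK) = 20.
MPNN_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"


def esm_to_mpnn_index_map() -> dict[int, int]:
    """Return {esm_id: mpnn_id} for every residue letter both vocabularies share."""
    mpnn_ids = {aa: i for i, aa in enumerate(MPNN_ALPHABET)}
    return {
        esm_id: mpnn_ids[tok]
        for esm_id, tok in enumerate(ESM_VOCAB)
        if tok in mpnn_ids
    }


mapping = esm_to_mpnn_index_map()
print(mapping[5], mapping[4])  # ESM "A" -> MPNN 0, ESM "L" -> MPNN 9
```

Special tokens (<cls>, <mask>, and so on) and ESM-only residues such as B or U have no MPNN counterpart, so they simply drop out of the map; a real projection has to decide how to handle them.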

API Reference

proteingen.models

Backward-compatible model namespace.

New code should prefer proteingen.modeling. This module keeps old import styles working, including:

  • from proteingen.models import ESMC
  • from proteingen.models import esmc (module)
  • from proteingen.models.esm import ESMC
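One common way to keep old import paths working is to register the old dotted names as aliases of the new modules in `sys.modules`. The sketch below demonstrates the three import styles above against throwaway in-memory modules; `demo_modeling`, `demo_models`, and the stand-in `ESMC` class are hypothetical, and this is not necessarily how `proteingen.models` is implemented.

```python
import sys
import types

# Hypothetical "new" namespace: a package with one submodule exporting one class.
new_pkg = types.ModuleType("demo_modeling")
new_pkg.__path__ = []  # mark it as a package
new_sub = types.ModuleType("demo_modeling.esmc")


class ESMC:  # stand-in class, not the real ESMC
    pass


new_sub.ESMC = ESMC
new_pkg.esmc = new_sub
new_pkg.ESMC = ESMC
sys.modules["demo_modeling"] = new_pkg
sys.modules["demo_modeling.esmc"] = new_sub

# Backward-compatible aliases: old dotted names resolve to the same module objects.
sys.modules["demo_models"] = new_pkg
sys.modules["demo_models.esmc"] = new_sub

from demo_models import ESMC as ClassImport  # from pkg import Class
from demo_models import esmc as module_import  # from pkg import module
from demo_models.esmc import ESMC as DeepImport  # from pkg.module import Class

print(ClassImport is ESMC, module_import is new_sub, DeepImport is ESMC)  # True True True
```

Aliasing module objects (rather than copying attributes) keeps a single class identity, so `isinstance` checks behave the same whichever path a caller imported from.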

DPLM2

Bases: GenerativeModelWithEmbedding

DPLM-2 discrete diffusion protein language model.

Wraps ByteDance's DPLM-2 (multimodal diffusion protein LM) as a GenerativeModelWithEmbedding for use with proteingen's sampling, guidance, and probe infrastructure.

Currently supports sequence-only mode: input is [<cls_aa>, AA..., <eos_aa>]. Structure conditioning (joint sequence + structure token input) is not yet implemented.

Available checkpoints
  • "airkingbd/dplm2_150m" — 150M params, 640d, 30 layers
  • "airkingbd/dplm2_650m" — 650M params, 1280d, 33 layers (default)
  • "airkingbd/dplm2_3b" — 3B params, 2560d, 36 layers
Tensor index legend

  • S: batch index
  • P: position index
  • T: token/vocab dimension (OUTPUT_DIM = vocab_size = 8229)
  • D: embedding dimension (EMB_DIM = hidden_size)

Example::

model = DPLM2("airkingbd/dplm2_650m")
log_probs = model.get_log_probs_from_string(["ACDEF"])
Source code in src/proteingen/modeling/models/dplm2/__init__.py
class DPLM2(GenerativeModelWithEmbedding):
    """DPLM-2 discrete diffusion protein language model.

    Wraps ByteDance's DPLM-2 (multimodal diffusion protein LM) as a
    GenerativeModelWithEmbedding for use with proteingen's sampling,
    guidance, and probe infrastructure.

    Currently supports sequence-only mode: input is [<cls_aa>, AA..., <eos_aa>].
    Structure conditioning (joint sequence + structure token input) is not yet
    implemented.

    Available checkpoints:
        - ``"airkingbd/dplm2_150m"`` — 150M params, 640d, 30 layers
        - ``"airkingbd/dplm2_650m"`` — 650M params, 1280d, 33 layers (default)
        - ``"airkingbd/dplm2_3b"``   — 3B params, 2560d, 36 layers

    Tensor index legend:
        S: batch index
        P: position index
        T: token/vocab dimension (OUTPUT_DIM = vocab_size = 8229)
        D: embedding dimension (EMB_DIM = hidden_size)

    Example::

        model = DPLM2("airkingbd/dplm2_650m")
        log_probs = model.get_log_probs_from_string(["ACDEF"])
    """

    # Token dropout rate hardcoded in ESM training — needed to match forward path
    _MASK_RATIO_TRAIN = 0.15 * 0.8  # = 0.12

    def __init__(self, checkpoint: str = "airkingbd/dplm2_650m"):
        self._checkpoint = checkpoint

        config = AutoConfig.from_pretrained(checkpoint)
        # DPLM2 was trained with untied embedding/decoder weights.
        # The HF config incorrectly says tie_word_embeddings=True.
        config.tie_word_embeddings = False

        tokenizer = DPLM2Tokenizer(checkpoint, vocab_size=config.vocab_size)
        model = AutoModelForMaskedLM.from_pretrained(checkpoint, config=config).eval()

        self.OUTPUT_DIM = config.vocab_size
        self.EMB_DIM = config.hidden_size

        logit_formatter = MaskedModelLogitFormatter(tokenizer, self.OUTPUT_DIM)

        super().__init__(
            model=model, tokenizer=tokenizer, logit_formatter=logit_formatter
        )

    def _save_args(self) -> dict:
        return {"checkpoint": self._checkpoint}

    def differentiable_embedding(
        self, ohe_seq_SPT: torch.FloatTensor
    ) -> torch.FloatTensor:
        """OHE (or soft distribution) → deep embeddings through the transformer.

        Replicates the forward path through EsmEmbeddings + EsmEncoder:
        1. Soft word embedding lookup via matmul
        2. Token dropout: zero mask positions, then rescale (ESM's mask-ratio
           compensation, runs in both train and eval mode)
        3. Attention mask application
        4. Transformer encoder with rotary attention

        For soft distributions (TAG guidance), mask tokens should not appear
        in the input — the zeroing step uses argmax to identify mask positions,
        which is non-differentiable at those positions.
        """
        # Soft word embedding lookup
        x = ohe_seq_SPT @ self.model.esm.embeddings.word_embeddings.weight  # (S, P, D)

        # Token dropout: zero out mask positions, then compensate by rescaling.
        # ESM always applies this, even in eval mode.
        # For non-masked inputs: scale = (1 - 0.12) / (1 - 0) = 0.88
        pseudo_ids = ohe_seq_SPT.argmax(-1)
        attention_mask = pseudo_ids.ne(self.tokenizer.pad_token_id)
        is_mask = pseudo_ids == self.tokenizer.mask_token_id
        x = x.masked_fill(is_mask.unsqueeze(-1), 0.0)
        src_lengths = attention_mask.sum(-1)
        mask_ratio_observed = is_mask.sum(-1).float() / src_lengths
        x = (
            x * (1 - self._MASK_RATIO_TRAIN) / (1 - mask_ratio_observed)[:, None, None]
        ).to(x.dtype)

        # Zero out padding positions (same as EsmEmbeddings.forward)
        x = (x * attention_mask.unsqueeze(-1)).to(x.dtype)

        # Build extended attention mask for the encoder
        extended_mask = self.model.esm.get_extended_attention_mask(
            attention_mask, pseudo_ids.shape
        )
        head_mask = self.model.esm.get_head_mask(
            None, self.model.config.num_hidden_layers
        )

        # Run through transformer encoder
        encoder_out = self.model.esm.encoder(
            x,
            attention_mask=extended_mask,
            head_mask=head_mask,
        )
        return encoder_out.last_hidden_state

    def embedding_to_outputs(
        self, embedding_SPD: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Deep embeddings → logits via the LM head."""
        return self.model.lm_head(embedding_SPD)

    def format_raw_to_logits(
        self, raw_output, seq_SP: torch.LongTensor, **kwargs
    ) -> torch.FloatTensor:
        """Extract logits from MaskedLMOutput and apply logit formatting."""
        logits_SPT = raw_output.logits.float()
        return self.logit_formatter(logits_SPT, seq_SP)
differentiable_embedding
differentiable_embedding(ohe_seq_SPT: FloatTensor) -> torch.FloatTensor

OHE (or soft distribution) → deep embeddings through the transformer.

Replicates the forward path through EsmEmbeddings + EsmEncoder:

  1. Soft word embedding lookup via matmul
  2. Token dropout: zero mask positions, then rescale (ESM's mask-ratio compensation, runs in both train and eval mode)
  3. Attention mask application
  4. Transformer encoder with rotary attention

For soft distributions (TAG guidance), mask tokens should not appear in the input — the zeroing step uses argmax to identify mask positions, which is non-differentiable at those positions.
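The token-dropout rescale in step 2 is a single scalar per sequence. The sketch below recomputes it in plain Python from token IDs, using the DPLM2 IDs documented on this page (<pad>=1, <mask_aa>=32) and the 0.15 × 0.8 = 0.12 training mask ratio. The function name is hypothetical and it mirrors only the arithmetic, not the tensor code.

```python
MASK_RATIO_TRAIN = 0.15 * 0.8  # 0.12, the training-time mask ratio ESM compensates for
PAD_ID = 1  # <pad>
MASK_ID = 32  # <mask_aa>


def token_dropout_scale(token_ids: list[int]) -> float:
    """Per-sequence factor applied to the (mask-zeroed) word embeddings.

    Padding is excluded from the length, matching the attention-mask-based
    src_lengths in the forward pass.
    """
    src_length = sum(1 for t in token_ids if t != PAD_ID)
    n_mask = sum(1 for t in token_ids if t == MASK_ID)
    mask_ratio_observed = n_mask / src_length
    return (1 - MASK_RATIO_TRAIN) / (1 - mask_ratio_observed)


# <cls_aa>, three residues, <eos_aa>, <pad>: no masks, so the scale is 0.88.
print(round(token_dropout_scale([0, 5, 6, 7, 2, 1]), 4))  # 0.88
# One of four non-pad tokens masked: 0.88 / 0.75.
print(round(token_dropout_scale([0, 32, 23, 2]), 4))  # 1.1733
```

For unmasked inputs, which is the typical case under TAG guidance, every embedding is therefore scaled by a constant 0.88.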

embedding_to_outputs
embedding_to_outputs(embedding_SPD: FloatTensor) -> torch.FloatTensor

Deep embeddings → logits via the LM head.

format_raw_to_logits
format_raw_to_logits(raw_output, seq_SP: LongTensor, **kwargs) -> torch.FloatTensor

Extract logits from MaskedLMOutput and apply logit formatting.


DPLM2Tokenizer

Tokenizer for DPLM2's extended vocabulary (AA + structure tokens).

Wraps HuggingFace's EsmTokenizer with DPLM2-specific special token assignments. The DPLM2 vocabulary has 3 regions:

  • Tokens 0-32: amino acid tokens + AA special tokens
  • Tokens 33-8228: structure tokens + struct special tokens
  • Token IDs >= vocab_size: generic HF special tokens (excluded)

Key special tokens
  • 0: <cls_aa> (BOS for amino acids)
  • 1: <pad>
  • 2: <eos_aa> (EOS for amino acids)
  • 32: <mask_aa> (mask for amino acid diffusion)
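The region boundaries account for the model's vocab_size of 8229: 33 amino-acid-region IDs (0-32) plus 8196 structure-region IDs (33-8228). A small helper that classifies an ID by region, written against those documented boundaries (the function and its labels are hypothetical, not part of DPLM2Tokenizer):

```python
VOCAB_SIZE = 8229
AA_REGION = range(0, 33)  # amino acid tokens + AA special tokens
STRUCT_REGION = range(33, VOCAB_SIZE)  # structure tokens + struct special tokens


def token_region(token_id: int) -> str:
    """Classify a DPLM2 token ID by vocabulary region."""
    if token_id in AA_REGION:
        return "aa"
    if token_id in STRUCT_REGION:
        return "struct"
    return "excluded"  # generic HF special tokens at or beyond vocab_size


print(len(AA_REGION) + len(STRUCT_REGION))  # 8229
print(token_region(32), token_region(33), token_region(8228), token_region(8229))
# aa struct struct excluded
```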
Source code in src/proteingen/modeling/models/dplm2/__init__.py
class DPLM2Tokenizer:
    """Tokenizer for DPLM2's extended vocabulary (AA + structure tokens).

    Wraps HuggingFace's EsmTokenizer with DPLM2-specific special token assignments.
    The DPLM2 vocabulary has 3 regions:
        - Tokens 0-32: amino acid tokens + AA special tokens
        - Tokens 33-8228: structure tokens + struct special tokens
        - Token IDs >= vocab_size: generic HF special tokens (excluded)

    Key special tokens:
        - 0: <cls_aa>  (BOS for amino acids)
        - 1: <pad>
        - 2: <eos_aa>  (EOS for amino acids)
        - 32: <mask_aa> (mask for amino acid diffusion)
    """

    # Special token IDs in the DPLM2 vocabulary
    _AA_SPECIAL = {0, 1, 2, 3, 32}  # cls_aa, pad, eos_aa, unk_aa, mask_aa
    _STRUCT_SPECIAL = {
        33,
        34,
        35,
        8228,
    }  # cls_struct, eos_struct, unk_struct, mask_struct
    _NON_STANDARD_AA = {24, 25, 26, 27, 28}  # X, B, U, Z, O
    _OTHER = {29, 30, 31}  # '.', '-', '<null_1>'

    def __init__(self, checkpoint: str, vocab_size: int):
        self._tok = EsmTokenizer.from_pretrained(checkpoint)
        # Filter vocab to model's actual vocab_size (excludes generic HF special tokens)
        full_vocab = self._tok.get_vocab()
        self._vocab = {k: v for k, v in full_vocab.items() if v < vocab_size}
        self._vocab_size = vocab_size
        self._id_to_tok = {v: k for k, v in self._vocab.items()}

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    @property
    def vocab(self) -> dict[str, int]:
        return dict(self._vocab)

    @property
    def mask_token_id(self) -> int:
        return 32  # <mask_aa>

    @property
    def cls_token_id(self) -> int:
        return 0  # <cls_aa>

    @property
    def eos_token_id(self) -> int:
        return 2  # <eos_aa>

    @property
    def pad_token_id(self) -> int:
        return 1  # <pad>

    @property
    def added_tokens_decoder(self) -> dict[int, str]:
        all_special = (
            self._AA_SPECIAL
            | self._STRUCT_SPECIAL
            | self._NON_STANDARD_AA
            | self._OTHER
        )
        return {i: self._id_to_tok[i] for i in all_special if i in self._id_to_tok}

    @property
    def all_special_ids(self) -> list[int]:
        return sorted(self._AA_SPECIAL | self._STRUCT_SPECIAL)

    def encode(self, sequence: str, add_special_tokens: bool = True) -> list[int]:
        """Encode an amino acid sequence to token IDs."""
        token_ids = [self._vocab.get(c, 3) for c in sequence]  # 3 = unk_aa
        if add_special_tokens:
            token_ids = [self.cls_token_id] + token_ids + [self.eos_token_id]
        return token_ids

    def decode(self, token_ids: list[int] | torch.Tensor) -> str:
        """Decode token IDs back to an amino acid sequence (skipping special tokens)."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        skip = self._AA_SPECIAL | self._STRUCT_SPECIAL
        return "".join(self._id_to_tok.get(i, "X") for i in token_ids if i not in skip)

    def __call__(
        self,
        sequences: str | list[str],
        padding: bool = False,
        return_tensors: str | None = None,
    ) -> dict[str, list | torch.Tensor]:
        if isinstance(sequences, str):
            sequences = [sequences]
        encoded = [self.encode(seq) for seq in sequences]
        if padding:
            max_len = max(len(e) for e in encoded)
            encoded = [e + [self.pad_token_id] * (max_len - len(e)) for e in encoded]
        result = {"input_ids": encoded}
        if return_tensors == "pt":
            result["input_ids"] = torch.tensor(result["input_ids"], dtype=torch.long)
        return result
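The encoding and right-padding done by `encode` and `__call__` can be reproduced without HuggingFace. This sketch swaps in a hypothetical three-letter toy vocabulary for the real DPLM2 vocab but keeps the documented special IDs (<cls_aa>=0, <pad>=1, <eos_aa>=2, unknown=3).

```python
CLS_ID, PAD_ID, EOS_ID, UNK_ID = 0, 1, 2, 3
TOY_VOCAB = {"A": 5, "G": 6, "V": 7}  # hypothetical subset, not the real vocab


def encode(sequence: str, add_special_tokens: bool = True) -> list[int]:
    # Unknown residues fall back to UNK_ID, as in DPLM2Tokenizer.encode.
    ids = [TOY_VOCAB.get(c, UNK_ID) for c in sequence]
    if add_special_tokens:
        ids = [CLS_ID] + ids + [EOS_ID]
    return ids


def batch_encode(sequences: list[str]) -> list[list[int]]:
    # Right-pad every row to the longest encoded length with PAD_ID.
    encoded = [encode(s) for s in sequences]
    max_len = max(len(e) for e in encoded)
    return [e + [PAD_ID] * (max_len - len(e)) for e in encoded]


print(batch_encode(["AGV", "AZ"]))
# [[0, 5, 6, 7, 2], [0, 5, 3, 2, 1]]
```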
encode
encode(sequence: str, add_special_tokens: bool = True) -> list[int]

Encode an amino acid sequence to token IDs.

decode
decode(token_ids: list[int] | Tensor) -> str

Decode token IDs back to an amino acid sequence (skipping special tokens).
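`decode` drops every ID in the AA and structure special sets and maps any remaining ID missing from the vocabulary to "X". A sketch of that filtering, with the special-ID sets copied from the class and a hypothetical toy reverse vocabulary (`TOY_ID_TO_TOK`) standing in for the real one:

```python
AA_SPECIAL = {0, 1, 2, 3, 32}  # cls_aa, pad, eos_aa, unk_aa, mask_aa
STRUCT_SPECIAL = {33, 34, 35, 8228}  # cls_struct, eos_struct, unk_struct, mask_struct
TOY_ID_TO_TOK = {5: "A", 6: "G", 7: "V"}  # hypothetical subset of the vocab


def decode(token_ids: list[int]) -> str:
    skip = AA_SPECIAL | STRUCT_SPECIAL
    # Special IDs are dropped; IDs absent from the vocabulary become "X".
    return "".join(TOY_ID_TO_TOK.get(i, "X") for i in token_ids if i not in skip)


print(decode([0, 5, 6, 7, 2, 1, 1]))  # AGV
print(decode([0, 5, 3, 2]))  # A
print(decode([0, 5, 99999, 2]))  # AX
```

The second call shows a side effect of the skip set: an unknown residue encoded as ID 3 (unk_aa) disappears on decode rather than coming back as "X", because ID 3 is in the AA special set.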


ESM3

Bases: GenerativeModelWithEmbedding

ESM3 masked language model as a GenerativeModelWithEmbedding.

Non-sequence tracks (structure, ss8, sasa, function, residue) default to padding values. Only the sequence embedding is differentiable.

Structure conditioning: pass atom37 coordinates via set_condition_() or conditioned_on(). The structure VQ-VAE encodes them once; the resulting structure tokens and coordinates are used in both the differentiable embedding path and the full forward path.
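The "encode once, reuse" behaviour described above is a caching pattern: preprocess the observations when the condition is set, keep the result for every later forward pass, and (in this sketch) restore the previous state when a `conditioned_on()` block exits. The toy class below uses a counter in place of the expensive VQ-VAE call; `ToyConditionalModel` and its internals are hypothetical and only mimic the documented method names.

```python
from contextlib import contextmanager


class ToyConditionalModel:
    """Hypothetical stand-in mimicking set_condition_() / conditioned_on()."""

    def __init__(self) -> None:
        self.observations = None
        self.encode_calls = 0

    def preprocess_observations(self, observations: dict) -> dict:
        self.encode_calls += 1  # stands in for the expensive structure encoding
        return {"structure_tokens": list(observations["coords_RAX"])}

    def set_condition_(self, observations: dict) -> None:
        # Preprocess once and cache; later forward passes reuse the cache.
        self.observations = self.preprocess_observations(observations)

    @contextmanager
    def conditioned_on(self, observations: dict):
        previous = self.observations
        self.set_condition_(observations)
        try:
            yield self
        finally:
            self.observations = previous  # restore the earlier state on exit

    def forward(self) -> bool:
        # Reports whether a cached condition was available for this call.
        return self.observations is not None


model = ToyConditionalModel()
with model.conditioned_on({"coords_RAX": [1, 2, 3]}):
    results = [model.forward() for _ in range(5)]
print(results, model.encode_calls, model.forward())
# [True, True, True, True, True] 1 False
```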

Example::

model = ESM3()
coords_RAX, wt_seq = pdb_to_atom37_and_seq("1abc.pdb")
with model.conditioned_on({"coords_RAX": coords_RAX}):
    log_probs = model.get_log_probs(seq_SP)
Tensor Index Legend

  • S: sequence index in batch
  • P: position index in sequence
  • T: token/vocab dimension (OUTPUT_DIM = 64)
  • D: embedding dimension (EMB_DIM = 1536)

Source code in src/proteingen/modeling/models/esm/esm3.py
class ESM3(GenerativeModelWithEmbedding):
    """ESM3 masked language model as a GenerativeModelWithEmbedding.

    Non-sequence tracks (structure, ss8, sasa, function, residue) default to
    padding values. Only the sequence embedding is differentiable.

    Structure conditioning: pass atom37 coordinates via ``set_condition_()``
    or ``conditioned_on()``. The structure VQ-VAE encodes them once; the
    resulting structure tokens and coordinates are used in both the
    differentiable embedding path and the full forward path.

    Example::

        model = ESM3()
        coords_RAX, wt_seq = pdb_to_atom37_and_seq("1abc.pdb")
        with model.conditioned_on({"coords_RAX": coords_RAX}):
            log_probs = model.get_log_probs(seq_SP)

    Tensor Index Legend:
        S: sequence index in batch
        P: position index in sequence
        T: token/vocab dimension (OUTPUT_DIM = 64)
        D: embedding dimension (EMB_DIM = 1536)
    """

    OUTPUT_DIM = 64

    def __init__(self, esm3_checkpoint: str = "esm3-open"):
        self._esm3_checkpoint = esm3_checkpoint
        tokenizer = EsmSequenceTokenizer()
        logit_formatter = MaskedModelLogitFormatter(tokenizer, ESM3.OUTPUT_DIM)
        esm3 = (
            _ESM3.from_pretrained(esm3_checkpoint, device=torch.device("cpu"))
            .float()
            .eval()
        )
        self.EMB_DIM = esm3.encoder.sequence_embed.weight.shape[1]  # 1536
        super().__init__(
            model=esm3, tokenizer=tokenizer, logit_formatter=logit_formatter
        )

    def _save_args(self) -> dict:
        return {"esm3_checkpoint": self._esm3_checkpoint}

    def preprocess_observations(self, observations: dict) -> dict:
        """Encode structure once via VQ-VAE (expensive).

        Args:
            observations: {"coords_RAX": (L, 37, 3) tensor or np.array}

        Returns:
            {"structure_tokens": (L+2,), "coordinates": (L+2, 37, 3)} with BOS/EOS padding.
        """
        coords = observations["coords_RAX"]
        if isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        protein = ESMProtein(coordinates=coords)
        with torch.no_grad():
            encoded = self.model.encode(protein)
        return {
            "structure_tokens": encoded.structure,  # (L+2,) with BOS/EOS
            "coordinates": encoded.coordinates,  # (L+2, 37, 3) with BOS/EOS
        }

    def collate_observations(
        self, seq_SP: torch.LongTensor, observations: dict
    ) -> dict:
        """Tile cached structure to match batch size."""
        B = seq_SP.shape[0]
        return {
            "structure_tokens": observations["structure_tokens"]
            .unsqueeze(0)
            .expand(B, -1),
            "coordinates": observations["coordinates"]
            .unsqueeze(0)
            .expand(B, -1, -1, -1),
        }

    def _non_sequence_embedding(
        self, seq_SP: torch.LongTensor, structure_tokens: torch.LongTensor | None = None
    ) -> torch.Tensor:
        """Non-sequence track embeddings.

        If structure_tokens is provided (from conditioning), uses them directly.
        Otherwise defaults to STRUCTURE_MASK_TOKEN with special-position overrides.
        """
        C = ESM3_CONSTANTS
        B, L = seq_SP.shape
        device = seq_SP.device

        if structure_tokens is None:
            # Default: mask everywhere, override at special sequence positions
            structure_tokens = torch.full(
                (B, L), C.STRUCTURE_MASK_TOKEN, dtype=torch.long, device=device
            )
            structure_tokens.masked_fill_(
                seq_SP == C.SEQUENCE_BOS_TOKEN, C.STRUCTURE_BOS_TOKEN
            )
            structure_tokens.masked_fill_(
                seq_SP == C.SEQUENCE_PAD_TOKEN, C.STRUCTURE_PAD_TOKEN
            )
            structure_tokens.masked_fill_(
                seq_SP == C.SEQUENCE_EOS_TOKEN, C.STRUCTURE_EOS_TOKEN
            )
            structure_tokens.masked_fill_(
                seq_SP == C.SEQUENCE_CHAINBREAK_TOKEN, C.STRUCTURE_CHAINBREAK_TOKEN
            )

        ss8_tokens = torch.full(
            (B, L), C.SS8_PAD_TOKEN, dtype=torch.long, device=device
        )
        sasa_tokens = torch.full(
            (B, L), C.SASA_PAD_TOKEN, dtype=torch.long, device=device
        )
        average_plddt = torch.ones((B, L), dtype=torch.float, device=device)
        per_res_plddt = torch.zeros((B, L), dtype=torch.float, device=device)
        function_tokens = torch.full(
            (B, L, 8), C.INTERPRO_PAD_TOKEN, dtype=torch.long, device=device
        )
        residue_annotation_tokens = torch.full(
            (B, L, 16), C.RESIDUE_PAD_TOKEN, dtype=torch.long, device=device
        )

        rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
        enc = self.model.encoder

        plddt_embed = enc.plddt_projection(rbf_16_fn(average_plddt))
        structure_per_res_plddt = enc.structure_per_res_plddt_projection(
            rbf_16_fn(per_res_plddt)
        )
        structure_embed = enc.structure_tokens_embed(structure_tokens)
        ss8_embed = enc.ss8_embed(ss8_tokens)
        sasa_embed = enc.sasa_embed(sasa_tokens)
        function_embed = torch.cat(
            [fn(t) for fn, t in zip(enc.function_embed, function_tokens.unbind(-1))], -1
        )
        residue_embed = enc.residue_embed(
            einops.rearrange(residue_annotation_tokens, "B L N -> (B L) N")
        )
        residue_embed = einops.rearrange(residue_embed, "(B L) D -> B L D", B=B, L=L)

        return (
            plddt_embed
            + structure_per_res_plddt
            + structure_embed
            + ss8_embed
            + sasa_embed
            + function_embed
            + residue_embed
        )

    def differentiable_embedding(
        self, ohe_seq_SPT: torch.FloatTensor
    ) -> torch.FloatTensor:
        if ohe_seq_SPT.shape[-1] < self.OUTPUT_DIM:
            ohe_seq_SPT = F.pad(
                ohe_seq_SPT, (0, self.OUTPUT_DIM - ohe_seq_SPT.shape[-1])
            )

        B, L, _ = ohe_seq_SPT.shape
        device = ohe_seq_SPT.device
        seq_SP = ohe_seq_SPT.argmax(-1)

        # Differentiable sequence embedding
        seq_embed_SPD = ohe_seq_SPT @ self.model.encoder.sequence_embed.weight

        # Non-sequence track embeddings — use structure conditioning if available
        structure_tokens = None
        coordinates = None
        if self.observations is not None:
            obs = self.collate_observations(seq_SP, self.observations)
            structure_tokens = obs["structure_tokens"].to(device)
            coordinates = obs["coordinates"].to(device)

        with torch.no_grad():
            non_seq_embed_SPD = self._non_sequence_embedding(seq_SP, structure_tokens)

        x_SPD = seq_embed_SPD + non_seq_embed_SPD

        # Build affine for geometric attention
        sequence_id = seq_SP != self.tokenizer.pad_token_id
        if coordinates is not None:
            # structure_coords is (B, L, 37, 3) — ESM3 transformer expects (B, L, 3, 3)
            affine, affine_mask = build_affine3d_from_coordinates(
                coordinates[..., :3, :]
            )
        else:
            coords = torch.full((B, L, 3, 3), float("nan"), device=device)
            affine, affine_mask = build_affine3d_from_coordinates(coords)

        x_SPD, _, _ = self.model.transformer(
            x_SPD, sequence_id=sequence_id, affine=affine, affine_mask=affine_mask
        )
        return x_SPD

    def embedding_to_outputs(
        self, embedding_SPD: torch.FloatTensor
    ) -> torch.FloatTensor:
        return self.model.output_heads.sequence_head(embedding_SPD)

    def forward(self, seq_SP: torch.LongTensor, **kwargs) -> torch.FloatTensor:
        # ESM3's forward uses keyword-only args; pass through conditioning if provided
        fwd_kwargs = {"sequence_tokens": seq_SP}
        if "structure_tokens" in kwargs:
            fwd_kwargs["structure_tokens"] = kwargs["structure_tokens"]
        if "coordinates" in kwargs:
            fwd_kwargs["structure_coords"] = kwargs["coordinates"]
        return self.model(**fwd_kwargs)

    def format_raw_to_logits(
        self, raw_output, seq_SP: torch.LongTensor, **kwargs
    ) -> torch.FloatTensor:
        logits_SPT = raw_output.sequence_logits.float()
        return self.logit_formatter(logits_SPT, seq_SP)
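When no structure condition is set, `_non_sequence_embedding` fills the structure track with the structure mask token and overwrites positions whose sequence token is BOS, PAD, EOS, or chain break with the matching structure special token. The same rule in plain Python follows. The sequence IDs 0/1/2 follow the ESM tokenizer table at the top of this page; the chain-break ID and every structure-track value in the demo are placeholders, not the real ESM3_CONSTANTS values.

```python
def default_structure_tokens(
    seq: list[int], seq_special: dict, struct_special: dict, struct_mask: int
) -> list[int]:
    """Mask every position, then copy special positions across tracks.

    seq_special / struct_special map the same keys ("BOS", "PAD", "EOS",
    "CHAINBREAK") to sequence-track and structure-track token IDs.
    """
    seq_to_struct = {seq_special[k]: struct_special[k] for k in seq_special}
    return [seq_to_struct.get(tok, struct_mask) for tok in seq]


SEQ_SPECIAL = {"BOS": 0, "PAD": 1, "EOS": 2, "CHAINBREAK": 31}  # chain break: placeholder
STRUCT_SPECIAL = {"BOS": 902, "PAD": 903, "EOS": 901, "CHAINBREAK": 904}  # placeholders
STRUCT_MASK = 900  # placeholder

seq = [0, 5, 6, 31, 7, 2, 1]
print(default_structure_tokens(seq, SEQ_SPECIAL, STRUCT_SPECIAL, STRUCT_MASK))
# [902, 900, 900, 904, 900, 901, 903]
```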
preprocess_observations
preprocess_observations(observations: dict) -> dict

Encode structure once via VQ-VAE (expensive).

Parameters:

  • observations (dict, required): {"coords_RAX": (L, 37, 3) tensor or np.array}

Returns:

  • dict: {"structure_tokens": (L+2,), "coordinates": (L+2, 37, 3)} with BOS/EOS padding.

collate_observations
collate_observations(seq_SP: LongTensor, observations: dict) -> dict

Tile cached structure to match batch size.


ESM3IF

Bases: ESM3

Deprecated: use ESM3 with set_condition_() instead.

This thin subclass exists only for backwards compatibility. It issues a deprecation warning on construction and delegates everything to ESM3.
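The ESM3IF shim follows the standard `warnings.warn(..., DeprecationWarning, stacklevel=2)` pattern. A self-contained sketch with hypothetical `NewModel` / `OldModel` classes shows how a caller or test can confirm that the warning fires while behaviour is unchanged:

```python
import warnings


class NewModel:
    def __init__(self, checkpoint: str = "default") -> None:
        self.checkpoint = checkpoint


class OldModel(NewModel):
    """Deprecated alias: warns, then defers entirely to NewModel."""

    def __init__(self, checkpoint: str = "default", **kwargs) -> None:
        warnings.warn(
            "OldModel is deprecated; use NewModel instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        super().__init__(checkpoint=checkpoint, **kwargs)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record the warning regardless of default filters
    legacy = OldModel("ckpt")

print(isinstance(legacy, NewModel), legacy.checkpoint, caught[0].category.__name__)
# True ckpt DeprecationWarning
```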

Source code in src/proteingen/modeling/models/esm/esm3if.py
class ESM3IF(ESM3):
    """Deprecated: use ``ESM3`` with ``set_condition_()`` instead.

    This thin subclass exists only for backwards compatibility. It issues a
    deprecation warning on construction and delegates everything to ``ESM3``.
    """

    def __init__(self, esm3_checkpoint: str = "esm3-open", **kwargs):
        warnings.warn(
            "ESM3IF is deprecated — use ESM3 with set_condition_() or "
            "conditioned_on() for structure-conditioned generation. "
            "ESM3IF will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(esm3_checkpoint=esm3_checkpoint, **kwargs)

ESMC

Bases: GenerativeModelWithEmbedding

ESM-C masked language model as a GenerativeModelWithEmbedding.

Tensor Index Legend

  • S: sequence index in batch
  • P: position index in sequence
  • T: token/vocab dimension (OUTPUT_DIM = 64)
  • D: embedding dimension (EMB_DIM = 960 for 300m, 1152 for 600m)

Source code in src/proteingen/modeling/models/esm/esmc.py
class ESMC(GenerativeModelWithEmbedding):
    """ESM-C masked language model as a GenerativeModelWithEmbedding.

    Tensor Index Legend:
        S: sequence index in batch
        P: position index in sequence
        T: token/vocab dimension (OUTPUT_DIM = 64)
        D: embedding dimension (EMB_DIM = 960 for 300m, 1152 for 600m)
    """

    OUTPUT_DIM = 64

    def __init__(self, esmc_checkpoint: str = "esmc_300m"):
        self._esmc_checkpoint = esmc_checkpoint
        tokenizer = EsmSequenceTokenizer()
        logit_formatter = MaskedModelLogitFormatter(tokenizer, ESMC.OUTPUT_DIM)
        esmc = _ESMC.from_pretrained(esmc_checkpoint, device=torch.device("cpu")).eval()
        self.EMB_DIM = esmc.embed.weight.shape[1]  # 960 for 300m, 1152 for 600m
        super().__init__(
            model=esmc, tokenizer=tokenizer, logit_formatter=logit_formatter
        )

    def _save_args(self) -> dict:
        return {"esmc_checkpoint": self._esmc_checkpoint}

    def differentiable_embedding(
        self, ohe_seq_SPT: torch.FloatTensor
    ) -> torch.FloatTensor:
        if ohe_seq_SPT.shape[-1] < self.OUTPUT_DIM:
            ohe_seq_SPT = F.pad(
                ohe_seq_SPT, (0, self.OUTPUT_DIM - ohe_seq_SPT.shape[-1])
            )
        x_SPD = ohe_seq_SPT @ self.model.embed.weight
        sequence_id = ohe_seq_SPT.argmax(-1) != self.tokenizer.pad_token_id
        x_SPD, _, _ = self.model.transformer(x_SPD, sequence_id=sequence_id)
        return x_SPD

    def embedding_to_outputs(
        self, embedding_SPD: torch.FloatTensor
    ) -> torch.FloatTensor:
        return self.model.sequence_head(embedding_SPD)

    def format_raw_to_logits(
        self, raw_output, seq_SP: torch.LongTensor, **kwargs
    ) -> torch.FloatTensor:
        logits_SPT = raw_output.sequence_logits.float()
        return self.logit_formatter(logits_SPT, seq_SP)
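`differentiable_embedding` zero-pads the last axis up to OUTPUT_DIM = 64 and then multiplies by the embedding matrix. A one-hot row therefore reproduces an ordinary lookup, while a soft distribution yields a probability-weighted mix of rows. A plain-Python sketch of that arithmetic, using a tiny hypothetical embedding table (4 token rows of width 2) and inputs of width 3:

```python
def pad_last(row: list[float], width: int) -> list[float]:
    # Zero-pad a probability row to the table's row count (the F.pad step).
    return row + [0.0] * (width - len(row))


def soft_embed(probs: list[float], table: list[list[float]]) -> list[float]:
    """Row vector (T,) times matrix (T, D): a probability-weighted sum of rows."""
    probs = pad_last(probs, len(table))
    dim = len(table[0])
    return [sum(p * table[t][d] for t, p in enumerate(probs)) for d in range(dim)]


TABLE = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [9.0, 9.0]]  # hypothetical (T=4, D=2)

print(soft_embed([0.0, 0.0, 1.0], TABLE))  # [2.0, 2.0]  one-hot selects row 2
print(soft_embed([0.5, 0.5, 0.0], TABLE))  # [0.5, 0.5]  mix of rows 0 and 1
```

Rows reached only through padding (row 3 here) always receive zero weight, so padding a narrower OHE input to OUTPUT_DIM changes the shape without changing the result.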

ESMForgeAPI

Bases: GenerativeModel

ESM model accessed via the EvolutionaryScale Forge API.

Wraps a Forge inference client to provide the same get_log_probs interface as the local ESM wrappers. No local weights needed — inference happens remotely.

Automatically selects the right client (ESM3 vs ESMC) based on the model name. Structure conditioning is supported for ESM3 models only.

Limitations vs local models
  • No gradient access (no embed, no TAG guidance)
  • No LoRA / fine-tuning
  • No checkpointing
  • Batched inference loops sequentially over the API
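Because the Forge client scores one protein per request, a batched call becomes a loop: strip each row's padding, score the unpadded sequence, and pad the per-position results back to the batch length. A sketch of that loop, with a fake scoring function in place of the remote call. `fake_remote_logits`, right-only padding, the ESM pad ID of 1, and filling padded positions with zeros are all assumptions here, not a reproduction of the class's forward body.

```python
PAD_ID = 1  # ESM <pad>


def fake_remote_logits(seq: list[int]) -> list[list[float]]:
    # Stand-in for one API request: returns an (L, T) table with T = 2.
    return [[float(tok), -float(tok)] for tok in seq]


def batched_logits(batch: list[list[int]], fill: float = 0.0) -> list[list[list[float]]]:
    out = []
    for row in batch:
        n = sum(t != PAD_ID for t in row)  # assumes padding only at the right end
        logits = fake_remote_logits(row[:n])  # one sequential request per row
        width = len(logits[0])
        logits += [[fill] * width for _ in range(len(row) - n)]
        out.append(logits)
    return out


batch = [[0, 5, 6, 2], [0, 5, 2, 1]]
result = batched_logits(batch)
print([len(r) for r in result], result[1][-1])  # [4, 4] [0.0, 0.0]
```

The loop makes latency grow linearly with batch size, which is why the list above calls out sequential batching as a limitation.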

Example::

import os
model = ESMForgeAPI("esmc-6b-2024-12", token=os.environ["FORGE_TOKEN"])
log_probs = model.get_log_probs_from_string(["ACDEF"])
Tensor Index Legend

  • S: sequence index in batch
  • P: position index in sequence
  • T: token/vocab dimension (OUTPUT_DIM = 64)

Source code in src/proteingen/modeling/models/esm/esm_api.py
class ESMForgeAPI(GenerativeModel):
    """ESM model accessed via the EvolutionaryScale Forge API.

    Wraps a Forge inference client to provide the same ``get_log_probs``
    interface as the local ESM wrappers. No local weights needed — inference
    happens remotely.

    Automatically selects the right client (ESM3 vs ESMC) based on the
    model name. Structure conditioning is supported for ESM3 models only.

    Limitations vs local models:
        - No gradient access (no ``embed``, no TAG guidance)
        - No LoRA / fine-tuning
        - No checkpointing
        - Batched inference loops sequentially over the API

    Example::

        import os
        model = ESMForgeAPI("esmc-6b-2024-12", token=os.environ["FORGE_TOKEN"])
        log_probs = model.get_log_probs_from_string(["ACDEF"])

    Tensor Index Legend:
        S: sequence index in batch
        P: position index in sequence
        T: token/vocab dimension (OUTPUT_DIM = 64)
    """

    OUTPUT_DIM = 64

    def __init__(
        self,
        model_name: str,
        token: str,
        url: str = "https://forge.evolutionaryscale.ai",
    ):
        tokenizer = EsmSequenceTokenizer()
        logit_formatter = MaskedModelLogitFormatter(tokenizer, self.OUTPUT_DIM)
        # GenerativeModel expects an nn.Module for self.model; use an empty module
        # since inference is remote. We override forward() to use self.client.
        super().__init__(
            model=nn.Module(), tokenizer=tokenizer, logit_formatter=logit_formatter
        )
        self._model_name = model_name
        self._is_esm3 = model_name.startswith("esm3")

        if self._is_esm3:
            from esm.sdk.forge import ESM3ForgeInferenceClient

            self.client = ESM3ForgeInferenceClient(
                model=model_name, url=url, token=token
            )
        else:
            from esm.sdk.forge import ESMCForgeInferenceClient

            self.client = ESMCForgeInferenceClient(
                model=model_name, url=url, token=token
            )

    @property
    def device(self) -> torch.device:
        return torch.device("cpu")

    def _call_logits(
        self, seq_1d: torch.LongTensor, **tensor_kwargs
    ) -> torch.FloatTensor:
        """Call the Forge logits endpoint for a single (unpadded) sequence.

        Args:
            seq_1d: 1-D token IDs (L,) including CLS/EOS, no padding.
            **tensor_kwargs: Extra ESMProteinTensor fields (e.g. structure, coordinates).

        Returns:
            Logits tensor of shape (L, OUTPUT_DIM).
        """
        input_tensor = ESMProteinTensor(sequence=seq_1d, **tensor_kwargs)
        output = self.client.logits(input_tensor, LogitsConfig(sequence=True))
        if isinstance(output, ESMProteinError):
            raise RuntimeError(
                f"Forge API error ({output.error_code}): {output.error_msg}"
            )
        assert output.logits is not None and output.logits.sequence is not None
        logits = output.logits.sequence.float()
        if logits.dim() == 3 and logits.shape[0] == 1:
            logits = logits.squeeze(0)
        return logits

    def forward(self, seq_SP: torch.LongTensor, **kwargs) -> torch.FloatTensor:
        """Forward pass via the Forge API.

        Loops over the batch dimension, stripping padding from each sequence
        before sending to the API, then re-pads results to the max length.

        Returns:
            Logits tensor of shape (S, P, OUTPUT_DIM).
        """
        S, P = seq_SP.shape
        pad_id = self.tokenizer.pad_token_id

        all_logits = []
        for i in range(S):
            seq = seq_SP[i]
            non_pad = seq != pad_id
            seq_unpadded = seq[non_pad]

            tensor_kwargs = {}
            if "structure_tokens" in kwargs:
                tensor_kwargs["structure"] = kwargs["structure_tokens"][i]
            if "coordinates" in kwargs:
                tensor_kwargs["coordinates"] = kwargs["coordinates"][i]

            logits = self._call_logits(seq_unpadded, **tensor_kwargs)
            L_actual = logits.shape[0]

            if L_actual < P:
                padding = torch.zeros(
                    P - L_actual, logits.shape[-1], dtype=logits.dtype
                )
                logits = torch.cat([logits, padding], dim=0)
            all_logits.append(logits)

        return torch.stack(all_logits)

    # ── Structure conditioning (ESM3 only) ───────────────────────────────

    def preprocess_observations(self, observations: dict) -> dict:
        """Encode structure via the Forge API (remote VQ-VAE).

        Only supported for ESM3 models.

        Args:
            observations: {"coords_RAX": (L, 37, 3) tensor or np.array}

        Returns:
            {"structure_tokens": (L+2,), "coordinates": (L+2, 37, 3)} with BOS/EOS.
        """
        assert self._is_esm3, (
            f"Structure conditioning not supported for ESMC model '{self._model_name}'"
        )
        coords = observations["coords_RAX"]
        if isinstance(coords, torch.Tensor):
            coords = coords.float()
        elif isinstance(coords, np.ndarray):
            coords = torch.from_numpy(coords).float()
        protein = ESMProtein(coordinates=coords)
        encoded = self.client.encode(protein)
        if isinstance(encoded, ESMProteinError):
            raise RuntimeError(
                f"Forge encode error ({encoded.error_code}): {encoded.error_msg}"
            )
        return {
            "structure_tokens": encoded.structure,
            "coordinates": encoded.coordinates,
        }

    def collate_observations(
        self, seq_SP: torch.LongTensor, observations: dict
    ) -> dict:
        B = seq_SP.shape[0]
        return {
            "structure_tokens": observations["structure_tokens"]
            .unsqueeze(0)
            .expand(B, -1),
            "coordinates": observations["coordinates"]
            .unsqueeze(0)
            .expand(B, -1, -1, -1),
        }

    # ── Unsupported operations ───────────────────────────────────────────

    def _save_args(self) -> dict:
        raise NotImplementedError("API models don't support checkpointing")

    def apply_lora(self, **kwargs) -> None:
        raise NotImplementedError("LoRA is not supported for API models")

    def save_lora(self, path) -> None:
        raise NotImplementedError("LoRA is not supported for API models")

    def load_lora(self, path) -> None:
        raise NotImplementedError("LoRA is not supported for API models")
forward
forward(seq_SP: LongTensor, **kwargs) -> torch.FloatTensor

Forward pass via the Forge API.

Loops over the batch dimension, stripping padding from each sequence before sending to the API, then re-pads results to the max length.

Returns:

FloatTensor: logits tensor of shape (S, P, OUTPUT_DIM).
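The strip, call, and re-pad loop can be sketched with plain lists instead of tensors. In this sketch `call_api` is a hypothetical stand-in for one Forge logits request (it takes an unpadded token list and returns one logits row per token), and `fake_api` is a toy used only to exercise the loop. As in the source, every pad token is dropped before the call, and the returned rows are padded back to length P with zeros.

```python
from typing import Callable, List

Row = List[float]


def batched_remote_logits(
    batch: List[List[int]],
    pad_id: int,
    call_api: Callable[[List[int]], List[Row]],
) -> List[List[Row]]:
    """Plain-list sketch of the loop in ESMForgeAPI.forward."""
    P = len(batch[0])  # every row of seq_SP shares the padded length P
    out: List[List[Row]] = []
    for seq in batch:
        unpadded = [tok for tok in seq if tok != pad_id]  # drop all pad tokens
        rows = call_api(unpadded)  # one sequential request per sequence
        width = len(rows[0])  # vocab width (OUTPUT_DIM in the real class)
        # Append all-zero rows so the result has P positions again.
        rows = rows + [[0.0] * width for _ in range(P - len(rows))]
        out.append(rows)
    return out


def fake_api(tokens: List[int]) -> List[Row]:
    """Hypothetical stand-in for the remote call: a 2-wide row per token."""
    return [[float(tok), 1.0] for tok in tokens]


PAD = 1  # ESM <pad> token id
batch = [[0, 5, 6, 2], [0, 7, 2, PAD]]
logits = batched_remote_logits(batch, PAD, fake_api)
```

Because each sequence costs one round trip, wall-clock time grows linearly with batch size, which is the "loops sequentially over the API" limitation listed for this class.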

preprocess_observations
preprocess_observations(observations: dict) -> dict

Encode structure via the Forge API (remote VQ-VAE).

Only supported for ESM3 models.

Parameters:

observations (dict, required): {"coords_RAX": (L, 37, 3) tensor or np.array}

Returns:

dict: {"structure_tokens": (L+2,), "coordinates": (L+2, 37, 3)}, including the BOS/EOS positions.
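The encoded structure track has L + 2 entries (BOS, the L residues, EOS). The class's `collate_observations` then reuses that single condition for every sequence in the batch via `unsqueeze(0).expand(B, ...)`, which creates a view rather than copying. The sketch below imitates that broadcast with lists; `collate_condition` is a hypothetical name, and the token ids are placeholders, not real VQ-VAE codes. Repeating the same list object B times is the closest plain-Python analogue of a non-copying `expand`.

```python
from typing import List


def collate_condition(structure_tokens: List[int], batch_size: int) -> List[List[int]]:
    """Hypothetical sketch of collate_observations for the structure tokens."""
    # Every batch entry refers to the same underlying list, much like expand().
    return [structure_tokens] * batch_size


L = 5  # residues in the conditioning structure
structure_tokens = list(range(L + 2))  # placeholder ids: BOS + L residues + EOS
batch = collate_condition(structure_tokens, batch_size=3)
```

Coordinates follow the same pattern, with shape (L+2, 37, 3) broadcast to (B, L+2, 37, 3).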


Frame2seq

Bases: GenerativeModelWithEmbedding

Frame2seq structure-conditioned inverse folding model.

Loads Frame2seq's bundled checkpoint ensemble and exposes it through proteingen's GenerativeModelWithEmbedding interface.

Conditioning is required and must be set with:

model.set_condition_({"pdb_path": "1abc.pdb", "chain_id": "A"})
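The conditioning payload is a small dict, and the model refuses to embed without it. The sketch below mirrors `condition_from_pdb` and the check in `differentiable_embedding` as standalone functions. Modeling `Frame2seqConditioning` as a `TypedDict` is an assumption (its real definition is not shown on this page), and `require_condition` is a hypothetical helper.

```python
from typing import Optional, TypedDict


class Frame2seqConditioning(TypedDict):
    """Assumed shape of the dict returned by Frame2seq.condition_from_pdb."""

    pdb_path: str
    chain_id: str


def condition_from_pdb(pdb_path: str, chain_id: str) -> Frame2seqConditioning:
    return {"pdb_path": str(pdb_path), "chain_id": chain_id}


def require_condition(observations: Optional[dict]) -> dict:
    """Hypothetical guard mirroring the check in differentiable_embedding."""
    if observations is None:
        raise ValueError(
            "Frame2seq requires structure conditioning. "
            "Call set_condition_() or use conditioned_on() first."
        )
    return observations
```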

Tensor Index Legend

B: batch index
P: residue position index
A: atom index (N, CA, C, CB, O)
U: Frame2seq sequence dim (21 = 20 AAs + X)
T: proteingen tokenizer dim (22 = U + <mask>)
D: single-model embedding dim (128)
E: concatenated ensemble embedding dim (D * n_models)
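The E and T dimensions in the legend come from how `embedding_to_outputs` combines the ensemble. The embedding is split into one D-dim chunk per member, each member's head produces U = 21 logits, the logits are averaged, and a `-inf` column is appended for `<mask>`. Below is a per-position sketch with plain lists. The toy `D = 2` and the `toy_head` function are hypothetical; the real D is 128 and each head is a member model's `single_to_sequence`.

```python
import math
from typing import Callable, List

Vector = List[float]


def ensemble_logits(
    embedding_E: Vector,
    single_dim: int,
    heads: List[Callable[[Vector], Vector]],
) -> Vector:
    """Per-position sketch of Frame2seq.embedding_to_outputs."""
    assert len(embedding_E) == single_dim * len(heads)  # E = D * n_models
    # Split the concatenated embedding back into one D-dim chunk per member.
    chunks = [
        embedding_E[i * single_dim : (i + 1) * single_dim] for i in range(len(heads))
    ]
    # Each member's head maps its own chunk to U = 21 logits.
    member_logits = [head(chunk) for head, chunk in zip(heads, chunks)]
    # Average the logits elementwise across members.
    mean_U = [sum(vals) / len(heads) for vals in zip(*member_logits)]
    # Append a -inf column for <mask> so the output has T = U + 1 = 22 entries.
    return mean_U + [-math.inf]


def toy_head(chunk: Vector) -> Vector:
    """Hypothetical head: all 21 logits equal the sum of the chunk."""
    return [sum(chunk)] * 21


out = ensemble_logits([1.0, 1.0, 3.0, 3.0], single_dim=2, heads=[toy_head, toy_head])
```

Averaging logits (rather than probabilities) matches what the source does; the `-inf` column gives `<mask>` zero probability after a softmax.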

Source code in src/proteingen/modeling/models/frame2seq/frame2seq.py
class Frame2seq(GenerativeModelWithEmbedding):
    """Frame2seq structure-conditioned inverse folding model.

    Loads Frame2seq's bundled checkpoint ensemble and exposes it through
    proteingen's GenerativeModelWithEmbedding interface.

    Conditioning is required and must be set with:

    ``model.set_condition_({"pdb_path": "1abc.pdb", "chain_id": "A"})``

    Tensor Index Legend:
        B: batch index
        P: residue position index
        A: atom index (N, CA, C, CB, O)
        U: Frame2seq sequence dim (21 = 20 AAs + X)
        T: proteingen tokenizer dim (22 = U + <mask>)
        D: single-model embedding dim (128)
        E: concatenated ensemble embedding dim (D * n_models)
    """

    OUTPUT_DIM = 22
    SEQUENCE_DIM = 21

    def __init__(self, checkpoint_paths: list[str | Path] | None = None):
        self._checkpoint_paths = [
            str(path) for path in self._resolve_checkpoint_paths(checkpoint_paths)
        ]

        models = [
            _Frame2seqModel.load_from_checkpoint(path, map_location="cpu").eval()
            for path in self._checkpoint_paths
        ]
        first_model = models[0]

        for model in models[1:]:
            assert model.single_dim == first_model.single_dim
            assert model.sequence_dim == first_model.sequence_dim

        self._n_models = len(models)
        self._single_emb_dim = first_model.single_dim
        self.EMB_DIM = self._single_emb_dim * self._n_models

        tokenizer = Frame2seqTokenizer(include_mask_token=True)
        unk_token_id = tokenizer.vocab["X"]
        assert tokenizer.mask_token_id is not None
        mask_token_id = tokenizer.mask_token_id
        logit_formatter = _Frame2seqLogitFormatter(unk_token_id, mask_token_id)

        super().__init__(
            model=_Frame2seqEnsemble(models),
            tokenizer=tokenizer,
            logit_formatter=logit_formatter,
        )

    @staticmethod
    def _resolve_checkpoint_paths(
        checkpoint_paths: list[str | Path] | None,
    ) -> list[Path]:
        if checkpoint_paths is not None:
            resolved_paths = [Path(path) for path in checkpoint_paths]
        else:
            package_dir = Path(_frame2seq_init_file).resolve().parent
            checkpoint_dir = package_dir / "trained_models"
            resolved_paths = sorted(checkpoint_dir.glob("*.ckpt"))

        if len(resolved_paths) == 0:
            raise ValueError("No Frame2seq checkpoint files were found.")

        for path in resolved_paths:
            if not path.exists():
                raise ValueError(f"Frame2seq checkpoint does not exist: {path}")

        return resolved_paths

    @staticmethod
    def condition_from_pdb(
        pdb_path: str | Path, chain_id: str
    ) -> Frame2seqConditioning:
        return {"pdb_path": str(pdb_path), "chain_id": chain_id}

    def _save_args(self) -> dict:
        return {"checkpoint_paths": self._checkpoint_paths}

    def preprocess_observations(
        self, observations: Frame2seqConditioning
    ) -> _Frame2seqCachedConditioning:
        seq_mask_BP, native_seq_BP, X_BPA3 = get_inference_inputs(
            str(observations["pdb_path"]), observations["chain_id"]
        )

        X_LA3 = X_BPA3.squeeze(0)  # X_LA3 [P, A, 3] - chain atom coordinates
        seq_mask_L = seq_mask_BP.squeeze(
            0
        ).bool()  # seq_mask_L [P] - valid residue mask
        native_seq_L = native_seq_BP.squeeze(
            0
        ).long()  # native_seq_L [P] - native Frame2seq token ids

        return {
            "X_LA3": X_LA3,
            "seq_mask_L": seq_mask_L,
            "native_seq_L": native_seq_L,
        }

    def _to_frame2seq_vocab(self, ohe_seq_BPT: torch.FloatTensor) -> torch.FloatTensor:
        input_ohe_BPU = ohe_seq_BPT[
            :, :, : self.SEQUENCE_DIM
        ].clone()  # input_ohe_BPU [B, P, U] - AA+X channels
        extra_mass_BP = ohe_seq_BPT[:, :, self.SEQUENCE_DIM :].sum(
            dim=-1
        )  # extra_mass_BP [B, P] - total probability on extra channels
        input_ohe_BPU[:, :, 20] = (
            input_ohe_BPU[:, :, 20] + extra_mass_BP
        )  # input_ohe_BPU [B, P, U] - fold extra mass into unknown token
        return input_ohe_BPU

    def _single_model_embedding(
        self,
        model: _Frame2seqModel,
        X_BPA3: torch.FloatTensor,
        seq_mask_BP: torch.BoolTensor,
        input_ohe_BPU: torch.FloatTensor,
    ) -> torch.FloatTensor:
        rigid_BPR = Rigid.from_3_points(
            X_BPA3[:, :, 0], X_BPA3[:, :, 1], X_BPA3[:, :, 2]
        )  # rigid_BPR [B, P] - residue rigid frames from backbone atoms

        s_init_BPF, in_s_BPD = make_s_init(
            model, X_BPA3, input_ohe_BPU, seq_mask_BP
        )  # s_init_BPF [B, P, single_dim+6], in_s_BPD [B, P, D]
        s_BPD = model.sequence_to_single(
            s_init_BPF
        )  # s_BPD [B, P, D] - geometry+position to single representation
        s_BPD = s_BPD + model.input_sequence_layer_norm(
            in_s_BPD
        )  # s_BPD [B, P, D] - add masked input-sequence embedding

        z_init_BPPF = make_z_init(
            model, X_BPA3
        )  # z_init_BPPF [B, P, P, pair_dim] - initial pairwise features
        z_BPPD = model.edge_to_pair(
            z_init_BPPF
        )  # z_BPPD [B, P, P, D] - pair representation

        seq_mask_long_BP = (
            seq_mask_BP.long()
        )  # seq_mask_long_BP [B, P] - IPA mask as int
        attn_drop_rate = 0.2 if model.training else 0.0

        for layer in model.layers:
            ipa, ipa_dropout, layer_norm_ipa, *transit_layers, edge_transition = layer

            ipa_update_BPD = ipa(
                s_BPD,
                z_BPPD,
                rigid_BPR,
                seq_mask_long_BP,
                attn_drop_rate=attn_drop_rate,
            )  # ipa_update_BPD [B, P, D] - IPA single-state update
            s_BPD = s_BPD + ipa_update_BPD  # s_BPD [B, P, D]
            s_BPD = ipa_dropout(s_BPD)  # s_BPD [B, P, D]
            s_BPD = layer_norm_ipa(s_BPD)  # s_BPD [B, P, D]

            if model.st_mod_tsit_factor > 1:
                pre_transit = transit_layers[0]
                transition = transit_layers[1]
                post_transit = transit_layers[2]

                s_BPD = pre_transit(
                    s_BPD
                )  # s_BPD [B, P, D*factor] - transition expansion
                s_BPD = transition(s_BPD)  # s_BPD [B, P, D*factor] - transition block
                s_BPD = post_transit(s_BPD)  # s_BPD [B, P, D] - transition projection
            else:
                transition = transit_layers[0]
                s_BPD = transition(s_BPD)  # s_BPD [B, P, D] - transition block

            if edge_transition is not None:
                z_BPPD = edge_transition(
                    s_BPD, z_BPPD
                )  # z_BPPD [B, P, P, D] - pair update conditioned on singles

        return s_BPD

    def _differentiable_embedding_with_observations(
        self,
        ohe_seq_BPT: torch.FloatTensor,
        observations: _Frame2seqCachedConditioning | dict[str, torch.Tensor],
    ) -> torch.FloatTensor:
        B, P, _ = ohe_seq_BPT.shape
        device = ohe_seq_BPT.device

        input_ohe_BPU = self._to_frame2seq_vocab(
            ohe_seq_BPT
        )  # input_ohe_BPU [B, P, U] - Frame2seq AA+X input distribution

        if observations["X_LA3"].dim() == 3:
            dummy_batch_B = torch.zeros(
                B, dtype=torch.long, device=device
            )  # dummy_batch_B [B] - batch-size carrier for default collator
            obs = self.collate_observations(dummy_batch_B, observations)
        else:
            obs = observations

        X_BPA3 = obs["X_LA3"].to(
            device=device
        )  # X_BPA3 [B, P, A, 3] - conditioned coordinates
        seq_mask_BP = obs["seq_mask_L"].to(
            device=device, dtype=torch.bool
        )  # seq_mask_BP [B, P] - conditioned valid-position mask

        assert X_BPA3.shape[:2] == (B, P)
        assert seq_mask_BP.shape == (B, P)

        embeddings_by_model = []
        for model in self.model.models:
            emb_BPD = self._single_model_embedding(
                model=model,
                X_BPA3=X_BPA3,
                seq_mask_BP=seq_mask_BP,
                input_ohe_BPU=input_ohe_BPU,
            )  # emb_BPD [B, P, D] - one ensemble member single representation
            embeddings_by_model.append(emb_BPD)

        emb_BPE = torch.cat(
            embeddings_by_model, dim=-1
        )  # emb_BPE [B, P, E] - concatenated ensemble embeddings
        return emb_BPE

    def differentiable_embedding(
        self, ohe_seq_SPT: torch.FloatTensor
    ) -> torch.FloatTensor:
        if self.observations is None:
            raise ValueError(
                "Frame2seq requires structure conditioning. "
                "Call set_condition_() or use conditioned_on() first."
            )
        return self._differentiable_embedding_with_observations(
            ohe_seq_BPT=ohe_seq_SPT,
            observations=self.observations,
        )

    def forward(self, seq_SP: torch.LongTensor, **kwargs) -> torch.FloatTensor:
        ohe_BPT = F.one_hot(seq_SP, num_classes=self.tokenizer.vocab_size).float()
        if kwargs:
            emb_BPE = self._differentiable_embedding_with_observations(
                ohe_seq_BPT=ohe_BPT,
                observations=kwargs,
            )
        else:
            emb_BPE = self.differentiable_embedding(ohe_BPT)
        return self.embedding_to_outputs(emb_BPE)

    @staticmethod
    def _pad_logits(logits_BPU: torch.FloatTensor) -> torch.FloatTensor:
        pad_mask_BP1 = torch.full(
            (*logits_BPU.shape[:-1], 1),
            float("-inf"),
            device=logits_BPU.device,
            dtype=logits_BPU.dtype,
        )  # pad_mask_BP1 [B, P, 1] - mask-token output column
        logits_BPT = torch.cat(
            [logits_BPU, pad_mask_BP1], dim=-1
        )  # logits_BPT [B, P, T] - logits with extra mask-token column
        return logits_BPT

    def embedding_to_outputs(
        self, embedding_SPD: torch.FloatTensor
    ) -> torch.FloatTensor:
        embedding_chunks = torch.split(
            embedding_SPD, self._single_emb_dim, dim=-1
        )  # embedding_chunks list[[B, P, D]] - one chunk per ensemble member
        assert len(embedding_chunks) == self._n_models

        logits_sum_BPU = torch.zeros(
            embedding_SPD.shape[0],
            embedding_SPD.shape[1],
            self.SEQUENCE_DIM,
            device=embedding_SPD.device,
            dtype=embedding_SPD.dtype,
        )  # logits_sum_BPU [B, P, U] - running sum over ensemble logits

        for model, emb_chunk_BPD in zip(self.model.models, embedding_chunks):
            logits_member_BPU = model.single_to_sequence(
                emb_chunk_BPD
            )  # logits_member_BPU [B, P, U] - one model's sequence logits
            logits_sum_BPU = (
                logits_sum_BPU + logits_member_BPU
            )  # logits_sum_BPU [B, P, U]

        logits_mean_BPU = logits_sum_BPU / self._n_models  # logits_mean_BPU [B, P, U]
        return self._pad_logits(logits_mean_BPU)

    def format_raw_to_logits(
        self,
        raw_output: torch.FloatTensor,
        seq_SP: torch.LongTensor,
        **kwargs,
    ) -> torch.FloatTensor:
        return self.logit_formatter(raw_output.float(), seq_SP)
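`_to_frame2seq_vocab` in the listing above converts a proteingen-vocabulary distribution (T = 22) into Frame2seq's input vocabulary (U = 21). It keeps the 20 amino-acid channels and X, then folds any probability on the extra channels into X (index 20). Below is a per-position sketch with plain lists. It assumes `<mask>` sits at the last index, 21, consistent with the mask column that `_pad_logits` appends at the end.

```python
from typing import List

SEQUENCE_DIM = 21  # 20 amino acids + X, as in Frame2seq.SEQUENCE_DIM
UNK_INDEX = 20  # X channel, matching the hard-coded index in _to_frame2seq_vocab


def to_frame2seq_vocab(probs_T: List[float]) -> List[float]:
    """Per-position sketch of _to_frame2seq_vocab."""
    probs_U = list(probs_T[:SEQUENCE_DIM])  # copy the AA + X channels
    extra_mass = sum(probs_T[SEQUENCE_DIM:])  # mass on <mask> (and any other extras)
    probs_U[UNK_INDEX] += extra_mass  # fold it into the unknown token
    return probs_U


masked = [0.0] * 21 + [1.0]  # a fully masked position: one-hot on <mask>
print(to_frame2seq_vocab(masked)[UNK_INDEX])  # 1.0
```

Folding rather than dropping the extra mass keeps each position a valid distribution, so a masked position reaches Frame2seq as "unknown residue".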