probability_model

ProbabilityModel is the root of the entire class hierarchy — both generative and predictive models inherit from it. It's an nn.Module + ABC that provides temperature scaling, observation conditioning, log-probability computation, and checkpointing.

The get_log_probs pipeline

Every call to get_log_probs follows the same chain:

collate_observations(x_B, self.observations)
    → forward(x_B, **obs)
        → format_raw_to_logits(raw, x_B, **obs)
            → log_softmax(logits / temp)

When no observations are set, forward and format_raw_to_logits are called without keyword arguments.
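The chain above can be traced in a small, dependency-free sketch. ToyModel is hypothetical and not part of proteingen: plain Python lists stand in for tensors, forward returns a dict to mimic a non-tensor raw output, and log_softmax is a hand-written, numerically stable stand-in for torch's.

```python
import math
from typing import Any, Dict, List, Optional


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax: x - logsumexp(x)
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


class ToyModel:
    """Hypothetical stand-in that mirrors the get_log_probs chain with lists."""

    def __init__(self) -> None:
        self.temp = 1.0
        self.observations: Optional[Dict[str, Any]] = None

    def collate_observations(self, x_B: List[float], observations: Dict[str, Any]) -> Dict[str, Any]:
        # Repeat each cached observation once per batch element
        return {k: [v] * len(x_B) for k, v in observations.items()}

    def forward(self, x_B: List[float], **obs: Any) -> Dict[str, Any]:
        # Raw output is a dict here, not a bare logit list
        bias = obs.get("bias", [0.0] * len(x_B))
        return {"scores": [[x + b, -x] for x, b in zip(x_B, bias)]}

    def format_raw_to_logits(self, raw: Dict[str, Any], x_B: List[float], **obs: Any) -> List[List[float]]:
        return raw["scores"]

    def get_log_probs(self, x_B: List[float]) -> List[List[float]]:
        assert self.temp > 0, f"Temperature must be positive, got {self.temp}"
        if self.observations is not None:
            obs = self.collate_observations(x_B, self.observations)
            raw = self.forward(x_B, **obs)
            logits = self.format_raw_to_logits(raw, x_B, **obs)
        else:
            # No condition set: forward/format get no keyword arguments
            raw = self.forward(x_B)
            logits = self.format_raw_to_logits(raw, x_B)
        return [log_softmax([v / self.temp for v in row]) for row in logits]


model = ToyModel()
print(model.get_log_probs([0.5, 1.0]))
model.observations = {"bias": 2.0}  # set directly; this toy has no preprocess step
print(model.get_log_probs([0.5, 1.0]))
```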

Abstract methods

Subclasses must implement two methods:

| Method | Signature | Notes |
|---|---|---|
| forward | (x_B, **kwargs) → Any | Returns raw output, which can be a non-tensor type (e.g. ESM returns a dataclass). |
| format_raw_to_logits | (raw_output, x_B, **kwargs) → FloatTensor | Extracts a logit tensor suitable for log_softmax. Receives the full context via **kwargs. |
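As a sketch of how the two abstract methods divide the work, assuming a hypothetical RawOutput dataclass and a three-token vocabulary: forward returns the whole dataclass, and format_raw_to_logits pulls out the logit field and applies output masking (here, setting the mask token's logit to -inf so it receives zero probability). Lists stand in for tensors.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class RawOutput:
    # Hypothetical multi-field raw output (a model may return more than logits)
    sequence_logits: List[List[float]]
    embeddings: List[List[float]]


class MaskedToy:
    VOCAB = ["A", "C", "<mask>"]  # hypothetical three-token vocabulary

    def forward(self, x_B: List[int], **kwargs) -> RawOutput:
        logits = [[float(x), 0.5, 3.0] for x in x_B]
        return RawOutput(sequence_logits=logits, embeddings=[[0.0] for _ in x_B])

    def format_raw_to_logits(self, raw: RawOutput, x_B: List[int], **kwargs) -> List[List[float]]:
        # Extract the logit field and mask out the <mask> token
        mask_idx = self.VOCAB.index("<mask>")
        return [
            [-math.inf if j == mask_idx else v for j, v in enumerate(row)]
            for row in raw.sequence_logits
        ]


model = MaskedToy()
raw = model.forward([1, 2])
print(model.format_raw_to_logits(raw, [1, 2]))
```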

Overridable defaults

| Method | Default behavior | Override when… |
|---|---|---|
| preprocess_observations(obs) | Pass-through | You have expensive one-time ops (e.g. ESM3's VQ-VAE structure encoding). |
| collate_observations(x_B, obs) | Add a leading batch dimension to each non-scalar tensor (via expand); other values pass through unchanged | You have non-tensor observations or need selective expansion. |

Conditioning

Conditioning attaches observations (e.g. structure coordinates) that persist across get_log_probs calls. There are two distinct conditioning modes depending on whether you're doing inference or training.

Inference: single conditioning, tiled to batch

For sampling and evaluation, you typically have one conditioning input (e.g. a single backbone structure) shared across all sequences in the batch. Use set_condition_(), its chainable form set_condition(), or the conditioned_on() context manager:

model.set_condition_(obs)              # in-place, returns None
model = model.set_condition(obs)       # returns self (chainable)
with model.conditioned_on(obs):        # context manager, reverts on exit
    log_probs = model.get_log_probs(x)

The flow is:

  1. set_condition_() calls preprocess_observations(obs) once and caches the result in self.observations
  2. Each get_log_probs() call runs collate_observations(x_B, self.observations) to tile the cached observations to match the batch size
  3. The tiled observations are passed as **kwargs to forward() and format_raw_to_logits()

The preprocess_observations / collate_observations split is a performance optimization: expensive preprocessing (e.g. running a VQ-VAE encoder) happens once when the condition is set, while cheap per-batch collation (tiling tensors to batch size) happens on every get_log_probs call.

# Example: ESM3 inverse folding — one structure, many sequence samples
model = ESM3().cuda()
model.set_condition_({"coords_RAX": backbone_coords})  # VQ-VAE runs once

# Every get_log_probs call tiles the cached structure tokens to batch size
for step in range(n_steps):
    log_probs = model.get_log_probs(batch_of_sequences)
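To make the once-versus-every-call split concrete, here is a dependency-free sketch with counters. CountingModel, its toy "encoder", and the step method (which stops where get_log_probs would call forward) are all hypothetical.

```python
from typing import Any, Dict, List, Optional


class CountingModel:
    """Hypothetical toy that counts how often each conditioning hook runs."""

    def __init__(self) -> None:
        self.observations: Optional[Dict[str, Any]] = None
        self.n_preprocess = 0
        self.n_collate = 0

    def preprocess_observations(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        self.n_preprocess += 1
        # Stand-in for an expensive encoder: one "token" per residue
        return {"tokens": [sum(residue) for residue in obs["coords"]]}

    def collate_observations(self, x_B: List[Any], obs: Dict[str, Any]) -> Dict[str, Any]:
        self.n_collate += 1
        return {k: [v] * len(x_B) for k, v in obs.items()}

    def set_condition_(self, obs: Dict[str, Any]) -> None:
        self.observations = self.preprocess_observations(obs)

    def step(self, x_B: List[Any]) -> Dict[str, Any]:
        # Mirrors get_log_probs up to the point where forward() would be called
        return self.collate_observations(x_B, self.observations)


model = CountingModel()
model.set_condition_({"coords": [[1.0, 2.0], [3.0, 4.0]]})
for _ in range(5):
    batch_obs = model.step(["seq_a", "seq_b", "seq_c"])
print(model.n_preprocess, model.n_collate)
```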

Conditioning is mutable state

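set_condition_() overwrites self.observations on the model instance. The conditioned_on() context manager restores the previous value on exit (even if the block raises), but while it is active the condition is visible to everything that uses the same instance. Avoid sharing one conditioned model across threads or interleaved sampling loops that need different conditions.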
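The revert-on-exit behavior can be checked with a minimal mirror of set_condition_ / set_condition / conditioned_on on a plain object (Conditionable is hypothetical and skips nn.Module). It also shows the caveat: inside the block, any other code holding the same instance sees the temporary condition.

```python
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional


class Conditionable:
    """Hypothetical mirror of the conditioning API on a plain object."""

    def __init__(self) -> None:
        self.observations: Optional[Dict[str, Any]] = None

    def preprocess_observations(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        return obs  # pass-through, like the default

    def set_condition_(self, obs: Dict[str, Any]) -> None:
        self.observations = self.preprocess_observations(obs)

    def set_condition(self, obs: Dict[str, Any]) -> "Conditionable":
        self.set_condition_(obs)
        return self

    @contextmanager
    def conditioned_on(self, obs: Dict[str, Any]) -> Iterator["Conditionable"]:
        previous = self.observations
        try:
            yield self.set_condition(obs)
        finally:
            # Runs on normal exit and when the block raises
            self.observations = previous


def current_structure(m: Conditionable) -> Any:
    # Some other code path that happens to share the same instance
    return m.observations["structure"]


model = Conditionable()
model.set_condition_({"structure": "A"})
with model.conditioned_on({"structure": "B"}):
    print(current_structure(model))  # the shared instance now reports "B"
print(current_structure(model))
```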

Training: per-sample conditioning via collator

For training, each sequence in the batch typically has its own conditioning input (e.g. each protein has a different predicted structure). This pattern bypasses set_condition_() / get_log_probs() entirely — the collator prepares per-sample observations and the training loop calls model.forward() directly.

The flow is:

  1. ProteinDataset stores per-sequence observations (e.g. structure tokens, coordinates)
  2. The collator batches observations from individual samples into batched tensors
  3. The training loop passes observations directly to model(input_ids, **observations)

import torch.nn.functional as F

# Example: inverse folding training, where each sequence has its own structure
for batch in loader:
    input_ids = batch["input_ids"].to(device)
    target_ids = batch["target_ids"].to(device)
    struct_tokens = batch["structure_tokens"].to(device)
    coords = batch["coordinates"].to(device)

    # Call forward directly with per-sample observations
    raw = model(input_ids, structure_tokens=struct_tokens, coordinates=coords)
    logits = model.format_raw_to_logits(
        raw, input_ids, structure_tokens=struct_tokens, coordinates=coords
    )
    # `masked` is a boolean mask of the positions to score (construction not shown)
    loss = F.cross_entropy(logits[masked], target_ids[masked])

There are two ways to build the collator:

ProteinDataset.collator() — for sequence-only training or when preprocess_observations handles batched inputs. The built-in collator gathers per-sample observations into list-valued dicts and calls model.preprocess_observations(batched). See data for details.

Custom collator — for complex cases like inverse folding where you need to pad both sequences and structures to the batch max length. The fine-tuning workflow shows a complete example.
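A custom collator for per-sample conditioning might look like the following sketch. The sample keys (input_ids, coordinates), the attention_mask output, and the pad values are assumptions for illustration; it pads plain lists, whereas a real collator would typically convert the result to tensors (for example when used as the collate_fn of a torch DataLoader).

```python
from typing import Any, Dict, List, Sequence

PAD_ID = 0        # hypothetical padding token id
COORD_PAD = 0.0   # hypothetical fill value for padded structure positions


def collate_inverse_folding(samples: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Pad token ids and per-residue coordinates to the batch max length.

    Each sample is assumed to hold "input_ids" (list of int) and
    "coordinates" (one [x, y, z] triple per residue) of equal length.
    """
    max_len = max(len(s["input_ids"]) for s in samples)
    batch: Dict[str, List[Any]] = {"input_ids": [], "coordinates": [], "attention_mask": []}
    for s in samples:
        n = len(s["input_ids"])
        pad = max_len - n
        batch["input_ids"].append(list(s["input_ids"]) + [PAD_ID] * pad)
        batch["coordinates"].append(
            [list(xyz) for xyz in s["coordinates"]] + [[COORD_PAD] * 3 for _ in range(pad)]
        )
        # 1 marks real residues, 0 marks padding (so the loss can skip it)
        batch["attention_mask"].append([1] * n + [0] * pad)
    return batch


batch = collate_inverse_folding([
    {"input_ids": [5, 6, 7], "coordinates": [[1.0, 2.0, 3.0]] * 3},
    {"input_ids": [8], "coordinates": [[4.0, 5.0, 6.0]]},
])
print(batch["input_ids"], batch["attention_mask"])
```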

Why not use set_condition_() for training?

set_condition_() caches a single observation and tiles it to the batch — it assumes every sample in the batch shares the same conditioning. During training, each sample has different conditioning, so you pass observations directly through the collator → forward() path instead.

Summary: which pattern to use

| Scenario | Pattern | Observations flow |
|---|---|---|
| Sampling / evaluation | set_condition_() or conditioned_on() | One obs → preprocess (once) → collate (tile to batch) → get_log_probs |
| Training (sequence-only) | ProteinDataset.collator() | No observations needed |
| Training (per-sample conditioning) | Custom collator or ProteinDataset.collator() | Per-sample obs → collator batches → model.forward(**obs) directly |

Checkpointing

Save and restore models with their constructor arguments:

model.save("checkpoints/my_model")
restored = MyModel.from_checkpoint("checkpoints/my_model")

Subclasses participate via:

  • _save_args() → dict — return JSON-serializable constructor kwargs (raises NotImplementedError by default)
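  • Override save() to write additional state (weights, LoRA adapters), calling super().save(path) first so that the directory and config.json exist
  • Override from_checkpoint() to load additional state, calling super().from_checkpoint(path) to construct the base object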

For example, GenerativeModel.save writes a lora_adapter/ directory if LoRA is present, and LinearProbe.save writes head.pt plus delegates to embed_model.save().
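These hooks can be sketched without torch. Checkpointable below is a standalone copy of the base class's save/from_checkpoint logic, and ToyProbe is hypothetical (loosely modeled on the "write a head file on top of the config" pattern) and stores its extra state in a head.json file; real subclasses would save weights in their own format.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Union

PathLike = Union[str, Path]


class Checkpointable:
    """Standalone copy of the base class's save/from_checkpoint logic."""

    def _save_args(self) -> dict:
        raise NotImplementedError(
            f"{type(self).__name__} must implement _save_args() for checkpointing"
        )

    def save(self, path: PathLike) -> None:
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        with open(path / "config.json", "w") as f:
            json.dump(self._save_args(), f, indent=2)

    @classmethod
    def from_checkpoint(cls, path: PathLike) -> "Checkpointable":
        with open(Path(path) / "config.json") as f:
            config = json.load(f)
        return cls(**config)


class ToyProbe(Checkpointable):
    """Hypothetical subclass: constructor args plus extra learned state."""

    def __init__(self, embed_dim: int, n_classes: int) -> None:
        self.embed_dim = embed_dim
        self.n_classes = n_classes
        self.head_weights: List[float] = [0.0] * (embed_dim * n_classes)

    def _save_args(self) -> dict:
        return {"embed_dim": self.embed_dim, "n_classes": self.n_classes}

    def save(self, path: PathLike) -> None:
        super().save(path)  # first: creates the directory and writes config.json
        with open(Path(path) / "head.json", "w") as f:  # hypothetical extra-state file
            json.dump(self.head_weights, f)

    @classmethod
    def from_checkpoint(cls, path: PathLike) -> "ToyProbe":
        model = super().from_checkpoint(path)  # cls(**config) builds a ToyProbe
        with open(Path(path) / "head.json") as f:
            model.head_weights = json.load(f)
        return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        probe = ToyProbe(embed_dim=2, n_classes=3)
        probe.head_weights[0] = 1.5
        probe.save(Path(tmp) / "my_probe")
        restored = ToyProbe.from_checkpoint(Path(tmp) / "my_probe")
        print(restored.embed_dim, restored.n_classes, restored.head_weights[0])
```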

Gotchas

  • The device property uses next(self.parameters()).device, so it raises StopIteration on a model with no parameters (all current subclasses have parameters).
  • forward returns Any, not just tensors. This is intentional — ESM models return dataclass outputs. The format_raw_to_logits step is where you extract the tensor.
  • The default collate_observations only tiles tensors with at least one dimension; 0-dim tensors and non-tensor values (lists, strings) are passed through unchanged rather than tiled. If forward expects those values batched, override this method.
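One way to override collation is to tile only selected keys and pass the rest through untouched. In this dependency-free sketch, lists stand in for tensors and TILED_KEYS is a hypothetical choice; [value] * batch_size repeats a reference to the same object rather than copying it, much as Tensor.expand returns a view without allocating new memory.

```python
from typing import Any, Dict, List, Tuple


class SelectiveCollate:
    """Hypothetical collate_observations override: tile only selected keys."""

    TILED_KEYS: Tuple[str, ...] = ("structure_tokens",)

    def collate_observations(self, x_B: List[Any], observations: Dict[str, Any]) -> Dict[str, Any]:
        batch_size = len(x_B)
        collated: Dict[str, Any] = {}
        for key, value in observations.items():
            if key in self.TILED_KEYS:
                # One reference per batch element (shared, not copied)
                collated[key] = [value] * batch_size
            else:
                # e.g. chain ids that forward() should receive once, untiled
                collated[key] = value
        return collated


obs = {"structure_tokens": [7, 8, 9], "chain_ids": ["A", "B"]}
collated = SelectiveCollate().collate_observations(["seq1", "seq2"], obs)
print(collated)
```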

API Reference

proteingen.modeling.probability_model

Base class for models that produce log probability distributions.

ProbabilityModel

Bases: Module, ABC

Base class for models that produce log probability distributions.

Subclasses implement:

  1. preprocess_observations (optional) to turn conditioning variables set through e.g. set_condition into the cached form that collate_observations expands into the forward function kwargs
  2. forward to return raw output based on the conditioning information
  3. format_raw_to_logits to convert the forward output into logits that can be safely passed to softmax in get_log_probs(). format_raw_to_logits typically includes output masking, coarse-graining classes, turning ensemble predictions into probabilities, etc.

The default get_log_probs applies log_softmax along the last dimension, which is appropriate for class-valued models. Models with other output types (real-valued, ensemble, etc.) should override it.
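For an ensemble, one option (an assumption here, not necessarily what any proteingen model does) is to average the members' class probabilities in format_raw_to_logits and return their log. Those log-probabilities are valid logits: at temperature 1 the default log_softmax leaves them unchanged. A dependency-free sketch:

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax: x - logsumexp(x)
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def ensemble_to_logits(member_logits: List[List[float]]) -> List[float]:
    """Average member class probabilities and return their log.

    member_logits[e][c] is member e's logit for class c (hypothetical layout).
    """
    member_probs = [[math.exp(v) for v in log_softmax(row)] for row in member_logits]
    n_members = len(member_probs)
    n_classes = len(member_probs[0])
    mean_probs = [sum(p[c] for p in member_probs) / n_members for c in range(n_classes)]
    return [math.log(p) for p in mean_probs]


# Two members that disagree symmetrically
logits = ensemble_to_logits([[2.0, 0.0], [0.0, 2.0]])
print(logits)
print(log_softmax(logits))
```

Averaging logits instead of probabilities is a different, also common choice: softmax of the mean logits is the normalized geometric mean of the members' probabilities rather than their arithmetic mean.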

Checkpointing

Subclasses that support save/load implement _save_args() returning a JSON-serializable dict of constructor kwargs. The base class provides save(path) and from_checkpoint(path) that serialize/deserialize the constructor args and rebuild the object. Subclasses add their own state (weights, adapters, etc.) on top by overriding save/from_checkpoint and calling super().

Source code in src/proteingen/modeling/probability_model.py
class ProbabilityModel(nn.Module, ABC):
    """Base class for models that produce log probability distributions.

    Subclasses implement:
    1. ``preprocess_observations`` (optional) to turn conditioning variables set
    through e.g. ``set_condition`` into the cached form that
    ``collate_observations`` expands into the forward function kwargs
    2. ``forward`` to return raw output based on the conditioning information
    3. ``format_raw_to_logits`` to convert the forward output into logits
    that can be safely passed to softmax in ``get_log_probs()``.
    ``format_raw_to_logits`` typically includes output masking,
    coarse-graining classes, turning ensemble predictions into
    probabilities, etc.

    The default ``get_log_probs`` applies ``log_softmax`` along the last
    dimension, which is appropriate for class-valued models. Models with
    other output types (real-valued, ensemble, etc.) should override it.

    Checkpointing:
        Subclasses that support save/load implement ``_save_args()`` returning
        a JSON-serializable dict of constructor kwargs. The base class provides
        ``save(path)`` and ``from_checkpoint(path)`` that serialize/deserialize
        the constructor args and rebuild the object. Subclasses add their own
        state (weights, adapters, etc.) on top by overriding save/from_checkpoint
        and calling super().
    """

    ConditioningInputType: TypedDict
    RawOutputType: TypedDict

    def __init__(self):
        super().__init__()
        self.temp = 1.0
        self.observations: Optional[ProbabilityModel.ConditioningInputType] = None

    @property
    def device(self):
        return next(self.parameters()).device

    def preprocess_observations(
        self, observations: ProbabilityModel.ConditioningInputType
    ) -> Dict[str, Any]:
        """Transform raw observations into cached form.

        Called once when ``set_condition_()`` or ``conditioned_on()`` is
        invoked. Override for expensive operations (e.g. encoding structure)
        that should not be repeated every forward pass.

        Default: pass through.
        """
        return observations

    def collate_observations(
        self, x_B: torch.Tensor, observations: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Collate observations for input to the forward function.

        Default: tile each observation tensor to match batch size.
        Override when you need custom collation (e.g. only expanding
        certain keys, or handling non-tensor observations).
        """
        batch_size = x_B.size(0)
        return {
            k: v.unsqueeze(0).expand(batch_size, *v.shape)
            if isinstance(v, torch.Tensor) and v.dim() > 0
            else v
            for k, v in observations.items()
        }

    def set_condition_(self, observations: Dict[str, Any]):
        """Preprocess and cache observations in-place."""
        self.observations = self.preprocess_observations(observations)

    def set_condition(self, observations: Dict[str, Any]):
        """Preprocess and cache observations, returning self for chaining."""
        self.set_condition_(observations)
        return self

    @contextmanager
    def conditioned_on(self, observations: Dict[str, Any]):
        """Context manager that temporarily sets observations, reverting on exit."""
        pre_context_obs = self.observations
        try:
            yield self.set_condition(observations)
        finally:
            self.observations = pre_context_obs

    @abstractmethod
    def forward(self, x_B: torch.Tensor, **kwargs) -> Any:
        """Return raw output from batched input (may be a non-tensor, e.g. a dataclass)."""

    @abstractmethod
    def format_raw_to_logits(
        self, raw_forward_output: Any, x_B: torch.Tensor, **kwargs
    ) -> torch.FloatTensor:
        """Convert raw forward output to logits suitable for log_softmax.

        Examples: output masking (generative models), coarse-graining classes,
        turning ensemble predictions into per-class logits, etc.
        """
        ...

    def set_temp_(self, temp: float):
        self.temp = temp

    def set_temp(self, temp: float):
        self.set_temp_(temp)
        return self

    @contextmanager
    def with_temp(self, temp: float):
        """Context manager to temporarily change the temperature."""
        pre_context_temp = self.temp
        try:
            yield self.set_temp(temp)
        finally:
            self.temp = pre_context_temp

    def get_log_probs(self, x_B: torch.Tensor) -> torch.FloatTensor:
        """Return temperature-scaled log probabilities.

        Input is some batched tensor with otherwise arbitrary dimensions.
        Default implementation:
        ``log_softmax(format_raw_to_logits(forward(x, **obs), x, **obs) / temp)``.
        """
        assert self.temp > 0, f"Temperature must be positive, got {self.temp}"
        if self.observations is not None:
            obs = self.collate_observations(x_B, self.observations)
            raw_output = self.forward(x_B, **obs)
            logits = self.format_raw_to_logits(raw_output, x_B, **obs)
        else:
            raw_output = self.forward(x_B)
            logits = self.format_raw_to_logits(raw_output, x_B)
        return F.log_softmax(logits / self.temp, dim=-1)

    # ── Checkpointing ────────────────────────────────────────────────────

    def _save_args(self) -> dict:
        """Return constructor kwargs as a JSON-serializable dict.

        Override in subclasses that support checkpointing. The dict must
        contain everything needed to reconstruct the object via ``cls(**args)``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _save_args() for checkpointing"
        )

    def save(self, path: str | Path) -> None:
        """Save model to a directory. Writes config.json with constructor args.

        Subclasses override to add their own state (weights, adapters, etc.)
        and should call ``super().save(path)`` first.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        with open(path / "config.json", "w") as f:
            json.dump(self._save_args(), f, indent=2)

    @classmethod
    def from_checkpoint(cls, path: str | Path) -> "ProbabilityModel":
        """Load model from a directory. Reads config.json and calls ``cls(**args)``.

        Subclasses override to load additional state (weights, adapters, etc.)
        and should call ``super().from_checkpoint(path)`` to get the base object.
        """
        path = Path(path)
        with open(path / "config.json") as f:
            config = json.load(f)
        return cls(**config)

preprocess_observations
preprocess_observations(observations: ConditioningInputType) -> Dict[str, Any]

Transform raw observations into cached form.

Called once when set_condition_() or conditioned_on() is invoked. Override for expensive operations (e.g. encoding structure) that should not be repeated every forward pass.

Default: pass through.

collate_observations
collate_observations(x_B: Tensor, observations: Dict[str, Any]) -> Dict[str, Any]

Collate observations for input to the forward function.

Default: tile each observation tensor to match batch size. Override when you need custom collation (e.g. only expanding certain keys, or handling non-tensor observations).

set_condition_
set_condition_(observations: Dict[str, Any])

Preprocess and cache observations in-place.

set_condition
set_condition(observations: Dict[str, Any])

Preprocess and cache observations, returning self for chaining.

conditioned_on
conditioned_on(observations: Dict[str, Any])

Context manager that temporarily sets observations, reverting on exit.

forward abstractmethod
forward(x_B: Tensor, **kwargs) -> Any

Return raw output from batched input (may be a non-tensor, e.g. a dataclass).

format_raw_to_logits abstractmethod
format_raw_to_logits(raw_forward_output: Any, x_B: Tensor, **kwargs) -> torch.FloatTensor

Convert raw forward output to logits suitable for log_softmax.

Examples: output masking (generative models), coarse-graining classes, turning ensemble predictions into per-class logits, etc.

with_temp
with_temp(temp: float)

Context manager to temporarily change the temperature.
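Temperature divides the logits before the softmax, so values below 1 sharpen the distribution and values above 1 flatten it. A dependency-free sketch of that effect, plus a mirror of with_temp (tempered_probs and TempHolder are hypothetical names):

```python
import math
from contextlib import contextmanager
from typing import Iterator, List


def tempered_probs(logits: List[float], temp: float) -> List[float]:
    """softmax(logits / temp), computed stably."""
    assert temp > 0, f"Temperature must be positive, got {temp}"
    scaled = [v / temp for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


class TempHolder:
    """Hypothetical mirror of set_temp / with_temp on a plain object."""

    def __init__(self) -> None:
        self.temp = 1.0

    def set_temp(self, temp: float) -> "TempHolder":
        self.temp = temp
        return self

    @contextmanager
    def with_temp(self, temp: float) -> Iterator["TempHolder"]:
        previous = self.temp
        try:
            yield self.set_temp(temp)
        finally:
            self.temp = previous


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in tempered_probs(logits, t)])

holder = TempHolder()
with holder.with_temp(0.7) as h:
    print(h.temp)
print(holder.temp)
```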

get_log_probs
get_log_probs(x_B: Tensor) -> torch.FloatTensor

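Return temperature-scaled log probabilities. Input is some batched tensor with otherwise arbitrary dimensions. Default implementation: log_softmax(format_raw_to_logits(forward(x, **obs), x, **obs) / temp), where obs are the collated observations (omitted when no condition is set).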

save
save(path: str | Path) -> None

Save model to a directory. Writes config.json with constructor args.

Subclasses override to add their own state (weights, adapters, etc.) and should call super().save(path) first.

from_checkpoint classmethod
from_checkpoint(path: str | Path) -> 'ProbabilityModel'

Load model from a directory. Reads config.json and calls cls(**args).

Subclasses override to load additional state (weights, adapters, etc.) and should call super().from_checkpoint(path) to get the base object.
