probability_model
ProbabilityModel is the root of the entire class hierarchy — both generative and predictive models inherit from it. It's an nn.Module + ABC that provides temperature scaling, observation conditioning, log-probability computation, and checkpointing.
The get_log_probs pipeline
Every call to get_log_probs follows the same chain:
```
collate_observations(x_B, self.observations)
  → forward(x_B, **obs)
  → format_raw_to_logits(raw, x_B, **obs)
  → log_softmax(logits / temp)
```
When no observations are set, forward and format_raw_to_logits are called without keyword arguments.
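The chain above can be sketched as a small, self-contained toy. ToyProbabilityModel and ToyBiasModel are hypothetical names, and this is a simplified re-implementation for illustration, not the proteingen source. It assumes that cached observations are unbatched tensors, which collation tiles by adding a leading batch dimension.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyProbabilityModel(nn.Module, ABC):
    """Hypothetical, simplified stand-in for ProbabilityModel (not the proteingen source)."""

    def __init__(self, temp: float = 1.0):
        super().__init__()
        self.temp = temp
        self.observations: Optional[dict] = None

    @abstractmethod
    def forward(self, x_B: torch.Tensor, **kwargs) -> Any: ...

    @abstractmethod
    def format_raw_to_logits(self, raw_output: Any, x_B: torch.Tensor, **kwargs) -> torch.Tensor: ...

    def preprocess_observations(self, obs: dict) -> dict:
        return obs  # default: pass-through

    def set_condition_(self, obs: dict) -> None:
        self.observations = self.preprocess_observations(obs)

    def collate_observations(self, x_B: torch.Tensor, obs: dict) -> dict:
        # Assumption: cached tensors are unbatched; add a batch dim and expand to len(x_B).
        B = x_B.shape[0]
        return {k: v.unsqueeze(0).expand(B, *v.shape) if torch.is_tensor(v) else v
                for k, v in obs.items()}

    def get_log_probs(self, x_B: torch.Tensor) -> torch.Tensor:
        if self.observations is None:
            # No observations: forward and format_raw_to_logits get no kwargs.
            raw = self.forward(x_B)
            logits = self.format_raw_to_logits(raw, x_B)
        else:
            obs = self.collate_observations(x_B, self.observations)
            raw = self.forward(x_B, **obs)
            logits = self.format_raw_to_logits(raw, x_B, **obs)
        return F.log_softmax(logits / self.temp, dim=-1)


class ToyBiasModel(ToyProbabilityModel):
    """Toy model: logits come from an embedding, optionally shifted by a `bias` observation."""

    def __init__(self, vocab_size: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x_B, bias=None):
        out = {"logits": self.embed(x_B)}  # (B, L, V); raw output need not be a tensor
        if bias is not None:
            out["logits"] = out["logits"] + bias[:, None, :]  # bias is (B, V) after collation
        return out

    def format_raw_to_logits(self, raw_output, x_B, bias=None):
        return raw_output["logits"]


torch.manual_seed(0)
model = ToyBiasModel()
x = torch.randint(0, 4, (2, 5))
lp = model.get_log_probs(x)  # unconditioned
model.set_condition_({"bias": torch.tensor([0.0, 0.0, 0.0, 10.0])})
lp_cond = model.get_log_probs(x)  # token 3 is now strongly favoured
```

Because the cached bias is stored once and expanded per call, the same conditioned model works for any batch size.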
Abstract methods
Subclasses must implement two methods:
| Method | Signature | Notes |
|---|---|---|
| forward | (x_B, **kwargs) → Any | Returns raw output, which can be a non-tensor type (e.g. ESM returns a dataclass). |
| format_raw_to_logits | (raw_output, x_B, **kwargs) → FloatTensor | Extracts a logit tensor suitable for log_softmax. Receives the full context via **kwargs. |
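To illustrate the split between the two abstract methods, here is a hypothetical toy whose forward returns a dataclass (mirroring, in spirit, the dataclass outputs of ESM models) and whose format_raw_to_logits pulls out the tensor and applies output masking. ToyOutput, MaskedToyModel, and the choice of token ids 0 and 1 as "never emit" tokens are illustrative assumptions.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ToyOutput:
    sequence_logits: torch.Tensor  # (B, L, V)
    embeddings: torch.Tensor       # (B, L, D), carried along but not needed for log-probs


class MaskedToyModel(nn.Module):
    def __init__(self, vocab_size: int = 6, dim: int = 8, disallowed=(0, 1)):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)
        # Hypothetical: these token ids play the role of special tokens that must never be emitted.
        self.register_buffer("disallowed", torch.tensor(disallowed))

    def forward(self, x_B: torch.Tensor) -> ToyOutput:
        h = self.embed(x_B)
        return ToyOutput(sequence_logits=self.head(h), embeddings=h)

    def format_raw_to_logits(self, raw_output: ToyOutput, x_B: torch.Tensor) -> torch.Tensor:
        logits = raw_output.sequence_logits.clone()
        logits[..., self.disallowed] = float("-inf")  # output masking
        return logits


model = MaskedToyModel()
x = torch.randint(0, 6, (3, 7))
logits = model.format_raw_to_logits(model(x), x)
log_probs = F.log_softmax(logits / 1.0, dim=-1)  # masked tokens get log-prob -inf
```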
Overridable defaults
| Method | Default behavior | Override when… |
|---|---|---|
| preprocess_observations(obs) | Pass-through | You have expensive one-time ops (e.g. ESM3's VQ-VAE structure encoding). |
| collate_observations(x_B, obs) | Tile each tensor to match batch size | You have non-tensor observations or need selective expansion. |
Conditioning
Conditioning attaches observations (e.g. structure coordinates) that persist across get_log_probs calls. There are two distinct conditioning modes depending on whether you're doing inference or training.
Inference: single conditioning, tiled to batch
For sampling and evaluation, you typically have one conditioning input (e.g. a single backbone structure) shared across all sequences in the batch. Use set_condition_(), its chainable variant set_condition(), or the conditioned_on() context manager:
```python
model.set_condition_(obs)         # in-place, returns None
model = model.set_condition(obs)  # returns self (chainable)
with model.conditioned_on(obs):   # context manager, reverts on exit
    log_probs = model.get_log_probs(x)
```
The flow is:
1. set_condition_() calls preprocess_observations(obs) once and caches the result in self.observations.
2. Each get_log_probs() call runs collate_observations(x_B, self.observations) to tile the cached observations to match the batch size.
3. The tiled observations are passed as **kwargs to forward() and format_raw_to_logits().
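The three conditioning entry points can be sketched as a small pure-Python mixin. ConditioningMixin and Demo are hypothetical names, not the proteingen source. The sketch assumes conditioned_on restores whatever was cached before (including None); the try/finally, which also reverts when the body raises, is a choice made for this sketch.

```python
from contextlib import contextmanager


class ConditioningMixin:
    """Hypothetical sketch of the conditioning entry points (not the proteingen source)."""

    observations = None  # cached, preprocessed observations (None = unconditioned)

    def preprocess_observations(self, obs):
        return obs  # default: pass-through

    def set_condition_(self, obs):
        # In place: preprocess once, cache the result, return None.
        self.observations = self.preprocess_observations(obs)

    def set_condition(self, obs):
        # Chainable variant: same effect, but returns self.
        self.set_condition_(obs)
        return self

    @contextmanager
    def conditioned_on(self, obs):
        previous = self.observations
        self.set_condition_(obs)
        try:
            yield
        finally:
            # Revert to whatever was cached before, even if the body raised.
            self.observations = previous


class Demo(ConditioningMixin):
    pass


m = Demo().set_condition({"coords": [1, 2, 3]})
with m.conditioned_on({"coords": [9, 9, 9]}):
    inside = m.observations  # {"coords": [9, 9, 9]}
after = m.observations       # back to {"coords": [1, 2, 3]}
```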
The preprocess_observations → collate_observations split is a performance optimization: expensive preprocessing (running a VQ-VAE encoder) happens once when the condition is set, while cheap per-batch collation (tiling tensors to batch size) happens on every forward pass.
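To see the split in action, this pure-Python toy counts how often each stage runs. CountingModel is hypothetical; its doubling "encoder" stands in for an expensive step such as a VQ-VAE, and list replication stands in for tensor tiling.

```python
class CountingModel:
    """Hypothetical toy that counts how often each conditioning stage runs."""

    def __init__(self):
        self.observations = None
        self.n_preprocess = 0
        self.n_collate = 0

    def preprocess_observations(self, obs):
        self.n_preprocess += 1
        # Placeholder for an expensive encoder such as a VQ-VAE.
        return {"tokens": [2 * c for c in obs["coords"]]}

    def collate_observations(self, x_B, obs):
        self.n_collate += 1
        # Cheap per-batch step: one copy of each cached value per batch row.
        return {k: [v] * len(x_B) for k, v in obs.items()}

    def set_condition_(self, obs):
        self.observations = self.preprocess_observations(obs)

    def get_log_probs(self, x_B):
        obs = self.collate_observations(x_B, self.observations)
        # A real model would now run forward, format_raw_to_logits and log_softmax;
        # this toy just returns the collated kwargs.
        return obs


model = CountingModel()
model.set_condition_({"coords": [1, 2, 3]})
for batch in (["AC", "GT"], ["AAA", "CCC", "GGG"]):
    model.get_log_probs(batch)
print(model.n_preprocess, model.n_collate)  # 1 2
```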
```python
# Example: ESM3 inverse folding with one structure and many sequence samples
model = ESM3().cuda()
model.set_condition_({"coords_RAX": backbone_coords})  # VQ-VAE runs once

# Every get_log_probs call tiles the cached structure tokens to batch size
for step in range(n_steps):
    log_probs = model.get_log_probs(batch_of_sequences)
```
Conditioning is mutable state
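set_condition_() modifies self.observations in place. The conditioned_on() context manager restores the previous observations on exit, but all callers of a model instance share the same self.observations, so avoid conditioning one instance from concurrent threads or interleaved tasks.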
Training: per-sample conditioning via collator
For training, each sequence in the batch typically has its own conditioning input (e.g. each protein has a different predicted structure). This pattern bypasses set_condition_() / get_log_probs() entirely — the collator prepares per-sample observations and the training loop calls model.forward() directly.
The flow is:
1. ProteinDataset stores per-sequence observations (e.g. structure tokens, coordinates).
2. The collator batches observations from individual samples into batched tensors.
3. The training loop passes the observations directly to model(input_ids, **observations).
```python
import torch.nn.functional as F

# Example: inverse folding training, where each sequence has its own structure
for batch in loader:
    input_ids = batch["input_ids"].to(device)
    target_ids = batch["target_ids"].to(device)
    struct_tokens = batch["structure_tokens"].to(device)
    coords = batch["coordinates"].to(device)

    # Call forward directly with per-sample observations
    raw = model(input_ids, structure_tokens=struct_tokens, coordinates=coords)
    logits = model.format_raw_to_logits(
        raw, input_ids, structure_tokens=struct_tokens, coordinates=coords
    )
    # `masked` is a boolean (B, L) mask of the positions that contribute to the loss;
    # how it is built depends on your masking scheme.
    loss = F.cross_entropy(logits[masked], target_ids[masked])
```
There are two ways to build the collator:
- ProteinDataset.collator(): for sequence-only training, or when preprocess_observations can handle batched inputs. The built-in collator gathers per-sample observations into list-valued dicts and calls model.preprocess_observations(batched). See data for details.
- Custom collator: for complex cases like inverse folding, where you need to pad both sequences and structures to the batch max length. The fine-tuning workflow shows a complete example.
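A minimal sketch of such a custom collator, assuming each dataset item is a dict with a 1-D input_ids tensor and an (L, 3, 3) coordinates tensor. The key names, the pad id of 0, zero-padding for coordinates, and the residue_mask output are hypothetical choices, not the library's conventions; the fine-tuning workflow's collator remains the reference.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(items, pad_id: int = 0):
    """Hypothetical collator: pad sequences and structures to the batch max length."""
    input_ids = pad_sequence(
        [it["input_ids"] for it in items], batch_first=True, padding_value=pad_id
    )
    # pad_sequence pads along the first (length) dim, so (L, 3, 3) -> (B, L_max, 3, 3).
    coordinates = pad_sequence(
        [it["coordinates"] for it in items], batch_first=True, padding_value=0.0
    )
    lengths = torch.tensor([len(it["input_ids"]) for it in items])
    # True at real residues, False at padding; useful for masking the loss.
    residue_mask = torch.arange(input_ids.shape[1])[None, :] < lengths[:, None]
    return {"input_ids": input_ids, "coordinates": coordinates, "residue_mask": residue_mask}


items = [
    {"input_ids": torch.tensor([5, 6, 7]), "coordinates": torch.randn(3, 3, 3)},
    {"input_ids": torch.tensor([8, 9]), "coordinates": torch.randn(2, 3, 3)},
]
batch = pad_collate(items)
```

A function like this is passed to torch.utils.data.DataLoader through its collate_fn argument.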
Why not use set_condition_() for training?
set_condition_() caches a single observation and tiles it to the batch — it assumes every sample in the batch shares the same conditioning. During training, each sample has different conditioning, so you pass observations directly through the collator → forward() path instead.
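The difference shows up directly in the batched tensors. In this small sketch (sizes and values are arbitrary), tiling one cached observation yields identical rows, whereas stacking per-sample observations, as a collator does, keeps each row distinct.

```python
import torch

B, L = 3, 4

# Inference pattern: one cached observation, tiled so every batch row is identical.
single = torch.arange(L * 3, dtype=torch.float32).reshape(L, 3)
tiled = single.unsqueeze(0).expand(B, L, 3)

# Training pattern: the collator stacks a different observation for each sample.
per_sample = torch.stack([torch.full((L, 3), float(i)) for i in range(B)])

print(tiled.shape, per_sample.shape)  # torch.Size([3, 4, 3]) torch.Size([3, 4, 3])
```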
Summary: which pattern to use
| Scenario | Pattern | Observations flow |
|---|---|---|
| Sampling / evaluation | set_condition_() or conditioned_on() | One obs → preprocess (once) → collate (tile to batch) → get_log_probs |
| Training (sequence-only) | ProteinDataset.collator() | No observations needed |
| Training (per-sample conditioning) | Custom collator or ProteinDataset.collator() | Per-sample obs → collator batches → model.forward(**obs) directly |
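Checkpointing
Models are saved and restored together with their constructor arguments: save(path) writes config.json to a directory, and the from_checkpoint(path) classmethod reads it back and calls cls(**args).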
Subclasses participate via:
- _save_args() → dict: return JSON-serializable constructor kwargs (raises NotImplementedError by default).
- Override save() to write additional state (weights, LoRA adapters), calling super().save(path) first so that config.json is written.
- Override from_checkpoint() to load additional state after construction, starting from the object returned by super().from_checkpoint(path).
For example, GenerativeModel.save writes a lora_adapter/ directory if LoRA is present, and LinearProbe.save writes head.pt and delegates to embed_model.save().
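The pattern can be sketched with a toy base class and a probe-like subclass. CheckpointableBase and ToyProbe are hypothetical; the sketch follows the documented contract (config.json holds the _save_args() kwargs and the object is rebuilt via cls(**args)) and stores extra weights in head.pt in the spirit of LinearProbe, but it is not the library's implementation.

```python
import json
from pathlib import Path

import torch
import torch.nn as nn


class CheckpointableBase(nn.Module):
    """Hypothetical sketch of the config.json save/load contract."""

    def _save_args(self) -> dict:
        raise NotImplementedError

    def save(self, path):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        (path / "config.json").write_text(json.dumps(self._save_args()))

    @classmethod
    def from_checkpoint(cls, path):
        args = json.loads((Path(path) / "config.json").read_text())
        return cls(**args)


class ToyProbe(CheckpointableBase):
    def __init__(self, in_dim: int, n_classes: int):
        super().__init__()
        self.in_dim, self.n_classes = in_dim, n_classes
        self.head = nn.Linear(in_dim, n_classes)

    def _save_args(self) -> dict:
        return {"in_dim": self.in_dim, "n_classes": self.n_classes}

    def save(self, path):
        super().save(path)  # writes config.json (and creates the directory) first
        torch.save(self.head.state_dict(), Path(path) / "head.pt")

    @classmethod
    def from_checkpoint(cls, path):
        model = super().from_checkpoint(path)  # rebuilds via cls(**config)
        model.head.load_state_dict(torch.load(Path(path) / "head.pt"))
        return model
```

Calling super().save() first means the subclass never has to create the directory or know the config format.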
Gotchas
- The device property uses next(self.parameters()).device, so it raises StopIteration on models with no parameters (all current subclasses have parameters).
- forward returns Any, not just tensors. This is intentional: ESM models return dataclass outputs. The format_raw_to_logits step is where you extract the tensor.
- The default collate_observations assumes all observation values are tensors or scalars. If you store non-tensor observations (lists, strings), override this method.
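The first gotcha is easy to reproduce with plain PyTorch. The classes below are hypothetical; the device property mirrors the one described above.

```python
import torch
import torch.nn as nn


class DeviceFromParams(nn.Module):
    """Mirrors the documented device property: the device of the first parameter."""

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device


class WithParams(DeviceFromParams):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)


class NoParams(DeviceFromParams):
    pass


print(WithParams().device)  # cpu
try:
    NoParams().device
except StopIteration:
    print("no parameters: next() raised StopIteration")
```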
API Reference
proteingen.modeling.probability_model
Base class for models that produce log probability distributions.
ProbabilityModel
Bases: Module, ABC
Base class for models that produce log probability distributions.
Subclasses implement:
1. preprocess_observations, to turn conditioning inputs set through e.g. set_condition into input tensors that match the forward function's kwargs (collate_observations then tiles them to the batch)
2. forward, to return the raw model output given the conditioning information
3. format_raw_to_logits, to convert the forward output into something that can be safely softmaxed in get_log_probs(). format_raw_to_logits typically includes output masking, coarse-graining classes, turning ensemble predictions into probabilities, etc.
The default get_log_probs applies log_softmax along the last
dimension, which is appropriate for class-valued models. Models with
other output types (real-valued, ensemble, etc.) should override it.
Checkpointing
Subclasses that support save/load implement _save_args() returning
a JSON-serializable dict of constructor kwargs. The base class provides
save(path) and from_checkpoint(path) that serialize/deserialize
the constructor args and rebuild the object. Subclasses add their own
state (weights, adapters, etc.) on top by overriding save/from_checkpoint
and calling super().
Source code in src/proteingen/modeling/probability_model.py
preprocess_observations
Transform raw observations into cached form.
Called once when set_condition_() or conditioned_on() is
invoked. Override for expensive operations (e.g. encoding structure)
that should not be repeated every forward pass.
Default: pass through.
collate_observations
Collate observations for input to the forward function.
Default: tile each observation tensor to match batch size. Override when you need custom collation (e.g. only expanding certain keys, or handling non-tensor observations).
set_condition_
set_condition
Preprocess and cache observations, returning self for chaining.
conditioned_on
Context manager that temporarily sets observations, reverting on exit.
forward
abstractmethod
format_raw_to_logits
abstractmethod
Convert raw forward output to logits suitable for log_softmax.
Examples: output masking (generative models), coarse-graining classes, turning ensemble predictions into per-class logits, etc.
with_temp
Context manager to temporarily change the temperature.
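A minimal sketch of what a temperature context manager does, assuming the temperature lives in an attribute named temp (as in the log_softmax(logits / temp) step of the pipeline). TempScaled is a hypothetical class, not ProbabilityModel itself.

```python
from contextlib import contextmanager

import torch
import torch.nn.functional as F


class TempScaled:
    """Hypothetical holder for a temperature and a fixed logit vector."""

    def __init__(self, logits: torch.Tensor, temp: float = 1.0):
        self.logits = logits
        self.temp = temp

    def get_log_probs(self) -> torch.Tensor:
        return F.log_softmax(self.logits / self.temp, dim=-1)

    @contextmanager
    def with_temp(self, temp: float):
        old = self.temp
        self.temp = temp
        try:
            yield
        finally:
            self.temp = old  # restore the original temperature on exit


m = TempScaled(torch.tensor([2.0, 1.0, 0.0]))
base = m.get_log_probs().exp()
with m.with_temp(0.5):
    sharp = m.get_log_probs().exp()  # lower temperature: more peaked distribution
with m.with_temp(2.0):
    flat = m.get_log_probs().exp()   # higher temperature: flatter distribution
```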
get_log_probs
Return temperature-scaled log probabilities.
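The input is a batched tensor; its remaining dimensions are arbitrary.
Default implementation: log_softmax(format_raw_to_logits(forward(x, **obs), x, **obs) / temp) along the last dimension, where obs are the collated observations (none when unconditioned).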
save
Save model to a directory. Writes config.json with constructor args.
Subclasses override to add their own state (weights, adapters, etc.)
and should call super().save(path) first.
from_checkpoint
classmethod
Load model from a directory. Reads config.json and calls cls(**args).
Subclasses override to load additional state (weights, adapters, etc.)
and should call super().from_checkpoint(path) to get the base object.