Design Philosophy
ProteinGen is organized around four modules (Data, Models, Sampling, and Evaluation) that mirror the stages of a library design pipeline. The Models module contains the core abstractions: a small number of composable classes that follow the math of guided generation. Understanding these base classes is all you need to use the library.
The core abstractions (Models module)
ProbabilityModel
All models in ProteinGen are subclasses of ProbabilityModel, an nn.Module that produces log-probability distributions. This shared base class provides:
- Temperature: scales the sharpness of distributions
- Conditioning: attaches observations (e.g., structure coordinates) that the model conditions on
- Log-prob pipeline: get_log_probs(x) chains collation → forward → logit formatting → log-softmax
Subclasses must implement two abstract methods:
- forward(x, **kwargs): runs the model and returns raw output (which can be any type)
- format_raw_to_logits(raw_output, x, **kwargs): extracts a float tensor of logits from the raw output
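Here is a minimal sketch of the get_log_probs pipeline. It assumes two things: temperature is applied by dividing the logits before the log-softmax, and collation is a no-op. SketchProbabilityModel and ToyModel are hypothetical stand-ins written for illustration. They are not ProteinGen's classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchProbabilityModel(nn.Module):
    """Hypothetical stand-in for ProbabilityModel (not the library's code)."""

    def __init__(self):
        super().__init__()
        self.temp = 1.0

    def collate(self, x):
        # Assumed no-op here; a real model might pad, tokenize, or attach conditions.
        return x

    def format_raw_to_logits(self, raw_output, x, **kwargs):
        raise NotImplementedError

    def get_log_probs(self, x, **kwargs):
        x = self.collate(x)
        raw = self(x, **kwargs)  # nn.Module.__call__ dispatches to forward()
        logits = self.format_raw_to_logits(raw, x, **kwargs)
        # Assumption: temperature divides the logits before normalization.
        return F.log_softmax(logits / self.temp, dim=-1)


class ToyModel(SketchProbabilityModel):
    def __init__(self, vocab_size=4):
        super().__init__()
        self.scores = nn.Parameter(torch.randn(vocab_size))

    def forward(self, x, **kwargs):
        # Raw output is a dict, to show that it need not be a tensor.
        return {"scores": self.scores.expand(x.shape[0], x.shape[1], -1)}

    def format_raw_to_logits(self, raw_output, x, **kwargs):
        return raw_output["scores"]


model = ToyModel()
tokens = torch.zeros(2, 5, dtype=torch.long)  # batch of 2 sequences, length 5
log_probs = model.get_log_probs(tokens)
print(log_probs.shape)  # torch.Size([2, 5, 4])
```

Because forward may return any type, only format_raw_to_logits needs to know the model's output structure.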
Temperature
Three styles are available for setting the temperature (in-place, chained, and context-managed):
```python
model.set_temp_(2.0)  # in-place mutation
model = model.set_temp(2.0)  # returns self (chainable)
with model.with_temp(2.0):  # reverts when exiting
    log_probs = model.get_log_probs(x)
```
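The three temperature styles could be implemented as follows. TempMixin is a hypothetical mixin and is not part of ProteinGen. The context manager uses try/finally, so the previous temperature is restored even if the body raises.

```python
from contextlib import contextmanager


class TempMixin:
    """Hypothetical sketch of the three temperature-setting styles."""

    temp = 1.0

    def set_temp_(self, temp):
        # In-place: mutate the object's temperature.
        self.temp = temp

    def set_temp(self, temp):
        # Chainable: mutate, then return self.
        self.set_temp_(temp)
        return self

    @contextmanager
    def with_temp(self, temp):
        # Scoped: remember the old value and restore it on exit.
        old = self.temp
        self.set_temp_(temp)
        try:
            yield self
        finally:
            self.set_temp_(old)


m = TempMixin()
with m.with_temp(2.0):
    print(m.temp)  # 2.0
print(m.temp)  # 1.0
```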
Conditioning
Attach observations that persist across calls:
```python
# In-place
model.set_condition_({"coords_RAX": coords})

# Context manager (reverts on exit)
with model.conditioned_on({"coords_RAX": coords}):
    log_probs = model.get_log_probs(x)
```
preprocess_observations runs once when conditions are set (e.g., encoding a structure). collate_observations tiles the observations to match the batch size at inference time.
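This sketch shows the split between the two hooks. preprocess_observations runs once when the condition is set, and collate_observations runs on each call to match the batch. ConditioningSketch and n_preprocess_calls are hypothetical. Reading the coords_RAX suffix as (residues, atoms, xyz) is also an assumption.

```python
import torch


class ConditioningSketch:
    """Hypothetical sketch: preprocess once when set, tile per batch at call time."""

    def __init__(self):
        self.observations = None
        self.n_preprocess_calls = 0

    def preprocess_observations(self, obs):
        # Stand-in for expensive work (e.g. encoding a structure), done once.
        self.n_preprocess_calls += 1
        return {k: v.float() for k, v in obs.items()}

    def set_condition_(self, obs):
        self.observations = self.preprocess_observations(obs)

    def collate_observations(self, batch_size):
        # Add a leading batch dim and broadcast it; expand() avoids copying memory.
        return {k: v.unsqueeze(0).expand(batch_size, *v.shape)
                for k, v in self.observations.items()}


coords = torch.randn(100, 3, 3)  # assumed (residues, atoms, xyz) layout
m = ConditioningSketch()
m.set_condition_({"coords_RAX": coords})
for batch_size in (4, 8):
    tiled = m.collate_observations(batch_size)
    print(tiled["coords_RAX"].shape)
print(m.n_preprocess_calls)  # 1: preprocessing is not repeated per batch
```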
Checkpointing
Models are saved and restored together with their constructor arguments.
Subclasses implement _save_args() to return JSON-serializable constructor kwargs.
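Here is a sketch of the pattern. It assumes a checkpoint that stores the _save_args() output as JSON alongside the state_dict. save_checkpoint, load_checkpoint, and TinyHead are hypothetical names, and ProteinGen's actual file layout may differ.

```python
import io
import json

import torch
import torch.nn as nn


class TinyHead(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in, self.d_out = d_in, d_out
        self.linear = nn.Linear(d_in, d_out)

    def _save_args(self):
        # JSON-serializable constructor kwargs.
        return {"d_in": self.d_in, "d_out": self.d_out}


def save_checkpoint(model):
    # Hypothetical helper: bundle constructor args with the weights.
    buf = io.BytesIO()
    torch.save({"args": json.dumps(model._save_args()),
                "state_dict": model.state_dict()}, buf)
    return buf.getvalue()


def load_checkpoint(cls, blob):
    # Hypothetical helper: rebuild the module from its args, then load weights.
    ckpt = torch.load(io.BytesIO(blob))
    model = cls(**json.loads(ckpt["args"]))
    model.load_state_dict(ckpt["state_dict"])
    return model


original = TinyHead(8, 2)
restored = load_checkpoint(TinyHead, save_checkpoint(original))
```

Storing constructor arguments means the loader can rebuild the module's architecture before loading its weights.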
GenerativeModel
A concrete ProbabilityModel subclass that wraps any nn.Module generative model via composition:
```python
from proteingen import GenerativeModel, MaskedModelLogitFormatter

model = GenerativeModel(
    model=my_nn_module,
    tokenizer=my_tokenizer,
    logit_formatter=MaskedModelLogitFormatter(my_tokenizer, output_dim=64),
)
```
- forward delegates to self.model(seq, **kwargs)
- format_raw_to_logits applies the logit formatter
- Override format_raw_to_logits when the wrapped model returns non-tensor output (e.g., ESM dataclasses)
LoRA support is built in: apply_lora(), save_lora(), load_lora().
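This sketch shows the composition pattern for a wrapped model whose output is a dataclass rather than a tensor. FakeESM, FakeESMOutput, its sequence_logits field, and WrappedSketch are all hypothetical stand-ins. They do not reflect the real ESM output types or field names.

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class FakeESMOutput:
    # Hypothetical non-tensor output, standing in for a model's dataclass.
    sequence_logits: torch.Tensor
    embeddings: torch.Tensor


class FakeESM(nn.Module):
    def __init__(self, vocab_size=33, d_model=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, seq):
        h = self.embed(seq)
        return FakeESMOutput(sequence_logits=self.head(h), embeddings=h)


class WrappedSketch(nn.Module):
    """Composition: forward delegates, format_raw_to_logits extracts the tensor."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, seq, **kwargs):
        return self.model(seq, **kwargs)  # raw output is a dataclass

    def format_raw_to_logits(self, raw_output, seq, **kwargs):
        return raw_output.sequence_logits.float()  # pull out a float tensor


wrapped = WrappedSketch(FakeESM())
seq = torch.randint(0, 33, (2, 10))
logits = wrapped.format_raw_to_logits(wrapped(seq), seq)
print(logits.shape)  # torch.Size([2, 10, 33])
```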
GenerativeModelWithEmbedding
An ABC extending GenerativeModel for models that support differentiable embedding extraction. Subclasses implement:
- differentiable_embedding(ohe) → embeddings: passes the one-hot encoding (OHE) through the embedding layer and transformer
- embedding_to_outputs(embeddings) → raw_output: passes the embeddings through the output head
This enables LinearProbe to extract and cache embeddings, and provides a differentiable path from one-hot inputs through the full model (needed for TAG gradients).
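This sketch shows why a one-hot path is differentiable. Multiplying a one-hot matrix by the embedding table returns exactly the rows an index lookup would, but gradients can still flow back to the one-hot input. The tiny embedding, transformer layer, and output head are hypothetical stand-ins for a pretrained model. The two plain functions play the roles of the subclass methods.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model, length = 20, 8, 6

embedding = nn.Embedding(vocab_size, d_model)
encoder = nn.TransformerEncoderLayer(d_model, nhead=2, batch_first=True).eval()
output_head = nn.Linear(d_model, vocab_size)


def differentiable_embedding(ohe):
    # ohe @ W selects the same rows as embedding(tokens), but stays differentiable.
    return encoder(ohe @ embedding.weight)


def embedding_to_outputs(embeddings):
    return output_head(embeddings)


tokens = torch.randint(0, vocab_size, (1, length))
ohe = F.one_hot(tokens, vocab_size).float().requires_grad_(True)

score = embedding_to_outputs(differentiable_embedding(ohe)).sum()
(grad,) = torch.autograd.grad(score, ohe)  # d score / d one-hot input
print(grad.shape)  # torch.Size([1, 6, 20])
```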
PredictiveModel
An ABC extending ProbabilityModel for models that answer "what is log p(target | sequence)?". It uses a binary logit pattern: format_raw_to_logits returns (B, 2) logits [false_logit, true_logit], and get_log_probs extracts the true_logit entry after log-softmax.
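```python
class MyPredictor(PredictiveModel):
    def forward(self, ohe, **kwargs):
        # self.mlp is assumed to be defined in __init__ (omitted here)
        return self.mlp(ohe.flatten(1))  # raw scalar predictions, shape (B, 1)

    def format_raw_to_logits(self, raw, ohe, **kwargs):
        # (B, 1) -> (B,) scalars -> (B, 2) binary logits
        return point_estimate_binary_logits(raw.squeeze(-1), threshold=0.7, k=10)
```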
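Here is the binary logit pattern in plain Python. Applying log-softmax to [false_logit, true_logit] and keeping the true entry gives log sigmoid(true_logit - false_logit), so only the difference between the two logits matters. point_estimate_logits_sketch is a hypothetical soft threshold written for illustration. It is not necessarily how point_estimate_binary_logits is implemented.

```python
import math


def log_softmax_pair(false_logit, true_logit):
    # Numerically stable log-softmax over the two classes.
    m = max(false_logit, true_logit)
    log_z = m + math.log(math.exp(false_logit - m) + math.exp(true_logit - m))
    return false_logit - log_z, true_logit - log_z


def point_estimate_logits_sketch(y, threshold, k):
    # Hypothetical: a sigmoid "soft threshold" on a scalar prediction y.
    # Only true - false matters, so the false logit is pinned at 0.
    return 0.0, k * (y - threshold)


for y in (0.5, 0.7, 0.9):
    f, t = point_estimate_logits_sketch(y, threshold=0.7, k=10)
    log_p_false, log_p_true = log_softmax_pair(f, t)
    print(y, round(math.exp(log_p_true), 3))
# prints: 0.5 0.119, then 0.7 0.5, then 0.9 0.881
```

Larger k makes the threshold sharper. A prediction exactly at the threshold always maps to p(true) = 0.5.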
Four layers
A predictive model integration decomposes into four separable layers. Understanding this decomposition makes it clear what you are building versus reusing:
- Raw Predictor: the original pretrained model (architecture + weights), ported with minimal changes. Not proteingen-specific.
- Binary Logit Function: converts raw output to (B, 2) binary logits. It is independent of the model, so the same predictor could use different functions. The library provides binary_logits, categorical_binary_logits, point_estimate_binary_logits, and gaussian_binary_logits.
- Template Model Class (optional): a reusable architecture pattern (e.g. LinearProbe, EmbeddingMLP). If the predictor's architecture generalizes, add a template. If it is one-off, subclass PredictiveModel directly.
- PredictiveModel Subclass: thin glue wiring the first three layers together with conditioning, tokenizer, and OHE basis. If the other layers are well designed, this should be mostly boilerplate.
See the contributing guide for details on each layer.
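Here is another Binary Logit Function, sketched under two assumptions: the raw predictor outputs a Gaussian mean and standard deviation, and the target is "value above a threshold". gaussian_logits_sketch is hypothetical and may differ from the library's gaussian_binary_logits. Returning log-probabilities works because a normalized pair of log-probabilities is already a valid pair of logits: log-softmax leaves it unchanged.

```python
import math

EPS = 1e-12  # keeps the logs finite when the probability saturates


def gaussian_logits_sketch(mean, std, threshold):
    """Hypothetical binary logits for P(y > threshold), y ~ Normal(mean, std)."""
    z = (mean - threshold) / std
    p_true = 0.5 * math.erfc(-z / math.sqrt(2.0))  # standard normal CDF at z
    p_true = min(max(p_true, EPS), 1.0 - EPS)
    return math.log1p(-p_true), math.log(p_true)  # [false_logit, true_logit]


# Same mean, growing uncertainty: the prediction is pulled toward 50/50.
for std in (0.05, 0.2, 1.0):
    _, true_logit = gaussian_logits_sketch(mean=0.8, std=std, threshold=0.7)
    print(std, round(math.exp(true_logit), 3))
```

Unlike a point estimate, this function lets the predictor's own uncertainty soften the guidance signal.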
Target management
```python
model.set_target_(True)  # in-place
with model.with_target(True):  # context manager
    log_prob = model.get_log_probs(x)  # log p(target=True | x)
```
Gradient access (for TAG)
Template subclasses
- LinearProbe: frozen GenerativeModelWithEmbedding + nn.Linear head
- EmbeddingMLP: learnable embeddings + MLP, with PCA initialization from pretrained models
- OneHotMLP: flattened one-hot + MLP
All are ABCs — you implement format_raw_to_logits using the binary logit functions listed above.
On-the-fly Conditional Models
TAG (Taylor-Approximate Guidance) and DEG (Discrete Enumeration Guidance) combine a generative model with a predictive model using Bayes' rule:
$$ p_\text{guided}(x_t | x_{<t}) \propto p_\text{gen}(x_t | x_{<t}) \cdot p_\text{pred}(\text{target} | x)^\gamma $$
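Here is a worked example of the rule above at a single position, using a made-up three-token vocabulary. The generative log-probs are shifted by $\gamma$ times the predictor's log-prob for each candidate token, then renormalized. The numbers are illustrative only.

```python
import math


def log_softmax(xs):
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def guided_log_probs(gen_log_probs, pred_log_probs, gamma):
    # log p_guided = log p_gen + gamma * log p_pred(target | x with x_t = a), renormalized
    return log_softmax([g + gamma * p for g, p in zip(gen_log_probs, pred_log_probs)])


gen = [math.log(0.6), math.log(0.3), math.log(0.1)]   # p_gen(x_t = a | x_<t)
pred = [math.log(0.1), math.log(0.5), math.log(0.9)]  # p_pred(target | x with x_t = a)
for gamma in (0.0, 1.0, 3.0):
    print(gamma, [round(math.exp(lp), 3) for lp in guided_log_probs(gen, pred, gamma)])
```

With $\gamma = 0$ the generative distribution is returned unchanged. With $\gamma = 1$ it becomes [0.2, 0.5, 0.3], because the products 0.06, 0.15, and 0.09 are normalized by their sum, 0.3. Larger $\gamma$ moves more mass toward tokens the predictor favors.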
Both are GenerativeModel subclasses: they produce guided log-probs and can be passed directly to any sampler.
- TAG uses a first-order Taylor expansion of the predictive model's log-prob. It works well when gradients are reliable.
- DEG enumerates all 20 amino acids at each position and reweights each candidate by the predictive model's log-prob. It is more robust for frozen-LM probes, where gradients through the transformer are unreliable.
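This sketch contrasts the two estimates on a toy predictor. DEG scores every single-site substitution with a separate predictor call. TAG takes one gradient $g$ with respect to the one-hot input and approximates each substitution's log-prob as $f(x) + g_{i,a} - g_{i,x_i}$, since changing site $i$ to token $a$ moves the one-hot by $e_a - e_{x_i}$. pred_log_prob and its weight matrix W are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L, V = 5, 20  # sequence length, vocabulary size (20 amino acids)
W = torch.randn(L, V)


def pred_log_prob(ohe):
    # Hypothetical predictor: log p(target | x) = logsigmoid(sum of per-site scores).
    return F.logsigmoid((ohe * W).sum(dim=(-2, -1)))


x = torch.randint(0, V, (L,))
ohe = F.one_hot(x, V).float()

# DEG: evaluate the predictor exactly on every single-site substitution.
deg = torch.empty(L, V)
for i in range(L):
    for a in range(V):
        variant = ohe.clone()
        variant[i] = 0.0
        variant[i, a] = 1.0
        deg[i, a] = pred_log_prob(variant)

# TAG: one gradient, then a first-order estimate for all substitutions at once.
ohe_req = ohe.clone().requires_grad_(True)
base = pred_log_prob(ohe_req)
(grad,) = torch.autograd.grad(base, ohe_req)
tag = base.detach() + grad - (grad * ohe).sum(dim=-1, keepdim=True)

print("max |TAG - DEG|:", (tag - deg).abs().max().item())
```

Both estimates agree at the current tokens. Elsewhere TAG is only approximate, because logsigmoid is nonlinear. If pred_log_prob were linear in the one-hot input, the two would coincide exactly.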
GuidanceProjection handles cross-tokenizer mapping when the predictive and generative models use different vocabularies.
Sampling
sample generates sequences by unmasking positions one (or n_parallel) at a time, using model.get_log_probs at each step. With no in_order argument, positions are unmasked in random order:
```python
from proteingen import sample
from proteingen.models import ESMC

model = ESMC().cuda()
sequences = sample(model, ["<mask>" * 100] * 8)["sequences"]
```
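Here is a plain-Python sketch of random-order ancestral unmasking. toy_probs is a hypothetical stand-in for model.get_log_probs: it returns probabilities rather than log-probs and ignores context. The sketch assumes that, with n_parallel > 1, every position in a step is drawn from the same model call.

```python
import random

MASK = "<mask>"
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def toy_probs(tokens, pos):
    # Hypothetical model: uniform over amino acids, ignoring the context.
    return {aa: 1.0 / len(AMINO_ACIDS) for aa in AMINO_ACIDS}


def sample_random_order(tokens, probs_fn, n_parallel=1, rng=None):
    rng = rng or random.Random()
    tokens = list(tokens)
    order = [i for i, tok in enumerate(tokens) if tok == MASK]
    rng.shuffle(order)  # no explicit order given: unmask in random order
    for start in range(0, len(order), n_parallel):
        step = order[start:start + n_parallel]
        # One "model call" per step: all positions in the step see the same context.
        dists = {pos: probs_fn(tokens, pos) for pos in step}
        for pos, dist in dists.items():
            choices = list(dist)
            tokens[pos] = rng.choices(choices, weights=[dist[c] for c in choices])[0]
    return "".join(tokens)


print(sample_random_order([MASK] * 12, toy_probs, n_parallel=3, rng=random.Random(0)))
```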
Linear interpolation sampler
sample_ctmc_linear_interpolation generates sequences by interpolating between the current token distribution and the model's predicted distribution over a fixed number of steps $N$. At each step $i \in \{0, \dots, N-1\}$:
$$ p_\text{next}(x) = \frac{N - i - 1}{N - i} \cdot \mathbb{1}[x = x_\text{current}] + \frac{1}{N - i} \cdot p_\text{model}(x) $$
Tokens are resampled from this mixture at every position simultaneously, so the distribution gradually shifts from the initial state (fully masked) to the model's predicted distribution. At the final step ($i = N-1$) the weight on the current token is zero, so every position is drawn from $p_\text{model}$. Unlike ancestral sampling, which unmasks one position at a time, linear interpolation updates all positions in parallel at each step.
```python
from proteingen.sampling import sample_ctmc_linear_interpolation
from proteingen.models import ESMC

model = ESMC().cuda()
sequences = sample_ctmc_linear_interpolation(model, ["<mask>" * 100] * 8, n_steps=50)
```
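Here is a plain-Python sketch of the mixture formula and the parallel update loop. It uses a fixed toy p_model in place of a real model. That is an assumption made for brevity: in practice the model's prediction would be recomputed at each step. mixture and linear_interpolation_sample are hypothetical names.

```python
import random

MASK = "<mask>"


def mixture(current, p_model, i, n_steps):
    """p_next(x) = (N-i-1)/(N-i) * 1[x == current] + 1/(N-i) * p_model(x)."""
    probs = {tok: p / (n_steps - i) for tok, p in p_model.items()}
    probs[current] = probs.get(current, 0.0) + (n_steps - i - 1) / (n_steps - i)
    return probs


def linear_interpolation_sample(length, p_model, n_steps, rng):
    seq = [MASK] * length
    for i in range(n_steps):  # i = 0, ..., N-1
        new_seq = []
        for tok in seq:  # every position is resampled from its own mixture
            probs = mixture(tok, p_model, i, n_steps)
            choices = list(probs)
            new_seq.append(rng.choices(choices, weights=[probs[c] for c in choices])[0])
        seq = new_seq
    return seq


p_model = {"A": 0.5, "C": 0.3, "G": 0.2}  # toy, fixed model prediction
print(mixture(MASK, p_model, i=0, n_steps=4)[MASK])  # 0.75: mostly stay masked
print(linear_interpolation_sample(10, p_model, n_steps=4, rng=random.Random(0)))
```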
Flow-matching Euler sampler
sample_flow_matching_legacy integrates a rate matrix using Euler steps, following the continuous-time flow-matching framework. At each time step, the model predicts an $x_1$ distribution, and a rate matrix $R_t$ is constructed such that masked positions transition toward the predicted distribution at a rate proportional to $1/(1-t)$. With optional stochasticity, unmasked positions can also be remasked.
When a predictive model is provided, guidance is applied by reweighting the rate matrix with likelihood ratios — either via enumeration (DEG-style) or first-order Taylor approximation (TAG-style).
```python
from proteingen.sampling import sample_flow_matching_legacy
from proteingen.models import ESMC

model = ESMC().cuda()
sequences = sample_flow_matching_legacy(model, ["<mask>" * 100] * 8, dt=0.01)
```
Key parameters:
- dt: step size (default 0.01, i.e. 100 steps)
- x1_temp: temperature applied to the model's $x_1$ prediction
- stochasticity: controls the remasking rate (0 = deterministic flow, >0 = stochastic)
- argmax_final: if True, remaining masked positions are filled with the argmax at $t=1$
- predictor_log_prob: optional guidance function (use build_legacy_predictor_log_prob to construct one from a TAG model)
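Here is a single-position sketch of one Euler step under the rate described above. It assumes stochasticity = 0, no guidance, and x1_temp = 1. Because the predicted $x_1$ distribution sums to one, the total rate out of the mask is $1/(1-t)$. A masked position therefore unmasks with probability $\min(1, \Delta t/(1-t))$ per step and, if it does, draws its token from the prediction. euler_step and integrate are hypothetical names. The final fill mirrors what argmax_final is described as doing.

```python
import random

MASK = "<mask>"


def euler_step(token, p_x1, t, dt, rng):
    if token != MASK:
        return token  # with stochasticity = 0, unmasked positions stay fixed
    # R_t(mask -> j) = p_x1(j) / (1 - t); summed over j this is 1 / (1 - t).
    p_unmask = min(1.0, dt / (1.0 - t))
    if rng.random() < p_unmask:
        choices = list(p_x1)
        return rng.choices(choices, weights=[p_x1[c] for c in choices])[0]
    return MASK


def integrate(length, p_x1, dt=0.01, argmax_final=True, rng=None):
    rng = rng or random.Random()
    seq = [MASK] * length
    for k in range(round(1.0 / dt)):
        t = k * dt
        seq = [euler_step(tok, p_x1, t, dt, rng) for tok in seq]
    if argmax_final:
        # Fill anything still masked at t = 1 with the most likely token.
        best = max(p_x1, key=p_x1.get)
        seq = [best if tok == MASK else tok for tok in seq]
    return seq


p_x1 = {"A": 0.5, "C": 0.3, "G": 0.2}  # toy, fixed x1 prediction
print("".join(integrate(20, p_x1, rng=random.Random(0))))
```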
Composition
The key design insight: because TAG, DEG, and all models share the ProbabilityModel interface, they compose naturally. You can layer multiple guidance signals, swap generative backbones, or mix sampling strategies without changing code for any other part of your pipeline.
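Here is a plain-Python sketch of the composition idea. If a guidance wrapper exposes the same get_log_probs method as the model it wraps, wrappers can be nested to stack several guidance signals. FixedModel, GuidedSketch, and the two toy predictors are hypothetical. They stand in for ProbabilityModel subclasses and are not ProteinGen's TAG or DEG classes.

```python
import math


def log_softmax(xs):
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


class FixedModel:
    """Toy generative model with a fixed 3-token distribution."""

    def __init__(self, probs):
        self.log_probs = [math.log(p) for p in probs]

    def get_log_probs(self, x):
        return list(self.log_probs)


class GuidedSketch:
    """Hypothetical wrapper exposing the same get_log_probs as what it wraps."""

    def __init__(self, base, pred_log_probs, gamma):
        self.base, self.pred, self.gamma = base, pred_log_probs, gamma

    def get_log_probs(self, x):
        base = self.base.get_log_probs(x)
        return log_softmax([b + self.gamma * p for b, p in zip(base, self.pred(x))])


def stability(x):
    return [math.log(0.1), math.log(0.5), math.log(0.9)]


def solubility(x):
    return [math.log(0.9), math.log(0.5), math.log(0.1)]


gen = FixedModel([0.6, 0.3, 0.1])
# Nesting works because every layer exposes the same interface.
model = GuidedSketch(GuidedSketch(gen, stability, 1.0), solubility, 1.0)
print([round(math.exp(lp), 3) for lp in model.get_log_probs(None)])
```

Each layer renormalizes, and log-softmax ignores constant shifts. Nesting wrappers with weights $\gamma_1$ and $\gamma_2$ therefore gives the same distribution as a single wrapper whose predictor term is $\gamma_1 \log p_1 + \gamma_2 \log p_2$.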