predictive_modeling
This module contains the PredictiveModel ABC, template model classes (LinearProbe, OneHotMLP, EmbeddingMLP, PairwiseLinearModel), binary logit helper functions, and PCA embedding initialization.
PredictiveModel
An ABC extending ProbabilityModel for models that answer "what is log p(target | sequence)?". Uses a binary logit pattern: format_raw_to_logits must return (B, 2) logits [false_logit, true_logit], and get_log_probs extracts the true_logit after log-softmax.
Abstract methods
| Method |
Signature |
Notes |
forward |
(ohe_seq_SPT, **kwargs) → Any |
Takes OHE features (not token IDs). Must be differentiable for TAG. |
format_raw_to_logits |
(raw_output, seq_SPT, **kwargs) → FloatTensor(B, 2) |
Must return binary logits. Use the binary logit helper functions below. |
The get_log_probs pipeline
PredictiveModel overrides the base pipeline to add OHE creation and binary extraction:
get_log_probs(seq_SP)
→ tokens_to_ohe(seq_SP) → ohe_SPK (requires_grad=True, stashed as self._ohe)
→ forward(ohe) → raw output
→ format_raw_to_logits(raw, ohe) → (B, 2) binary logits
→ log_softmax(logits / temp) → (B, 2)
→ [:, 1] → (B,) log p(target | x)
Target management
A target must be set before calling get_log_probs:
model.set_target_(True) # in-place
model = model.set_target(True) # chainable
with model.with_target(True): # context manager
log_prob = model.get_log_probs(x)
Gradient access (for TAG)
grad = model.grad_log_prob(seq_SP) # ∂log p(target|x) / ∂OHE, shape (B, P, K)
This runs get_log_probs with gradient tracking, backprops through the binary logits, and returns self._ohe.grad. This is the core computation that TAG uses to compute guidance deltas.
Token OHE basis
- Default
token_ohe_basis() returns torch.eye(vocab_size) — each token maps to its own one-hot vector (K = T)
- Override when tokens should map to a reduced space (e.g. the stability predictor maps
<mask> to an all-zero OHE row to preserve original PMPNN masking semantics)
grad_log_prob returns gradients in feature space (B, P, K), which is not always (B, P, vocab_size)
Binary logit functions
Standalone functions for converting raw predictions to (B, 2) binary logits. Call these from your format_raw_to_logits:
| Function |
Input |
Use case |
categorical_binary_logits(logits_BC, target_class) |
Multi-class logits |
Classification (target class vs rest) |
binary_logits(logit_B, target) |
Single logit |
Binary classification |
point_estimate_binary_logits(pred_B, threshold, k) |
Scalar prediction |
Thresholding a regression output |
gaussian_binary_logits(mu_B, log_var_B, threshold) |
Gaussian params |
P(Y > threshold) via CDF |
Steep sigmoid saturates gradients
point_estimate_binary_logits uses a sigmoid with steepness parameter k. Large values (k=100) make sigmoid(k*(pred - threshold)) ≈ 1, driving gradients to zero. Use k=5–10, or prefer DEG over TAG when gradients are unreliable.
Gaussian logits are TAG-friendly
gaussian_binary_logits computes P(Y > threshold) via the CDF and is differentiable through both the mean and variance. This makes it naturally compatible with TAG gradient-based guidance.
Template models
All template classes are ABCs — you implement format_raw_to_logits using the binary logit functions above.
LinearProbe
Frozen GenerativeModelWithEmbedding + trainable nn.Linear head.
LinearProbe(embed_model, output_dim, pooling_fn=None, freeze_embed_model=True)
- Default pooling: mean over non-padding positions
pooling_fn takes two args: (embeddings_SPD, seq_SP) — the token IDs are needed for masking special tokens during pooling
- Set
freeze_embed_model=False when using LoRA on the embed_model so PEFT's freeze/unfreeze state is preserved
precompute_embeddings(sequences, batch_size, device) caches pooled embeddings for fast training
EmbeddingMLP
Learnable embeddings + MLP. The embedding lookup uses ohe @ self.embed.weight, which is differentiable for TAG gradient flow.
padding_idx defaults to tokenizer.pad_token_id
- Supports PCA initialization from pretrained models (see below)
OneHotMLP
Flattened one-hot encoding → MLP. Takes sequence_length as a required constructor argument.
PairwiseLinearModel
Linear model on single + pairwise OHE features. Computes pairwise outer products and takes the upper triangle. Quadratic in sequence_length × vocab_size.
PCA embedding initialization
EmbeddingMLP supports post-construction initialization from PCA of pretrained embeddings:
model.init_embed_from_pretrained_pca(
source=esmc_model, # GenerativeModelWithEmbedding
source_vocab=esmc_tokenizer.vocab,
target_vocab=model.tokenizer.vocab,
)
Key details:
- Token matching by string key — shared tokens are the intersection of vocabulary keys
- PCA is computed only over shared tokens (special tokens excluded from centering/SVD)
- Automatically zeroes the padding row after copy
- This is a post-construction method (not a constructor parameter) to avoid redundant shape arguments
Variance capture
ESMC's 960-dim embeddings for 20 amino acids have effective rank 19 (after centering). The first 20 PCs capture ~100% of the variance. Small embedding dims (8–32) still capture the most important AA similarity structure.
API Reference
proteingen.modeling.predictive_modeling
PredictiveModel
Bases: ProbabilityModel, ABC
Base class for predictive models used in guidance.
Predictive models answer: "what is log p(target_event | sequence)?"
The target event is set via set_target_() or the with_target() context
manager. forward() returns raw predictions (class logits, regression
values, etc.). format_raw_to_logits() converts those to binary logits
(B, 2): [false_logit, true_logit]. The inherited
ProbabilityModel.get_log_probs applies temperature-scaled log_softmax,
and this class's override takes [:, 1] to return the scalar
log p(target=True | x).
Pipeline::
get_log_probs(seq_SP) — creates OHE, stashes self._ohe
↓
forward(ohe_seq) → raw output (class logits, regression value, ...)
↓
format_raw_to_logits(raw) → (B, 2) binary logits [false, true]
↓
ProbabilityModel.get_log_probs: log_softmax(logits / temp) → (B, 2)
↓
PredictiveModel.get_log_probs: [:, 1] → (B,) log p(target | x)
For TAG guidance, use grad_log_prob(seq_SP) which runs the pipeline,
backprops, and returns gradients w.r.t. the model's OHE feature space.
Source code in src/proteingen/modeling/predictive_modeling.py
| class PredictiveModel(ProbabilityModel, ABC):
"""Base class for predictive models used in guidance.
Predictive models answer: "what is log p(target_event | sequence)?"
The target event is set via ``set_target_()`` or the ``with_target()`` context
manager. ``forward()`` returns raw predictions (class logits, regression
values, etc.). ``format_raw_to_logits()`` converts those to binary logits
``(B, 2)``: ``[false_logit, true_logit]``. The inherited
``ProbabilityModel.get_log_probs`` applies temperature-scaled log_softmax,
and this class's override takes ``[:, 1]`` to return the scalar
log p(target=True | x).
Pipeline::
get_log_probs(seq_SP) — creates OHE, stashes self._ohe
↓
forward(ohe_seq) → raw output (class logits, regression value, ...)
↓
format_raw_to_logits(raw) → (B, 2) binary logits [false, true]
↓
ProbabilityModel.get_log_probs: log_softmax(logits / temp) → (B, 2)
↓
PredictiveModel.get_log_probs: [:, 1] → (B,) log p(target | x)
For TAG guidance, use ``grad_log_prob(seq_SP)`` which runs the pipeline,
backprops, and returns gradients w.r.t. the model's OHE feature space.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__()
self.tokenizer = tokenizer
self.target = None
self._ohe = None
def set_target_(self, target):
"""Set the target event in-place."""
self.target = target
def set_target(self, target):
"""Set the target event, returning self for chaining."""
self.set_target_(target)
return self
@contextmanager
def with_target(self, target_spec):
"""Context manager: temporarily set target, revert on exit."""
old = self.target
self.set_target_(target_spec)
try:
yield self
finally:
self.target = old
def forward(self, ohe_seq_SPT: torch.FloatTensor, **kwargs) -> Any:
"""Forward pass to produce raw predictions. To be implemented by subclasses. Must take OHE as input.
OHE dimension is defined by ``tokens_to_ohe`` / ``token_ohe_basis``."""
...
# TODO[pi] add the decorator which gives this instance-variable semantics
def token_ohe_basis(self) -> torch.FloatTensor:
"""Return token-id → OHE feature matrix (T, K).
Default is identity: each token id maps to its one-hot basis vector.
Override when token ids should map to a reduced OHE space (e.g. an
explicit mask token represented as an all-zero feature vector).
"""
vocab_size = self.tokenizer.vocab_size
return torch.eye(vocab_size, dtype=torch.float32)
def tokens_to_ohe(self, seq_SP: torch.LongTensor) -> torch.FloatTensor:
"""Map token IDs to model OHE features using ``token_ohe_basis``."""
basis_TK = self.token_ohe_basis().to(seq_SP.device)
return basis_TK[seq_SP.long()]
# TODO[pi] The relationship between target and format_raw_to_logits is not
# obvious from the interface. A user subclassing OneHotMLP has to just *know*
# that self.target exists, what type it should be, and that they need to use
# it inside format_raw_to_logits. The binary logit functions help but still
# require the user to pass self.target manually. Consider ways to make this
# more discoverable — e.g. requiring subclasses to declare expected target
# type, or having format_raw_to_logits receive target as an explicit argument
# instead of reading it off self.
@abstractmethod
def format_raw_to_logits(
self, raw_output: Any, seq_SPT: torch.FloatTensor, **kwargs
) -> torch.FloatTensor:
"""Convert raw predictions → binary logits (B, 2): [false_logit, true_logit].
Uses ``self.target`` to determine what event is being evaluated.
The parent's ``get_log_probs`` applies ``log_softmax(logits / temp)``
on top.
"""
...
def get_log_probs(self, seq_SP: torch.LongTensor) -> torch.FloatTensor:
"""Return log p(target=True | x), scalar per sequence.
Creates model OHE features from token IDs, stashes them as ``self._ohe``
(for gradient access via ``grad_log_prob``), then runs the parent
pipeline (forward → format_raw_to_logits → log_softmax) to get
(B, 2) and returns [:, 1].
"""
assert self.target is not None, (
"Target not set. Call set_target_() or use with_target() context manager."
)
ohe_seq_SPT = self.tokens_to_ohe(seq_SP).float()
ohe_seq_SPT.requires_grad_(True)
self._ohe = ohe_seq_SPT
log_probs_B2 = super().get_log_probs(ohe_seq_SPT) # (B, 2)
assert log_probs_B2.shape[1] == 2, (
f"Expected binary logits (B, 2) from format_raw_to_logits, got shape {log_probs_B2.shape}"
)
return log_probs_B2[:, 1] # (B,)
def predict(self, seq_SP: torch.LongTensor) -> Any:
"""Get raw model predictions from token IDs.
Creates OHE and calls forward — returns whatever forward returns
(scalar predictions, class logits, etc.) without binary logit
conversion. Use for training (e.g. MSE loss) and scoring.
"""
ohe = self.tokens_to_ohe(seq_SP).float()
return self.forward(ohe)
def grad_log_prob(self, seq_SP: torch.LongTensor) -> torch.FloatTensor:
"""Return gradient of log p(target | x) w.r.t. model OHE features.
Runs ``get_log_probs`` (which creates and stashes the OHE),
backprops, and returns ``self._ohe.grad`` of shape (B, P, K).
This is the gradient signal TAG adds to generative model logits.
"""
with torch.enable_grad():
log_p = self.get_log_probs(seq_SP)
log_p.sum().backward()
return self._ohe.grad
|
set_target_
Set the target event in-place.
Source code in src/proteingen/modeling/predictive_modeling.py
| def set_target_(self, target):
"""Set the target event in-place."""
self.target = target
|
set_target
Set the target event, returning self for chaining.
Source code in src/proteingen/modeling/predictive_modeling.py
| def set_target(self, target):
"""Set the target event, returning self for chaining."""
self.set_target_(target)
return self
|
with_target
Context manager: temporarily set target, revert on exit.
Source code in src/proteingen/modeling/predictive_modeling.py
| @contextmanager
def with_target(self, target_spec):
"""Context manager: temporarily set target, revert on exit."""
old = self.target
self.set_target_(target_spec)
try:
yield self
finally:
self.target = old
|
forward
forward(ohe_seq_SPT: FloatTensor, **kwargs) -> Any
Forward pass to produce raw predictions. To be implemented by subclasses. Must take OHE as input.
OHE dimension is defined by tokens_to_ohe / token_ohe_basis.
Source code in src/proteingen/modeling/predictive_modeling.py
| def forward(self, ohe_seq_SPT: torch.FloatTensor, **kwargs) -> Any:
"""Forward pass to produce raw predictions. To be implemented by subclasses. Must take OHE as input.
OHE dimension is defined by ``tokens_to_ohe`` / ``token_ohe_basis``."""
...
|
token_ohe_basis
token_ohe_basis() -> torch.FloatTensor
Return token-id → OHE feature matrix (T, K).
Default is identity: each token id maps to its one-hot basis vector.
Override when token ids should map to a reduced OHE space (e.g. an
explicit mask token represented as an all-zero feature vector).
Source code in src/proteingen/modeling/predictive_modeling.py
| def token_ohe_basis(self) -> torch.FloatTensor:
"""Return token-id → OHE feature matrix (T, K).
Default is identity: each token id maps to its one-hot basis vector.
Override when token ids should map to a reduced OHE space (e.g. an
explicit mask token represented as an all-zero feature vector).
"""
vocab_size = self.tokenizer.vocab_size
return torch.eye(vocab_size, dtype=torch.float32)
|
tokens_to_ohe
tokens_to_ohe(seq_SP: LongTensor) -> torch.FloatTensor
Map token IDs to model OHE features using token_ohe_basis.
Source code in src/proteingen/modeling/predictive_modeling.py
| def tokens_to_ohe(self, seq_SP: torch.LongTensor) -> torch.FloatTensor:
"""Map token IDs to model OHE features using ``token_ohe_basis``."""
basis_TK = self.token_ohe_basis().to(seq_SP.device)
return basis_TK[seq_SP.long()]
|
format_raw_to_logits(raw_output: Any, seq_SPT: FloatTensor, **kwargs) -> torch.FloatTensor
Convert raw predictions → binary logits (B, 2): [false_logit, true_logit].
Uses self.target to determine what event is being evaluated.
The parent's get_log_probs applies log_softmax(logits / temp)
on top.
Source code in src/proteingen/modeling/predictive_modeling.py
| @abstractmethod
def format_raw_to_logits(
self, raw_output: Any, seq_SPT: torch.FloatTensor, **kwargs
) -> torch.FloatTensor:
"""Convert raw predictions → binary logits (B, 2): [false_logit, true_logit].
Uses ``self.target`` to determine what event is being evaluated.
The parent's ``get_log_probs`` applies ``log_softmax(logits / temp)``
on top.
"""
...
|
get_log_probs
get_log_probs(seq_SP: LongTensor) -> torch.FloatTensor
Return log p(target=True | x), scalar per sequence.
Creates model OHE features from token IDs, stashes them as self._ohe
(for gradient access via grad_log_prob), then runs the parent
pipeline (forward → format_raw_to_logits → log_softmax) to get
(B, 2) and returns [:, 1].
Source code in src/proteingen/modeling/predictive_modeling.py
| def get_log_probs(self, seq_SP: torch.LongTensor) -> torch.FloatTensor:
"""Return log p(target=True | x), scalar per sequence.
Creates model OHE features from token IDs, stashes them as ``self._ohe``
(for gradient access via ``grad_log_prob``), then runs the parent
pipeline (forward → format_raw_to_logits → log_softmax) to get
(B, 2) and returns [:, 1].
"""
assert self.target is not None, (
"Target not set. Call set_target_() or use with_target() context manager."
)
ohe_seq_SPT = self.tokens_to_ohe(seq_SP).float()
ohe_seq_SPT.requires_grad_(True)
self._ohe = ohe_seq_SPT
log_probs_B2 = super().get_log_probs(ohe_seq_SPT) # (B, 2)
assert log_probs_B2.shape[1] == 2, (
f"Expected binary logits (B, 2) from format_raw_to_logits, got shape {log_probs_B2.shape}"
)
return log_probs_B2[:, 1] # (B,)
|
predict
predict(seq_SP: LongTensor) -> Any
Get raw model predictions from token IDs.
Creates OHE and calls forward — returns whatever forward returns
(scalar predictions, class logits, etc.) without binary logit
conversion. Use for training (e.g. MSE loss) and scoring.
Source code in src/proteingen/modeling/predictive_modeling.py
| def predict(self, seq_SP: torch.LongTensor) -> Any:
"""Get raw model predictions from token IDs.
Creates OHE and calls forward — returns whatever forward returns
(scalar predictions, class logits, etc.) without binary logit
conversion. Use for training (e.g. MSE loss) and scoring.
"""
ohe = self.tokens_to_ohe(seq_SP).float()
return self.forward(ohe)
|
grad_log_prob
grad_log_prob(seq_SP: LongTensor) -> torch.FloatTensor
Return gradient of log p(target | x) w.r.t. model OHE features.
Runs get_log_probs (which creates and stashes the OHE),
backprops, and returns self._ohe.grad of shape (B, P, K).
This is the gradient signal TAG adds to generative model logits.
Source code in src/proteingen/modeling/predictive_modeling.py
| def grad_log_prob(self, seq_SP: torch.LongTensor) -> torch.FloatTensor:
"""Return gradient of log p(target | x) w.r.t. model OHE features.
Runs ``get_log_probs`` (which creates and stashes the OHE),
backprops, and returns ``self._ohe.grad`` of shape (B, P, K).
This is the gradient signal TAG adds to generative model logits.
"""
with torch.enable_grad():
log_p = self.get_log_probs(seq_SP)
log_p.sum().backward()
return self._ohe.grad
|
LinearProbe
Bases: PredictiveModel, ABC
Linear probe on top of pre-computed embeddings.
Tensor Dimension Labels
I: batch index
D: embedding dimension
O: output dimension
Source code in src/proteingen/modeling/predictive_modeling.py
| class LinearProbe(PredictiveModel, ABC):
"""
Linear probe on top of pre-computed embeddings.
Tensor Dimension Labels:
I: batch index
D: embedding dimension
O: output dimension
"""
def __init__(
self,
embed_model: GenerativeModelWithEmbedding,
output_dim: int,
pooling_fn: Optional[callable] = None,
freeze_embed_model: bool = True,
):
super().__init__(tokenizer=embed_model.tokenizer)
self.embed_model = embed_model
self.embedding_dim = embed_model.EMB_DIM
self.output_dim = output_dim
self.w = nn.Linear(self.embedding_dim, self.output_dim)
def _mean_pool_non_padding(emb_SPD, seq_SP):
mask = (seq_SP != self.tokenizer.pad_token_id).unsqueeze(-1).float()
return (emb_SPD * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
self.pooling_fn = pooling_fn or _mean_pool_non_padding
if freeze_embed_model:
for p in self.embed_model.parameters():
p.requires_grad = False
def forward(self, ohe_seq_SPT: torch.FloatTensor) -> torch.FloatTensor:
"""Full forward: differentiable embed → pool → linear."""
emb_SPD = self.embed_model.differentiable_embedding(ohe_seq_SPT)
seq_SP = ohe_seq_SPT.argmax(dim=-1)
pooled_SD = self.pooling_fn(emb_SPD, seq_SP)
return self.w(pooled_SD)
@torch.no_grad()
def precompute_embeddings(
self,
sequences: list[str],
batch_size: int,
device: torch.device,
) -> torch.Tensor:
"""Pre-compute pooled embeddings for training. Shape (N, EMB_DIM)."""
self.embed_model.to(device)
tokenizer = self.embed_model.tokenizer
all_embeddings = []
n = len(sequences)
for start in range(0, n, batch_size):
batch_seqs = sequences[start : start + batch_size]
token_ids = tokenizer(batch_seqs, padding=True, return_tensors="pt")[
"input_ids"
].to(device)
emb_SPD = self.embed_model.embed(token_ids)
pooled = self.pooling_fn(emb_SPD, token_ids)
all_embeddings.append(pooled.cpu())
if (start // batch_size) % 50 == 0:
print(f" Embedded {min(start + batch_size, n):>6d} / {n}")
return torch.cat(all_embeddings, dim=0)
# ── Checkpointing ────────────────────────────────────────────────────
def save(self, path: str | Path) -> None:
"""Save probe to a directory: config.json, head.pt, and embed_model/.
Subclasses must implement ``_save_args()`` returning constructor kwargs.
The embed_model is saved to a subdirectory via its own ``save()`` method.
"""
path = Path(path)
super().save(path)
torch.save(self.w.state_dict(), path / "head.pt")
self.embed_model.save(path / "embed_model")
@classmethod
def from_checkpoint(cls, path: str | Path) -> "LinearProbe":
"""Load probe from a directory.
Reconstructs the object from config.json (calls ``cls(**args)``),
loads the LoRA adapter onto the embed_model if present, and loads
the head weights.
"""
path = Path(path)
obj = super().from_checkpoint(path)
embed_path = path / "embed_model"
if (embed_path / "lora_adapter").exists():
obj.embed_model.load_lora(embed_path / "lora_adapter")
obj.w.load_state_dict(torch.load(path / "head.pt", weights_only=True))
return obj
|
forward
forward(ohe_seq_SPT: FloatTensor) -> torch.FloatTensor
Full forward: differentiable embed → pool → linear.
Source code in src/proteingen/modeling/predictive_modeling.py
| def forward(self, ohe_seq_SPT: torch.FloatTensor) -> torch.FloatTensor:
"""Full forward: differentiable embed → pool → linear."""
emb_SPD = self.embed_model.differentiable_embedding(ohe_seq_SPT)
seq_SP = ohe_seq_SPT.argmax(dim=-1)
pooled_SD = self.pooling_fn(emb_SPD, seq_SP)
return self.w(pooled_SD)
|
precompute_embeddings
precompute_embeddings(sequences: list[str], batch_size: int, device: device) -> torch.Tensor
Pre-compute pooled embeddings for training. Shape (N, EMB_DIM).
Source code in src/proteingen/modeling/predictive_modeling.py
| @torch.no_grad()
def precompute_embeddings(
self,
sequences: list[str],
batch_size: int,
device: torch.device,
) -> torch.Tensor:
"""Pre-compute pooled embeddings for training. Shape (N, EMB_DIM)."""
self.embed_model.to(device)
tokenizer = self.embed_model.tokenizer
all_embeddings = []
n = len(sequences)
for start in range(0, n, batch_size):
batch_seqs = sequences[start : start + batch_size]
token_ids = tokenizer(batch_seqs, padding=True, return_tensors="pt")[
"input_ids"
].to(device)
emb_SPD = self.embed_model.embed(token_ids)
pooled = self.pooling_fn(emb_SPD, token_ids)
all_embeddings.append(pooled.cpu())
if (start // batch_size) % 50 == 0:
print(f" Embedded {min(start + batch_size, n):>6d} / {n}")
return torch.cat(all_embeddings, dim=0)
|
save
save(path: str | Path) -> None
Save probe to a directory: config.json, head.pt, and embed_model/.
Subclasses must implement _save_args() returning constructor kwargs.
The embed_model is saved to a subdirectory via its own save() method.
Source code in src/proteingen/modeling/predictive_modeling.py
| def save(self, path: str | Path) -> None:
"""Save probe to a directory: config.json, head.pt, and embed_model/.
Subclasses must implement ``_save_args()`` returning constructor kwargs.
The embed_model is saved to a subdirectory via its own ``save()`` method.
"""
path = Path(path)
super().save(path)
torch.save(self.w.state_dict(), path / "head.pt")
self.embed_model.save(path / "embed_model")
|
from_checkpoint
classmethod
from_checkpoint(path: str | Path) -> LinearProbe
Load probe from a directory.
Reconstructs the object from config.json (calls cls(**args)),
loads the LoRA adapter onto the embed_model if present, and loads
the head weights.
Source code in src/proteingen/modeling/predictive_modeling.py
| @classmethod
def from_checkpoint(cls, path: str | Path) -> "LinearProbe":
"""Load probe from a directory.
Reconstructs the object from config.json (calls ``cls(**args)``),
loads the LoRA adapter onto the embed_model if present, and loads
the head weights.
"""
path = Path(path)
obj = super().from_checkpoint(path)
embed_path = path / "embed_model"
if (embed_path / "lora_adapter").exists():
obj.embed_model.load_lora(embed_path / "lora_adapter")
obj.w.load_state_dict(torch.load(path / "head.pt", weights_only=True))
return obj
|
OneHotMLP
Bases: PredictiveModel, ABC
MLP operating on one-hot encoded sequences.
Receives OHE input from PredictiveModel.get_log_probs, flattens across
all positions, and passes through an MLP. Subclasses implement
format_raw_to_logits to convert the MLP output to binary logits.
Tensor Dimension Labels
S: batch (sample) index
P: position in sequence
T: token dimension (one-hot / vocab size)
O: output dimension
Source code in src/proteingen/modeling/predictive_modeling.py
| class OneHotMLP(PredictiveModel, ABC):
"""MLP operating on one-hot encoded sequences.
Receives OHE input from PredictiveModel.get_log_probs, flattens across
all positions, and passes through an MLP. Subclasses implement
format_raw_to_logits to convert the MLP output to binary logits.
Tensor Dimension Labels:
S: batch (sample) index
P: position in sequence
T: token dimension (one-hot / vocab size)
O: output dimension
"""
def __init__(
self,
tokenizer,
sequence_length: int,
model_dim: int,
n_layers: int,
output_dim: int,
dropout: float = 0.0,
):
super().__init__(tokenizer)
self.vocab_size = tokenizer.vocab_size
self.sequence_length = sequence_length
self.model_dim = model_dim
self.n_layers = n_layers
self.output_dim = output_dim
self.dropout = dropout
layers: list[nn.Module] = [
nn.Linear(self.sequence_length * self.vocab_size, self.model_dim)
]
for _ in range(n_layers - 1):
layers.append(nn.ReLU())
if dropout > 0:
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(self.model_dim, self.model_dim))
layers.append(nn.ReLU())
if dropout > 0:
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(self.model_dim, self.output_dim))
self.layers = nn.Sequential(*layers)
def forward(self, ohe_seq_SPT: torch.FloatTensor) -> torch.FloatTensor:
x_SPxT = ohe_seq_SPT.reshape(ohe_seq_SPT.size(0), -1)
return self.layers(x_SPxT)
|
PairwiseLinearModel
Bases: PredictiveModel, ABC
Linear model that uses single and pairwise mutation features
encoded as one-hot vectors.
Receives OHE input from PredictiveModel.get_log_probs, flattens across
all positions, and passes through an MLP. Subclasses implement
format_raw_to_logits to convert the MLP output to binary logits.
Tensor Dimension Labels
S: batch (sample) index
P: position in sequence
T: token dimension (one-hot / vocab size)
O: output dimension
Source code in src/proteingen/modeling/predictive_modeling.py
| class PairwiseLinearModel(PredictiveModel, ABC):
"""Linear model that uses single and pairwise mutation features
encoded as one-hot vectors.
Receives OHE input from PredictiveModel.get_log_probs, flattens across
all positions, and passes through an MLP. Subclasses implement
format_raw_to_logits to convert the MLP output to binary logits.
Tensor Dimension Labels:
S: batch (sample) index
P: position in sequence
T: token dimension (one-hot / vocab size)
O: output dimension
"""
def __init__(
self,
tokenizer,
sequence_length: int,
output_dim: int,
):
super().__init__(tokenizer)
self.vocab_size = tokenizer.vocab_size
self.sequence_length = sequence_length
self.n_ohe_features = self.sequence_length * self.vocab_size + 1
self.n_pairwise_features = torch.triu_indices(
self.n_ohe_features, self.n_ohe_features
).shape[1]
self.linear_layer = nn.Linear(self.n_pairwise_features, output_dim)
def forward(self, ohe_seq_SPT: torch.FloatTensor) -> torch.FloatTensor:
x_SPxT = ohe_seq_SPT.reshape(ohe_seq_SPT.size(0), -1)
ones_S1 = torch.ones_like(x_SPxT[:, :1])
x_Sf = torch.cat([ones_S1, x_SPxT], dim=-1)
pairwise_features_Sff = torch.einsum("si,sj->sij", x_Sf, x_Sf)
idx = torch.triu_indices(
self.n_ohe_features, self.n_ohe_features, device=ohe_seq_SPT.device
)
x_SF = pairwise_features_Sff[:, idx[0], idx[1]]
y_SO = self.linear_layer(x_SF)
return y_SO
|
EmbeddingMLP
Bases: PredictiveModel, ABC
MLP operating on learned token embeddings.
Receives OHE input from PredictiveModel.get_log_probs, multiplies through
a learned embedding matrix (ohe @ embed.weight), flattens across
all positions, and passes through an MLP. The matmul embedding lookup
is differentiable, enabling TAG gradient flow.
Use :meth:init_embed_from_pretrained_pca to initialize the embedding
layer from a pretrained model's embeddings (e.g. ESMC), compressed to
embed_dim principal components with automatic cross-tokenizer mapping.
Tensor Dimension Labels
S: batch (sample) index
P: position in sequence
E: embedding dimension
O: output dimension
Source code in src/proteingen/modeling/predictive_modeling.py
| class EmbeddingMLP(PredictiveModel, ABC):
"""MLP operating on learned token embeddings.
Receives OHE input from PredictiveModel.get_log_probs, multiplies through
a learned embedding matrix (``ohe @ embed.weight``), flattens across
all positions, and passes through an MLP. The matmul embedding lookup
is differentiable, enabling TAG gradient flow.
Use :meth:`init_embed_from_pretrained_pca` to initialize the embedding
layer from a pretrained model's embeddings (e.g. ESMC), compressed to
``embed_dim`` principal components with automatic cross-tokenizer mapping.
Tensor Dimension Labels:
S: batch (sample) index
P: position in sequence
E: embedding dimension
O: output dimension
"""
def __init__(
self,
tokenizer,
sequence_length: int,
embed_dim: int,
model_dim: int,
n_layers: int,
output_dim: int,
padding_idx: Optional[int] = None,
dropout: float = 0.0,
):
super().__init__(tokenizer)
self.vocab_size = tokenizer.vocab_size
self.sequence_length = sequence_length
self.embed_dim = embed_dim
self.model_dim = model_dim
self.n_layers = n_layers
self.output_dim = output_dim
self.padding_idx = (
padding_idx if padding_idx is not None else tokenizer.pad_token_id
)
self.dropout = dropout
self.embed = nn.Embedding(
self.vocab_size,
self.embed_dim,
padding_idx=self.padding_idx,
)
layers: list[nn.Module] = [
nn.Linear(self.sequence_length * self.embed_dim, self.model_dim)
]
for _ in range(n_layers - 1):
layers.append(nn.ReLU())
if dropout > 0:
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(self.model_dim, self.model_dim))
layers.append(nn.ReLU())
if dropout > 0:
layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(self.model_dim, self.output_dim))
self.layers = nn.Sequential(*layers)
def init_embed_from_pretrained_pca(
self,
source: nn.Embedding,
source_vocab: dict[str, int],
target_vocab: dict[str, int],
) -> None:
"""Initialize embedding layer from PCA of a pretrained model's embeddings.
Finds tokens shared between source and target vocabularies, runs PCA
on their pretrained embeddings, and copies the first ``self.embed_dim``
principal components into this model's embedding layer at the correct
target indices. Unmatched rows and the padding row are zeroed.
Args:
source: Pretrained embedding layer (e.g. ``esmc_model.embed``).
source_vocab: Token string → index mapping for the pretrained model
(e.g. ``esm_tokenizer.vocab``).
target_vocab: Token string → index mapping for this model
(e.g. ``mpnn_tokenizer.vocab``).
"""
weights = pca_embed_init(
pretrained_weights=source.weight.detach(),
pretrained_vocab=source_vocab,
target_vocab=target_vocab,
n_components=self.embed_dim,
target_vocab_size=self.vocab_size,
)
self.embed.weight.data.copy_(weights)
if self.padding_idx is not None:
self.embed.weight.data[self.padding_idx].zero_()
def forward(self, ohe_seq_SPT: torch.FloatTensor) -> torch.FloatTensor:
x_SPE = ohe_seq_SPT @ self.embed.weight # differentiable embedding lookup
x_SPxE = x_SPE.reshape(x_SPE.size(0), -1)
return self.layers(x_SPxE)
|
init_embed_from_pretrained_pca
init_embed_from_pretrained_pca(source: Embedding, source_vocab: dict[str, int], target_vocab: dict[str, int]) -> None
Initialize embedding layer from PCA of a pretrained model's embeddings.
Finds tokens shared between source and target vocabularies, runs PCA
on their pretrained embeddings, and copies the first self.embed_dim
principal components into this model's embedding layer at the correct
target indices. Unmatched rows and the padding row are zeroed.
Parameters:
| Name |
Type |
Description |
Default |
source
|
Embedding
|
Pretrained embedding layer (e.g. esmc_model.embed).
|
required
|
source_vocab
|
dict[str, int]
|
Token string → index mapping for the pretrained model
(e.g. esm_tokenizer.vocab).
|
required
|
target_vocab
|
dict[str, int]
|
Token string → index mapping for this model
(e.g. mpnn_tokenizer.vocab).
|
required
|
Source code in src/proteingen/modeling/predictive_modeling.py
| def init_embed_from_pretrained_pca(
self,
source: nn.Embedding,
source_vocab: dict[str, int],
target_vocab: dict[str, int],
) -> None:
"""Initialize embedding layer from PCA of a pretrained model's embeddings.
Finds tokens shared between source and target vocabularies, runs PCA
on their pretrained embeddings, and copies the first ``self.embed_dim``
principal components into this model's embedding layer at the correct
target indices. Unmatched rows and the padding row are zeroed.
Args:
source: Pretrained embedding layer (e.g. ``esmc_model.embed``).
source_vocab: Token string → index mapping for the pretrained model
(e.g. ``esm_tokenizer.vocab``).
target_vocab: Token string → index mapping for this model
(e.g. ``mpnn_tokenizer.vocab``).
"""
weights = pca_embed_init(
pretrained_weights=source.weight.detach(),
pretrained_vocab=source_vocab,
target_vocab=target_vocab,
n_components=self.embed_dim,
target_vocab_size=self.vocab_size,
)
self.embed.weight.data.copy_(weights)
if self.padding_idx is not None:
self.embed.weight.data[self.padding_idx].zero_()
|
categorical_binary_logits
categorical_binary_logits(logits_BC: FloatTensor, target_class: int) -> torch.FloatTensor
Multi-class logits (B, C) → binary logits (B, 2) for a target class.
true_logit = logits[:, target_class]
false_logit = logsumexp(logits for non-target classes)
Source code in src/proteingen/modeling/predictive_modeling.py
| def categorical_binary_logits(
logits_BC: torch.FloatTensor, target_class: int
) -> torch.FloatTensor:
"""Multi-class logits (B, C) → binary logits (B, 2) for a target class.
true_logit = logits[:, target_class]
false_logit = logsumexp(logits for non-target classes)
"""
target_logit_B = logits_BC[:, target_class]
C = logits_BC.shape[-1]
mask = torch.ones(C, dtype=torch.bool, device=logits_BC.device)
mask[target_class] = False
false_logit_B = torch.logsumexp(logits_BC[:, mask], dim=-1)
return torch.stack([false_logit_B, target_logit_B], dim=-1) # (B, 2)
|
binary_logits
binary_logits(logit_B: FloatTensor, target: bool = True) -> torch.FloatTensor
Single logit → binary logits (B, 2) via sigmoid(x) = softmax([0, x])[1].
If target is False, swaps the logits so [:, 1] gives P(negative).
Source code in src/proteingen/modeling/predictive_modeling.py
| def binary_logits(logit_B: torch.FloatTensor, target: bool = True) -> torch.FloatTensor:
"""Single logit → binary logits (B, 2) via sigmoid(x) = softmax([0, x])[1].
If target is False, swaps the logits so [:, 1] gives P(negative).
"""
logit_B = logit_B.reshape(-1)
zero_B = torch.zeros_like(logit_B)
if target:
return torch.stack([zero_B, logit_B], dim=-1) # (B, 2)
else:
return torch.stack([logit_B, zero_B], dim=-1) # (B, 2)
|
point_estimate_binary_logits
point_estimate_binary_logits(pred_B: FloatTensor, threshold: float, k: float = 100.0) -> torch.FloatTensor
Scalar prediction → binary logits (B, 2) via steep sigmoid.
sigmoid(k * (pred - threshold)) approximates a step function.
Gradients unstable through the threshold though — use DEG, not TAG.
Source code in src/proteingen/modeling/predictive_modeling.py
| def point_estimate_binary_logits(
pred_B: torch.FloatTensor, threshold: float, k: float = 100.0
) -> torch.FloatTensor:
"""Scalar prediction → binary logits (B, 2) via steep sigmoid.
sigmoid(k * (pred - threshold)) approximates a step function.
Gradients unstable through the threshold though — use DEG, not TAG.
"""
pred_B = pred_B.reshape(-1)
logit_B = k * (pred_B - threshold)
zero_B = torch.zeros_like(logit_B)
return torch.stack([zero_B, logit_B], dim=-1) # (B, 2)
|
gaussian_binary_logits
gaussian_binary_logits(mu_B: FloatTensor, log_var_B: FloatTensor, threshold: float) -> torch.FloatTensor
Gaussian (mean, log_var) → binary logits (B, 2) via CDF log-odds.
P(Y > threshold) = Phi((mu - threshold) / sigma).
If you want P(Y < threshold) for your application, just swap the order of the logits in your
format_raw_to_logits implementation.
Differentiable through both mean and variance — works with TAG.
Source code in src/proteingen/modeling/predictive_modeling.py
| def gaussian_binary_logits(
mu_B: torch.FloatTensor, log_var_B: torch.FloatTensor, threshold: float
) -> torch.FloatTensor:
"""Gaussian (mean, log_var) → binary logits (B, 2) via CDF log-odds.
P(Y > threshold) = Phi((mu - threshold) / sigma).
If you want P(Y < threshold) for your application, just swap the order of the logits in your
format_raw_to_logits implementation.
Differentiable through both mean and variance — works with TAG.
"""
sigma_B = (log_var_B / 2).exp()
z_B = (threshold - mu_B) / sigma_B
log_p_above = torch.special.log_ndtr(-z_B)
log_p_below = torch.special.log_ndtr(z_B)
return torch.stack([log_p_below, log_p_above], dim=-1) # (B, 2)
|
pca_embed_init
pca_embed_init(pretrained_weights: Tensor, pretrained_vocab: dict[str, int], target_vocab: dict[str, int], n_components: int, target_vocab_size: Optional[int] = None) -> torch.Tensor
Project pretrained embeddings onto their top principal components, mapped to target vocab indices.
Finds shared tokens between the pretrained and target vocabularies,
computes PCA over the pretrained embeddings of those shared tokens,
and returns the projected embeddings arranged in the target vocabulary's
index order.
Parameters:
| Name |
Type |
Description |
Default |
pretrained_weights
|
Tensor
|
Embedding weight matrix, shape (V_pretrained, D_pretrained).
|
required
|
pretrained_vocab
|
dict[str, int]
|
Token string → index mapping for the pretrained model.
|
required
|
target_vocab
|
dict[str, int]
|
Token string → index mapping for the target model.
|
required
|
n_components
|
int
|
Number of principal components to keep.
|
required
|
target_vocab_size
|
Optional[int]
|
Total number of rows in the output tensor. Must be
large enough to contain all indices in target_vocab. When
None (default), inferred as max(target_vocab.values()) + 1.
Set this when the target embedding table has extra slots (e.g. a
padding index) that are not listed in target_vocab.
|
None
|
Returns:
| Type |
Description |
Tensor
|
Tensor of shape (target_vocab_size, n_components) with PCA-projected
|
Tensor
|
embeddings at the correct target indices. Rows for tokens not found in
|
Tensor
|
the pretrained vocabulary (or extra slots) are zero.
|
Raises:
| Type |
Description |
ValueError
|
If no shared tokens exist, n_components exceeds the
number of shared tokens or the pretrained embedding dimension,
or target_vocab_size is too small for the target vocab.
|
Source code in src/proteingen/modeling/predictive_modeling.py
| def pca_embed_init(
pretrained_weights: torch.Tensor,
pretrained_vocab: dict[str, int],
target_vocab: dict[str, int],
n_components: int,
target_vocab_size: Optional[int] = None,
) -> torch.Tensor:
"""Project pretrained embeddings onto their top principal components, mapped to target vocab indices.
Finds shared tokens between the pretrained and target vocabularies,
computes PCA over the pretrained embeddings of those shared tokens,
and returns the projected embeddings arranged in the target vocabulary's
index order.
Args:
pretrained_weights: Embedding weight matrix, shape (V_pretrained, D_pretrained).
pretrained_vocab: Token string → index mapping for the pretrained model.
target_vocab: Token string → index mapping for the target model.
n_components: Number of principal components to keep.
target_vocab_size: Total number of rows in the output tensor. Must be
large enough to contain all indices in ``target_vocab``. When
``None`` (default), inferred as ``max(target_vocab.values()) + 1``.
Set this when the target embedding table has extra slots (e.g. a
padding index) that are not listed in ``target_vocab``.
Returns:
Tensor of shape (target_vocab_size, n_components) with PCA-projected
embeddings at the correct target indices. Rows for tokens not found in
the pretrained vocabulary (or extra slots) are zero.
Raises:
ValueError: If no shared tokens exist, n_components exceeds the
number of shared tokens or the pretrained embedding dimension,
or target_vocab_size is too small for the target vocab.
"""
shared_tokens = sorted(set(pretrained_vocab.keys()) & set(target_vocab.keys()))
if len(shared_tokens) == 0:
raise ValueError("No shared tokens between pretrained and target vocabularies")
if n_components > len(shared_tokens):
raise ValueError(
f"n_components ({n_components}) exceeds number of shared tokens ({len(shared_tokens)})"
)
if n_components > pretrained_weights.shape[1]:
raise ValueError(
f"n_components ({n_components}) exceeds pretrained embedding dim ({pretrained_weights.shape[1]})"
)
min_target_size = max(target_vocab.values()) + 1
if target_vocab_size is None:
target_vocab_size = min_target_size
elif target_vocab_size < min_target_size:
raise ValueError(
f"target_vocab_size ({target_vocab_size}) is too small to contain "
f"all target vocab indices (max index = {min_target_size - 1})"
)
# Extract pretrained embeddings for shared tokens
pretrained_indices = [pretrained_vocab[t] for t in shared_tokens]
shared_embeddings = pretrained_weights[pretrained_indices].float() # (n_shared, D)
# PCA: center, SVD, project onto top-k components
mean = shared_embeddings.mean(dim=0)
centered = shared_embeddings - mean
_, _, Vt = torch.linalg.svd(centered, full_matrices=False)
projected = centered @ Vt[:n_components].T # (n_shared, n_components)
# Place projections at the correct target indices
output = torch.zeros(target_vocab_size, n_components)
for i, token in enumerate(shared_tokens):
output[target_vocab[token]] = projected[i]
return output
|