Training Predictors

Train oracle and noisy predictive models from assay-labeled data for use in guided generation.

Two models, two roles

Guided generation with ProteinGuide requires two predictive models:

| Model | Training data | Input type | Role |
| --- | --- | --- | --- |
| Oracle | All available data (all rounds) | Clean sequences | Evaluation only; scores final generated sequences |
| Noisy predictor | Current round data | Randomly masked sequences | Used during sampling; must handle partial sequences |

The oracle answers "how good is this fully-designed sequence?". The noisy predictor answers "given what I can see so far, is this sequence likely to be good?" — which is what the sampler needs at each unmasking step.
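As a minimal sketch of the data split in the table above: assume each assay record is a dict with hypothetical `sequence`, `label`, and `round` fields (illustrative names, not part of ProteinGuide), and read "current round data" as everything available up to and including the current round. That reading is an assumption; adjust the filter if your setup uses only the latest round. The labels are toy values.

```python
# Hypothetical record layout; field names and labels are toy values for illustration.
records = [
    {"sequence": "MKTAYIAK", "label": 0.42, "round": 1},
    {"sequence": "MKTAYLAK", "label": 0.57, "round": 1},
    {"sequence": "MKSAYLAK", "label": 0.81, "round": 2},
    {"sequence": "MRSAYLAK", "label": 0.77, "round": 3},
]


def split_by_role(records, current_round):
    """Return (oracle_data, predictor_data) for a given design round."""
    # Oracle: all available data, including rounds after the current one.
    oracle_data = list(records)
    # Noisy predictor: data available at the current round (assumed: rounds <= current).
    predictor_data = [r for r in records if r["round"] <= current_round]
    return oracle_data, predictor_data


oracle_data, predictor_data = split_by_role(records, current_round=1)
print(len(oracle_data), len(predictor_data))
```

Because the oracle sees later rounds that the noisy predictor never trains on, it can act as a held-out judge of sequences generated under that predictor's guidance.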

Training the oracle

The oracle is trained on clean (fully unmasked) sequences with standard supervised learning. It should be as accurate as possible — use all available data including later experimental rounds.

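import torch
import torch.nn.functional as F

oracle = MyPredictor(tokenizer=gen_model.tokenizer, ...)
optimizer = torch.optim.Adam(oracle.parameters(), lr=1e-3)  # optimizer and lr are illustrative choices
# Standard training loop: MSE (regression) or cross-entropy (classification) on clean sequences
for batch in oracle_loader:
    optimizer.zero_grad()  # clear gradients left over from the previous step
    pred = oracle.forward(batch["ohe"])
    loss = F.mse_loss(pred, batch["labels"])
    loss.backward()
    optimizer.step()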

Architecture choice: Use whatever works best on your data. OneHotMLP is a strong default for small datasets (< 5k sequences). LinearProbe on frozen ESMC embeddings works well for larger datasets.

Training the noisy predictor

The noisy predictor uses the same training procedure as the oracle, with two differences: it is trained on the current round's data only, and its input sequences are randomly masked at each training step. The masking makes it robust to the partially masked sequences it sees during guided generation.

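import torch
import torch.nn.functional as F

noisy_predictor = MyPredictor(tokenizer=gen_model.tokenizer, ...)
optimizer = torch.optim.Adam(noisy_predictor.parameters(), lr=1e-3)  # illustrative choice
for batch in train_loader:
    optimizer.zero_grad()  # clear gradients left over from the previous step
    tokens = batch["tokens"].clone()

    # Random masking: mask fraction ~ Uniform(0, 1), redrawn each step
    t = torch.rand(1).item()
    mask = torch.rand(tokens.shape, device=tokens.device) < t
    # Assumes fixed-length sequences with BOS first and EOS last (no padding)
    mask[:, 0] = False   # preserve BOS
    mask[:, -1] = False  # preserve EOS
    tokens[mask] = mask_token_id  # mask_token_id: the tokenizer's mask token id

    ohe = F.one_hot(tokens, vocab_size).float()  # vocab_size: size of the tokenizer vocabulary
    pred = noisy_predictor.forward(ohe)
    loss = F.mse_loss(pred, batch["labels"])
    loss.backward()
    optimizer.step()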

Key detail: The masking distribution during training should match the sampling schedule. If you use uniform_mask_noise for generation, train with uniform t. Validate on clean sequences (no masking) — this is the regime at the end of generation where the predictor's accuracy matters most.
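One way to keep training and validation consistent is to factor the masking step into a helper that takes the mask fraction `t` explicitly. The sketch below assumes fixed-length batches with BOS at position 0 and EOS at the last position, as in the loop above. `random_mask` and the mask token id used here are illustrative, not part of ProteinGuide.

```python
import torch


def random_mask(tokens, mask_token_id, t, generator=None):
    """Return a copy of `tokens` with each interior position masked with probability `t`.

    Assumes `tokens` has shape (batch, length) with BOS first and EOS last.
    """
    masked = tokens.clone()
    noise = torch.rand(tokens.shape, device=tokens.device, generator=generator)
    mask = noise < t
    mask[:, 0] = False    # preserve BOS
    mask[:, -1] = False   # preserve EOS
    masked[mask] = mask_token_id
    return masked


# Toy batch: 2 sequences of length 6; token id 31 stands in for the mask token.
tokens = torch.tensor([[0, 5, 6, 7, 8, 2],
                       [0, 9, 9, 4, 5, 2]])
clean = random_mask(tokens, mask_token_id=31, t=0.0)  # validation regime: nothing masked
full = random_mask(tokens, mask_token_id=31, t=1.0)   # everything except BOS/EOS masked
```

During training, drawing `t = torch.rand(1).item()` once per step matches a uniform schedule. For validation, `t=0.0` gives the clean regime, and sweeping a few fixed values of `t` shows how accuracy degrades as more of the sequence is hidden.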

Validating predictor–oracle agreement

Before using the noisy predictor for guidance, check that it agrees with the oracle on clean sequences:

from scipy.stats import spearmanr

oracle_scores = oracle.predict(val_sequences)
predictor_scores = noisy_predictor.predict(val_sequences)
rho, _ = spearmanr(oracle_scores, predictor_scores)
print(f"Agreement: ρ = {rho:.3f}")

If agreement is low (ρ < 0.5): The predictor can't be trusted during generation. Consider:

  • More training data
  • Simpler architecture (less overfitting)
  • Different masking schedule
  • Collecting more experimental data before attempting guided generation

See the PbrR walkthrough for a complete implementation with oracle–predictor agreement plots.
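The agreement check and the 0.5 cutoff above can be wrapped in a small gate that runs before guidance. This is a sketch with toy score lists. `agreement_gate` is a hypothetical helper, and the threshold is the rule of thumb from this page rather than a library default.

```python
from scipy.stats import spearmanr


def agreement_gate(oracle_scores, predictor_scores, min_rho=0.5):
    """Return (rho, ok), where ok is True if rank agreement reaches min_rho."""
    rho, _ = spearmanr(oracle_scores, predictor_scores)
    # A NaN rho (e.g. constant scores) compares False, so the gate stays closed.
    return float(rho), bool(rho >= min_rho)


# Toy scores: the predictor preserves the oracle's ranking exactly.
rho, ok = agreement_gate([0.1, 0.4, 0.35, 0.8], [1.0, 3.0, 2.0, 9.0])
print(f"rho = {rho:.3f}, usable for guidance: {ok}")
```

Spearman correlation compares rankings only, so the two models can output values on different scales and still agree.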

Choosing format_raw_to_logits

The binary logit function determines how raw predictions become guidance signals:

| Function | Best for | TAG-compatible? |
| --- | --- | --- |
| point_estimate_binary_logits | Simple thresholding | Only with small k (5–10) |
| gaussian_binary_logits | Uncertainty-aware predictions | Yes; differentiable through mean and variance |
| binary_logits | Direct classification | Yes |
| categorical_binary_logits | Multi-class (one vs rest) | Yes |

If using TAG guidance, prefer gaussian_binary_logits — it provides smooth gradients. For DEG guidance, point_estimate_binary_logits works fine regardless of k since DEG only needs rankings.
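To illustrate why a Gaussian-based logit gives smooth gradients, here is a minimal sketch. It is not the library's `gaussian_binary_logits`. It assumes the predictor outputs a mean and a variance, and that the event of interest is the property exceeding a threshold. The function name, threshold, and input values are illustrative.

```python
import math

import torch


def sketch_gaussian_logit(mean, var, threshold, eps=1e-6):
    """Logit of P(y > threshold) for y ~ Normal(mean, var)."""
    z = (mean - threshold) / torch.sqrt(var)
    p = 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))  # standard normal CDF at z
    p = p.clamp(eps, 1.0 - eps)                      # avoid log(0)
    return torch.log(p) - torch.log1p(-p)


# Toy predictions below, at, and above the threshold.
mean = torch.tensor([0.2, 1.0, 1.5], requires_grad=True)
var = torch.tensor([0.25, 0.25, 0.25], requires_grad=True)
logits = sketch_gaussian_logit(mean, var, threshold=1.0)
logits.sum().backward()
# Both mean.grad and var.grad are populated, so a gradient-based sampler can
# use the predicted mean and the predicted uncertainty.
```

The logit is zero when the predicted mean sits exactly on the threshold and changes smoothly on either side, which is the property a gradient-based guidance method needs.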