guide

This module implements TAG (Taylor-Approximate Guidance) and DEG (Discrete Enumeration Guidance) — two algorithms for combining a generative model with a predictive model via Bayes' rule. It also contains the GuidanceProjection abstraction for handling cross-tokenizer mapping.

The guidance equation

Both algorithms implement the same principle:

$$ p_\text{guided}(x_t | x_{<t}) \propto p_\text{gen}(x_t | x_{<t}) \cdot p_\text{pred}(\text{target} | x)^\gamma $$

Since TAG and DEG are GenerativeModel subclasses, they produce guided log-probs that plug directly into any sampler — no special handling needed.
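In log space the rule above becomes additive: the guided log-probs are the generator's log-probs plus γ times the predictor's log-prob, renormalized over the vocabulary. The sketch below illustrates this on a toy three-token vocabulary in plain Python. The function name guided_log_probs and all numbers are illustrative assumptions, not part of the module.

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def guided_log_probs(logp_gen, logp_pred, gamma=1.0):
    """log p_guided = log p_gen + gamma * log p_pred, then renormalize.

    logp_pred[t] stands for log p_pred(target | x with token t at the
    current position); this per-token layout is an assumption of the sketch.
    """
    unnormalized = [g + gamma * p for g, p in zip(logp_gen, logp_pred)]
    return log_softmax(unnormalized)

# Toy example: the generator prefers token 0, the predictor prefers token 1.
logp_gen = log_softmax([2.0, 1.0, 0.0])
logp_pred = [math.log(0.1), math.log(0.6), math.log(0.3)]

unguided = guided_log_probs(logp_gen, logp_pred, gamma=0.0)
guided = guided_log_probs(logp_gen, logp_pred, gamma=2.0)
```

With gamma=0 the generator's distribution is returned unchanged; raising gamma moves probability mass toward the tokens the predictor favors.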

TAG (Taylor-Approximate Guidance)

Uses first-order Taylor expansion of the predictive model's log-prob to compute guidance deltas efficiently.

TAG(gen_model, pred_model, use_clean_classifier=False, projection=None, n_fill_samples=1)

Forward pass

logp_gen = gen_model.get_log_probs(seq_SP)            # generative log probs
prepared = projection.prepare(seq_SP, logp_gen, ...)  # build predictor input
grad = pred_model.grad_log_prob(prepared.seq_pred_SP) # predictor gradient (∂log p / ∂OHE), taken at temp=1
delta = projection.grad_to_gen_delta(grad, prepared, gen_output_dim=logp_gen.shape[-1])  # project gradient to gen logit space
return logp_gen + delta / pred_model.temp             # Bayes' rule combination

Temperature as guidance strength

TAG does not have a separate guidance_scale parameter. Instead:

  • Predictor temperature controls guidance strength: the Taylor delta is divided by the predictor's temp, so lower temp → larger delta → stronger guidance
  • Generator temperature controls prior flatness: higher temp → flatter prior → more room for guidance to steer
  • Internally, TAG computes the gradient at temp=1 (avoiding sigmoid saturation), then divides the delta by the predictor's temperature, so 1/temp acts as a linear multiplier
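The linear-multiplier behavior can be checked with a few lines of plain Python. This is a minimal sketch that mirrors only the last line of the forward pass (logp_gen + delta / pred_model.temp). The helper name apply_tag_delta and the numbers are made up for illustration.

```python
def apply_tag_delta(logp_gen, delta, pred_temp):
    """Add the Taylor delta to the generator log-probs, scaled by 1 / pred_temp."""
    return [lp + d / pred_temp for lp, d in zip(logp_gen, delta)]

# Generator log-probs and a delta that was computed at temp=1 (toy values).
logp_gen = [-0.1, -2.5, -4.0]
delta = [0.0, 0.2, 0.1]

weak = apply_tag_delta(logp_gen, delta, pred_temp=1.0)     # multiplier 1
strong = apply_tag_delta(logp_gen, delta, pred_temp=0.05)  # multiplier 20
```

Here pred_temp=0.05 plays the role of a guidance scale of 20: the delta itself is unchanged, only its weight relative to the prior grows.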

Temperature tuning in practice

ESMC's prior at well-determined positions (e.g. a conserved glycine with log prob ≈ 0.0) is nearly impossible to override at temp=1. Raising the generator temp to 2–3 flattens the prior, giving guidance room to steer. Combined with a predictor temp of 0.05–0.1 (equivalent to a guidance scale of 10–20, since the delta is scaled by 1/temp), this produces significant improvements.
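A toy calculation shows why flattening the prior matters. The sketch assumes the generator temperature divides the logits before normalization, which is a common convention but is not confirmed by this page. The logits and delta are invented.

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def guided_argmax(logits, delta, gen_temp):
    """Return the most likely token after adding a fixed guidance delta.

    Assumption: the generator temperature divides the logits before the
    log-softmax.
    """
    prior = log_softmax([x / gen_temp for x in logits])
    guided = [lp + d for lp, d in zip(prior, delta)]
    return max(range(len(guided)), key=guided.__getitem__)

# A "conserved" position: token 0 has log prob close to 0 at temp=1.
logits = [9.0, 0.0, 0.0]
# Guidance favors token 1 by a fixed amount.
delta = [0.0, 4.0, 0.0]

sharp_choice = guided_argmax(logits, delta, gen_temp=1.0)  # prior wins
flat_choice = guided_argmax(logits, delta, gen_temp=3.0)   # guidance wins
```

The same delta that cannot move the sharp prior is enough to change the top token once the prior has been flattened.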

TrpB benchmark: unguided mean fitness = 0.48 → DEG (scale=20, temp=3) mean = 0.62; the fraction above 0.7 rose from 0.5% to 32.5% (N=200, 10 runs).

When to use TAG

TAG is fast (one backward pass per sampling step) but requires reliable gradients through the predictive model. It works well for:

  • Small predictive models (OneHotMLP, EmbeddingMLP)
  • LoRA-adapted backbones where gradients flow through adapted layers
  • Gaussian binary logits (differentiable through mean and variance)

Gradient vanishes on <mask> tokens

Predictors trained on real AA embeddings produce vanishing gradients (~10⁶× attenuation) when evaluated on <mask> inputs through a frozen transformer. Fix: set use_clean_classifier=True to fill mask positions with the generative model's argmax before computing predictor gradients.
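The clean-classifier idea can be sketched as follows: every masked position is replaced by the generator's most likely token before the predictor sees the sequence. This is a hypothetical stand-in written in plain Python, not the module's private _fill_masked_with_argmax helper, whose exact behavior (for example, whether special tokens are excluded from the argmax) is not shown here.

```python
def fill_masked_with_argmax(seq, logp, mask_token_id):
    """Replace each masked token with the argmax of the generator's log-probs.

    seq:  list of token ids for one sequence
    logp: per-position list of log-probs over the vocabulary
    """
    filled = list(seq)
    for pos, tok in enumerate(seq):
        if tok == mask_token_id:
            row = logp[pos]
            filled[pos] = max(range(len(row)), key=row.__getitem__)
    return filled

MASK = 3  # illustrative mask token id
seq = [0, MASK, 2, MASK]
logp = [
    [-0.1, -3.0, -3.0, -9.0],
    [-2.0, -0.2, -3.0, -9.0],
    [-3.0, -3.0, -0.1, -9.0],
    [-0.3, -2.0, -2.5, -9.0],
]

clean = fill_masked_with_argmax(seq, logp, MASK)
```

Unmasked positions are left untouched, so the predictor is evaluated on a sequence made entirely of real tokens.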

DEG (Discrete Enumeration Guidance)

Evaluates the predictor at all vocabulary tokens for a given position and reweights.

DEG(gen_model, pred_model, argmax_masked_positions=False, n_fill_samples=1)

Position management

DEG requires position info via a context manager:

with deg.at_position(positions_to_score_S):
    log_probs = deg.get_log_probs(seq_SP)

positions_to_score_S is a list of length S (the number of sequences in the batch), with one position index per sequence (or None to skip that sequence). The sample function handles this automatically.
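For callers that build the list by hand, the sketch below shows one way to produce it: pick the first masked position in each sequence and use None for sequences with nothing left to unmask. The helper first_masked_positions is hypothetical, and the library's sample function may choose positions in a different order.

```python
def first_masked_positions(seq_SP, mask_token_id):
    """Return one position per sequence, or None when a sequence has no mask."""
    positions = []
    for seq in seq_SP:
        masked = [i for i, tok in enumerate(seq) if tok == mask_token_id]
        positions.append(masked[0] if masked else None)
    return positions

MASK = 9  # illustrative mask token id
seq_SP = [
    [1, MASK, 2, MASK],  # first mask at index 1
    [1, 2, 3, 4],        # fully unmasked, so skipped
    [MASK, 2, 3, 4],     # first mask at index 0
]

positions_to_score_S = first_masked_positions(seq_SP, MASK)

# The result would then be passed to the context manager, e.g.:
#   with deg.at_position(positions_to_score_S):
#       log_probs = deg.get_log_probs(seq_SP)
```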

When to use DEG

DEG is more robust than TAG when gradients are unreliable. It only needs correct rankings from the predictor, not accurate gradient magnitudes:

  • Frozen-LM probes — TAG gradients through 30-layer frozen transformers are unreliable. DEG gives better guidance.
  • Point estimate predictors with steep sigmoids — large k in point_estimate_binary_logits saturates gradients. DEG sidesteps this.

The tradeoff: DEG requires vocab_size predictor evaluations per position per step, batched into one forward call per sequence (vs. one backward pass for TAG).
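The enumeration and its cost can be illustrated with a toy predictor. The sketch below substitutes every token at one position, scores each candidate, and adds the result to the generator's log-probs for that position. It loops in plain Python for clarity, whereas the module batches the candidates into a single predictor call. The names deg_scores_at_position and toy_pred_log_prob are illustrative assumptions.

```python
import math

def deg_scores_at_position(seq, pos, logp_gen_row, pred_log_prob, vocab_size):
    """Unnormalized guided log-probs at one position via full enumeration."""
    guided = []
    for tok in range(vocab_size):
        candidate = list(seq)
        candidate[pos] = tok  # try every token at the scored position
        guided.append(logp_gen_row[tok] + pred_log_prob(candidate))
    return guided

calls = []  # records every predictor evaluation

def toy_pred_log_prob(seq):
    """Toy predictor: the target is more likely when token 2 is frequent."""
    calls.append(tuple(seq))
    frac = sum(1 for t in seq if t == 2) / len(seq)
    return math.log(0.05 + 0.9 * frac)

seq = [2, 0, 1]
logp_gen_row = [math.log(0.5), math.log(0.2), math.log(0.3)]

guided = deg_scores_at_position(seq, 1, logp_gen_row, toy_pred_log_prob, vocab_size=3)
best = max(range(3), key=guided.__getitem__)
```

The predictor is called once per vocabulary entry, and only the ranking of its outputs across candidates affects which token ends up on top.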

Limitation

DEG with n_parallel > 1 (unmasking multiple positions simultaneously) is not yet implemented.

GuidanceProjection

When the generative and predictive models use different tokenizers (e.g. ESM with 33 tokens vs. ProteinMPNN with 21 tokens), GuidanceProjection handles the mapping.

LinearGuidanceProjection

The default projection. Builds a fixed linear map between token spaces:

M[T_gen, K_pred]  — each gen token's row is its predictor OHE representation
delta(t) = g · (M[t] - M[baseline])   — first-order Taylor delta per gen token

Where g is the predictor gradient, M[t] is the predictor OHE for gen token t, and baseline is the current/argmax token.
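A worked example of the delta formula, in plain Python with a tiny map M. The values are invented. The last generator token is assumed to share a predictor row with another token, as would happen for an unmapped token that uses the fallback.

```python
def taylor_delta(g, M, baseline):
    """delta(t) = g · (M[t] - M[baseline]) for every generator token t."""
    def dot(row):
        return sum(gk * mk for gk, mk in zip(g, row))

    baseline_score = dot(M[baseline])
    return [dot(row) - baseline_score for row in M]

# 4 generator tokens, 3 predictor OHE dimensions.
# Generator token 3 reuses predictor row 2 (fallback).
M = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0],
]
g = [0.5, -1.0, 2.0]  # predictor gradient at one position

delta = taylor_delta(g, M, baseline=0)
```

The baseline token always gets a delta of zero, and tokens that share a predictor row get identical deltas.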

LinearGuidanceProjection(
    tokenizer_gen, tokenizer_pred,
    pred_token_ohe_basis_TK,       # from pred_model.token_ohe_basis()
    fallback_pred_token_id=None,    # for unmapped gen tokens
    strip_prefix=None,              # auto-detect CLS stripping
    strip_suffix=None,              # auto-detect EOS stripping
)

Key details:

  • Token matching by string key between gen and pred vocabularies
  • Unmapped tokens fall back to fallback_pred_token_id (defaults to pred's mask or unk token)
  • CLS/EOS stripping is auto-detected when gen has CLS/EOS but pred doesn't (e.g. ESM → ProteinMPNN)
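The three details above can be sketched with plain dictionaries standing in for tokenizer vocabularies. The function names and the toy vocabularies are illustrative, and the logic follows the constructor's behavior as described on this page rather than reproducing it exactly.

```python
def build_gen_to_pred_index(gen_vocab, pred_vocab, fallback_pred_token_id=None):
    """Map each generator token id to a predictor token id by string key."""
    mapping = [-1] * len(gen_vocab)
    for tok, i_gen in gen_vocab.items():
        if tok in pred_vocab:
            mapping[i_gen] = pred_vocab[tok]
    unmapped = [i for i, j in enumerate(mapping) if j < 0]
    if unmapped and fallback_pred_token_id is None:
        raise ValueError(f"No fallback for unmapped generator ids: {unmapped}")
    for i in unmapped:
        mapping[i] = fallback_pred_token_id
    return mapping

def auto_strip(gen_has_token, pred_has_token):
    """Strip one position when only the generator uses the special token."""
    return 1 if (gen_has_token and not pred_has_token) else 0

gen_vocab = {"<cls>": 0, "A": 1, "C": 2, "<mask>": 3, "<eos>": 4}
pred_vocab = {"A": 0, "C": 1, "X": 2}

gen_to_pred = build_gen_to_pred_index(gen_vocab, pred_vocab, fallback_pred_token_id=2)
strip_prefix = auto_strip(gen_has_token=True, pred_has_token=False)  # CLS
strip_suffix = auto_strip(gen_has_token=True, pred_has_token=False)  # EOS
```

With these toy vocabularies, the special tokens have no string match and land on the fallback id, while one position is stripped from each end of the generator sequence before it is handed to the predictor.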

If projection is not provided to TAG, it auto-creates a LinearGuidanceProjection from the two models' tokenizers.


API Reference

proteingen.modeling.guide

PreparedGuidanceInput dataclass

Inputs needed to convert predictor-space gradients into gen-logit deltas.

Source code in src/proteingen/modeling/guide.py
@dataclass
class PreparedGuidanceInput:
    """Inputs needed to convert predictor-space gradients into gen-logit deltas."""

    seq_pred_SP: torch.LongTensor
    pred_pos_to_gen_pos_P: torch.LongTensor
    baseline_gen_SP: torch.LongTensor
    gen_length: int

GuidanceProjection

Bases: Module, ABC

Maps between predictor OHE-gradient space and generative model logit space.

Source code in src/proteingen/modeling/guide.py
class GuidanceProjection(nn.Module, ABC):
    """Maps between predictor OHE-gradient space and generative model logit space."""

    @abstractmethod
    def prepare(
        self,
        seq_gen_SP: torch.LongTensor,
        logp_gen_SPT: torch.FloatTensor,
        *,
        use_clean_classifier: bool,
        n_samples: int = 1,
    ) -> PreparedGuidanceInput:
        """Build predictor input tokens and Taylor baseline in gen token space."""
        ...

    @abstractmethod
    def grad_to_gen_delta(
        self,
        grad_pred_SPK: torch.FloatTensor,
        prepared: PreparedGuidanceInput,
        *,
        gen_output_dim: int,
    ) -> torch.FloatTensor:
        """Project predictor-space gradients into generator logit space."""
        ...

prepare abstractmethod
prepare(seq_gen_SP: LongTensor, logp_gen_SPT: FloatTensor, *, use_clean_classifier: bool, n_samples: int = 1) -> PreparedGuidanceInput

Build predictor input tokens and Taylor baseline in gen token space.


grad_to_gen_delta abstractmethod
grad_to_gen_delta(grad_pred_SPK: FloatTensor, prepared: PreparedGuidanceInput, *, gen_output_dim: int) -> torch.FloatTensor

Project predictor-space gradients into generator logit space.


LinearGuidanceProjection

Bases: GuidanceProjection

Linear token-space projection for TAG.

Uses a fixed map M[T_gen, K_pred] where each gen token's row is the predictor-OHE representation of that token. Given predictor gradient g, TAG's first-order term at each position is:

delta(t) = g · (M[t] - M[baseline])

where baseline is the current token (or argmax-filled token when using clean-classifier mode).

Source code in src/proteingen/modeling/guide.py
class LinearGuidanceProjection(GuidanceProjection):
    """Linear token-space projection for TAG.

    Uses a fixed map ``M[T_gen, K_pred]`` where each gen token's row is the
    predictor-OHE representation of that token. Given predictor gradient ``g``,
    TAG's first-order term at each position is:

        ``delta(t) = g · (M[t] - M[baseline])``

    where ``baseline`` is the current token (or argmax-filled token when using
    clean-classifier mode).
    """

    def __init__(
        self,
        tokenizer_gen: PreTrainedTokenizerBase,
        tokenizer_pred: PreTrainedTokenizerBase,
        pred_token_ohe_basis_TK: torch.FloatTensor,
        fallback_pred_token_id: Optional[int] = None,
        strip_prefix: Optional[int] = None,
        strip_suffix: Optional[int] = None,
    ):
        super().__init__()
        self.tokenizer_gen = tokenizer_gen
        self.tokenizer_pred = tokenizer_pred

        if pred_token_ohe_basis_TK.ndim != 2:
            raise ValueError(
                "pred_token_ohe_basis_TK must have shape (pred_vocab_size, pred_ohe_dim)"
            )
        if pred_token_ohe_basis_TK.shape[0] != tokenizer_pred.vocab_size:
            raise ValueError(
                "pred_token_ohe_basis_TK first dimension must match predictor tokenizer vocab_size"
            )

        fallback_id = fallback_pred_token_id
        if fallback_id is None:
            fallback_id = getattr(tokenizer_pred, "mask_token_id", None)
        if fallback_id is None:
            fallback_id = getattr(tokenizer_pred, "unk_token_id", None)

        src_vocab = tokenizer_gen.vocab
        tgt_vocab = tokenizer_pred.vocab
        gen_to_pred_idx_T = torch.full(
            (tokenizer_gen.vocab_size,),
            fill_value=-1,
            dtype=torch.long,
        )
        for tok, i_gen in src_vocab.items():
            if tok in tgt_vocab:
                gen_to_pred_idx_T[i_gen] = tgt_vocab[tok]

        unmapped_mask = gen_to_pred_idx_T < 0
        if unmapped_mask.any():
            if fallback_id is None:
                idx_to_tok = {idx: tok for tok, idx in src_vocab.items()}
                missing = [
                    idx_to_tok.get(i, f"<idx:{i}>")
                    for i in torch.where(unmapped_mask)[0].tolist()[:10]
                ]
                raise ValueError(
                    "No fallback predictor token available for unmapped generator tokens. "
                    f"Example unmapped tokens: {missing}"
                )
            gen_to_pred_idx_T[unmapped_mask] = int(fallback_id)

        self.register_buffer("gen_to_pred_idx_T", gen_to_pred_idx_T)
        self.register_buffer("pred_token_ohe_basis_TK", pred_token_ohe_basis_TK.float())
        self.register_buffer(
            "gen_to_pred_ohe_TK",
            self.pred_token_ohe_basis_TK[self.gen_to_pred_idx_T],
        )

        if strip_prefix is None:
            src_has_cls = getattr(tokenizer_gen, "cls_token_id", None) is not None
            tgt_has_cls = getattr(tokenizer_pred, "cls_token_id", None) is not None
            strip_prefix = 1 if (src_has_cls and not tgt_has_cls) else 0
        if strip_suffix is None:
            src_has_eos = getattr(tokenizer_gen, "eos_token_id", None) is not None
            tgt_has_eos = getattr(tokenizer_pred, "eos_token_id", None) is not None
            strip_suffix = 1 if (src_has_eos and not tgt_has_eos) else 0

        self._strip_prefix = int(strip_prefix)
        self._strip_suffix = int(strip_suffix)

    def _pred_window(self, seq_gen_SP: torch.LongTensor) -> tuple[int, int]:
        start = self._strip_prefix
        end = seq_gen_SP.size(1) - self._strip_suffix
        if end < start:
            raise ValueError(
                f"Invalid strip configuration: start={start}, end={end}, sequence_length={seq_gen_SP.size(1)}"
            )
        return start, end

    def prepare(
        self,
        seq_gen_SP: torch.LongTensor,
        logp_gen_SPT: torch.FloatTensor,
        *,
        use_clean_classifier: bool,
        n_samples: int = 1,
    ) -> PreparedGuidanceInput:
        seq_for_grad_SP = seq_gen_SP
        if use_clean_classifier:
            if n_samples > 1:
                raise NotImplementedError("n_samples > 1 not implemented for TAG yet")
            seq_for_grad_SP = _fill_masked_with_argmax(
                seq_gen_SP,
                logp_gen_SPT,
                getattr(self.tokenizer_gen, "mask_token_id", None),
                self.tokenizer_gen.vocab_size,
            )

        start, end = self._pred_window(seq_gen_SP)
        seq_for_pred_SP = seq_for_grad_SP[:, start:end]
        seq_pred_SP = self.gen_to_pred_idx_T[seq_for_pred_SP]
        pred_pos_to_gen_pos_P = torch.arange(
            start, end, device=seq_gen_SP.device, dtype=torch.long
        )
        return PreparedGuidanceInput(
            seq_pred_SP=seq_pred_SP,
            pred_pos_to_gen_pos_P=pred_pos_to_gen_pos_P,
            baseline_gen_SP=seq_for_pred_SP,
            gen_length=seq_gen_SP.size(1),
        )

    def grad_to_gen_delta(
        self,
        grad_pred_SPK: torch.FloatTensor,
        prepared: PreparedGuidanceInput,
        *,
        gen_output_dim: int,
    ) -> torch.FloatTensor:
        pred_ohe_dim = self.gen_to_pred_ohe_TK.shape[1]
        if grad_pred_SPK.shape[-1] != pred_ohe_dim:
            raise ValueError(
                f"Predictor grad dim ({grad_pred_SPK.shape[-1]}) does not match projection pred_ohe_dim ({pred_ohe_dim})"
            )

        gen_vocab_size = self.tokenizer_gen.vocab_size
        if gen_output_dim < gen_vocab_size:
            raise ValueError(
                f"gen_output_dim ({gen_output_dim}) must be >= generator vocab_size ({gen_vocab_size})"
            )

        # score_gen[s, p, t] = g[s, p, :] · M[t, :]
        score_gen_vocab_SPT = torch.einsum(
            "spk,tk->spt", grad_pred_SPK, self.gen_to_pred_ohe_TK
        )

        # TAG Taylor delta: g · (M_t - M_baseline)
        baseline_score_SP = score_gen_vocab_SPT.gather(
            dim=-1, index=prepared.baseline_gen_SP.unsqueeze(-1)
        ).squeeze(-1)
        delta_vocab_SPT = score_gen_vocab_SPT - baseline_score_SP.unsqueeze(-1)

        S = grad_pred_SPK.size(0)
        delta_gen_SPT = grad_pred_SPK.new_zeros(
            (S, prepared.gen_length, gen_output_dim)
        )
        delta_gen_SPT[:, prepared.pred_pos_to_gen_pos_P, :gen_vocab_size] = (
            delta_vocab_SPT
        )
        return delta_gen_SPT

TAG

Bases: GenerativeModel

Token-level Autoregressive Guidance.

Combines a generative model with a predictive model via Bayes' rule. Uses gradients through the predictive model's OHE to shift transition logits.

Guidance projection is handled by a GuidanceProjection object, keeping TAG's core update rule focused on Bayes composition in gen-logit space.

Source code in src/proteingen/modeling/guide.py
class TAG(GenerativeModel):
    """Token-level Autoregressive Guidance.

    Combines a generative model with a predictive model via Bayes' rule.
    Uses gradients through the predictive model's OHE to shift transition logits.

    Guidance projection is handled by a ``GuidanceProjection`` object, keeping
    TAG's core update rule focused on Bayes composition in gen-logit space.
    """

    def __init__(
        self,
        gen_model: GenerativeModel,
        pred_model: PredictiveModel,
        use_clean_classifier: bool = False,
        projection: Optional[GuidanceProjection] = None,
        n_fill_samples: int = 1,
    ):
        super().__init__(
            model=gen_model.model,
            tokenizer=gen_model.tokenizer,
            logit_formatter=gen_model.logit_formatter,
        )
        self.gen_model = gen_model
        self.pred_model = pred_model
        self.argmax_masked_positions = use_clean_classifier
        self.n_fill_samples = n_fill_samples
        if projection is None:
            projection = LinearGuidanceProjection(
                tokenizer_gen=gen_model.tokenizer,
                tokenizer_pred=pred_model.tokenizer,
                pred_token_ohe_basis_TK=pred_model.token_ohe_basis().detach(),
            )
        self.projection = projection.to(self.gen_model.device)

    def forward(self, seq_SP: torch.LongTensor):
        if self.gen_model.device != self.pred_model.device:
            raise ValueError(
                "TAG requires gen_model and pred_model to be on the same device"
            )
        logp_xtilde_g_x_SPT = self.gen_model.get_log_probs(seq_SP)
        prepared = self.projection.prepare(
            seq_SP,
            logp_xtilde_g_x_SPT,
            use_clean_classifier=self.argmax_masked_positions,
            n_samples=self.n_fill_samples,
        )
        # Compute gradient at temp=1 for natural gradient shape (no sigmoid
        # saturation), then use the predictor's temperature purely as a linear
        # guidance strength multiplier on the Taylor delta.
        guidance_temp = self.pred_model.temp
        self.pred_model.set_temp_(1.0)
        grad_pred_SPK = self.pred_model.grad_log_prob(prepared.seq_pred_SP)
        self.pred_model.set_temp_(guidance_temp)
        delta_gen_SPT = self.projection.grad_to_gen_delta(
            grad_pred_SPK,
            prepared,
            gen_output_dim=logp_xtilde_g_x_SPT.shape[-1],
        )
        return logp_xtilde_g_x_SPT + delta_gen_SPT / guidance_temp

DEG

Bases: GenerativeModel

Source code in src/proteingen/modeling/guide.py
class DEG(GenerativeModel):
    # TAG and DEG are basically ways to efficiently compute the vector p(y|x_const)
    # Therefore, we don't make the predictive models deal with that kind of query
    def __init__(
        self,
        gen_model: GenerativeModel,
        pred_model: PredictiveModel,
        argmax_masked_positions: bool = False,
        n_fill_samples: int = 1,
    ):
        super().__init__(
            model=gen_model.model,
            tokenizer=gen_model.tokenizer,
            logit_formatter=gen_model.logit_formatter,
        )
        # main stipulation here is that the predictive model has to take OHE as input
        self.gen_model = gen_model
        self.pred_model = pred_model
        self.argmax_masked_positions = argmax_masked_positions
        self.n_fill_samples = n_fill_samples
        self.positions_to_score_S = None

    @contextmanager
    def at_position(self, positions_to_score_S: List[int]):
        """
        positions_to_score_S is a list that gives the index at each sequence to try to sample
        If a sequence does not need to be sampled, pass None for that index in the list
        """
        old = self.positions_to_score_S
        self.positions_to_score_S = positions_to_score_S
        try:
            yield self
        finally:
            self.positions_to_score_S = old

    def forward(self, seq_SP: torch.LongTensor):
        if self.positions_to_score_S is None:
            raise ValueError(
                "Need to call ``self.at_position(positions_to_score_S)`` to provide the position to score for each sequence"
            )
        logp_xtilde_g_x_SPT = self.gen_model.get_log_probs(seq_SP)
        logp_y_g_xtilde_SPT = torch.zeros_like(logp_xtilde_g_x_SPT)
        n_tok = self.tokenizer.vocab_size
        mask_token_id = self.gen_model.tokenizer.mask_token_id
        for s, p in enumerate(self.positions_to_score_S):
            if p is None:
                continue
            base = seq_SP[s].clone()
            if self.argmax_masked_positions:
                base_M = _fill_masked_with_argmax(
                    base.unsqueeze(0),
                    logp_xtilde_g_x_SPT[s].unsqueeze(0),
                    mask_token_id,
                    n_tok,
                    n_samples=self.n_fill_samples,
                ) # [n_samples, P]
            else:
                base_M = base.unsqueeze(0)

            n_samples = base_M.shape[0]
            # Create [n_samples, n_tok, P] by expanding base_M
            seq_MXP = base_M.unsqueeze(1).repeat(1, n_tok, 1)
            # Set the current position p to all possible tokens
            seq_MXP[:, :, p] = torch.arange(n_tok, device=seq_SP.device).unsqueeze(0).expand(n_samples, -1)
            # Flatten to [n_samples * n_tok, P]
            seq_flat = seq_MXP.view(-1, seq_SP.size(1))

            with torch.no_grad():
                logp_y_g_xtilde_flat = self.pred_model.get_log_probs(seq_flat)
            # Reshape back to [n_samples, n_tok]
            logp_y_g_xtilde_MX = logp_y_g_xtilde_flat.view(n_samples, n_tok)
            # Log-mean-exp over the samples: log( (1/M) sum_m exp(logp_m) )
            logp_y_g_xtilde_X = torch.logsumexp(logp_y_g_xtilde_MX, dim=0) - torch.log(torch.tensor(n_samples, dtype=torch.float32, device=seq_SP.device))

            logp_y_g_xtilde_SPT[s, p, :n_tok] = logp_y_g_xtilde_X
            # Don't need to take care of making the others -inf since the logit_formatter will take care of the invalid ones (including the invalid ones we tested lol)
        return logp_y_g_xtilde_SPT + logp_xtilde_g_x_SPT

at_position
at_position(positions_to_score_S: List[int])

positions_to_score_S is a list giving, for each sequence, the position index to sample. If a sequence does not need to be sampled, pass None for that sequence's entry.
