sampling

Sampling algorithms that generate sequences from GenerativeModel instances (including guided TAG/DEG models). Because TAG and DEG are GenerativeModel subclasses, all samplers work transparently with both guided and unguided models.

Samplers

sample

The main unified sampler. Unmasks n_parallel positions at a time, in a random order (default), left to right, or in an explicit caller-supplied order.

sample(model, x_SP, n_parallel=1, in_order=None, live_preview=True, record_p_y_gt_t=False) -> SamplingTrajectory
  • Accepts x_SP as token IDs or a list of strings (auto-tokenizes with padding)
  • in_order: None, "left_to_right", or a list[LongTensor] with one tensor per sequence giving the positions to unmask, in order. If None, a random permutation of each sequence's masked positions is sampled.
  • Returns a SamplingTrajectory dict with sequences (strings), step_log_probs, step_positions, step_tokens, and step_p_y_gt_t (None unless record_p_y_gt_t=True).
  • If the model exposes at_position() (as DEG does), each sequence's current position is passed via model.at_position() before computing log probs.

Sharp edge: orders are padded with position 0 (BOS/CLS) so that all sequences take the same number of steps. At padding steps, every sequence is still sampled at its designated position, so sequences whose real positions have run out are sampled at position 0. If the model's logit formatter is correctly configured, special-token positions predict only themselves, making the padding a no-op. If logits are NOT properly formatted, the token at position 0 may be mutated.
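The padding can be sketched in plain Python. This mirrors the `order_flat` / `order_steps` construction in `sample` (listed in the API reference) using lists instead of tensors; `pad_and_chunk_orders` is a hypothetical helper, not part of the library.

```python
import math


def pad_and_chunk_orders(orders: list[list[int]], n_parallel: int) -> list[list[list[int]]]:
    """Pad per-sequence unmask orders with position 0, then split into steps.

    Hypothetical helper mirroring the padding logic in `sample`.
    Returns steps[s][k] = positions unmasked for sequence s at step k.
    """
    max_positions = max((len(o) for o in orders), default=0)
    n_steps = math.ceil(max_positions / n_parallel) if max_positions > 0 else 0
    padded_len = n_steps * n_parallel
    steps = []
    for order in orders:
        # Pad with position 0 (BOS/CLS): a correctly formatted model re-samples
        # the existing special token there, so the extra slot is a no-op.
        padded = list(order) + [0] * (padded_len - len(order))
        steps.append([padded[k * n_parallel:(k + 1) * n_parallel] for k in range(n_steps)])
    return steps


# Two sequences with 3 and 1 masked positions, unmasked 2 at a time.
print(pad_and_chunk_orders([[4, 2, 3], [5]], n_parallel=2))
# [[[4, 2], [3, 0]], [[5, 0], [0, 0]]]
```

The second sequence finishes its only real position in step 0 and spends the remaining slots re-sampling position 0.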

sample_ctmc_linear_interpolation

Euler integration for flow-matching / linear interpolation:

sample_ctmc_linear_interpolation(model, x_SP, n_steps, return_string=True)

Each step: X_1 = ((steps_left - 1) / steps_left) * X_0 + (1 / steps_left) * exp(log_probs), then sample from the interpolated distribution.
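A minimal single-position sketch of that update, assuming `X_0` is the one-hot encoding of the current token and `exp(log_probs)` is the model's predicted distribution at that position. The function name and plain-list representation are illustrative only; whether the library applies the update to every position or only to masked ones is not shown on this page.

```python
import random


def interpolated_step(current_token: int, model_probs: list[float],
                      steps_left: int, rng: random.Random) -> tuple[list[float], int]:
    """One linear-interpolation step for a single position (hypothetical helper).

    Mixes the one-hot of the current token (weight (steps_left - 1) / steps_left)
    with the model's distribution (weight 1 / steps_left), then samples from it.
    """
    keep = (steps_left - 1) / steps_left
    move = 1.0 / steps_left
    x0 = [1.0 if t == current_token else 0.0 for t in range(len(model_probs))]
    mixed = [keep * a + move * b for a, b in zip(x0, model_probs)]
    new_token = rng.choices(range(len(mixed)), weights=mixed, k=1)[0]
    return mixed, new_token


rng = random.Random(0)
# Vocabulary of 3 tokens; token 2 plays the role of <mask>.
mixed, tok = interpolated_step(2, [0.5, 0.5, 0.0], steps_left=4, rng=rng)
print(mixed)  # [0.125, 0.125, 0.75]
```

On the final step (`steps_left == 1`) the current token gets zero weight, so every position is drawn from the model's prediction.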

sample_flow_matching_legacy

Legacy flow-matching sampler with dt and x1_temp parameters. Kept for reproducing results from the original stability guidance demo.

Helpers

  • tensor_to_string(x_SP, tokenizer) — batch decode, strips special tokens (<mask>, <cls>, <eos>)
  • build_legacy_predictor_log_prob — builds a legacy-compatible predictor log-prob callable
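The decoding behavior of tensor_to_string can be sketched with a toy vocabulary. The `id_to_token` mapping, the special-token set, and joining residues with no separator are assumptions of this sketch; the real helper batch-decodes through the tokenizer.

```python
SPECIAL_TOKENS = {"<mask>", "<cls>", "<eos>"}  # assumed special-token strings


def toy_tensor_to_string(x_SP: list[list[int]], id_to_token: dict[int, str]) -> list[str]:
    """Decode each row of token IDs, dropping special tokens (toy version)."""
    out = []
    for row in x_SP:
        tokens = [id_to_token[i] for i in row]
        out.append("".join(t for t in tokens if t not in SPECIAL_TOKENS))
    return out


vocab = {0: "<cls>", 1: "<eos>", 2: "<mask>", 3: "A", 4: "C", 5: "D"}
print(toy_tensor_to_string([[0, 3, 4, 5, 1], [0, 5, 2, 3, 1]], vocab))
# ['ACD', 'DA']
```

Because `<mask>` is stripped, a leftover mask silently shortens the decoded string, which is why `sample` asserts that no masked positions remain before decoding.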

Gotchas

  • DEG + n_parallel > 1: not yet implemented — raises NotImplementedError.
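  • In-place mutation: any_order_ancestral_step modifies x_SP in-place. sample clones its input first, so the caller's tensor is left unchanged.
  • Device handling: sample and sample_ctmc_linear_interpolation move input to model.device. The tensors in the SamplingTrajectory returned by sample live on the CPU. sample_flow_matching_legacy also moves input to model.device and, with return_string=False, returns tokens on the input tensor's original device.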

API Reference

proteingen.sampling

Sampling utilities for generative models.

SamplingTrajectory

Bases: TypedDict

Per-step data recorded during ancestral sampling.

Fields (S = number of sequences; n_total = number of steps × n_parallel, i.e. the padded order length):

  • sequences: (S,) list of generated sequences (strings).
  • step_log_probs: (S, n_total), log p(sampled_token) at each sampling step. Padding entries (from order padding with position 0) hold the log-prob of re-sampling the existing BOS token (0.0 if logits are formatted correctly).
  • step_positions: (S, n_total), which position was sampled at each step. Padding entries are 0 (the BOS position).
  • step_tokens: (S, n_total), which token was sampled at each step.
  • step_p_y_gt_t: (S, n_total), optional tracking of property probabilities. None unless record_p_y_gt_t=True; left as NaN if the model has no pred_model.

Source code in src/proteingen/sampling/sampling.py
class SamplingTrajectory(TypedDict):
    """Per-step data recorded during ancestral sampling.

    sequences: (S,) list of generated sequences (strings).
    step_log_probs: (S, n_total) — log p(sampled_token) at each sampling step.
        Padding entries (from order padding with position 0) will have the
        log-prob of re-sampling the existing BOS token (0.0 if formatted correctly).
    step_positions: (S, n_total) — which position was sampled at each step.
        Padding entries are 0 (the BOS position).
    step_tokens: (S, n_total) — which token was sampled at each step.
    step_p_y_gt_t: (S, n_total) — optional tracking of property probabilities.
    """

    sequences: list[str]
    step_log_probs: torch.Tensor
    step_positions: torch.Tensor
    step_tokens: torch.Tensor
    step_p_y_gt_t: torch.Tensor | None
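A sketch of reading a trajectory: summing log-probabilities over the real (non-padding) steps gives the log-likelihood of the sampled tokens along that unmask order. It assumes position 0 is never a genuinely masked position (true when it holds BOS/CLS). Plain lists stand in for the `(S, n_total)` tensors, and `order_log_likelihood` is a hypothetical helper.

```python
def order_log_likelihood(step_log_probs: list[list[float]],
                         step_positions: list[list[int]]) -> list[float]:
    """Sum log p(sampled_token) over non-padding steps for each sequence."""
    totals = []
    for lps, positions in zip(step_log_probs, step_positions):
        # Padding steps re-sample position 0 (BOS); skip them explicitly rather
        # than relying on their log-prob being exactly 0.0.
        totals.append(sum(lp for lp, pos in zip(lps, positions) if pos != 0))
    return totals


traj_log_probs = [[-0.5, -1.0, -0.25], [-2.0, 0.0, 0.0]]
traj_positions = [[3, 1, 2], [2, 0, 0]]
print(order_log_likelihood(traj_log_probs, traj_positions))  # [-1.75, -2.0]
```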

LiveSamplingPreview

Terminal renderer for in-place sequence previews below tqdm.

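Interface:
  • update(x_SP): render current sampled sequences in-place
  • close(): flush output (no-op cleanup)

Rendering happens only when enabled=True and the output stream (stdout by default) is a TTY. The preview is also a context manager whose exit calls close().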

Intended usage inside sampling loops:

preview = LiveSamplingPreview(model.tokenizer, enabled=True)
for step in ...:
    ...  # update x_SP
    pbar.update(1)
    preview.update(x_SP)
preview.close()
Source code in src/proteingen/sampling/sampling.py
class LiveSamplingPreview:
    """Terminal renderer for in-place sequence previews below tqdm.

    Interface:
    - ``update(x_SP)``: render current sampled sequences in-place
    - ``close()``: flush output (no-op cleanup)

    Intended usage inside sampling loops:

    ```python
    preview = LiveSamplingPreview(model.tokenizer, enabled=True)
    for step in ...:
        ...  # update x_SP
        pbar.update(1)
        preview.update(x_SP)
    preview.close()
    ```
    """

    def __init__(
        self,
        tokenizer,
        *,
        enabled: bool,
        reserve_lines: int = 2,
        stream: Optional[TextIO] = None,
    ) -> None:
        self.tokenizer = tokenizer
        self._reserve_lines = reserve_lines
        self._stream = sys.stdout if stream is None else stream
        self._enabled = bool(
            enabled
            and hasattr(self._stream, "isatty")
            and self._stream.isatty()
        )
        self._line_count = 0

    @property
    def enabled(self) -> bool:
        return self._enabled

    def update(self, x_SP: torch.LongTensor) -> None:
        if not self._enabled:
            return

        if self._line_count > 0:
            self._stream.write(f"\x1b[{self._line_count}A")
            self._stream.flush()

        terminal_size = shutil.get_terminal_size(fallback=(80, 24))
        max_lines = max(terminal_size.lines - self._reserve_lines, 1)
        max_width = max(terminal_size.columns, 1)
        preview_lines = _build_live_preview_lines(
            x_SP=x_SP,
            tokenizer=self.tokenizer,
            max_lines=max_lines,
            max_width=max_width,
        )
        self._line_count = _render_live_preview(
            preview_lines,
            previous_line_count=self._line_count,
            stream=self._stream,
        )

    def close(self) -> None:
        if self._enabled:
            self._stream.flush()

    def __enter__(self) -> "LiveSamplingPreview":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
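The redraw in `update()` relies on the ANSI escape `ESC[nA`, which moves the cursor up `n` lines so the previous preview can be overwritten. A minimal sketch of that mechanism, independent of the private helpers `_build_live_preview_lines` / `_render_live_preview` (not shown on this page). `MiniPreview` is hypothetical, and erasing each line with `ESC[2K` before rewriting it is an assumption of this sketch.

```python
import io
import sys
from typing import Optional, TextIO


class MiniPreview:
    """Toy in-place renderer (hypothetical): redraws a block of lines on a TTY."""

    def __init__(self, enabled: bool, stream: Optional[TextIO] = None) -> None:
        self._stream = sys.stdout if stream is None else stream
        # Same gating as LiveSamplingPreview: only draw on a real terminal.
        self._enabled = bool(
            enabled and hasattr(self._stream, "isatty") and self._stream.isatty()
        )
        self._line_count = 0

    def update(self, lines: list[str]) -> None:
        if not self._enabled:
            return
        if self._line_count > 0:
            # Move the cursor back up to the first line of the previous preview.
            self._stream.write(f"\x1b[{self._line_count}A")
        for line in lines:
            # Assumed: clear the whole line (ESC[2K), then write the new content.
            self._stream.write("\x1b[2K" + line + "\n")
        self._line_count = len(lines)
        self._stream.flush()


class FakeTTY(io.StringIO):
    """In-memory stream that claims to be a terminal, for demonstration."""

    def isatty(self) -> bool:
        return True


buf = FakeTTY()
p = MiniPreview(enabled=True, stream=buf)
p.update(["MK<mask>"])
p.update(["MKV"])
print(repr(buf.getvalue()))  # '\x1b[2KMK<mask>\n\x1b[1A\x1b[2KMKV\n'
```

This sketch does not clear leftover lines when the preview shrinks; the real renderer receives `previous_line_count`, presumably to handle that case.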

sample

sample(model: GenerativeModel, x_SP: LongTensor | List[str], n_parallel: int = 1, in_order: Optional[list[LongTensor] | str] = None, live_preview: bool = True, record_p_y_gt_t: bool = False) -> SamplingTrajectory

Ancestral sampling for masked generative models.

Unmasks positions n_parallel at a time, sampling from the model's predicted distribution at each step. If in_order is not provided, a random permutation of masked positions is generated for each sequence.

Sharp edge: unmask orders are padded to uniform length across sequences with position 0 (typically BOS/CLS). At padding steps, ALL sequences are still sampled at their designated positions — including the padding position. If the model's logit formatter is correctly configured, special-token positions predict only themselves, making the padding a no-op. If logits are NOT properly formatted, the token at position 0 may be mutated.

Parameters:

  • model (GenerativeModel): a GenerativeModel (or guided TAG/DEG model). Required.
  • x_SP (LongTensor | List[str]): (S, P) partially masked token IDs, or a list of strings. Required.
  • n_parallel (int): number of positions to unmask per step. Default: 1.
  • in_order (Optional[list[LongTensor] | str]): controls the unmask order. Can be:
      • None (default): random permutation of masked positions per sequence.
      • "left_to_right": masked positions in ascending index order.
      • list[LongTensor]: one tensor per sequence giving explicit positions to unmask (first element = first revealed).
  • live_preview (bool): if True, render in-place sequence updates below tqdm in a real terminal. Automatically disabled when stdout is not a TTY. Default: True.
  • record_p_y_gt_t (bool): if True and the model has a pred_model (as TAG does), record exp(pred_model.get_log_probs(x_SP)) after every step into step_p_y_gt_t. Default: False.

Returns:

  • SamplingTrajectory: generated sequences and per-step data.

Source code in src/proteingen/sampling/sampling.py
@torch.no_grad()
def sample(
    model: GenerativeModel,
    x_SP: torch.LongTensor | List[str],
    n_parallel: int = 1,
    in_order: Optional[list[torch.LongTensor] | str] = None,
    live_preview: bool = True,
    record_p_y_gt_t: bool = False,
) -> SamplingTrajectory:
    """Ancestral sampling for masked generative models.

    Unmasks positions n_parallel at a time, sampling from the model's predicted
    distribution at each step. If ``in_order`` is not provided, a random
    permutation of masked positions is generated for each sequence.

    **Sharp edge**: unmask orders are padded to uniform length across sequences
    with position 0 (typically BOS/CLS). At padding steps, ALL sequences are
    still sampled at their designated positions — including the padding position.
    If the model's logit formatter is correctly configured, special-token
    positions predict only themselves, making the padding a no-op. If logits
    are NOT properly formatted, the token at position 0 may be mutated.

    Args:
        model: a GenerativeModel (or guided TAG/DEG model).
        x_SP: (S, P) partially masked token IDs, or list of strings.
        n_parallel: number of positions to unmask per step. Default 1.
        in_order: controls the unmask order. Can be:
            - None (default): random permutation of masked positions per sequence.
            - ``"left_to_right"``: masked positions in ascending index order.
            - ``list[LongTensor]``: one tensor per sequence giving explicit
              positions to unmask (first element = first revealed).
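        live_preview: if True, render in-place sequence updates below tqdm in a
            real terminal. Automatically disabled when stdout is not a TTY.
        record_p_y_gt_t: if True and the model has a ``pred_model``, record
            ``exp(pred_model.get_log_probs(x_SP))`` after each step in
            ``step_p_y_gt_t``.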

    Returns:
        SamplingTrajectory with generated sequences and per-step data.
    """
    mask_token_id = model.tokenizer.mask_token_id

    if isinstance(x_SP, list):
        x_SP = model.tokenizer(x_SP, padding=True, return_tensors="pt")["input_ids"]
    x_SP = x_SP.clone().to(model.device)
    S, P = x_SP.shape

    # Build unmask orders
    if in_order is None:
        in_order = []
        for s in range(S):
            masked = (x_SP[s] == mask_token_id).nonzero().flatten()
            in_order.append(masked[torch.randperm(len(masked))])
    elif in_order == "left_to_right":
        in_order = []
        for s in range(S):
            in_order.append((x_SP[s] == mask_token_id).nonzero().flatten())

    assert len(in_order) == S, f"Expected {S} orders, got {len(in_order)}"

    # Pad orders to uniform length with position 0, then chunk by n_parallel
    max_positions = max((len(o) for o in in_order), default=0)
    n_steps = math.ceil(max_positions / n_parallel) if max_positions > 0 else 0
    padded_len = n_steps * n_parallel

    # order_flat: (S, padded_len) — padded with position 0
    order_flat = torch.zeros(S, padded_len, dtype=torch.long)
    for s, order in enumerate(in_order):
        order_flat[s, : len(order)] = order

    # (S, n_steps, n_parallel) — positions to unmask at each step
    if n_steps > 0:
        order_steps = order_flat.reshape(S, n_steps, n_parallel).to(model.device)
    else:
        order_steps = order_flat.reshape(S, 0, n_parallel).to(model.device)

    # Trajectory storage — flat (S, padded_len)
    step_log_probs = torch.full((S, padded_len), float("nan"))
    step_positions = order_flat.clone()
    step_tokens = torch.full((S, padded_len), -1, dtype=torch.long)
    step_p_y_gt_t = torch.full((S, padded_len), float("nan")) if record_p_y_gt_t else None

    preview = LiveSamplingPreview(model.tokenizer, enabled=live_preview)

    with preview, tqdm(total=n_steps) as pbar:
        for step in range(n_steps):
            positions = order_steps[:, step, :]  # (S, n_parallel)

            # DEG needs position info before computing log probs
            if hasattr(model, "at_position"):
                if n_parallel > 1:
                    raise NotImplementedError("DEG with n_parallel > 1 not implemented")
                positions_per_seq: List[Optional[int]] = [
                    pos[0].item() for pos in positions
                ]
                with model.at_position(positions_per_seq):
                    log_probs_SPT = model.get_log_probs(x_SP)
            else:
                log_probs_SPT = model.get_log_probs(x_SP)

            T = log_probs_SPT.size(-1)
            probs_SPT = torch.exp(log_probs_SPT)

            # Gather probs at selected positions
            pos_expanded = positions.unsqueeze(-1).expand(S, n_parallel, T)  # (S, n_parallel, T)
            probs_at_pos = probs_SPT.gather(1, pos_expanded)  # (S, n_parallel, T)

            # Sample tokens
            tokens = torch.multinomial(
                probs_at_pos.reshape(S * n_parallel, T), num_samples=1
            ).reshape(S, n_parallel)  # (S, n_parallel)

            # Update sequences
            x_SP.scatter_(1, positions, tokens)

            # Record trajectory
            flat_start = step * n_parallel
            flat_end = flat_start + n_parallel
            log_probs_at_pos = log_probs_SPT.gather(1, pos_expanded)  # (S, n_parallel, T)
            token_log_probs = log_probs_at_pos.gather(
                2, tokens.unsqueeze(-1)
            ).squeeze(-1)  # (S, n_parallel)
            step_log_probs[:, flat_start:flat_end] = token_log_probs.cpu()
            step_tokens[:, flat_start:flat_end] = tokens.cpu()

            if record_p_y_gt_t and hasattr(model, "pred_model"):
                pred_log_probs = model.pred_model.get_log_probs(x_SP)
                p_y_gt_t = torch.exp(pred_log_probs).cpu()
                if step_p_y_gt_t is not None:
                    step_p_y_gt_t[:, flat_start:flat_end] = p_y_gt_t.unsqueeze(-1).expand(S, n_parallel)

            pbar.update(1)

            preview.update(x_SP)

    assert (x_SP == mask_token_id).sum() == 0, "Some positions remain masked"
    sequences = tensor_to_string(x_SP, model.tokenizer)

    return SamplingTrajectory(
        sequences=sequences,
        step_log_probs=step_log_probs,
        step_positions=step_positions,
        step_tokens=step_tokens,
        step_p_y_gt_t=step_p_y_gt_t,
    )
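Stripped of tensor plumbing (`gather`, `torch.multinomial`, `scatter_`), each loop iteration above takes the predicted distribution at the chosen positions, samples a token from each, writes it back in place, and records its log-probability. A plain-Python sketch of one such step for a single sequence, where `log_probs_PT` stands in for one row of `model.get_log_probs(x_SP)`; `ancestral_step` is a hypothetical name.

```python
import math
import random


def ancestral_step(x_P: list[int], log_probs_PT: list[list[float]],
                   positions: list[int], rng: random.Random) -> list[float]:
    """Unmask `positions` in-place in x_P; return log p(sampled_token) per position."""
    recorded = []
    for pos in positions:
        probs = [math.exp(lp) for lp in log_probs_PT[pos]]
        token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        x_P[pos] = token  # in-place write, like x_SP.scatter_ in `sample`
        recorded.append(log_probs_PT[pos][token])
    return recorded


NEG = float("-inf")
# Vocabulary: 0 = <cls>, 1 = A, 2 = <mask>. Position 0 is formatted to predict itself.
x = [0, 2, 2]
log_probs = [[0.0, NEG, NEG], [NEG, 0.0, NEG], [NEG, 0.0, NEG]]
recorded = ancestral_step(x, log_probs, positions=[1, 0], rng=random.Random(0))
print(x, recorded)  # [0, 1, 2] [0.0, 0.0]
```

The padded slot (position 0) re-samples `<cls>` with log-prob 0.0, which is the no-op behavior the sharp-edge note relies on.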

generate_unmask_orders

generate_unmask_orders(seq_lengths: list[int], n_orders: int, special_positions: Optional[list[set[int]]] = None, seed: Optional[int] = None) -> list[list[torch.LongTensor]]

Generate random unmask orders for a set of sequences.

Returns orders[s][k] = 1-D LongTensor of maskable position indices for sequence s, order k. Positions are listed in the order they should be unmasked (first element = first position revealed).

Parameters:

  • seq_lengths (list[int]): length of each tokenized sequence (including special tokens). Required.
  • n_orders (int): how many random orders to generate per sequence. Required.
  • special_positions (Optional[list[set[int]]]): per-sequence set of position indices that are NOT maskable (BOS, EOS, PAD). If None, positions 0 and L-1 are treated as special (BOS/EOS convention for ESM-family tokenizers). Default: None.
  • seed (Optional[int]): optional RNG seed for reproducibility. Default: None.
Source code in src/proteingen/sampling/sampling.py
def generate_unmask_orders(
    seq_lengths: list[int],
    n_orders: int,
    special_positions: Optional[list[set[int]]] = None,
    seed: Optional[int] = None,
) -> list[list[torch.LongTensor]]:
    """Generate random unmask orders for a set of sequences.

    Returns orders[s][k] = 1-D LongTensor of maskable position indices for
    sequence s, order k. Positions are listed in the order they should be
    unmasked (first element = first position revealed).

    Args:
        seq_lengths: length of each tokenized sequence (including special tokens).
        n_orders: how many random orders to generate per sequence.
        special_positions: per-sequence set of position indices that are NOT
            maskable (BOS, EOS, PAD). If None, positions 0 and L-1 are treated
            as special (BOS/EOS convention for ESM-family tokenizers).
        seed: optional RNG seed for reproducibility.
    """
    rng = torch.Generator()
    if seed is not None:
        rng.manual_seed(seed)

    orders: list[list[torch.LongTensor]] = []
    for s, L in enumerate(seq_lengths):
        if special_positions is not None:
            maskable = sorted(set(range(L)) - special_positions[s])
        else:
            # Default: skip first and last (BOS/EOS)
            maskable = list(range(1, L - 1))
        maskable_t = torch.LongTensor(maskable)
        seq_orders = []
        for _ in range(n_orders):
            perm = torch.randperm(len(maskable_t), generator=rng)
            seq_orders.append(maskable_t[perm])
        orders.append(seq_orders)
    return orders
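The same contract (maskable positions only, reproducible under a seed) can be checked with the standard library. This is a sketch, not the implementation above: `unmask_orders_sketch` is hypothetical and draws permutations with `random.Random` rather than `torch.randperm`, so it will not reproduce the library's orders for a given seed.

```python
import random
from typing import Optional


def unmask_orders_sketch(seq_lengths: list[int], n_orders: int,
                         special_positions: Optional[list[set[int]]] = None,
                         seed: Optional[int] = None) -> list[list[list[int]]]:
    """Hypothetical stdlib mirror of generate_unmask_orders."""
    rng = random.Random(seed)
    orders = []
    for s, L in enumerate(seq_lengths):
        if special_positions is not None:
            maskable = sorted(set(range(L)) - special_positions[s])
        else:
            maskable = list(range(1, L - 1))  # skip BOS (0) and EOS (L-1)
        # rng.sample(seq, len(seq)) returns a random permutation of seq.
        orders.append([rng.sample(maskable, len(maskable)) for _ in range(n_orders)])
    return orders


a = unmask_orders_sketch([6, 4], n_orders=2, seed=0)
b = unmask_orders_sketch([6, 4], n_orders=2, seed=0)
print(a == b)             # True
print(sorted(a[0][0]))    # [1, 2, 3, 4]
```

One difference worth knowing: a freshly constructed `torch.Generator()` starts from PyTorch's fixed default seed rather than fresh entropy, so the torch version above appears to return the same orders on every call when `seed=None`. `random.Random(None)` in this sketch seeds from system entropy instead.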

mask_by_order

mask_by_order(token_ids: LongTensor, order: LongTensor, mask_fraction: float, mask_token_id: int) -> torch.LongTensor

Mask positions according to a decoding order.

Positions appearing LAST in the order (the tail) are masked. Specifically, the last ceil(mask_fraction * len(order)) positions in the order are set to mask_token_id.

Parameters:

  • token_ids (LongTensor): (P,) token IDs for a single sequence. Required.
  • order (LongTensor): (M,) position indices in unmask order (first = first revealed). Required.
  • mask_fraction (float): fraction of maskable positions to mask (0 to 1). Required.
  • mask_token_id (int): token ID to use for masking. Required.

Returns:

  • LongTensor: (P,) masked token IDs. The order of unmasking during generation should be order[n_keep:] (i.e. the masked positions, in unmask order), where n_keep = len(order) - ceil(mask_fraction * len(order)).

Source code in src/proteingen/sampling/sampling.py
def mask_by_order(
    token_ids: torch.LongTensor,
    order: torch.LongTensor,
    mask_fraction: float,
    mask_token_id: int,
) -> torch.LongTensor:
    """Mask positions according to a decoding order.

    Positions appearing LAST in the order (the tail) are masked. Specifically,
    the last ``ceil(mask_fraction * len(order))`` positions in the order are
    set to mask_token_id.

    Args:
        token_ids: (P,) token IDs for a single sequence.
        order: (M,) position indices in unmask order (first = first revealed).
        mask_fraction: fraction of maskable positions to mask (0 to 1).
        mask_token_id: token ID to use for masking.

    Returns:
        (P,) masked token IDs. The order of unmasking during generation should
        be order[n_keep:] (i.e. the masked positions, in unmask order).
    """
    n_maskable = len(order)
    n_to_mask = math.ceil(mask_fraction * n_maskable)
    n_keep = n_maskable - n_to_mask

    masked = token_ids.clone()
    positions_to_mask = order[n_keep:]
    masked[positions_to_mask] = mask_token_id
    return masked
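A plain-Python version of the masking rule that also returns the unmask order it implies for generation. `mask_by_order_sketch` is a hypothetical list-based mirror of the function above, with the same `ceil`-based split.

```python
import math


def mask_by_order_sketch(token_ids: list[int], order: list[int],
                         mask_fraction: float, mask_token_id: int) -> tuple[list[int], list[int]]:
    """Mask the tail of `order`; return (masked_ids, remaining_unmask_order)."""
    n_to_mask = math.ceil(mask_fraction * len(order))
    n_keep = len(order) - n_to_mask
    masked = list(token_ids)  # copy, like token_ids.clone()
    for pos in order[n_keep:]:
        masked[pos] = mask_token_id
    return masked, order[n_keep:]


MASK = 32
masked, remaining = mask_by_order_sketch([0, 5, 6, 7, 8, 2], order=[3, 1, 4, 2],
                                         mask_fraction=0.5, mask_token_id=MASK)
print(masked)     # [0, 5, 32, 7, 32, 2]
print(remaining)  # [4, 2]
```

Passing `remaining` (as a `LongTensor`, inside a one-element list) as `in_order` to `sample` would reveal the masked positions in the same order the mask was derived from. Because the split applies `ceil` to a float product, some fractions round up by one position: `0.07 * 100` evaluates to `7.000000000000001`, so 8 of 100 positions are masked.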

sample_flow_matching_legacy

sample_flow_matching_legacy(model: GenerativeModel, x_SP: LongTensor | List[str], dt: float = 0.01, predictor_log_prob=None, guide_temp: float = 1.0, use_tag: bool = False, x1_temp: float = 1.0, stochasticity: float = 0.0, argmax_final: bool = True, logits_postprocess: Optional[Callable[[FloatTensor, LongTensor], FloatTensor]] = None, return_string: bool = True, live_preview: bool = True) -> torch.LongTensor | list[str]

Legacy flow-matching Euler sampler (old stability demo numerics).

This reproduces the original rate-matrix integration loop and guidance-ratio update so DFM models can be compared head-to-head against old behavior.

live_preview shows in-place sequence updates below tqdm in real terminals.

Source code in src/proteingen/sampling/sampling.py
@torch.no_grad()
def sample_flow_matching_legacy(
    model: GenerativeModel,
    x_SP: torch.LongTensor | List[str],
    dt: float = 0.01,
    predictor_log_prob=None,
    guide_temp: float = 1.0,
    use_tag: bool = False,
    x1_temp: float = 1.0,
    stochasticity: float = 0.0,
    argmax_final: bool = True,
    logits_postprocess: Optional[
        Callable[[torch.FloatTensor, torch.LongTensor], torch.FloatTensor]
    ] = None,
    return_string: bool = True,
    live_preview: bool = True,
) -> torch.LongTensor | list[str]:
    """Legacy flow-matching Euler sampler (old stability demo numerics).

    This reproduces the original rate-matrix integration loop and guidance-ratio
    update so DFM models can be compared head-to-head against old behavior.

    live_preview shows in-place sequence updates below tqdm in real terminals.
    """
    if isinstance(x_SP, list):
        x_SP = model.tokenizer(x_SP, padding=True, return_tensors="pt")["input_ids"]
    x_device = x_SP.device
    xt = x_SP.to(model.device)

    S = model.tokenizer.vocab_size
    mask_idx = model.tokenizer.mask_token_id
    if mask_idx is None:
        raise ValueError("sample_flow_matching_legacy requires tokenizer.mask_token_id")

    t = 0.0
    n_steps = int(1.0 / dt)
    mask_one_hot = torch.zeros((S,), device=xt.device)
    mask_one_hot[mask_idx] = 1.0

    preview = LiveSamplingPreview(model.tokenizer, enabled=live_preview)

    with preview, tqdm(total=n_steps) as pbar:
        for _ in range(n_steps):
            logits = _formatted_logits(model, xt)[..., :S]
            if logits_postprocess is not None:
                logits = logits_postprocess(logits, xt)
            pt_x1_probs = F.softmax(logits / x1_temp, dim=-1)

            xt_is_mask = (xt == mask_idx).view(*xt.shape, 1).float()
            R_t = xt_is_mask * pt_x1_probs * ((1 + stochasticity * t) / (1 - t))
            remask_rates = (1 - xt_is_mask) * mask_one_hot.view(1, 1, -1) * stochasticity
            R_t += remask_rates

            if predictor_log_prob is not None:
                R_t = _legacy_get_guided_rates(
                    predictor_log_prob,
                    xt,
                    t,
                    R_t,
                    S,
                    use_tag=use_tag,
                    guide_temp=guide_temp,
                )

            R_t.scatter_(-1, xt[:, :, None], 0.0)
            R_t.scatter_(-1, xt[:, :, None], (-R_t.sum(dim=-1, keepdim=True)))

            step_probs = (R_t * dt).clamp(min=0.0, max=1.0)
            step_probs.scatter_(-1, xt[:, :, None], 0.0)
            step_probs.scatter_(
                -1,
                xt[:, :, None],
                (1.0 - torch.sum(step_probs, dim=-1, keepdim=True)).clamp(min=0.0),
            )
            step_probs = torch.clamp(step_probs, min=0.0, max=1.0)

            xt = torch.distributions.Categorical(step_probs).sample()

            pbar.update(1)
            preview.update(xt)

            t += dt
            if t > 1.0:
                break

        if argmax_final:
            xt_is_mask = (xt == mask_idx).view(*xt.shape).float()
            logits = _formatted_logits(model, xt)[..., :S]
            if logits_postprocess is not None:
                logits = logits_postprocess(logits, xt)
            xt = (torch.argmax(logits, dim=-1) * xt_is_mask + xt * (1 - xt_is_mask)).long()
            preview.update(xt)

    if return_string:
        return tensor_to_string(xt, model.tokenizer)
    return xt.to(x_device)
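For a single position and no guidance (`predictor_log_prob=None`), the legacy Euler update reduces to: compute off-diagonal rates, turn them into jump probabilities `rate * dt`, and put the leftover mass on staying put. A plain-Python sketch mirroring the `R_t` / `step_probs` lines above; `legacy_step_probs` is a hypothetical name and works on one position's probabilities rather than a batch.

```python
def legacy_step_probs(current: int, x1_probs: list[float], mask_idx: int,
                      t: float, dt: float, stochasticity: float = 0.0) -> list[float]:
    """Transition probabilities for one position in one legacy Euler step (sketch)."""
    is_mask = 1.0 if current == mask_idx else 0.0
    # Unmasking rates apply only from the mask state, scaled by (1 + eta * t) / (1 - t).
    rates = [is_mask * p * (1 + stochasticity * t) / (1 - t) for p in x1_probs]
    # Re-masking rate applies only from unmasked states.
    rates[mask_idx] += (1 - is_mask) * stochasticity
    # Zero the diagonal; staying put gets the leftover mass below.
    rates[current] = 0.0
    step = [min(max(r * dt, 0.0), 1.0) for r in rates]
    step[current] = max(1.0 - sum(step), 0.0)
    return step


# A masked position (token 2) at t = 0 with dt = 0.1.
p = legacy_step_probs(2, [0.5, 0.5, 0.0], mask_idx=2, t=0.0, dt=0.1)
print([round(v, 6) for v in p])  # [0.05, 0.05, 0.9]
```

If the off-diagonal jump probabilities sum to more than 1, the stay probability clamps to 0 and `torch.distributions.Categorical` renormalizes the row. The `(1 - t)` denominator grows without bound as `t` approaches 1; the loop runs `int(1.0 / dt)` steps, so its last step uses `t` of roughly `1 - dt`, and `argmax_final` then fills any positions still masked.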

build_legacy_predictor_log_prob

build_legacy_predictor_log_prob(tag_model)

Build old-demo style predictor_log_prob closure from DFM TAG components.

Returns a closure compatible with the original flow-matching guidance loop:
  • integer token input: (B, P) in generator token space
  • one-hot input: (B, P, T_gen) for TAG Taylor guidance

Raises ValueError unless tag_model.projection is a LinearGuidanceProjection.

Source code in src/proteingen/sampling/sampling.py
def build_legacy_predictor_log_prob(tag_model):
    """Build old-demo style predictor_log_prob closure from DFM TAG components.

    Returns a closure compatible with the original flow-matching guidance loop:
    - integer token input: ``(B, P)`` in generator token space
    - one-hot input: ``(B, P, T_gen)`` for TAG Taylor guidance
    """
    from ..modeling.guide import LinearGuidanceProjection

    projection = tag_model.projection
    if not isinstance(projection, LinearGuidanceProjection):
        raise ValueError(
            "Legacy predictor_log_prob currently supports LinearGuidanceProjection only"
        )
    pred_model = tag_model.pred_model
    gen_vocab_size = projection.tokenizer_gen.vocab_size
    start = projection._strip_prefix
    suffix = projection._strip_suffix

    def _pred_window(x):
        end = x.shape[1] - suffix
        if end < start:
            raise ValueError(
                f"Invalid strip window: start={start}, end={end}, length={x.shape[1]}"
            )
        return x[:, start:end]

    def predictor_log_prob(xt, t, **kwargs):
        if xt.is_floating_point():
            xt_inner = _pred_window(xt)
            xt_inner = xt_inner[..., :gen_vocab_size]
            pred_ohe = xt_inner @ projection.gen_to_pred_ohe_TK.to(xt.device)
            return _predictive_log_prob_from_ohe(pred_model, pred_ohe)

        xt = xt.long()
        xt_inner = _pred_window(xt)
        pred_tokens = projection.gen_to_pred_idx_T[xt_inner]
        return pred_model.get_log_probs(pred_tokens)

    return predictor_log_prob
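The closure's integer path is strip-then-lookup: drop `_strip_prefix` leading and `_strip_suffix` trailing positions, then map each generator token ID to a predictor token ID through `gen_to_pred_idx_T`. The one-hot path does the same mapping as a matrix product with `gen_to_pred_ohe_TK`. A plain-Python sketch of both, with small made-up tables standing in for the projection's attributes; all function names here are hypothetical, and the sketch skips the `[..., :gen_vocab_size]` truncation and the final predictor call.

```python
def pred_window(rows: list[list], start: int, suffix: int) -> list[list]:
    """Strip `start` leading and `suffix` trailing positions from every row."""
    out = []
    for row in rows:
        end = len(row) - suffix
        if end < start:
            raise ValueError(f"Invalid strip window: start={start}, end={end}, length={len(row)}")
        out.append(row[start:end])
    return out


def map_tokens(rows: list[list[int]], gen_to_pred_idx: list[int]) -> list[list[int]]:
    """Integer path: look up the predictor token ID for each generator token ID."""
    return [[gen_to_pred_idx[tok] for tok in row] for row in rows]


def map_one_hot(rows: list[list[list[float]]],
                gen_to_pred_ohe: list[list[float]]) -> list[list[list[float]]]:
    """One-hot path: (P, T_gen) @ (T_gen, K) for each row, giving (P, K)."""
    K = len(gen_to_pred_ohe[0])
    return [[[sum(v[g] * gen_to_pred_ohe[g][k] for g in range(len(v))) for k in range(K)]
             for v in row] for row in rows]


gen_to_pred_idx = [2, 0, 1]  # made-up table: generator token g -> predictor token
ohe_table = [[1.0 if k == gen_to_pred_idx[g] else 0.0 for k in range(3)] for g in range(3)]

x = [[9, 0, 1, 2, 9]]  # first and last positions are special tokens to strip
print(map_tokens(pred_window(x, start=1, suffix=1), gen_to_pred_idx))  # [[2, 0, 1]]
```

With a one-hot lookup table, the matrix-product path gives the same mapping as the index path on exact one-hot inputs, while also accepting the relaxed (soft) inputs that TAG Taylor guidance differentiates through.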