sampling
Sampling algorithms that generate sequences from GenerativeModel instances (including guided TAG/DEG models). Because TAG and DEG are GenerativeModel subclasses, all samplers work transparently with both guided and unguided models.
Samplers
sample
The main unified sampler. Unmasks positions n_parallel at a time, either in random order (default) or in an explicit order.
- Accepts `x_SP` as token IDs or a list of strings (auto-tokenizes).
- `in_order`: optional `list[LongTensor]`, one per sequence, giving positions to unmask in order. If None, a random permutation of masked positions is sampled.
- Returns a `SamplingTrajectory` dict with `sequences` (strings), `step_log_probs`, `step_positions`, and `step_tokens`.
- If the model is DEG, automatically passes position info via `model.at_position()` before computing log probs.
Sharp edge: orders are padded to uniform length across sequences with position 0 (BOS/CLS). At padding steps, all sequences are sampled at their designated positions — including padding. If the model's logit formatter is correctly configured, special-token positions predict only themselves, making the padding a no-op. If logits are NOT properly formatted, the token at position 0 may be mutated.
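The padding behaviour described above can be illustrated with a minimal sketch. It uses plain Python lists instead of tensors, and `pad_orders` is a hypothetical helper written for this example, not a function from the library.

```python
def pad_orders(orders: list[list[int]], pad_position: int = 0) -> list[list[int]]:
    """Right-pad each unmask order with pad_position so all orders have equal length."""
    n_total = max((len(order) for order in orders), default=0)
    return [order + [pad_position] * (n_total - len(order)) for order in orders]


# Sequence 0 has three masked positions, sequence 1 has only one.
orders = [[3, 1, 2], [2]]
padded = pad_orders(orders)
print(padded)  # [[3, 1, 2], [2, 0, 0]]
```

In this toy batch, steps 2 and 3 for the second sequence point at position 0, which is why the sampler re-samples the BOS/CLS token there and relies on the logit formatter to make that a no-op.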
sample_ctmc_linear_interpolation
Euler integration for flow-matching / linear interpolation:
Each step: X_1 = ((steps_left - 1) / steps_left) * X_0 + (1 / steps_left) * exp(log_probs), then sample from the interpolated distribution.
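As a rough illustration of the update rule above, the sketch below applies one step to a single position over a three-token vocabulary. It assumes `X_0` is the current per-position distribution (here a one-hot vector) and uses plain lists rather than tensors; `interpolation_step` is a hypothetical name, not the library function.

```python
import math
import random


def interpolation_step(x0: list[float], log_probs: list[float], steps_left: int) -> list[float]:
    """One Euler step: mix the current distribution with the model's prediction."""
    keep = (steps_left - 1) / steps_left  # weight on the current state X_0
    move = 1 / steps_left                 # weight on exp(log_probs)
    return [keep * p0 + move * math.exp(lp) for p0, lp in zip(x0, log_probs)]


# Toy example: one position, three possible tokens.
x0 = [1.0, 0.0, 0.0]                                # all mass on token 0
log_probs = [math.log(p) for p in (0.2, 0.5, 0.3)]  # stand-in for model output
x1 = interpolation_step(x0, log_probs, steps_left=4)

# Sample a token from the interpolated distribution.
rng = random.Random(0)
token = rng.choices(range(len(x1)), weights=x1, k=1)[0]
```

With `steps_left=1` the weight on `X_0` is zero, so the final step samples directly from the model's predicted distribution.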
sample_flow_matching_legacy
Legacy flow-matching sampler with dt and x1_temp parameters. Kept for reproducing results from the original stability guidance demo.
Helpers
- `tensor_to_string(x_SP, tokenizer)`: batch decode, strips special tokens (`<mask>`, `<cls>`, `<eos>`)
- `build_legacy_predictor_log_prob`: builds a legacy-compatible predictor log-prob callable
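The decoding behaviour of `tensor_to_string` can be pictured with a small stand-in. The vocabulary, token IDs, and the function `toy_tensor_to_string` below are all hypothetical; the real helper works on tensors and a tokenizer object.

```python
SPECIAL_TOKENS = {"<mask>", "<cls>", "<eos>"}

# Hypothetical toy vocabulary: token ID -> token string.
TOY_VOCAB = {0: "<cls>", 1: "<eos>", 2: "<mask>", 3: "A", 4: "C", 5: "D"}


def toy_tensor_to_string(x_SP: list[list[int]], vocab: dict[int, str]) -> list[str]:
    """Decode each row of token IDs, dropping special tokens."""
    return ["".join(vocab[t] for t in row if vocab[t] not in SPECIAL_TOKENS) for row in x_SP]


decoded = toy_tensor_to_string([[0, 3, 4, 1], [0, 5, 2, 1]], TOY_VOCAB)
print(decoded)  # ['AC', 'D']
```

Note that a still-masked position simply disappears from the decoded string in this sketch, so decoded lengths can differ from the tokenized length.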
Gotchas
- DEG + `n_parallel > 1`: not yet implemented; raises `NotImplementedError`.
- In-place mutation: `any_order_ancestral_step` modifies `x_SP` in-place. `sample` clones first.
- Device handling: `sample` and `sample_ctmc_linear_interpolation` move input to `model.device`.
API Reference
proteingen.sampling
Sampling utilities for generative models.
SamplingTrajectory
Bases: TypedDict
Per-step data recorded during ancestral sampling.
- `sequences`: (S,) list of generated sequences (strings).
- `step_log_probs`: (S, n_total), log p(sampled_token) at each sampling step. Padding entries (from order padding with position 0) will have the log-prob of re-sampling the existing BOS token (0.0 if formatted correctly).
- `step_positions`: (S, n_total), which position was sampled at each step. Padding entries are 0 (the BOS position).
- `step_tokens`: (S, n_total), which token was sampled at each step.
- `step_p_y_gt_t`: (S, n_total), optional tracking of property probabilities.
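To make the shapes concrete, here is a sketch of a trajectory-like dict for S = 2 sequences and n_total = 3 steps. `ToyTrajectory` is a hypothetical stand-in that uses nested lists where the real `SamplingTrajectory` holds tensors, and all values are made up for illustration.

```python
from typing import TypedDict


class ToyTrajectory(TypedDict, total=False):
    sequences: list[str]                  # (S,)
    step_log_probs: list[list[float]]     # (S, n_total)
    step_positions: list[list[int]]       # (S, n_total)
    step_tokens: list[list[int]]          # (S, n_total)
    step_p_y_gt_t: list[list[float]]      # (S, n_total), optional


# The second sequence had one masked position, so its last two steps are
# padding: position 0, the BOS token ID (assumed to be 0 here), log-prob 0.0.
traj: ToyTrajectory = {
    "sequences": ["ACDE", "ACDE"],
    "step_log_probs": [[-1.2, -0.7, -2.1], [-0.9, 0.0, 0.0]],
    "step_positions": [[3, 1, 2], [2, 0, 0]],
    "step_tokens": [[7, 5, 9], [6, 0, 0]],
}

# Because padding steps contribute 0.0, summing over steps gives a per-sequence
# log-probability of the sampled path without any extra masking.
path_log_probs = [sum(row) for row in traj["step_log_probs"]]
```

Summing over steps is only one plausible way to consume the trajectory; it holds as written only when the padding log-probs really are 0.0, that is, when the logit formatter is configured correctly.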
Source code in src/proteingen/sampling/sampling.py
LiveSamplingPreview
Terminal renderer for in-place sequence previews below tqdm.
Interface:
- update(x_SP): render current sampled sequences in-place
- close(): flush output (no-op cleanup)
Intended usage inside sampling loops:
    preview = LiveSamplingPreview(model.tokenizer, enabled=True)
    for step in ...:
        ...  # update x_SP
        pbar.update(1)
        preview.update(x_SP)
    preview.close()
Source code in src/proteingen/sampling/sampling.py
sample
sample(model: GenerativeModel, x_SP: LongTensor | List[str], n_parallel: int = 1, in_order: Optional[list[LongTensor] | str] = None, live_preview: bool = True, record_p_y_gt_t: bool = False) -> SamplingTrajectory
Ancestral sampling for masked generative models.
Unmasks positions n_parallel at a time, sampling from the model's predicted
distribution at each step. If in_order is not provided, a random
permutation of masked positions is generated for each sequence.
Sharp edge: unmask orders are padded to uniform length across sequences with position 0 (typically BOS/CLS). At padding steps, ALL sequences are still sampled at their designated positions — including the padding position. If the model's logit formatter is correctly configured, special-token positions predict only themselves, making the padding a no-op. If logits are NOT properly formatted, the token at position 0 may be mutated.
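The unmasking loop can be sketched on a single string with a stand-in "model" that assigns equal probability to every residue. Everything here (`toy_sample`, `MASK`, `ALPHABET`) is hypothetical and only mirrors the control flow described above; the real sampler operates on batched token tensors and a `GenerativeModel`.

```python
import random

MASK = "_"
ALPHABET = "ACDE"


def toy_sample(seq: str, n_parallel: int = 1, seed: int = 0) -> tuple[str, list[int]]:
    """Unmask MASK positions n_parallel at a time, in a random order."""
    rng = random.Random(seed)
    tokens = list(seq)  # work on a copy so the caller's input is untouched
    order = [i for i, t in enumerate(tokens) if t == MASK]
    rng.shuffle(order)  # random permutation of the masked positions
    for start in range(0, len(order), n_parallel):
        # One sampling step: fill up to n_parallel positions.
        for pos in order[start:start + n_parallel]:
            # A real model would predict a distribution conditioned on `tokens`;
            # this stand-in treats every residue as equally likely.
            tokens[pos] = rng.choice(ALPHABET)
    return "".join(tokens), order


result, order = toy_sample("A__D_", n_parallel=2)
```

With three masked positions and `n_parallel=2`, the loop takes two steps: two positions in the first and one in the second. Already unmasked positions are never touched.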
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `GenerativeModel` | a GenerativeModel (or guided TAG/DEG model). | required |
| `x_SP` | `LongTensor \| List[str]` | (S, P) partially masked token IDs, or list of strings. | required |
| `n_parallel` | `int` | number of positions to unmask per step. Default 1. | `1` |
| `in_order` | `Optional[list[LongTensor] \| str]` | controls the unmask order. If None (default), a random permutation of masked positions is used per sequence. | `None` |
| `live_preview` | `bool` | if True, render in-place sequence updates below tqdm in a real terminal. Automatically disabled when stdout is not a TTY. | `True` |
Returns:
| Type | Description |
|---|---|
| `SamplingTrajectory` | SamplingTrajectory with generated sequences and per-step data. |
Source code in src/proteingen/sampling/sampling.py
generate_unmask_orders
generate_unmask_orders(seq_lengths: list[int], n_orders: int, special_positions: Optional[list[set[int]]] = None, seed: Optional[int] = None) -> list[list[torch.LongTensor]]
Generate random unmask orders for a set of sequences.
Returns orders[s][k] = 1-D LongTensor of maskable position indices for sequence s, order k. Positions are listed in the order they should be unmasked (first element = first position revealed).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seq_lengths` | `list[int]` | length of each tokenized sequence (including special tokens). | required |
| `n_orders` | `int` | how many random orders to generate per sequence. | required |
| `special_positions` | `Optional[list[set[int]]]` | per-sequence set of position indices that are NOT maskable (BOS, EOS, PAD). If None, positions 0 and L-1 are treated as special (BOS/EOS convention for ESM-family tokenizers). | `None` |
| `seed` | `Optional[int]` | optional RNG seed for reproducibility. | `None` |
Source code in src/proteingen/sampling/sampling.py
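The documented behaviour of `generate_unmask_orders` can be approximated in a few lines of standard-library Python. This is a sketch under the assumptions stated in the parameter descriptions; `toy_unmask_orders` is a hypothetical name, it returns plain lists rather than `LongTensor`s, and its random stream will not match the library's.

```python
import random
from typing import Optional


def toy_unmask_orders(
    seq_lengths: list[int],
    n_orders: int,
    special_positions: Optional[list[set[int]]] = None,
    seed: Optional[int] = None,
) -> list[list[list[int]]]:
    """Return orders[s][k]: a shuffled list of maskable positions for sequence s."""
    rng = random.Random(seed)
    orders = []
    for s, length in enumerate(seq_lengths):
        # Default: first and last positions are special (BOS/EOS convention).
        special = special_positions[s] if special_positions is not None else {0, length - 1}
        maskable = [p for p in range(length) if p not in special]
        # rng.sample with k == len(maskable) yields a random permutation.
        orders.append([rng.sample(maskable, len(maskable)) for _ in range(n_orders)])
    return orders


orders = toy_unmask_orders([5, 4], n_orders=2, seed=0)
```

Each `orders[s][k]` contains every maskable position exactly once, so a sequence of tokenized length 5 yields permutations of positions 1, 2, and 3 under the default special positions.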
mask_by_order
mask_by_order(token_ids: LongTensor, order: LongTensor, mask_fraction: float, mask_token_id: int) -> torch.LongTensor
Mask positions according to a decoding order.
Positions appearing LAST in the order (the tail) are masked. Specifically,
the last ceil(mask_fraction * len(order)) positions in the order are
set to mask_token_id.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `token_ids` | `LongTensor` | (P,) token IDs for a single sequence. | required |
| `order` | `LongTensor` | (M,) position indices in unmask order (first = first revealed). | required |
| `mask_fraction` | `float` | fraction of maskable positions to mask (0 to 1). | required |
| `mask_token_id` | `int` | token ID to use for masking. | required |
Returns:
| Type | Description |
|---|---|
| `LongTensor` | (P,) masked token IDs. The order of unmasking during generation should be order[n_keep:] (i.e. the masked positions, in unmask order). |
Source code in src/proteingen/sampling/sampling.py
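The tail-masking rule can be worked through on a small example. The sketch below assumes `n_keep = len(order) - ceil(mask_fraction * len(order))`, which is inferred from the description rather than confirmed from the source, and uses plain lists; `toy_mask_by_order` and the token IDs are hypothetical.

```python
import math


def toy_mask_by_order(
    token_ids: list[int], order: list[int], mask_fraction: float, mask_token_id: int
) -> list[int]:
    """Mask the last ceil(mask_fraction * len(order)) positions of the order."""
    n_mask = math.ceil(mask_fraction * len(order))
    n_keep = len(order) - n_mask          # assumed definition of n_keep
    masked = list(token_ids)              # copy; the input is left unchanged
    for pos in order[n_keep:]:            # the tail of the order gets masked
        masked[pos] = mask_token_id
    return masked


token_ids = [0, 11, 12, 13, 14, 2]  # toy IDs: BOS, four residues, EOS
order = [3, 1, 4, 2]                # unmask order over the maskable positions
masked = toy_mask_by_order(token_ids, order, mask_fraction=0.5, mask_token_id=32)
print(masked)  # [0, 11, 32, 13, 32, 2]
```

Here positions 3 and 1 stay revealed, and positions 4 and 2 (the tail `order[2:]`) are masked, so generation would unmask position 4 first and position 2 second.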
sample_flow_matching_legacy
sample_flow_matching_legacy(model: GenerativeModel, x_SP: LongTensor | List[str], dt: float = 0.01, predictor_log_prob=None, guide_temp: float = 1.0, use_tag: bool = False, x1_temp: float = 1.0, stochasticity: float = 0.0, argmax_final: bool = True, logits_postprocess: Optional[Callable[[FloatTensor, LongTensor], FloatTensor]] = None, return_string: bool = True, live_preview: bool = True) -> torch.LongTensor | list[str]
Legacy flow-matching Euler sampler (old stability demo numerics).
This reproduces the original rate-matrix integration loop and guidance-ratio update so DFM models can be compared head-to-head against old behavior.
live_preview shows in-place sequence updates below tqdm in real terminals.
Source code in src/proteingen/sampling/sampling.py
build_legacy_predictor_log_prob
Build old-demo style predictor_log_prob closure from DFM TAG components.
Returns a closure compatible with the original flow-matching guidance loop:
- integer token input: (B, P) in generator token space
- one-hot input: (B, P, T_gen) for TAG Taylor guidance