ESM3
ESM3-open (1.4B parameters). Supports optional structure conditioning via atom37 coordinates.
- Output dim: 64
- Embedding dim: 1536
- LoRA support: yes; use lr ≤ 1e-4 to avoid mode collapse
Structure conditioning (inverse folding)
ESM3 accepts atom37-format coordinates as conditioning input, enabling structure-conditioned sequence generation (inverse folding). For inference, the set_condition_() method runs the expensive VQ-VAE structure encoder once; all subsequent calls reuse the cached structure tokens:
```python
from proteingen.models import ESM3
from proteingen.sampling import sample_ctmc_linear_interpolation

model = ESM3("esm3-open").cuda()
coords = ...  # atom37 format, shape (L, 37, 3)

# Set structure conditioning: the VQ-VAE encodes the coordinates once
model.set_condition_({"coords_RAX": coords})

# Sample sequences conditioned on the structure.
# `tokenizer` and `L` (the number of residues) are assumed to be defined by the caller.
init_tokens = tokenizer(["<mask>" * L], return_tensors="pt")["input_ids"].cuda()
sequences = sample_ctmc_linear_interpolation(model, init_tokens, n_steps=100)

# Or use a context manager (reverts conditioning on exit).
# `seq` is assumed to be a sequence matching the structure's length.
with model.conditioned_on({"coords_RAX": coords}):
    log_probs = model.get_log_probs(seq)
```
For training with per-sample structures (e.g. fine-tuning on a family of AF3-predicted structures), pass observations directly through the collator instead of using set_condition_(). See the conditioning docs and the fine-tuning workflow for the full pattern.
Structure conditioning is length-locked
set_condition_() preprocesses the coordinates into a fixed-length sequence of structure tokens (L+2, counting BOS and EOS). All subsequent calls must use sequences of exactly that length; any other length produces a shape mismatch. For training with variable-length sequences, use per-sample conditioning via the collator.
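The length lock can be checked up front, before a mismatched batch reaches the model. This is a minimal sketch, assuming tokenized inputs carry exactly one BOS and one EOS token as described above. `expected_token_length` and `check_length` are hypothetical helpers for illustration, not part of proteingen.

```python
def expected_token_length(n_residues: int) -> int:
    """Token length for a structure with n_residues residues: L + 2 (BOS and EOS)."""
    return n_residues + 2


def check_length(token_shape: tuple, n_residues: int) -> None:
    """Raise early if a (batch, length) token shape does not match the cached structure."""
    expected = expected_token_length(n_residues)
    actual = token_shape[-1]
    if actual != expected:
        raise ValueError(
            f"sequence has {actual} tokens but the conditioning structure "
            f"expects {expected} (L={n_residues} plus BOS/EOS)"
        )


# Example: a 297-residue structure requires 299-token inputs.
check_length((4, 299), n_residues=297)  # passes silently
```

With a real batch, `tuple(init_tokens.shape)` could be passed as `token_shape` right after tokenization.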
Lazy VQ-VAE loading and parameter freezing
ESM3's structure encoder (_structure_encoder) is loaded lazily on the first set_condition_() call. If you freeze parameters before conditioning (e.g. via apply_lora()), the encoder's ~30M parameters are loaded afterwards and are not frozen. Always set up conditioning before freezing, or re-freeze after.
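One way to re-freeze after conditioning is to walk the named parameters again and leave only the adapter weights trainable. This is a sketch under assumptions: `freeze_except` is a hypothetical helper, and the `"lora_"` name marker is an assumed naming scheme for the LoRA weights, which may differ in proteingen. The demo uses stand-in objects so that it runs without a model; with PyTorch, `model.named_parameters()` yields compatible `(name, parameter)` pairs.

```python
from types import SimpleNamespace


def freeze_except(named_params, trainable_markers=("lora_",)):
    """Disable gradients on every parameter whose name contains no trainable marker.

    Returns the names left trainable so the result can be inspected.
    """
    trainable = []
    for name, param in named_params:
        keep = any(marker in name for marker in trainable_markers)
        param.requires_grad = keep
        if keep:
            trainable.append(name)
    return trainable


# Stand-ins for (name, parameter) pairs; the names are illustrative only.
params = [
    ("_structure_encoder.layer.weight", SimpleNamespace(requires_grad=True)),
    ("transformer.blocks.0.attn.lora_A.weight", SimpleNamespace(requires_grad=True)),
]
trainable = freeze_except(params)
print(trainable)
```

Calling such a helper after set_condition_() would catch the lazily loaded encoder parameters as well, since they exist by then.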
Memory considerations
ESM3's geometric attention computes pairwise (L×L) tensors. For a sequence of length 297:
- batch_size ≥ 32 OOMs on a 48GB GPU
- batch_size = 16 with bfloat16 AMP is optimal (~32s/epoch for LoRA training)
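The quadratic growth of the pairwise tensors can be estimated with simple arithmetic. This is a rough sketch, assuming a pairwise tensor of shape (batch, channels, L, L). The `channels` and `bytes_per_element` parameters are illustrative, since the number of heads and layers that hold such tensors is not given here, so the result is a scaling estimate for one tensor and not a prediction of total GPU memory.

```python
def pairwise_elements(batch_size: int, seq_len: int, channels: int = 1) -> int:
    """Number of elements in one (batch, channels, L, L) pairwise tensor."""
    return batch_size * channels * seq_len * seq_len


def pairwise_bytes(batch_size: int, seq_len: int, channels: int = 1,
                   bytes_per_element: int = 2) -> int:
    """Size in bytes of that tensor; 2 bytes per element corresponds to bfloat16."""
    return pairwise_elements(batch_size, seq_len, channels) * bytes_per_element


# Doubling the batch size doubles the tensor; doubling the length quadruples it.
for batch_size in (16, 32):
    print(batch_size, pairwise_elements(batch_size, 297))
```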
LoRA specifics
| Rank (r) | Trainable params | % of 1.4B |
|---|---|---|
| 8 | ~9.8M | 0.69% |
| 4 | ~4.9M | 0.35% |
| 2 | ~2.5M | 0.18% |
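The near-linear scaling with rank in the table follows from how LoRA adds parameters: a linear layer with weight shape (d_out, d_in) gains two low-rank factors, of shapes (r, d_in) and (d_out, r). The sketch below computes that count for a single layer. Which ESM3 layers are adapted is not stated here, so the 1536×1536 layer (taken from the embedding dim above) is only an illustrative assumption.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Extra trainable parameters LoRA adds to one linear layer.

    A has shape (rank, d_in) and B has shape (d_out, rank).
    """
    return rank * (d_in + d_out)


# Per-layer count for an illustrative 1536 x 1536 projection at each rank.
for rank in (8, 4, 2):
    print(rank, lora_params(1536, 1536, rank))
```

Halving the rank halves the per-layer count, which matches the trend in the table.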