# data
Dataset, collation, and noise utilities for training protein models.
## ProteinDataset
A torch.utils.data.Dataset that holds raw protein data: sequences, per-sample observations (conditioning variables), and optional labels.
- `sequences`: list of amino acid strings
- `observations`: dict mapping names to per-sample lists (e.g. `{"structure_tokens": [...], "coordinates": [...]}`)
- `labels`: optional `(N,)` or `(N, n_targets)` tensor
The dataset stores raw data only — all model-specific transforms (tokenization, noising, padding) happen in the collator.
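To make the storage contract concrete, here is a minimal pure-Python sketch. `RawProteinDatasetSketch` is hypothetical, not the library's class: it skips the `torch.utils.data.Dataset` base, takes labels as a plain list instead of a tensor, and assumes each item is a flat dict with a `"sequence"` key (the key the custom collator example reads), one key per observation, and a `"label"` key.

```python
from typing import Any, Optional


class RawProteinDatasetSketch:
    """Hypothetical stand-in mirroring ProteinDataset's storage contract.

    Stores raw data only; no tokenization, noising, or padding happens here.
    """

    def __init__(
        self,
        sequences: list[str],
        observations: Optional[dict[str, list[Any]]] = None,
        labels: Optional[list[Any]] = None,
    ) -> None:
        observations = observations or {}
        n = len(sequences)
        # Every observation list must line up one-to-one with `sequences`.
        for name, values in observations.items():
            if len(values) != n:
                raise ValueError(f"observation {name!r} has {len(values)} entries, expected {n}")
        if labels is not None and len(labels) != n:
            raise ValueError(f"labels has {len(labels)} entries, expected {n}")
        self.sequences = sequences
        self.observations = observations
        self.labels = labels

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        # Assumed item layout: raw sequence plus each observation under its own key.
        item: dict[str, Any] = {"sequence": self.sequences[idx]}
        for name, values in self.observations.items():
            item[name] = values[idx]
        if self.labels is not None:
            item["label"] = self.labels[idx]  # the "label" key name is an assumption
        return item


ds = RawProteinDatasetSketch(
    ["MKT", "MKTAY"],
    observations={"temperature": [37.0, 25.0]},
    labels=[0.1, 0.9],
)
item = ds[1]  # {"sequence": "MKTAY", "temperature": 25.0, "label": 0.9}
```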
### Built-in collator
ProteinDataset.collator() builds a collate_fn that handles tokenization, noising, and observation preprocessing:
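```python
from torch.utils.data import DataLoader

from proteingen.data import uniform_mask_noise, uniform_time

collate_fn = dataset.collator(
    model,  # provides .tokenizer and .preprocess_observations
    noise_fn=uniform_mask_noise(model.tokenizer),  # masking strategy (what kind of noise)
    time_sampler=uniform_time,  # schedule (how much to mask)
)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
```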
Each batch produced by the collator contains:
| Key | Shape | Description |
|---|---|---|
| `input_ids` | `(B, L)` | Tokenized, padded, optionally noised |
| `target_ids` | `(B, L)` | Tokenized, padded (clean, no noise) |
| `observations` | dict | Preprocessed observations ready for `model.forward(**obs)` |
| `labels` | `(B, ...)` or `None` | Per-sample targets |
The collator gathers per-sample observations into list-valued dicts and passes them to model.preprocess_observations(batched). This means the model's preprocess_observations must accept batched inputs (lists of values) when used with the collator, as opposed to single observations when used with set_condition_().
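A pure-Python sketch of that gathering step and of a batch-aware `preprocess_observations`. Both `gather_observations` and `ToyModel` are hypothetical; which keys the real collator gathers and how it treats missing keys are not documented here. The point is the input shape: each value arrives as a list with one entry per sample, so a `preprocess_observations` written for a single float would break.

```python
from typing import Any


def gather_observations(batch: list[dict[str, Any]], obs_keys: list[str]) -> dict[str, list[Any]]:
    """Hypothetical helper: turn per-sample values into list-valued batched dicts."""
    return {key: [sample[key] for sample in batch] for key in obs_keys}


class ToyModel:
    """Hypothetical model whose preprocess_observations accepts batched (list) inputs."""

    def preprocess_observations(self, obs: dict[str, list[Any]]) -> dict[str, Any]:
        # Batched input: obs["temperature"] is a list, not a single float.
        temps = obs["temperature"]
        mean = sum(temps) / len(temps)
        return {"temperature": [t - mean for t in temps]}  # center within the batch


batch = [
    {"sequence": "MKT", "temperature": 30.0},
    {"sequence": "MKTAY", "temperature": 40.0},
]
batched = gather_observations(batch, ["temperature"])  # {"temperature": [30.0, 40.0]}
obs = ToyModel().preprocess_observations(batched)  # {"temperature": [-5.0, 5.0]}
```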
#### Key arguments
`noise_fn`: `(input_ids_1D, t) -> noised_input_ids_1D`, the corruption strategy applied independently to each sequence. Built-in options:

- `uniform_mask_noise(tokenizer)`: mask non-special positions with probability `(1 - t)`
- `no_noise`: identity (no corruption)

`time_sampler`: `() -> float` in `[0, 1]`, controls how much masking to apply. Built-in options:

- `uniform_time`: sample `t ~ Uniform(0, 1)`
- `fully_unmasked`: always return `t = 1` (no masking)
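The callable contracts can be sketched as plain Python functions. This is an illustration, not the library's code: the `*_sketch` functions and the `NoiseFnLike` / `TimeSamplerLike` aliases are hypothetical, and token ids are plain lists here, whereas the real `NoiseFn` presumably operates on 1-D tensors.

```python
import random
from typing import Callable

# Hypothetical aliases mirroring the documented contracts; the library's
# NoiseFn / TimeSampler types may be defined differently (e.g. over tensors).
NoiseFnLike = Callable[[list[int], float], list[int]]
TimeSamplerLike = Callable[[], float]


def no_noise_sketch(input_ids: list[int], t: float) -> list[int]:
    # Identity corruption: ignore t and return an unchanged copy.
    return list(input_ids)


def fully_unmasked_sketch() -> float:
    # t = 1 means "keep everything" under the (1 - t) masking rule.
    return 1.0


def uniform_time_sketch() -> float:
    # t ~ Uniform(0, 1); random.random() draws from [0, 1).
    return random.random()


# The clean-data pairing: no corruption, no masking.
noise_fn: NoiseFnLike = no_noise_sketch
time_sampler: TimeSamplerLike = fully_unmasked_sketch
clean = noise_fn([0, 5, 6, 2], time_sampler())  # [0, 5, 6, 2]
```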
`rename_obs_keys`: `{model_kwarg: dataset_key}`, for when two models use different names for the same data. This lets one dataset serve multiple collators:

```python
from proteingen.data import fully_unmasked, no_noise

# Dataset stores "structure_coords", but the model expects "coordinates"
collate_fn = dataset.collator(
    model,
    noise_fn=no_noise,
    time_sampler=fully_unmasked,
    rename_obs_keys={"coordinates": "structure_coords"},
)
```
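One way to read the `{model_kwarg: dataset_key}` direction is to invert it before re-keying. A minimal sketch; `apply_rename` is a hypothetical helper, and passing unmapped keys through unchanged is an assumption of this sketch, not documented behavior.

```python
from typing import Any


def apply_rename(obs: dict[str, Any], rename_obs_keys: dict[str, str]) -> dict[str, Any]:
    """Hypothetical helper: re-key dataset observations to model kwargs.

    rename_obs_keys maps {model_kwarg: dataset_key}, so it is inverted to find
    the model-side name for each dataset key. Unmapped keys pass through (assumption).
    """
    dataset_to_model = {dataset_key: kwarg for kwarg, dataset_key in rename_obs_keys.items()}
    return {dataset_to_model.get(key, key): value for key, value in obs.items()}


obs = {"structure_coords": [[0.0, 0.0, 0.0]], "temperature": [37.0]}
renamed = apply_rename(obs, {"coordinates": "structure_coords"})
# {"coordinates": [[0.0, 0.0, 0.0]], "temperature": [37.0]}
```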
### Custom collators
For complex cases — like inverse folding where you need to pad both sequences and structures to the batch max length — write a custom collator instead. The training loop then passes observations directly to model.forward():
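```python
import torch


def inverse_folding_collator(tokenizer, mask_token_id, structure_pad_id):
    # structure_pad_id: model-specific pad value for structure tokens.
    # Ids the tokenizer treats as special (CLS, EOS, PAD, MASK, ...) are never masked.
    special_ids = torch.tensor(tokenizer.all_special_ids)

    def collate_fn(batch):
        sequences = [s["sequence"] for s in batch]
        tokenized = tokenizer(sequences, padding=True, return_tensors="pt")
        target_ids = tokenized["input_ids"]
        B, L = target_ids.shape

        # Mask all non-special positions
        maskable_positions = ~torch.isin(target_ids, special_ids)
        input_ids = target_ids.clone()
        input_ids[maskable_positions] = mask_token_id

        # Pad structures to match tokenized length L. Assumes each sample's
        # structure_tokens/coordinates are already aligned with its tokenized sequence.
        padded_struct = torch.full((B, L), structure_pad_id, dtype=torch.long)
        padded_coords = torch.zeros(B, L, 37, 3)  # atom37 coordinates
        for i, sample in enumerate(batch):
            seq_len = sample["structure_tokens"].shape[0]
            padded_struct[i, :seq_len] = sample["structure_tokens"]
            padded_coords[i, :seq_len] = sample["coordinates"]

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "structure_tokens": padded_struct,
            "coordinates": padded_coords,
        }

    return collate_fn
```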
See the fine-tuning workflow for a complete example and the conditioning docs for how this fits into the broader conditioning model.
### Noise design
noise_fn and time_sampler are intentionally separated:
- `noise_fn` owns the corruption strategy (what kind of noise)
- `time_sampler` owns the schedule (how much noise)
This lets you reuse the same corruption with different t distributions (e.g. uniform for training, fixed for evaluation), or swap corruption strategies while keeping the same schedule.
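A deterministic evaluation schedule only needs a different time sampler. A minimal sketch, assuming the collator accepts any zero-argument callable returning a float in `[0, 1]` (the documented `time_sampler` contract); `fixed_time` is a hypothetical helper, not part of `proteingen.data`.

```python
from typing import Callable


def fixed_time(t: float) -> Callable[[], float]:
    """Hypothetical time-sampler factory (not part of proteingen.data): always returns t."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    return lambda: t


# With uniform_mask_noise, t = 0.5 masks each maskable position with probability 0.5.
eval_sampler = fixed_time(0.5)
```

A training collator would keep `time_sampler=uniform_time`, while an evaluation collator passes `time_sampler=fixed_time(0.5)` with the same `noise_fn`, so only the schedule differs between the two.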
Both are required arguments to collator() — there are no defaults. Use the explicit sentinels no_noise + fully_unmasked when you want clean (unmasked) training data.
## FASTA utilities
- `read_fasta(path)`: returns `list[tuple[header, sequence]]`
- `aligned_sequences_to_raw(aligned_seqs)`: strips gap characters (`-`, `.`) from MSA-aligned sequences
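The documented behavior of both helpers can be sketched on an in-memory string. `parse_fasta_text` and `strip_gaps` are hypothetical analogues, not the library functions; the sketch assumes headers are returned without the leading `>` and ignores any lines before the first header.

```python
from typing import Optional


def parse_fasta_text(text: str) -> list[tuple[str, str]]:
    """Hypothetical in-memory analogue of read_fasta's documented behavior.

    Multi-line sequences are concatenated; gaps and lowercase are preserved.
    Assumes headers are returned without the leading ">".
    """
    records: list[tuple[str, str]] = []
    header: Optional[str] = None
    chunks: list[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        elif header is not None:
            chunks.append(line)  # continuation of the current record
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


def strip_gaps(aligned_seqs: list[str]) -> list[str]:
    # Mirrors aligned_sequences_to_raw: drop the MSA gap characters "-" and ".".
    return [s.replace("-", "").replace(".", "") for s in aligned_seqs]


records = parse_fasta_text(">seq1 desc\nMKT-a\nY.K\n>seq2\nAC\n")
# [("seq1 desc", "MKT-aY.K"), ("seq2", "AC")]
raw = strip_gaps([seq for _, seq in records])  # ["MKTaYK", "AC"]
```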
## GuidanceDataset (legacy)

**Deprecated:** `GuidanceDataset` is the older dataset class. Use `ProteinDataset` with appropriate noise functions for new code.
## API Reference

### `proteingen.data`
Data utilities for training and evaluation.
#### `ProteinDataset`

Bases: `Dataset`
Raw protein data: sequences, observations (conditioning variables), and labels.
Stores raw data only; all model-specific transforms (tokenization, observation preprocessing, noising, padding) happen in the collator returned by `collator()`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sequences` | `list[str]` | Amino acid strings. | required |
| `observations` | `Optional[dict[str, list[Any]]]` | Dict mapping names to per-sample lists (e.g. structures, temperatures). Each value must be indexable with the same length as `sequences`. | `None` |
| `labels` | `Optional[Tensor]` | Per-sample targets. Shape `(N,)` or `(N, n_targets)`. | `None` |
Source code in src/proteingen/data/data.py
##### `collator`

`collator(model: Any, noise_fn: NoiseFn, time_sampler: TimeSampler, rename_obs_keys: Optional[dict[str, str]] = None) -> Callable[[list[dict[str, Any]]], dict[str, Any]]`
Build a collate_fn that tokenizes, noises, and preprocesses per batch.
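Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Any` | Provides `.tokenizer` and `.preprocess_observations`. | required |
| `noise_fn` | `NoiseFn` | `(input_ids_1D, t) -> noised_input_ids_1D`; corruption applied independently to each sequence. | required |
| `time_sampler` | `TimeSampler` | `() -> float` in `[0, 1]`; controls how much noise to apply. | required |
| `rename_obs_keys` | `Optional[dict[str, str]]` | `{model_kwarg: dataset_key}` mapping for observations whose names differ between dataset and model. | `None` |

Returns:

| Type | Description |
|---|---|
| `Callable[[list[dict[str, Any]]], dict[str, Any]]` | A collate_fn producing dicts with `input_ids`, `target_ids`, `observations`, and `labels`. |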
Source code in src/proteingen/data/data.py
#### `PDBStructure` (dataclass)
Parsed PDB with per-residue chain and sequence info.
The atom_array is kept so model-specific code can re-encode coordinates with the appropriate atom layout (e.g. MPNN vs ESM).
Source code in src/proteingen/data/structure.py
#### `uniform_mask_noise`
Mask non-special positions independently with probability (1 - t).
At t=1 nothing is masked; at t=0 everything maskable is masked.
Uses tokenizer.all_special_ids to determine which positions to leave alone
(CLS, EOS, PAD, MASK, etc.), so the logic is tokenizer-agnostic.
Source code in src/proteingen/data/data.py
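The masking rule can be illustrated in pure Python. `uniform_mask_noise_sketch` is hypothetical: unlike the library function, it takes the special ids and mask id explicitly instead of a tokenizer and works on a list rather than a tensor. The ids in the usage line (CLS=0, PAD=1, EOS=2, MASK=32) are toy values.

```python
import random
from typing import Optional


def uniform_mask_noise_sketch(
    input_ids: list[int],
    t: float,
    special_ids: set[int],
    mask_id: int,
    rng: Optional[random.Random] = None,
) -> list[int]:
    """Pure-Python illustration of the documented rule, not the library code.

    Each non-special position is replaced by mask_id independently with
    probability (1 - t); special tokens are never touched.
    """
    rng = rng or random.Random()
    # rng.random() is in [0, 1): t=1 never masks, t=0 always masks.
    return [
        tok if tok in special_ids or rng.random() >= 1.0 - t else mask_id
        for tok in input_ids
    ]


# Toy ids for illustration only: CLS=0, PAD=1, EOS=2, MASK=32.
noised = uniform_mask_noise_sketch([0, 5, 6, 7, 2, 1], t=0.0, special_ids={0, 1, 2, 32}, mask_id=32)
# [0, 32, 32, 32, 2, 1]: at t=0 every non-special position is masked
```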
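#### `no_noise`

Identity noise function: returns the input unchanged (no corruption).

#### `fully_unmasked`

Time sampler that always returns `t = 1` (no masking).

#### `uniform_time`

Time sampler that draws `t ~ Uniform(0, 1)`.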
#### `read_fasta`
Read a FASTA file, returning (header, sequence) pairs.
Concatenates multi-line sequences. Does not modify sequences (gaps, lowercase, etc. are preserved).
Source code in src/proteingen/data/data.py
#### `aligned_sequences_to_raw`
Strip gap characters from aligned sequences to get raw AA strings.
Removes - and . characters used in MSA formats.
Source code in src/proteingen/data/data.py
#### `load_pdb`
Parse a PDB file into a PDBStructure.
If the file is not present locally, tries to infer a PDB id from
pdb_path and downloads it from RCSB into data/pdbs at repo root.
Assumes a single biological assembly. Handles multi-chain structures.
Source code in src/proteingen/data/structure.py
#### `cif_to_atom37`
Convert a CIF structure file to atom37 coordinates (L, 37, 3).
The CIF is converted through a temporary PDB because ProteinChain currently
consumes PDB input.
Source code in src/proteingen/data/folding.py
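Downstream code often needs only backbone atoms from the `(L, 37, 3)` array. The sketch below assumes the standard AlphaFold atom37 ordering (N, CA, C, CB, O in the first five slots); confirm that `cif_to_atom37` follows it before relying on the index. It uses nested lists to stay dependency-free (with a NumPy array or tensor the same selection is `coords[:, 1, :]`), the helper names are hypothetical, and it does not handle missing atoms.

```python
import math

# Assumption: standard AlphaFold atom37 ordering, where slot 1 is CA.
# Verify against the atom layout your model code actually uses.
ATOM37_CA = 1


def ca_trace(coords: list[list[list[float]]]) -> list[list[float]]:
    """Hypothetical helper: pick CA coordinates out of an (L, 37, 3) nested list."""
    return [residue[ATOM37_CA] for residue in coords]


def consecutive_ca_distances(coords: list[list[list[float]]]) -> list[float]:
    """Hypothetical helper: Euclidean distance between CA atoms of neighboring residues."""
    ca = ca_trace(coords)
    return [math.dist(a, b) for a, b in zip(ca, ca[1:])]
```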
#### `af3_result_cif_path`

`af3_result_cif_path(result_output_dir: str, result_name: str, *, container_output_root: str = '/app/af_output', host_output_root: str | Path = '/data/af3_server_output') -> Path`
Map AF3 server output dir (container path) to host CIF path.
Source code in src/proteingen/data/folding.py
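The container-to-host mapping amounts to re-rooting a path. A minimal sketch with `pathlib`, assuming a plain prefix swap between `container_output_root` and `host_output_root` (the defaults are copied from the signature above). `container_to_host` is hypothetical and stops at the result directory, because which CIF file `af3_result_cif_path` selects inside it via `result_name` is not documented here.

```python
from pathlib import Path, PurePosixPath
from typing import Union


def container_to_host(
    container_path: str,
    container_output_root: str = "/app/af_output",
    host_output_root: Union[str, Path] = "/data/af3_server_output",
) -> Path:
    """Hypothetical helper: re-root a container path under the host mount.

    Raises ValueError if container_path is not under container_output_root.
    """
    relative = PurePosixPath(container_path).relative_to(container_output_root)
    return Path(host_output_root) / relative


host_dir = container_to_host("/app/af_output/job_1")  # Path("/data/af3_server_output/job_1")
```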
#### `fold_sequence_and_download_cif`
Fold one sequence with AF3 server and download the resulting CIF file.
Source code in src/proteingen/data/folding.py
#### `fold_sequence_to_atom37`

`fold_sequence_to_atom37(client: Any, sequence: str, name: str, *, container_output_root: str = '/app/af_output', host_output_root: str | Path = '/data/af3_server_output', chain_id: str = 'A')`
Fold one sequence with AF3 server and return (result, coords_atom37).