data

Dataset, collation, and noise utilities for training protein models.

ProteinDataset

A torch.utils.data.Dataset that holds raw protein data: sequences, per-sample observations (conditioning variables), and optional labels.

ProteinDataset(sequences, observations=None, labels=None)
  • sequences — list of amino acid strings
  • observations — dict mapping names to per-sample lists (e.g. {"structure_tokens": [...], "coordinates": [...]})
  • labels — optional (N,) or (N, n_targets) tensor

The dataset stores raw data only. All model-specific transforms (tokenization, observation preprocessing, noising, padding) happen in the collator.

Built-in collator

ProteinDataset.collator() builds a collate_fn that handles tokenization, noising, and observation preprocessing:

from torch.utils.data import DataLoader
from proteingen.data import uniform_mask_noise, uniform_time

collate_fn = dataset.collator(
    model,                                # provides .tokenizer and .preprocess_observations
    noise_fn=uniform_mask_noise(model.tokenizer),  # masking strategy
    time_sampler=uniform_time,            # how much to mask (t drawn per sequence)
)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)

Each batch produced by the collator contains:

| Key | Shape | Description |
|---|---|---|
| input_ids | (B, L) | Tokenized, padded, optionally noised |
| target_ids | (B, L) | Tokenized, padded (clean, no noise) |
| observations | dict | Preprocessed observations ready for model.forward(**obs) |
| labels | (B, ...) or None | Per-sample targets |

The collator gathers per-sample observations into list-valued dicts and passes them to model.preprocess_observations(batched). This means the model's preprocess_observations must accept batched inputs (lists of values) when used with the collator, as opposed to single observations when used with set_condition_().

Key arguments

noise_fn: (input_ids_1D, t) -> noised_input_ids_1D — the corruption strategy applied independently to each sequence. Built-in options:

  • uniform_mask_noise(tokenizer) — mask non-special positions with probability (1 - t)
  • no_noise — identity (no corruption)
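Any callable with the (input_ids_1D, t) signature can serve as noise_fn. To make the contract concrete, the sketch below mirrors the logic of uniform_mask_noise on a plain list of token IDs instead of a 1-D tensor. make_list_mask_noise is a hypothetical name, and the token IDs are made up; the real function works on torch tensors and reads the special and mask IDs from the tokenizer.

```python
import random
from typing import Callable

# Same contract as NoiseFn, but on plain lists instead of 1-D tensors.
ListNoiseFn = Callable[[list[int], float], list[int]]


def make_list_mask_noise(special_ids: set[int], mask_id: int, rng: random.Random) -> ListNoiseFn:
    """Hypothetical list-based analogue of uniform_mask_noise."""

    def noise(input_ids: list[int], t: float) -> list[int]:
        # Special tokens are never touched; any other position is masked when a
        # uniform draw in [0, 1) exceeds t, i.e. with probability (1 - t).
        # A new list is returned, so the input is left unmodified.
        return [
            mask_id if tok not in special_ids and rng.random() > t else tok
            for tok in input_ids
        ]

    return noise


# Made-up IDs: 0 = CLS, 1 = PAD, 2 = EOS, 32 = MASK; 5-7 are residues.
noise = make_list_mask_noise(special_ids={0, 1, 2, 32}, mask_id=32, rng=random.Random(0))
ids = [0, 5, 6, 7, 2, 1]
print(noise(ids, 1.0))  # [0, 5, 6, 7, 2, 1]   (t = 1: nothing masked)
print(noise(ids, 0.0))  # [0, 32, 32, 32, 2, 1] (t = 0: every residue masked)
```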

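time_sampler: () -> float in [0, 1], called once per sequence. The returned t controls how much masking to apply: larger t means less masking, and t = 1 leaves the sequence clean. Built-in options: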

  • uniform_time — sample t ~ Uniform(0, 1)
  • fully_unmasked — always returns t = 1 (no masking)
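In the built-in collator, time_sampler is called once per sequence and its t is handed to that sequence's noise_fn call, so every row in a batch gets its own noise level. A minimal pure-Python sketch of that loop follows; noise_batch and identity_noise are hypothetical names, and rng.random stands in for uniform_time.

```python
import random
from typing import Callable


def noise_batch(
    batch: list[list[int]],
    noise_fn: Callable[[list[int], float], list[int]],
    time_sampler: Callable[[], float],
) -> tuple[list[list[int]], list[float]]:
    """Hypothetical mirror of the collator's noising loop: one fresh t per sequence."""
    noised, ts = [], []
    for ids in batch:
        t = time_sampler()  # drawn independently for every row
        noised.append(noise_fn(ids, t))
        ts.append(t)
    return noised, ts


def identity_noise(ids: list[int], t: float) -> list[int]:
    return ids  # same behaviour as no_noise


rng = random.Random(1)
batch = [[0, 5, 6, 2], [0, 7, 2, 1]]
_, ts = noise_batch(batch, identity_noise, rng.random)          # like uniform_time
_, clean_ts = noise_batch(batch, identity_noise, lambda: 1.0)   # like fully_unmasked
print(ts)        # two different t values, one per sequence
print(clean_ts)  # [1.0, 1.0]
```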

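rename_obs_keys: {model_kwarg: dataset_key} maps model keyword names to dataset keys, so one dataset can serve several models (one collator each) even when they name the same data differently. When it is given, only the listed keys are forwarded to preprocess_observations; any other dataset observations are dropped. For example: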

# Dataset stores "structure_coords", but the model expects "coordinates"
collate_fn = dataset.collator(
    model,
    noise_fn=no_noise,
    time_sampler=fully_unmasked,
    rename_obs_keys={"coordinates": "structure_coords"},
)
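The direction of the mapping is easy to get backwards, so here is a pure-Python sketch of the gathering step that feeds model.preprocess_observations, following the collator source in the API reference. The helper name batch_observations and the placeholder coordinate strings are hypothetical. Note that the result is list-valued (one entry per sample), which is why preprocess_observations must accept batched inputs.

```python
from typing import Any, Optional


def batch_observations(
    items: list[dict[str, Any]],
    rename_obs_keys: Optional[dict[str, str]] = None,
) -> dict[str, list[Any]]:
    """Hypothetical mirror of the collator's gathering step.

    Returns {model_kwarg: [value per item]}, the list-valued dict that
    model.preprocess_observations receives.
    """
    if not items or "observations" not in items[0]:
        return {}
    if rename_obs_keys is None:
        # Pass-through: dataset keys double as model kwargs.
        mapping = {key: key for key in items[0]["observations"]}
    else:
        # {model_kwarg: dataset_key}; dataset keys not listed here are dropped.
        mapping = rename_obs_keys
    return {
        model_key: [item["observations"][dataset_key] for item in items]
        for model_key, dataset_key in mapping.items()
    }


items = [
    {"sequence": "MKV", "observations": {"structure_coords": "coords_0", "temperature": 37.0}},
    {"sequence": "MKVL", "observations": {"structure_coords": "coords_1", "temperature": 50.0}},
]
print(batch_observations(items))
# {'structure_coords': ['coords_0', 'coords_1'], 'temperature': [37.0, 50.0]}
print(batch_observations(items, rename_obs_keys={"coordinates": "structure_coords"}))
# {'coordinates': ['coords_0', 'coords_1']}
```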

Custom collators

For complex cases — like inverse folding where you need to pad both sequences and structures to the batch max length — write a custom collator instead. The training loop then passes observations directly to model.forward():

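import torch

def inverse_folding_collator(tokenizer, mask_token_id, structure_pad_id):
    # structure_pad_id: the model's padding ID for structure tokens (model-specific)
    special_ids = torch.tensor(tokenizer.all_special_ids)

    def collate_fn(batch):
        sequences = [s["sequence"] for s in batch]
        tokenized = tokenizer(sequences, padding=True, return_tensors="pt")
        target_ids = tokenized["input_ids"]
        B, L = target_ids.shape

        # Mask all non-special positions (CLS, EOS, PAD, ... stay in place)
        maskable_positions = ~torch.isin(target_ids, special_ids)
        input_ids = target_ids.clone()
        input_ids[maskable_positions] = mask_token_id

        # Pad structures to match tokenized length L. Assumes each sample's
        # structure tokens and coordinates are already aligned with its
        # tokenized sequence (including any special-token positions).
        padded_struct = torch.full((B, L), structure_pad_id, dtype=torch.long)
        padded_coords = torch.zeros(B, L, 37, 3)  # atom37 layout
        for i, sample in enumerate(batch):
            obs = sample["observations"]  # ProteinDataset nests observations here
            seq_len = obs["structure_tokens"].shape[0]
            padded_struct[i, :seq_len] = obs["structure_tokens"]
            padded_coords[i, :seq_len] = obs["coordinates"]

        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "structure_tokens": padded_struct,
            "coordinates": padded_coords,
        }
    return collate_fn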

See the fine-tuning workflow for a complete example and the conditioning docs for how this fits into the broader conditioning model.

Noise design

noise_fn and time_sampler are intentionally separated:

  • noise_fn owns the corruption strategy (what kind of noise)
  • time_sampler owns the schedule (how much noise)

This lets you reuse the same corruption with different t distributions (e.g. uniform for training, fixed for evaluation), or swap corruption strategies while keeping the same schedule.

Both are required arguments to collator() — there are no defaults. Use the explicit sentinels no_noise + fully_unmasked when you want clean (unmasked) training data.
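Because a time sampler is just a zero-argument callable returning t, the "fixed for evaluation" case and other schedules are short closures. Neither fixed_time nor beta_time below is part of proteingen.data; they are hypothetical helpers that would plug into collator(time_sampler=...) while the noise_fn stays unchanged.

```python
import random
from typing import Callable, Optional

TimeSamplerFn = Callable[[], float]


def fixed_time(t: float) -> TimeSamplerFn:
    """Hypothetical: always return the same t, e.g. for a reproducible eval set."""
    if not 0.0 <= t <= 1.0:
        raise ValueError(f"t must lie in [0, 1], got {t}")
    return lambda: t


def beta_time(alpha: float, beta: float, rng: Optional[random.Random] = None) -> TimeSamplerFn:
    """Hypothetical: t ~ Beta(alpha, beta); alpha > beta skews toward larger t (lighter masking)."""
    rng = rng if rng is not None else random.Random()
    return lambda: rng.betavariate(alpha, beta)


eval_sampler = fixed_time(0.85)  # with uniform_mask_noise: each maskable position masked w.p. 0.15
print(eval_sampler())            # 0.85
```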

FASTA utilities

  • read_fasta(path) — returns list[tuple[header, sequence]]
  • aligned_sequences_to_raw(aligned_seqs) — strips gap characters (-, .) from MSA-aligned sequences
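Together, the two helpers turn an MSA file into plain sequences for ProteinDataset. Since read_fasta needs a file path, the self-contained sketch below reproduces the documented behavior on an in-memory string: multi-line records are concatenated, then "-" and "." are stripped. parse_fasta_text and strip_gaps are hypothetical stand-ins; with the library installed you would call read_fasta(path) and aligned_sequences_to_raw(...) instead.

```python
def parse_fasta_text(text: str) -> list[tuple[str, str]]:
    """Hypothetical in-memory analogue of read_fasta: (header, sequence) pairs."""
    entries: list[tuple[str, str]] = []
    header = None
    parts: list[str] = []
    for line in text.splitlines():
        if line.startswith(">"):
            if header is not None:
                entries.append((header, "".join(parts)))
            header, parts = line[1:], []
        elif line:
            parts.append(line)  # multi-line sequences are concatenated
    if header is not None:
        entries.append((header, "".join(parts)))
    return entries


def strip_gaps(aligned: list[str]) -> list[str]:
    # Same rule as aligned_sequences_to_raw: drop "-" and "." characters.
    return [s.replace("-", "").replace(".", "") for s in aligned]


msa = ">seq1\nMK-V\nL.A\n>seq2\nMKAVLA\n"
records = parse_fasta_text(msa)  # gaps preserved, as with read_fasta
raw = strip_gaps([seq for _, seq in records])
print(records)  # [('seq1', 'MK-VL.A'), ('seq2', 'MKAVLA')]
print(raw)      # ['MKVLA', 'MKAVLA']
```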

GuidanceDataset (legacy)

Deprecated

GuidanceDataset is the older dataset class. Use ProteinDataset with appropriate noise functions for new code.


API Reference

proteingen.data

Data utilities for training and evaluation.

ProteinDataset

Bases: Dataset

Raw protein data: sequences, observations (conditioning variables), and labels.

Stores raw data only; all model-specific transforms (tokenization, observation preprocessing, noising, padding) happen in the collator returned by collator().

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sequences | list[str] | Amino acid strings. | required |
| observations | Optional[dict[str, list[Any]]] | Dict mapping names to per-sample lists (e.g. structures, temperatures). Each value must be indexable with the same length as sequences. | None |
| labels | Optional[Tensor] | Per-sample targets. Shape (N,) or (N, n_targets). | None |

Source code in src/proteingen/data/data.py
class ProteinDataset(Dataset):
    """Raw protein data: sequences, observations (conditioning variables), and labels.

    Stores raw data only — all model-specific transforms (tokenization,
    observation preprocessing, noising, padding) happen in the collator
    returned by :meth:`collator`.

    Args:
        sequences: Amino acid strings.
        observations: Dict mapping names to per-sample lists (e.g. structures,
            temperatures). Each value must be indexable with the same length
            as ``sequences``.
        labels: Per-sample targets. Shape ``(N,)`` or ``(N, n_targets)``.
    """

    def __init__(
        self,
        sequences: list[str],
        observations: Optional[dict[str, list[Any]]] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        self.sequences = sequences
        self.observations = observations or {}
        self.labels = labels

        n = len(sequences)
        for key, vals in self.observations.items():
            if len(vals) != n:
                raise ValueError(
                    f"Observation '{key}' has {len(vals)} items, expected {n}"
                )
        if self.labels is not None and len(self.labels) != n:
            raise ValueError(f"Labels has {len(self.labels)} items, expected {n}")

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        item: dict[str, Any] = {"sequence": self.sequences[idx]}
        if self.observations:
            item["observations"] = {k: v[idx] for k, v in self.observations.items()}
        if self.labels is not None:
            item["labels"] = self.labels[idx]
        return item

    def collator(
        self,
        model: Any,  # ProbabilityModel — avoids circular import
        noise_fn: NoiseFn,
        time_sampler: TimeSampler,
        rename_obs_keys: Optional[dict[str, str]] = None,
    ) -> Callable[[list[dict[str, Any]]], dict[str, Any]]:
        """Build a collate_fn that tokenizes, noises, and preprocesses per batch.

        Args:
            model: Provides ``.tokenizer`` and ``.preprocess_observations`` (batched).
            noise_fn: ``(input_ids_1D, t) -> noised_input_ids_1D``. Use
                :func:`uniform_mask_noise` for MLM training or :func:`no_noise`
                for clean inputs.
            time_sampler: ``() -> float`` in ``[0, 1]``. Use :func:`uniform_time`
                or :func:`fully_unmasked`.
            rename_obs_keys: ``{model_kwarg: dataset_key}`` for renaming
                observation keys. If ``None``, dataset keys are passed through
                as model kwargs.

        Returns:
            A collate_fn producing dicts with:
                - ``input_ids``:    ``(B, L)`` — tokenized, optionally noised
                - ``target_ids``:   ``(B, L)`` — tokenized original sequences
                - ``observations``: dict ready for ``model.forward(**obs)``
                - ``labels``:       ``(B, ...)`` or ``None``
        """
        tokenizer = model.tokenizer

        def collate_fn(items: list[dict[str, Any]]) -> dict[str, Any]:
            sequences = [item["sequence"] for item in items]

            # Tokenize + pad to longest in batch
            encoded = tokenizer(sequences, padding=True, return_tensors="pt")
            target_ids = encoded["input_ids"].clone()
            input_ids = encoded["input_ids"]

            # Noise each sequence independently
            for i in range(input_ids.size(0)):
                t = time_sampler()
                input_ids[i] = noise_fn(input_ids[i], t)

            # Preprocess observations (batched)
            observations: dict[str, Any] = {}
            if "observations" in items[0]:
                if rename_obs_keys is not None:
                    # {model_kwarg: dataset_key} → {model_kwarg: [values...]}
                    batched = {
                        model_key: [item["observations"][dataset_key] for item in items]
                        for model_key, dataset_key in rename_obs_keys.items()
                    }
                else:
                    # Pass through dataset keys as-is
                    batched = {
                        key: [item["observations"][key] for item in items]
                        for key in items[0]["observations"]
                    }
                observations = model.preprocess_observations(batched)

            # Labels
            labels = None
            if "labels" in items[0]:
                labels = torch.stack([item["labels"] for item in items])

            return {
                "input_ids": input_ids,
                "target_ids": target_ids,
                "observations": observations,
                "labels": labels,
            }

        return collate_fn
collator
collator(model: Any, noise_fn: NoiseFn, time_sampler: TimeSampler, rename_obs_keys: Optional[dict[str, str]] = None) -> Callable[[list[dict[str, Any]]], dict[str, Any]]

Build a collate_fn that tokenizes, noises, and preprocesses per batch.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | Provides .tokenizer and .preprocess_observations (batched). | required |
| noise_fn | NoiseFn | (input_ids_1D, t) -> noised_input_ids_1D. Use uniform_mask_noise for MLM training or no_noise for clean inputs. | required |
| time_sampler | TimeSampler | () -> float in [0, 1]. Use uniform_time or fully_unmasked. | required |
| rename_obs_keys | Optional[dict[str, str]] | {model_kwarg: dataset_key} for renaming observation keys. If None, dataset keys are passed through as model kwargs. | None |

Returns:

| Type | Description |
|---|---|
| Callable[[list[dict[str, Any]]], dict[str, Any]] | A collate_fn producing dicts with: input_ids, (B, L), tokenized and optionally noised; target_ids, (B, L), the tokenized original sequences; observations, a dict ready for model.forward(**obs); labels, (B, ...) or None. |

PDBStructure dataclass

Parsed PDB with per-residue chain and sequence info.

The atom_array is kept so model-specific code can re-encode coordinates with the appropriate atom layout (e.g. MPNN vs ESM).

Source code in src/proteingen/data/structure.py
@dataclass
class PDBStructure:
    """Parsed PDB with per-residue chain and sequence info.

    The atom_array is kept so model-specific code can re-encode
    coordinates with the appropriate atom layout (e.g. MPNN vs ESM).
    """

    atom_array: bts.AtomArray  # full biotite atom array
    chain_ids: np.ndarray  # (L,) per-residue chain ID strings, e.g. ['A','A','B']
    sequence: str  # full sequence across all chains

uniform_mask_noise

uniform_mask_noise(tokenizer: Any) -> NoiseFn

Mask non-special positions independently with probability (1 - t).

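At t=1 nothing is masked; at t=0 everything maskable is masked. Uses tokenizer.all_special_ids to determine which positions to leave alone (CLS, EOS, PAD, MASK, etc.), so the special-token handling is tokenizer-agnostic. The mask ID itself is looked up as tokenizer.vocab["<mask>"], so the tokenizer must define a "<mask>" token.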

Source code in src/proteingen/data/data.py
def uniform_mask_noise(tokenizer: Any) -> NoiseFn:
    """Mask non-special positions independently with probability (1 - t).

    At t=1 nothing is masked; at t=0 everything maskable is masked.
    Uses ``tokenizer.all_special_ids`` to determine which positions to leave alone
    (CLS, EOS, PAD, MASK, etc.), so the logic is tokenizer-agnostic.
    """
    mask_id = tokenizer.vocab["<mask>"]
    special = torch.tensor(tokenizer.all_special_ids)

    def noise(input_ids: torch.LongTensor, t: float) -> torch.LongTensor:
        maskable = ~torch.isin(input_ids, special)
        to_mask = maskable & (torch.rand(input_ids.shape) > t)
        out = input_ids.clone()
        out[to_mask] = mask_id
        return out

    return noise
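Because each maskable position is masked independently once t is fixed, the masking rate follows directly from the comparison torch.rand(...) > t in the source above, with u a uniform draw per position. Under uniform_time the expected masking rate is one half, although individual sequences range from nearly clean to fully masked:

```latex
\Pr[\text{position } i \text{ masked} \mid t] = \Pr[u_i > t] = 1 - t, \qquad u_i \sim \mathcal{U}[0, 1)

\mathbb{E}[\, n_{\text{masked}} \mid t \,] = (1 - t)\, n_{\text{maskable}}, \qquad
\mathbb{E}_{t \sim \mathcal{U}(0,1)}[\, 1 - t \,] = \int_0^1 (1 - t)\, dt = \tfrac{1}{2}
```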

no_noise

no_noise(input_ids: LongTensor, t: float) -> torch.LongTensor

Identity noise function — returns input unchanged.

Source code in src/proteingen/data/data.py
def no_noise(input_ids: torch.LongTensor, t: float) -> torch.LongTensor:
    """Identity noise function — returns input unchanged."""
    return input_ids

fully_unmasked

fully_unmasked() -> float

Time sampler that always returns t=1 (no masking).

Source code in src/proteingen/data/data.py
def fully_unmasked() -> float:
    """Time sampler that always returns t=1 (no masking)."""
    return 1.0

uniform_time

uniform_time() -> float

Time sampler that returns t ~ Uniform(0, 1).

Source code in src/proteingen/data/data.py
def uniform_time() -> float:
    """Time sampler that returns t ~ Uniform(0, 1)."""
    return random.random()

read_fasta

read_fasta(path: str) -> list[tuple[str, str]]

Read a FASTA file, returning (header, sequence) pairs.

Concatenates multi-line sequences. Does not modify sequences (gaps, lowercase, etc. are preserved).

Source code in src/proteingen/data/data.py
def read_fasta(path: str) -> list[tuple[str, str]]:
    """Read a FASTA file, returning (header, sequence) pairs.

    Concatenates multi-line sequences. Does not modify sequences
    (gaps, lowercase, etc. are preserved).
    """
    entries: list[tuple[str, str]] = []
    header = ""
    seq_parts: list[str] = []
    with open(path) as f:
        for line in f:
            line = line.rstrip("\n")
            if line.startswith(">"):
                if header or seq_parts:
                    entries.append((header, "".join(seq_parts)))
                header = line[1:]
                seq_parts = []
            else:
                seq_parts.append(line)
    if header or seq_parts:
        entries.append((header, "".join(seq_parts)))
    return entries

aligned_sequences_to_raw

aligned_sequences_to_raw(aligned_sequences: list[str]) -> list[str]

Strip gap characters from aligned sequences to get raw AA strings.

Removes - and . characters used in MSA formats.

Source code in src/proteingen/data/data.py
def aligned_sequences_to_raw(aligned_sequences: list[str]) -> list[str]:
    """Strip gap characters from aligned sequences to get raw AA strings.

    Removes ``-`` and ``.`` characters used in MSA formats.
    """
    return [seq.replace("-", "").replace(".", "") for seq in aligned_sequences]

load_pdb

load_pdb(pdb_path: Path | str, cache_dir: Path | str | None = None) -> PDBStructure

Parse a PDB file into a PDBStructure.

If the file is not present locally, tries to infer a PDB id from pdb_path and downloads it from RCSB into data/pdbs at repo root.

Assumes a single biological assembly. Handles multi-chain structures.

Source code in src/proteingen/data/structure.py
def load_pdb(
    pdb_path: Path | str,
    cache_dir: Path | str | None = None,
) -> PDBStructure:
    """Parse a PDB file into a PDBStructure.

    If the file is not present locally, tries to infer a PDB id from
    ``pdb_path`` and downloads it from RCSB into ``data/pdbs`` at repo root.

    Assumes a single biological assembly. Handles multi-chain structures.
    """
    resolved_path = _resolve_pdb_path(pdb_path, cache_dir=cache_dir)
    parsed = aio.parse(str(resolved_path))
    atom_array = parsed["assemblies"]["1"][0]

    residue_starts = bts.get_residue_starts(atom_array)
    chain_ids = atom_array.chain_id[residue_starts]  # (L,) str
    sequence = bts.to_sequence(atom_array)[0][0]

    return PDBStructure(
        atom_array=atom_array,
        chain_ids=chain_ids,
        sequence=sequence,
    )

cif_to_atom37

cif_to_atom37(cif_path: str | Path, chain_id: str = 'A') -> torch.Tensor

Convert a CIF structure file to atom37 coordinates (L, 37, 3).

The CIF is converted through a temporary PDB because ProteinChain currently consumes PDB input.

Source code in src/proteingen/data/folding.py
def cif_to_atom37(cif_path: str | Path, chain_id: str = "A") -> torch.Tensor:
    """Convert a CIF structure file to atom37 coordinates ``(L, 37, 3)``.

    The CIF is converted through a temporary PDB because ``ProteinChain`` currently
    consumes PDB input.
    """
    cif_path = Path(cif_path)
    f = pdbx.CIFFile.read(str(cif_path))
    atoms = pdbx.get_structure(f, model=1, extra_fields=["b_factor", "occupancy"])

    with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as tmp:
        pdb_file = pdb_io.PDBFile()
        pdb_file.set_structure(atoms)
        pdb_file.write(tmp.name)
        tmp_path = Path(tmp.name)

    try:
        protein_chain = ProteinChain.from_pdb(str(tmp_path), chain_id=chain_id)
        return torch.from_numpy(protein_chain.atom37_positions).float()
    finally:
        tmp_path.unlink(missing_ok=True)
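cif_to_atom37 returns coordinates in the atom37 layout, one slot per atom type. Assuming the AlphaFold-style atom37 ordering (N, CA, C, CB, O, ...), CA sits at index 1 of the atom axis, so coords[:, 1, :] on the returned tensor is the CA trace. The pure-Python sketch below uses hypothetical helpers, ca_trace and consecutive_ca_distances, on nested lists to extract that trace and measure consecutive CA-CA distances, which are about 3.8 Å along a real trans-peptide backbone and make a quick sanity check on a folded structure.

```python
import math
from typing import Sequence

# Assumed atom37 prefix (AlphaFold-style ordering); the remaining 32 slots follow.
ATOM37_PREFIX = ("N", "CA", "C", "CB", "O")
CA_INDEX = ATOM37_PREFIX.index("CA")  # == 1

Coord = Sequence[float]


def ca_trace(atom37: Sequence[Sequence[Coord]]) -> list[Coord]:
    """Pick the CA coordinate of every residue from an (L, 37, 3) nested sequence."""
    return [residue[CA_INDEX] for residue in atom37]


def consecutive_ca_distances(atom37: Sequence[Sequence[Coord]]) -> list[float]:
    """Euclidean distance between CA atoms of residues i and i + 1."""
    ca = ca_trace(atom37)
    return [math.dist(a, b) for a, b in zip(ca, ca[1:])]
```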

af3_result_cif_path

af3_result_cif_path(result_output_dir: str, result_name: str, *, container_output_root: str = '/app/af_output', host_output_root: str | Path = '/data/af3_server_output') -> Path

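Map an AF3 server output directory, given as a container path, to the host path of its result CIF ({mapped_dir}/{result_name}_model.cif). Directories that do not start with container_output_root are used as-is.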

Source code in src/proteingen/data/folding.py
def af3_result_cif_path(
    result_output_dir: str,
    result_name: str,
    *,
    container_output_root: str = "/app/af_output",
    host_output_root: str | Path = "/data/af3_server_output",
) -> Path:
    """Map AF3 server output dir (container path) to host CIF path."""
    output_dir = str(result_output_dir)
    host_root = Path(host_output_root)
    if output_dir.startswith(container_output_root):
        suffix = output_dir[len(container_output_root) :].lstrip("/")
        mapped_dir = host_root / suffix
    else:
        mapped_dir = Path(output_dir)
    return mapped_dir / f"{result_name}_model.cif"
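af3_result_cif_path rewrites the container prefix with a plain string startswith check, so a sibling directory such as /app/af_output_old/job would also match and be mapped to {host_output_root}/_old/job. If that matters for your deployment, a component-aware variant is straightforward with pathlib. map_container_cif_path below is a hypothetical sketch, not part of proteingen.data.

```python
from __future__ import annotations

from pathlib import Path, PurePosixPath


def map_container_cif_path(
    result_output_dir: str,
    result_name: str,
    container_output_root: str = "/app/af_output",
    host_output_root: str | Path = "/data/af3_server_output",
) -> Path:
    """Hypothetical variant of af3_result_cif_path that matches whole path components."""
    out_dir = PurePosixPath(result_output_dir)
    root = PurePosixPath(container_output_root)
    if out_dir == root or root in out_dir.parents:
        # Inside the container mount: re-root the relative part under the host directory.
        mapped = Path(host_output_root) / str(out_dir.relative_to(root))
    else:
        # Not under the mount (e.g. already a host path): use it unchanged.
        mapped = Path(result_output_dir)
    return mapped / f"{result_name}_model.cif"
```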

fold_sequence_and_download_cif

fold_sequence_and_download_cif(client: Any, sequence: str, name: str, cif_path: str | Path)

Fold one sequence with AF3 server and download the resulting CIF file.

Source code in src/proteingen/data/folding.py
def fold_sequence_and_download_cif(
    client: Any,
    sequence: str,
    name: str,
    cif_path: str | Path,
):
    """Fold one sequence with AF3 server and download the resulting CIF file."""
    result = client.fold(sequence=sequence, name=name)
    cif_path = Path(cif_path)
    cif_path.parent.mkdir(parents=True, exist_ok=True)
    client.download_cif(result.job_id, cif_path)
    return result, cif_path

fold_sequence_to_atom37

fold_sequence_to_atom37(client: Any, sequence: str, name: str, *, container_output_root: str = '/app/af_output', host_output_root: str | Path = '/data/af3_server_output', chain_id: str = 'A')

Fold one sequence with AF3 server and return (result, coords_atom37).

Source code in src/proteingen/data/folding.py
def fold_sequence_to_atom37(
    client: Any,
    sequence: str,
    name: str,
    *,
    container_output_root: str = "/app/af_output",
    host_output_root: str | Path = "/data/af3_server_output",
    chain_id: str = "A",
):
    """Fold one sequence with AF3 server and return ``(result, coords_atom37)``."""
    result = client.fold(sequence=sequence, name=name)
    cif_path = af3_result_cif_path(
        result_output_dir=result.output_dir,
        result_name=result.name,
        container_output_root=container_output_root,
        host_output_root=host_output_root,
    )
    coords = cif_to_atom37(cif_path, chain_id=chain_id)
    return result, coords
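Both folding helpers only touch a small, duck-typed slice of client: fold(sequence=..., name=...) returning an object with job_id, name and output_dir, plus download_cif(job_id, path). Going by those attribute accesses in the source above, an offline stub like the one below can stand in for the AF3 server in tests. StubAF3Client and StubFoldResult are hypothetical; the stub ignores the sequence and copies a template CIF you supply, so cif_to_atom37 will only succeed if that template is a real structure.

```python
from __future__ import annotations

import shutil
from dataclasses import dataclass
from pathlib import Path


@dataclass
class StubFoldResult:
    job_id: str
    name: str
    output_dir: str  # a host path, so af3_result_cif_path leaves it unchanged


class StubAF3Client:
    """Hypothetical offline stand-in exposing only what the folding helpers call."""

    def __init__(self, template_cif: str | Path, output_root: str | Path):
        self.template_cif = Path(template_cif)
        self.output_root = Path(output_root)
        self._results: dict[str, StubFoldResult] = {}

    def fold(self, sequence: str, name: str) -> StubFoldResult:
        # The sequence is ignored: every "prediction" is a copy of the template.
        out_dir = self.output_root / name
        out_dir.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self.template_cif, out_dir / f"{name}_model.cif")
        result = StubFoldResult(
            job_id=f"stub-{len(self._results)}", name=name, output_dir=str(out_dir)
        )
        self._results[result.job_id] = result
        return result

    def download_cif(self, job_id: str, cif_path: str | Path) -> None:
        result = self._results[job_id]
        shutil.copyfile(Path(result.output_dir) / f"{result.name}_model.cif", cif_path)
```

With the library installed, passing such a stub to fold_sequence_and_download_cif should copy the template to the requested cif_path, and fold_sequence_to_atom37 should read the copy under output_dir, because that directory does not start with container_output_root.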