Source code for molecular_simulations.analysis.ipSAE

# ruff: noqa: N801, N999
"""Interface prediction Score from Aligned Errors (ipSAE) module.

This module computes interaction prediction scores from pLDDT and PAE data,
adapted from https://doi.org/10.1101/2025.02.10.637595. Supports outputs
from structure prediction tools like Boltz and AlphaFold.
"""

from itertools import permutations
from pathlib import Path
from typing import Any

import numpy as np
import polars as pl
from scipy.spatial.distance import cdist

PathLike = Path | str
OptPath = Path | str | None

class ipSAE:
    """Compute interaction prediction Score from Aligned Errors.

    Computes various model quality scores including pDockQ, pDockQ2, LIS,
    ipTM, and ipSAE for structure predictions.

    Attributes:
        parser: ModelParser instance for structure file.
        plddt_file: Path to pLDDT data file.
        pae_file: Path to PAE data file.
        path: Output directory path.
        scores: Polars DataFrame of computed scores after run().

    Args:
        structure_file: Path to PDB/CIF model file.
        plddt_file: Path to pLDDT numpy file (.npz with 'plddt' key).
        pae_file: Path to PAE numpy file (.npz with 'pae' key).
        out_path: Output directory path. If None, uses parent directory of
            plddt_file.
        skip_chains: Chain IDs to exclude from scoring (e.g. glycan or
            ligand chains).

    Example:
        >>> scorer = ipSAE("model.pdb", "plddt.npz", "pae.npz")
        >>> scorer.run()
        >>> print(scorer.scores)
    """

    def __init__(
        self,
        structure_file: PathLike,
        plddt_file: PathLike,
        pae_file: PathLike,
        out_path: OptPath = None,
        skip_chains: set[str] | list[str] | None = None,
    ):
        """Initialize the ipSAE scorer.

        Args:
            structure_file: Path to structure file.
            plddt_file: Path to pLDDT data file.
            pae_file: Path to PAE data file.
            out_path: Output directory path. Created (including missing
                parent directories) if it does not exist.
            skip_chains: Chain IDs to exclude from scoring (e.g. glycan or
                ligand chains). Their pLDDT/PAE tokens are inferred from
                remaining tokens and dropped before scoring.
        """
        self.parser = ModelParser(structure_file, skip_chains=skip_chains)
        self.plddt_file = Path(plddt_file)
        self.pae_file = Path(pae_file)
        if out_path is not None:
            self.path = Path(out_path)
        else:
            self.path = self.plddt_file.parent
        # parents=True so a nested, not-yet-existing output directory works;
        # exist_ok=True so re-running into the same directory is harmless.
        self.path.mkdir(parents=True, exist_ok=True)

    def parse_structure_file(self) -> None:
        """Parse the structure file and extract relevant details.

        Runs the parser to read the structure file, classifies chains as
        protein or nucleic acid, and stacks the CA-like (``coordinates``)
        and CB-like (``cb_coordinates``) atom coordinates.
        """
        self.parser.parse_structure_file()
        self.parser.classify_chains()
        self.coordinates = np.vstack([res['coor'] for res in self.parser.residues])
        if self.parser.cb_residues:
            self.cb_coordinates = np.vstack(
                [res['coor'] for res in self.parser.cb_residues]
            )
        else:
            # No CB-like atoms were found; fall back to CA-like coordinates.
            self.cb_coordinates = self.coordinates

    def prepare_scorer(self) -> None:
        """Initialize the ScoreCalculator for computing scores.

        Creates a ScoreCalculator instance with chain information extracted
        from the parsed structure.
        """
        self.scorer = ScoreCalculator(
            chains=np.array(self.parser.chains),
            chain_pair_type=self.parser.chain_types,
            n_residues=len(self.parser.residues),
        )

    def run(self) -> None:
        """Execute the complete ipSAE scoring workflow.

        Parses structure, computes the CB distance matrix, loads pLDDT and
        PAE data (restricting them to kept-chain tokens when their length
        differs from the parsed residue count), runs the scorer, and saves
        results.

        Raises:
            ValueError: If the structure has a different number of CB-like
                and CA-like atoms, or if the PAE matrix is not square with
                the same length as the pLDDT array.
        """
        self.parse_structure_file()
        n_residues = len(self.parser.residues)
        if self.cb_coordinates.shape[0] != n_residues:
            # Distances would otherwise be misaligned with per-residue masks.
            raise ValueError(
                f'Structure has {self.cb_coordinates.shape[0]} CB-like atoms '
                f'but {n_residues} residues; cannot build a residue distance '
                'matrix.'
            )
        distances = cdist(self.cb_coordinates, self.cb_coordinates)

        pLDDT = self.load_pLDDT_file()
        PAE = self.load_PAE_file()
        n_tokens = pLDDT.shape[0]
        if PAE.shape != (n_tokens, n_tokens):
            raise ValueError(
                f'PAE shape {PAE.shape} does not match pLDDT length '
                f'{n_tokens}; expected a square ({n_tokens}, {n_tokens}) matrix.'
            )

        if n_tokens != n_residues:
            # Extra tokens (e.g. skipped ligand/glycan chains) must be
            # dropped so pLDDT/PAE line up with the parsed residues.
            self.parser.build_protein_token_indices(n_tokens)
            indices = np.asarray(self.parser.protein_token_indices, dtype=int)
            pLDDT = pLDDT[indices]
            PAE = PAE[np.ix_(indices, indices)]

        self.prepare_scorer()
        self.scorer.compute_scores(distances, pLDDT, PAE)
        self.scores = self.scorer.scores
        self.save_scores()

    def save_scores(self) -> None:
        """Save scores DataFrame to ``ipSAE_scores.parquet`` in ``self.path``."""
        self.scores.write_parquet(self.path / 'ipSAE_scores.parquet')

    def load_pLDDT_file(self) -> np.ndarray:
        """Load and scale pLDDT data.

        Returns:
            pLDDT array scaled to 0-100 range (values whose maximum is
            <= 1.0 are assumed to be on a 0-1 scale and multiplied by 100).
        """
        # The context manager closes the NpzFile's underlying file handle;
        # indexing it reads the member fully into memory first.
        with np.load(self.plddt_file) as data:
            pLDDT_arr = data['plddt']
        if np.max(pLDDT_arr) <= 1.0:
            pLDDT_arr = pLDDT_arr * 100.0
        return pLDDT_arr

    def load_PAE_file(self) -> np.ndarray:
        """Load PAE data from file.

        Returns:
            PAE array from the 'pae' key in the npz file.
        """
        with np.load(self.pae_file) as data:
            pae = data['pae']
        return pae
class ScoreCalculator:
    """Calculate model quality scores from structure predictions.

    Computes pDockQ, pDockQ2, LIS, ipTM, and ipSAE scores for all chain
    pairs in a structure.

    Attributes:
        chains: Array of chain IDs, one per residue.
        unique_chains: Sorted unique chain IDs in the structure.
        chain_pair_type: Dictionary mapping chain ID to type.
        n_res: Total number of residues.
        pDockQ_cutoff: Distance cutoff (Angstroms) defining contacts.
        PAE_cutoff: PAE cutoff (Angstroms) used by LIS and ipSAE.
        permuted: List of all ordered chain pairs to evaluate.
        df: DataFrame with one row per ordered chain pair after
            compute_scores().
        scores: DataFrame with one row per unordered chain pair after
            compute_scores().

    Args:
        chains: Array of chain IDs, one per residue.
        chain_pair_type: Dictionary mapping chain ID to chain type
            ('protein' or 'nucleic_acid').
        n_residues: Total number of residues across all chains.
        pdockq_cutoff: Distance cutoff for pDockQ in Angstroms.
            Defaults to 8.0.
        pae_cutoff: PAE cutoff for LIS and ipSAE in Angstroms.
            Defaults to 12.0.

    Example:
        >>> calc = ScoreCalculator(chains, chain_types, n_residues)
        >>> calc.compute_scores(distances, plddt, pae)
        >>> print(calc.scores)
    """

    def __init__(
        self,
        chains: np.ndarray,
        chain_pair_type: dict[str, str],
        n_residues: int,
        pdockq_cutoff: float = 8.0,
        pae_cutoff: float = 12.0,
    ):
        """Initialize the ScoreCalculator.

        Args:
            chains: Array of chain IDs, one per residue.
            chain_pair_type: Chain ID to type mapping.
            n_residues: Total number of residues.
            pdockq_cutoff: pDockQ distance cutoff.
            pae_cutoff: PAE cutoff.
        """
        self.chains = chains
        self.unique_chains = np.unique(chains)
        self.chain_pair_type = chain_pair_type
        self.n_res = n_residues
        self.pDockQ_cutoff = pdockq_cutoff
        self.PAE_cutoff = pae_cutoff
        self.permute_chains()

    def compute_scores(
        self, distances: np.ndarray, pLDDT: np.ndarray, PAE: np.ndarray
    ) -> None:
        """Compute all scores for all chain pairs.

        Calculates pDockQ, pDockQ2, LIS, ipTM, and ipSAE scores for each
        ordered chain pair, stores them in ``self.df`` and the per-pair
        reduction in ``self.scores``. A structure with a single chain
        yields empty frames with the full schema.

        Args:
            distances: Pairwise distance matrix between all residues.
            pLDDT: Per-residue pLDDT values (0-100 scale).
            PAE: Predicted aligned error matrix.
        """
        self.distances = distances
        self.pLDDT = pLDDT
        self.PAE = PAE

        rows = []
        for chain1, chain2 in self.permuted:
            pDockQ, pDockQ2 = self.compute_pDockQ_scores(chain1, chain2)
            LIS = self.compute_LIS(chain1, chain2)
            ipTM, ipSAE = self.compute_ipTM_ipSAE(chain1, chain2)
            # Keep native Python types: funnelling everything through a
            # single numpy array would stringify the floats, and an empty
            # result list would not form a 2-D array.
            rows.append((
                str(chain1),
                str(chain2),
                float(pDockQ),
                float(pDockQ2),
                float(LIS),
                float(ipTM),
                float(ipSAE),
            ))

        self.df = pl.DataFrame(
            rows,
            schema={
                'chain1': str,
                'chain2': str,
                'pDockQ': float,
                'pDockQ2': float,
                'LIS': float,
                'ipTM': float,
                'ipSAE': float,
            },
            orient='row',
        )
        self.get_max_values()

    def compute_pDockQ_scores(self, chain1: str, chain2: str) -> tuple[float, float]:
        """Compute pDockQ and pDockQ2 scores for a chain pair.

        Contacts are residue pairs (chain1, chain2) whose distance is at most
        ``pDockQ_cutoff``. pDockQ uses the mean pLDDT of residues in any
        contact and the contact count; pDockQ2 additionally uses the
        pTM-transformed PAE (d0 = 10) of the contacting pairs.

        Args:
            chain1: First chain identifier.
            chain2: Second chain identifier.

        Returns:
            Tuple of (pDockQ, pDockQ2) scores; (0.0, 0.0) if no contacts.
        """
        mask_c1 = self.chains == chain1
        mask_c2 = self.chains == chain2

        contact = self.distances[np.ix_(mask_c1, mask_c2)] <= self.pDockQ_cutoff
        n_pairs = np.sum(contact)
        if n_pairs == 0:
            return 0.0, 0.0

        # Residues from either chain that take part in at least one contact.
        c1_contact = np.where(mask_c1)[0][contact.any(axis=1)]
        c2_contact = np.where(mask_c2)[0][contact.any(axis=0)]
        residues = np.concatenate([c1_contact, c2_contact])
        mean_pLDDT = self.pLDDT[residues].mean()

        pDockQ = self.pDockQ_score(mean_pLDDT * np.log10(n_pairs))

        pae_contacts = self.PAE[np.ix_(mask_c1, mask_c2)][contact]
        mean_pTM = self.compute_pTM(pae_contacts, 10.0).mean()
        pDockQ2 = self.pDockQ2_score(mean_pLDDT * mean_pTM)

        return float(pDockQ), float(pDockQ2)

    def compute_LIS(self, chain1: str, chain2: str) -> float:
        """Compute Local Interaction Score (LIS) for a chain pair.

        LIS averages ``(cutoff - PAE) / cutoff`` over the chain1->chain2 PAE
        entries below ``PAE_cutoff`` (12 Angstroms by default). Values lie in
        [0, 1]; 0.0 is returned when no entry is below the cutoff.

        Adapted from: https://doi.org/10.1101/2024.02.19.580970

        Args:
            chain1: First chain identifier (rows of the PAE matrix).
            chain2: Second chain identifier (columns of the PAE matrix).

        Returns:
            LIS value for the chain pair.
        """
        mask_c1 = self.chains == chain1
        mask_c2 = self.chains == chain2
        # Sub-block indexing avoids materialising an n x n boolean mask.
        selected_pae = self.PAE[np.ix_(mask_c1, mask_c2)]
        valid_pae = selected_pae[selected_pae < self.PAE_cutoff]
        if valid_pae.size == 0:
            return 0.0
        return float(np.mean((self.PAE_cutoff - valid_pae) / self.PAE_cutoff))

    def compute_ipTM_ipSAE(self, chain1: str, chain2: str) -> tuple[Any, Any]:
        """Compute ipTM and ipSAE scores for a chain pair.

        ipTM uses d0 based on total chain pair length and averages over all
        chain2 residues. ipSAE uses a per-residue d0 based on the count of
        chain2 residues with PAE below the cutoff for each aligned residue in
        chain1, averaging only over those valid residues. Both scores are the
        maximum of their per-residue values over chain1.

        Args:
            chain1: First chain identifier (aligned chain).
            chain2: Second chain identifier (scored chain).

        Returns:
            Tuple of (ipTM, ipSAE) scores.
        """
        if 'nucleic_acid' in (
            self.chain_pair_type[chain1],
            self.chain_pair_type[chain2],
        ):
            pair_type = 'nucleic_acid'
        else:
            pair_type = 'protein'

        mask_c1 = self.chains == chain1
        mask_c2 = self.chains == chain2
        pae_sub = self.PAE[np.ix_(mask_c1, mask_c2)]

        # ipTM: d0 from total chain pair length, mean over all chain2 residues.
        d0_chain = self.compute_d0(np.sum(mask_c1) + np.sum(mask_c2), pair_type)
        ipTM_byres = np.array([0.0])
        if mask_c2.any():
            ipTM_byres = np.mean(self.compute_pTM(pae_sub, d0_chain), axis=1)

        # ipSAE: per-row d0 from the number of chain2 residues under the cutoff,
        # mean taken only over those residues.
        valid_mask = pae_sub < self.PAE_cutoff
        n_valid_per_row = valid_mask.sum(axis=1)
        d0_per_res = self.compute_d0_array(n_valid_per_row, pair_type)
        ptm_sub_res = self.compute_pTM(pae_sub, d0_per_res[:, None])

        ipSAE_byres = np.zeros(mask_c1.sum())
        rows_with_valid = n_valid_per_row > 0
        if rows_with_valid.any():
            masked_ptm = np.where(valid_mask, ptm_sub_res, 0.0)
            ipSAE_byres[rows_with_valid] = (
                masked_ptm[rows_with_valid].sum(axis=1)
                / n_valid_per_row[rows_with_valid]
            )

        ipTM = float(np.max(ipTM_byres))
        ipSAE = float(np.max(ipSAE_byres)) if ipSAE_byres.size > 0 else 0.0
        return ipTM, ipSAE

    def get_max_values(self) -> None:
        """Reduce directed chain-pair rows to one row per unordered pair.

        Because some scores like ipSAE are asymmetric (A->B != B->A), the
        direction with the higher ipSAE is kept, and all of that row's scores
        (including ipTM, LIS, pDockQ2) are reported. Other scores are NOT
        maximised independently. Output rows are ordered by ipSAE, highest
        first.
        """
        self.scores = (
            self.df.with_columns(
                pl.when(pl.col('chain1') < pl.col('chain2'))
                .then(pl.concat_str(['chain1', 'chain2'], separator='_'))
                .otherwise(pl.concat_str(['chain2', 'chain1'], separator='_'))
                .alias('pair_key')
            )
            .sort('ipSAE', descending=True)
            .unique(subset=['pair_key'], keep='first', maintain_order=True)
            .drop('pair_key')
        )

    def permute_chains(self) -> None:
        """Generate all ordered pairs of distinct chains into ``permuted``."""
        self.permuted = list(permutations(self.unique_chains, 2))

    @staticmethod
    def pDockQ_score(x: float) -> float:
        """Compute pDockQ score.

        Formula:
            pDockQ = 0.724 / (1 + exp(-0.052 * (x - 152.611))) + 0.018

        Reference: https://doi.org/10.1038/s41467-022-28865-w

        Args:
            x: Mean pLDDT of contacting residues scaled by log10 of the
                number of residue pairs within the distance cutoff.

        Returns:
            pDockQ score.
        """
        return 0.724 / (1 + np.exp(-0.052 * (x - 152.611))) + 0.018

    @staticmethod
    def pDockQ2_score(x: float) -> float:
        """Compute pDockQ2 score.

        Formula:
            pDockQ2 = 1.31 / (1 + exp(-0.075 * (x - 84.733))) + 0.005

        Reference: https://doi.org/10.1093/bioinformatics/btad424

        Args:
            x: Mean pLDDT of contacting residues scaled by the mean
                pTM-transformed PAE of the contacting pairs.

        Returns:
            pDockQ2 score.
        """
        return 1.31 / (1 + np.exp(-0.075 * (x - 84.733))) + 0.005

    @staticmethod
    def compute_pTM(x: np.ndarray, d0: float) -> np.ndarray:
        """Compute pTM score.

        Formula:
            pTM = 1.0 / (1 + (x / d0)^2)

        Args:
            x: PAE values (scalar or array).
            d0: Distance parameter (scalar, or array broadcastable to x).

        Returns:
            pTM score(s).
        """
        return 1.0 / (1 + (x / d0) ** 2)

    @staticmethod
    def compute_d0(L: int, pair_type: str) -> float:
        """Compute d0 parameter for pTM calculation.

        Formula:
            d0 = max(min_value, 1.24 * (max(L, 27) - 15)^(1/3) - 1.8)

        Args:
            L: Sequence length (clamped to a minimum of 27).
            pair_type: 'protein' (min_value 1.0) or 'nucleic_acid'
                (min_value 2.0).

        Returns:
            d0 parameter value.
        """
        L = max(27, L)
        min_value = 2.0 if pair_type == 'nucleic_acid' else 1.0
        return max(min_value, 1.24 * (L - 15) ** (1 / 3) - 1.8)

    @staticmethod
    def compute_d0_array(L: np.ndarray, pair_type: str) -> np.ndarray:
        """Compute d0 parameter for an array of sequence lengths.

        Vectorized version of compute_d0 (same clamp of L to >= 27 and same
        floor) for per-residue d0 calculation used in ipSAE scoring.

        Args:
            L: Array of sequence lengths.
            pair_type: 'protein' or 'nucleic_acid'.

        Returns:
            Array of d0 parameter values.
        """
        L = np.maximum(27.0, np.asarray(L, dtype=float))
        min_value = 2.0 if pair_type == 'nucleic_acid' else 1.0
        return np.maximum(min_value, 1.24 * (L - 15) ** (1 / 3) - 1.8)
class ModelParser:
    """Parse structure files to extract residue and atom information.

    Handles both PDB and CIF format files. For each residue of a standard
    type (see ``STANDARD_RESIDUES``) it keeps one CA-like atom and one
    CB-like atom:

    - The CA-like atom is ``CA``, or an atom whose name contains ``C1``
      (C1' in standard nucleotides).
    - The CB-like atom is ``CB``, ``CA`` for glycine, or an atom whose name
      contains ``C3`` (C3' in standard nucleotides).

    Attributes:
        structure: Path to the structure file.
        skip_chains: Chain IDs that are excluded from parsing.
        residues: CA-like atom dictionaries of kept chains, in chain order.
        cb_residues: CB-like atom dictionaries of kept chains, in chain order.
        chains: Chain ID for each entry in ``residues``.
        chain_order: All chain IDs (kept and skipped) in first-seen order.
        chain_residue_counts: Number of entries in ``residues`` per kept chain.
        protein_token_indices: pLDDT/PAE indices for kept-chain anchor
            tokens; populated by :meth:`build_protein_token_indices`.
        chain_types: Dictionary mapping chain ID to type after
            classify_chains().

    Args:
        structure: Path to PDB or CIF file.
        skip_chains: Chain IDs to exclude from parsing.

    Example:
        >>> parser = ModelParser("model.pdb")
        >>> parser.parse_structure_file()
        >>> parser.classify_chains()
    """

    NUCLEIC_ACIDS = frozenset(['DA', 'DC', 'DT', 'DG', 'A', 'C', 'U', 'G'])
    STANDARD_RESIDUES = frozenset([
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
        'DA', 'DC', 'DT', 'DG', 'A', 'C', 'U', 'G',
    ])

    def __init__(
        self,
        structure: PathLike,
        skip_chains: set[str] | list[str] | None = None,
    ):
        """Initialize the ModelParser.

        Args:
            structure: Path to PDB or CIF file.
            skip_chains: Chain IDs to exclude (e.g. glycan/ligand chains
                whose per-atom token counts in the pLDDT/PAE arrays aren't
                recoverable from the structure file alone).
        """
        self.structure = Path(structure)
        self.skip_chains = set(skip_chains) if skip_chains else set()
        self.residues = []
        self.cb_residues = []
        self.chains = []
        self.chain_order: list[str] = []
        self.chain_residue_counts: dict[str, int] = {}
        self.protein_token_indices: list[int] = []

    def parse_structure_file(self) -> None:
        """Parse the structure file and extract atom/residue data.

        Files with a ``.pdb`` suffix (case-insensitive) are read as PDB;
        anything else is read as mmCIF. Atoms in ``skip_chains`` are
        ignored, but their chain IDs are still recorded in ``chain_order``
        so that :meth:`build_protein_token_indices` can later infer the
        token span of skipped chains by arithmetic. Previously parsed data
        is discarded, so calling this twice does not duplicate residues.
        """
        # Reset accumulated state so repeated calls start from scratch.
        self.residues = []
        self.cb_residues = []
        self.chains = []
        self.chain_order = []
        self.chain_residue_counts = {}
        self.protein_token_indices = []

        if self.structure.suffix.lower() == '.pdb':
            line_parser = self.parse_pdb_line
        else:
            line_parser = self.parse_cif_line

        fields: dict[str, int] = {}
        field_num = 0
        chain_residues: dict[str, list[dict[str, Any]]] = {}
        chain_cb: dict[str, list[dict[str, Any]]] = {}

        with open(self.structure) as f:
            for line in f:
                if line.startswith('_atom_site.'):
                    # mmCIF loop header: records the column index of each
                    # _atom_site field for the ATOM/HETATM rows that follow.
                    field_name = line.split()[0].split('.', 1)[1]
                    fields[field_name] = field_num
                    field_num += 1
                    continue
                if not line.startswith(('ATOM', 'HETATM')):
                    continue

                atom = line_parser(line, fields, allow_missing_seq_id=True)
                if atom is None:
                    continue

                cid = atom['chain_id']
                if cid not in chain_residues:
                    self.chain_order.append(cid)
                    chain_residues[cid] = []
                    chain_cb[cid] = []
                if cid in self.skip_chains:
                    continue
                if atom['res'] not in self.STANDARD_RESIDUES:
                    continue

                name = atom['atom_name']
                if name == 'CA' or 'C1' in name:
                    chain_residues[cid].append(atom)
                if (
                    name == 'CB'
                    or 'C3' in name
                    or (atom['res'] == 'GLY' and name == 'CA')
                ):
                    chain_cb[cid].append(atom)

        for cid in self.chain_order:
            if cid in self.skip_chains:
                continue
            residues = chain_residues[cid]
            self.residues.extend(residues)
            self.cb_residues.extend(chain_cb[cid])
            self.chains.extend([cid] * len(residues))
            self.chain_residue_counts[cid] = len(residues)

    def build_protein_token_indices(self, total_tokens: int) -> None:
        """Derive pLDDT/PAE indices for kept-chain anchor tokens.

        Each kept chain contributes one token per parsed residue; any run of
        consecutive skipped chains occupies the token span left over between
        kept chains. Solvable when skipped chains form at most one
        contiguous block in ``chain_order`` — the default layout emitted by
        Chai, Boltz, and AlphaFold (polymers first, then non-polymer chains).
        If no chains are skipped, kept tokens are taken from the start of
        the array.

        Args:
            total_tokens: Length of the pLDDT array (== PAE dim).

        Raises:
            ValueError: If skipped chains appear in multiple non-contiguous
                runs, making per-run sizes ambiguous, or if the kept residue
                count exceeds ``total_tokens``.
        """
        # Group chain_order into consecutive runs of kept / skipped chains.
        runs: list[tuple[bool, list[str]]] = []
        for cid in self.chain_order:
            is_kept = cid not in self.skip_chains
            if runs and runs[-1][0] == is_kept:
                runs[-1][1].append(cid)
            else:
                runs.append((is_kept, [cid]))

        skipped_runs = [run for run in runs if not run[0]]
        total_kept_tokens = sum(self.chain_residue_counts.values())
        skipped_span = total_tokens - total_kept_tokens

        if len(skipped_runs) > 1:
            raise ValueError(
                f'Cannot infer token layout: skipped chains appear in '
                f'{len(skipped_runs)} non-contiguous runs. Reorder input '
                'so non-polymer chains are grouped.'
            )
        if skipped_span < 0:
            raise ValueError(
                f'Kept residue count ({total_kept_tokens}) exceeds '
                f'total tokens ({total_tokens}); structure and score '
                'files are inconsistent.'
            )

        indices: list[int] = []
        offset = 0
        for is_kept, chain_ids in runs:
            if is_kept:
                for cid in chain_ids:
                    n = self.chain_residue_counts[cid]
                    indices.extend(range(offset, offset + n))
                    offset += n
            else:
                # The single skipped run absorbs all non-residue tokens.
                offset += skipped_span
        self.protein_token_indices = indices

    def classify_chains(self) -> None:
        """Classify chains as protein or nucleic acid.

        A chain is 'nucleic_acid' if any of its parsed residue names is in
        ``NUCLEIC_ACIDS``, otherwise 'protein'. Also stores the per-residue
        names in ``residue_types``.
        """
        self.residue_types = np.array([res['res'] for res in self.residues])
        chains_array = np.array(self.chains)
        self.chain_types = {}
        for chain in np.unique(self.chains):
            names = set(self.residue_types[chains_array == chain])
            if names & self.NUCLEIC_ACIDS:
                self.chain_types[chain] = 'nucleic_acid'
            else:
                self.chain_types[chain] = 'protein'

    @staticmethod
    def parse_pdb_line(line: str, *args, **kwargs) -> dict[str, Any]:
        """Parse a single fixed-column line of a PDB file.

        Args:
            line: ATOM/HETATM line from the PDB file.
            *args: Unused, for API compatibility with parse_cif_line.
            **kwargs: Unused, for API compatibility with parse_cif_line.

        Returns:
            Dictionary with atom/residue information.
        """
        atom_num = line[6:11].strip()
        atom_name = line[12:16].strip()
        residue_name = line[17:20].strip()
        chain_id = line[21]
        residue_id = line[22:26].strip()
        x = line[30:38].strip()
        y = line[38:46].strip()
        z = line[46:54].strip()
        return ModelParser.package_line(
            atom_num, atom_name, residue_name, chain_id, residue_id, x, y, z
        )

    @staticmethod
    def parse_cif_line(
        line: str,
        fields: dict[str, int],
        allow_missing_seq_id: bool = False,
    ) -> dict[str, Any] | None:
        """Parse a single whitespace-delimited line of a CIF file.

        Args:
            line: Line from the CIF file.
            fields: Dictionary mapping field names to column indices.
            allow_missing_seq_id: If True, fall back to auth_seq_id when
                label_seq_id is '.'. Used for non-polymer residues (ligands,
                glycans) which lack a label_seq_id.

        Returns:
            Dictionary with atom/residue information, or None if residue_id
            is missing and fallback is disabled (or auth_seq_id is absent).
        """
        tokens = line.split()
        atom_num = tokens[fields['id']]
        atom_name = tokens[fields['label_atom_id']]
        # mmCIF writers quote atom names that contain a prime, e.g. "C1'";
        # strip only a matching pair of surrounding quotes.
        if len(atom_name) >= 2 and atom_name[0] in '"\'' and atom_name[-1] == atom_name[0]:
            atom_name = atom_name[1:-1]
        residue_name = tokens[fields['label_comp_id']]
        chain_id = tokens[fields['label_asym_id']]
        residue_id = tokens[fields['label_seq_id']]
        x = tokens[fields['Cartn_x']]
        y = tokens[fields['Cartn_y']]
        z = tokens[fields['Cartn_z']]

        if residue_id == '.':
            if not allow_missing_seq_id or 'auth_seq_id' not in fields:
                return None
            residue_id = tokens[fields['auth_seq_id']]

        return ModelParser.package_line(
            atom_num, atom_name, residue_name, chain_id, residue_id, x, y, z
        )

    @staticmethod
    def package_line(
        atom_num: str,
        atom_name: str,
        residue_name: str,
        chain_id: str,
        residue_id: str,
        x: str,
        y: str,
        z: str,
    ) -> dict[str, Any]:
        """Package parsed line data into a dictionary.

        Args:
            atom_num: Atom index.
            atom_name: Atom name (e.g., 'CA', 'CB').
            residue_name: Residue name (e.g., 'ALA').
            chain_id: Chain identifier.
            residue_id: Residue sequence number.
            x: X coordinate as string.
            y: Y coordinate as string.
            z: Z coordinate as string.

        Returns:
            Dictionary with integer ``atom_num``/``resid`` and a float
            ``coor`` array of shape (3,).
        """
        return {
            'atom_num': int(atom_num),
            'atom_name': atom_name,
            'coor': np.array([x, y, z], dtype=float),
            'res': residue_name,
            'chain_id': chain_id,
            'resid': int(residue_id),
        }