# ruff: noqa: RUF012
"""
ConstantPH - Constant pH simulation using AMBER prmtop/inpcrd files directly.
This implementation preserves all force field parameters from the AMBER topology,
including custom ligand parameters and Lipid21 modular lipids.
Key features:
- Uses AmberPrmtopFile for explicit solvent simulation (preserves all parameters)
- Uses ParmEd to create implicit solvent system WITH lipids and ligands
- MC energy evaluations include protein-lipid and protein-ligand interactions
- Only water and ions are stripped from the implicit system
"""
import contextlib
from collections import defaultdict
from collections.abc import Sequence
from copy import deepcopy
import numpy as np
import parmed as pmd
from openmm import Context, GBSAOBCForce, NonbondedForce, System
from openmm.app import (
OBC2,
AmberInpcrdFile,
AmberPrmtopFile,
ForceField,
GBn2,
HBonds,
Modeller,
NoCutoff,
Simulation,
element,
)
from openmm.app.forcefield import NonbondedGenerator
from openmm.app.internal import compiled # ty: ignore[unresolved-import]
from openmm.unit import (
MOLAR_GAS_CONSTANT_R,
elementary_charge,
is_quantity,
kelvin,
kilojoules_per_mole,
nanometers, # ty: ignore[unresolved-import]
)
[docs]
class ResidueState:
    """Stores parameters for a particular protonation state of a residue.

    Attributes
    ----------
    residueIndex : int
        Index of the residue this state describes.
    atomIndices : dict
        ``{atom_name: atom_index}``
    particleParameters : dict
        ``{force_index: {atom_name: params}}``
    exceptionParameters : dict
        ``{force_index: {(res, a1, a2): params}}``
    numHydrogens : int
        Hydrogen count associated with this state.
    """

    def __init__(
        self,
        residueIndex,
        atomIndices,
        particleParameters,
        exceptionParameters,
        numHydrogens,
    ):
        """Store the given values unchanged (no copies are made)."""
        self.residueIndex = residueIndex
        self.atomIndices = atomIndices
        self.particleParameters = particleParameters
        self.exceptionParameters = exceptionParameters
        self.numHydrogens = numHydrogens
[docs]
class ResidueTitration:
    """Manages titration states for a single residue.

    Attributes
    ----------
    variants : list
        Variant names of the residue, as passed in.
    referenceEnergies : list
        Reference energies, as passed in.
    explicitStates, implicitStates : list
        Per-variant states; both start empty and are filled in later.
    explicitHydrogenIndices : list
        Starts empty; filled in later.
    protonatedIndex, currentIndex : int
        Both start at -1, meaning "not assigned yet".
    """

    def __init__(self, variants, referenceEnergies):
        """Record the variants and energies and start with no states."""
        self.variants = variants
        self.referenceEnergies = referenceEnergies
        self.explicitStates = []
        self.implicitStates = []
        self.explicitHydrogenIndices = []
        self.protonatedIndex = -1
        self.currentIndex = -1
[docs]
class ConstantPH:
"""
Constant pH simulation using AMBER topology files directly.
This class enables constant pH molecular dynamics while preserving all force
field parameters from AMBER prmtop files, including custom ligand parameters
and Lipid21 modular lipids.
The approach:
1. Use AmberPrmtopFile.createSystem() for the explicit solvent simulation
2. Use ParmEd to create implicit solvent system WITH lipids and ligands
3. MC energy evaluations include protein-lipid and protein-ligand interactions
4. Only water and ions are stripped from the implicit system
5. Use OpenMM ForceField only for building protonation state parameters
Parameters
----------
prmtop_file : str or Path
Path to AMBER prmtop file
inpcrd_file : str or Path
Path to AMBER inpcrd file
pH : float or list
The pH value(s) for simulation. If a list, simulated tempering is used.
residueVariants : dict
Maps residue indices to lists of variant names.
Example: {10: ['ASP', 'ASH'], 15: ['GLU', 'GLH']}
referenceEnergies : dict
Maps residue indices to lists of reference energies (kJ/mol).
Example: {10: [0.0, 5.2], 15: [0.0, 4.8]}
relaxationSteps : int
Steps to relax solvent after accepting a protonation state change.
explicitArgs : dict
Arguments for createSystem() for explicit solvent.
implicitArgs : dict
Arguments for ParmEd createSystem() for implicit solvent.
Supports: implicitSolvent (OBC2, GBn2), solventDielectric, soluteDielectric
integrator : openmm.Integrator
Integrator for the main simulation.
relaxationIntegrator : openmm.Integrator
Integrator for solvent relaxation (frozen solute).
implicitForceField : openmm.app.ForceField, optional
ForceField for building protonation state parameters (protein only).
Defaults to amber14 + GBn2.
gbModel : str, optional
GB model for implicit solvent: 'OBC2' or 'GBn2'. Default: 'GBn2'
weights : list, optional
Simulated tempering weights. None = auto-determine via Wang-Landau.
platform : openmm.Platform, optional
Platform for simulation. None = auto-select.
properties : dict, optional
Platform-specific properties.
"""
# Standard residues that can be parameterized by OpenMM ForceField
PROTEIN_RESIDUES = {
'ALA',
'ARG',
'ASN',
'ASP',
'ASH',
'CYS',
'CYM',
'CYX',
'GLN',
'GLU',
'GLH',
'GLY',
'HIS',
'HID',
'HIE',
'HIP',
'ILE',
'LEU',
'LYS',
'LYN',
'MET',
'PHE',
'PRO',
'SER',
'THR',
'TRP',
'TYR',
'VAL',
'ACE',
'NME',
'NHE',
}
# Ion elements to exclude from implicit system
ION_ELEMENTS = (
element.cesium,
element.potassium,
element.lithium,
element.sodium,
element.rubidium,
element.chlorine,
element.bromine,
element.fluorine,
element.iodine,
)
# Water/ion residue names to strip (lipids and ligands are KEPT)
WATER_ION_NAMES = {
'HOH',
'WAT',
'Na+',
'Cl-',
'NA',
'CL',
'K+',
'K',
'SOD',
'CLA',
'POT',
'OPC',
'TIP3',
'SPC',
}
# Signed hydrogen-count differences keyed by (variant, reference variant):
# the value is the hydrogen count of `variant` minus that of the reference
# variant, e.g. ('ASP', 'ASH') -> -1 because ASP has one hydrogen fewer than ASH.
TITRATION_H_DIFF = {
# Aspartate: ASH (protonated) -> ASP (deprotonated), loses 1 H
('ASP', 'ASH'): -1,
('ASH', 'ASP'): 1,
# Glutamate: GLH (protonated) -> GLU (deprotonated), loses 1 H
('GLU', 'GLH'): -1,
('GLH', 'GLU'): 1,
# Histidine: HIP (doubly protonated) -> HID/HIE (singly protonated)
('HID', 'HIP'): -1,
('HIP', 'HID'): 1,
('HIE', 'HIP'): -1,
('HIP', 'HIE'): 1,
('HID', 'HIE'): 0,
('HIE', 'HID'): 0, # Same H count, different position
# Lysine: LYS (protonated) -> LYN (neutral), loses 1 H
('LYN', 'LYS'): -1,
('LYS', 'LYN'): 1,
# Cysteine: CYS (protonated thiol) -> CYM (deprotonated thiolate)
('CYM', 'CYS'): -1,
('CYS', 'CYM'): 1,
# Tyrosine: TYR (protonated) -> TYD (deprotonated phenolate)
('TYD', 'TYR'): -1,
('TYR', 'TYD'): 1,
}
[docs]
def __init__(
    self,
    prmtop_file,
    inpcrd_file,
    pH,
    residueVariants,
    referenceEnergies,
    relaxationSteps,
    explicitArgs,
    implicitArgs,
    integrator,
    relaxationIntegrator,
    implicitForceField=None,
    excludeResidues=None,
    gbModel='GBn2',
    weights=None,
    platform=None,
    properties=None,
):
    """Load the AMBER files and build every system and protonation state.

    The parameters are described in the class docstring.  Additionally:

    excludeResidues : optional
        Deprecated and ignored; when it is not None a notice is printed.

    Raises
    ------
    KeyError
        If ``referenceEnergies`` has no entry for a residue index listed in
        ``residueVariants``.
    """
    # Store file paths for ParmEd
    self.prmtop_file = str(prmtop_file)
    self.inpcrd_file = str(inpcrd_file)

    # Load AMBER topology and coordinates
    print('Loading AMBER topology...')
    self.prmtop = AmberPrmtopFile(self.prmtop_file)
    self.inpcrd = AmberInpcrdFile(self.inpcrd_file)

    # Store parameters
    self._explicitArgs = explicitArgs
    self._implicitArgs = implicitArgs
    self.relaxationSteps = relaxationSteps
    self.gbModel = gbModel

    # Normalize pH to an indexable collection.  np.ndim() is used instead of
    # isinstance(pH, Sequence) because NumPy arrays are not Sequence
    # instances and would otherwise be wrapped as a single "scalar" entry.
    if np.ndim(pH) == 0:
        pH = [pH]
    elif not isinstance(pH, Sequence):
        pH = list(pH)
    self.setPH(pH, weights)
    self.currentPHIndex = 0

    # excludeResidues is no longer used - lipids and ligands are kept and
    # only water and ions are stripped.
    if excludeResidues is not None:
        print(' Note: excludeResidues parameter is deprecated.')
        print(' Lipids and ligands are now INCLUDED in MC evaluations.')

    # Implicit ForceField is used for building protonation state params only
    if implicitForceField is None:
        self.implicitForceField = ForceField('amber14-all.xml', 'implicit/gbn2.xml')
    else:
        self.implicitForceField = implicitForceField

    # Initialize titration tracking
    self.titrations = {}
    for resIndex, variants in residueVariants.items():
        energies = list(referenceEnergies[resIndex])
        self.titrations[resIndex] = ResidueTitration(variants, energies)

    # Build implicit system with lipids/ligands using ParmEd
    print('Building implicit solvent system (includes lipids and ligands)...')
    self._buildImplicitSystemWithParmEd()

    # Build protein-only topology for protonation state parameters
    print('Building protein-only topology for protonation states...')
    self._buildProteinOnlyTopology(residueVariants)

    # Build protonation states for each titratable residue
    print('Computing protonation state parameters...')
    self._buildProtonationStates(residueVariants)

    # Map protonation states to implicit system (fix force indices)
    print('Mapping protonation states to implicit system...')
    self._mapStatesToImplicitSystem()

    # Create the explicit system from AMBER topology
    print('Creating explicit solvent system from AMBER topology...')
    self._buildExplicitSystem(
        integrator, relaxationIntegrator, platform, properties
    )

    # Map protonation states to explicit system
    print('Mapping protonation states to explicit system...')
    self._mapStatesToExplicitSystem()
    print('ConstantPHAmber initialization complete.')
def _buildImplicitSystemWithParmEd(self):
    """
    Build implicit solvent system using ParmEd, keeping lipids and ligands.

    This method:
    1. Loads the AMBER topology with ParmEd
    2. Strips only water and ions (keeps protein, lipids, ligands)
    3. Creates an implicit solvent System with GB

    A single pass decides, per residue, whether it is stripped; the atom
    selection and the implicit <-> explicit residue maps are both derived
    from that one decision so they cannot disagree.

    Sets ``implicitToExplicitResidueMap``, ``explicitToImplicitResidueMap``,
    ``_strippedParm``, ``implicitSystem``, ``implicitTopology`` and
    ``implicitPositions``.

    Raises
    ------
    RuntimeError
        If ParmEd and the OpenMM topology disagree on the residue count, or
        the stripped structure does not contain one residue per map entry.
    ValueError
        If ``self.gbModel`` is neither ``'GBn2'`` nor ``'OBC2'``.
    """
    # Load structure with ParmEd
    parm = pmd.load_file(self.prmtop_file, self.inpcrd_file)

    # The maps below index ParmEd residues and OpenMM residues with the same
    # number, which is only valid when both readers see the same residues.
    numExplicitResidues = self.prmtop.topology.getNumResidues()
    if numExplicitResidues != len(parm.residues):
        raise RuntimeError(
            f'ParmEd found {len(parm.residues)} residues but the OpenMM '
            f'topology has {numExplicitResidues}'
        )

    # Use the same ion definition as ION_ELEMENTS (ParmEd's atom.element is
    # compared against atomic numbers here, as in the original check).
    ion_atomic_numbers = {ion.atomic_number for ion in self.ION_ELEMENTS}

    keep_indices = []
    self.implicitToExplicitResidueMap = []
    self.explicitToImplicitResidueMap = {}
    for explicit_idx, residue in enumerate(parm.residues):
        is_single_atom_ion = (
            len(residue.atoms) == 1
            and residue.atoms[0].element in ion_atomic_numbers
        )
        if residue.name in self.WATER_ION_NAMES or is_single_atom_ion:
            continue  # stripped: water or ion
        keep_indices.extend(atom.idx for atom in residue.atoms)
        self.explicitToImplicitResidueMap[explicit_idx] = len(
            self.implicitToExplicitResidueMap
        )
        self.implicitToExplicitResidueMap.append(explicit_idx)

    # Create stripped structure by selecting atoms
    stripped_parm = parm[keep_indices]
    if len(stripped_parm.residues) != len(self.implicitToExplicitResidueMap):
        raise RuntimeError(
            f'Stripped structure has {len(stripped_parm.residues)} residues '
            f'but {len(self.implicitToExplicitResidueMap)} were selected'
        )

    # Store the stripped ParmEd structure
    self._strippedParm = stripped_parm

    # Determine GB model
    gbModels = {'GBn2': GBn2, 'OBC2': OBC2}
    if self.gbModel not in gbModels:
        raise ValueError(f'Unknown GB model: {self.gbModel}. Use GBn2 or OBC2.')
    implicitSolvent = gbModels[self.gbModel]

    # Create implicit system with GB using ParmEd
    solventDielectric = self._implicitArgs.get('solventDielectric', 78.5)
    soluteDielectric = self._implicitArgs.get('soluteDielectric', 1.0)
    self.implicitSystem = stripped_parm.createSystem(
        nonbondedMethod=NoCutoff,
        constraints=HBonds,
        implicitSolvent=implicitSolvent,
        soluteDielectric=soluteDielectric,
        solventDielectric=solventDielectric,
        removeCMMotion=False,
    )

    # Store topology and positions
    self.implicitTopology = stripped_parm.topology
    self.implicitPositions = stripped_parm.positions

    # Count what we kept (reporting only)
    n_protein = sum(
        1 for r in stripped_parm.residues if r.name in self.PROTEIN_RESIDUES
    )
    n_lipid = sum(
        1
        for r in stripped_parm.residues
        if r.name in {'PA', 'PC', 'PE', 'OL', 'GL'}
    )
    n_other = len(stripped_parm.residues) - n_protein - n_lipid
    print(' Stripped water and ions only')
    print(
        f' Implicit system: {len(stripped_parm.residues)} residues, '
        f'{len(stripped_parm.atoms)} atoms'
    )
    print(f' Protein: {n_protein} residues')
    print(f' Lipids: {n_lipid} residues (PA/PC/PE/OL/GL)')
    print(f' Other (ligands, etc.): {n_other} residues')
    print(f' GB model: {self.gbModel}')
def _buildProteinOnlyTopology(self, residueVariants):
    """
    Derive a topology that contains only residues named in PROTEIN_RESIDUES.

    Every other residue of the prmtop topology is deleted with Modeller.
    The result is stored in ``proteinTopology`` / ``proteinPositions`` and
    the index maps ``proteinToExplicitResidueMap`` (list) and
    ``explicitToProteinResidueMap`` (dict) are recorded.

    ``residueVariants`` is accepted but not used by this method.
    """
    fullTopology = self.prmtop.topology
    startPositions = self.inpcrd.positions

    discarded = []
    self.proteinToExplicitResidueMap = []
    self.explicitToProteinResidueMap = {}
    for residue in fullTopology.residues():
        if residue.name not in self.PROTEIN_RESIDUES:
            discarded.append(residue)
            continue
        # Protein index = explicit index minus the residues dropped so far
        self.explicitToProteinResidueMap[residue.index] = (
            residue.index - len(discarded)
        )
        self.proteinToExplicitResidueMap.append(residue.index)

    # Let Modeller remove everything that is not protein
    proteinModel = Modeller(fullTopology, startPositions)
    proteinModel.delete(discarded)
    self.proteinTopology = proteinModel.topology
    self.proteinPositions = proteinModel.positions

    residueCount = self.proteinTopology.getNumResidues()
    atomCount = self.proteinTopology.getNumAtoms()
    print(
        f' Protein-only topology: {residueCount} residues, '
        f'{atomCount} atoms'
    )
def _buildProtonationStates(self, residueVariants):
"""
Build ResidueState objects for each protonation state.
Uses the protein-only topology with OpenMM ForceField to build
protonation state parameters. These parameters (charges, etc.) will
be applied to both the implicit system (with lipids/ligands) and
the explicit system.
"""
# We need to iterate through each variant index
variantIndex = 0
finished = False
# Use protein-only topology for ForceField parameterization
proteinVariants = [None] * self.proteinTopology.getNumResidues()
while not finished:
finished = True
# Set variants for this iteration
for proteinIndex, explicitIndex in enumerate(
self.proteinToExplicitResidueMap
):
if explicitIndex in residueVariants:
variants = residueVariants[explicitIndex]
if variantIndex < len(variants):
finished = False
proteinVariants[proteinIndex] = variants[variantIndex]
if finished:
break
# Build states using protein-only topology
proteinStates = self._findResidueStates(
self.proteinTopology,
self.proteinPositions,
self.implicitForceField,
proteinVariants,
self._implicitArgs,
)
# Add to ResidueTitration objects
for proteinState in proteinStates:
# Map protein residue index to explicit
proteinResIndex = proteinState.residueIndex
explicitResIndex = self.proteinToExplicitResidueMap[proteinResIndex]
if explicitResIndex in self.titrations:
titration = self.titrations[explicitResIndex]
if variantIndex < len(titration.variants):
# Update residue index to explicit system index
proteinState.residueIndex = explicitResIndex
titration.implicitStates.append(proteinState)
variantIndex += 1
# Identify the fully protonated state and fix numHydrogens for each titration
# This is necessary because AMBER prmtops already have hydrogens, so addHydrogens
# may not correctly adjust the topology for different variants
for resIndex, titration in self.titrations.items():
variants = titration.variants
# Find the protonated state based on variant name conventions
protonatedIdx = self._identifyProtonatedState(variants)
titration.protonatedIndex = protonatedIdx
# Get baseline hydrogen count from the protonated state
baselineH = titration.implicitStates[protonatedIdx].numHydrogens
# Update numHydrogens for each state based on variant differences
protonatedVariant = variants[protonatedIdx]
for i, (state, variant) in enumerate(
zip(titration.implicitStates, variants, strict=True)
):
if i == protonatedIdx:
continue
# Calculate H difference from protonated state
key = (variant, protonatedVariant)
if key in self.TITRATION_H_DIFF:
hDiff = self.TITRATION_H_DIFF[key]
state.numHydrogens = baselineH + hDiff
if (
state.numHydrogens
!= titration.implicitStates[protonatedIdx].numHydrogens
):
print(
f' Res {resIndex}: {variant} has {state.numHydrogens} H '
f'(vs {baselineH} for {protonatedVariant})'
)
titration.currentIndex = titration.protonatedIndex
def _identifyProtonatedState(self, variants):
"""Identify which variant index corresponds to the fully protonated state.
For standard titratable residues:
- ASH > ASP (ASH is protonated aspartate)
- GLH > GLU (GLH is protonated glutamate)
- HIP > HID, HIE (HIP is doubly protonated histidine)
- LYS > LYN (LYS is protonated lysine)
- CYS > CYM (CYS has the thiol proton)
- TYR > TYD (TYR has the phenolic proton)
"""
# Protonated forms (higher proton count)
PROTONATED_FORMS = {'ASH', 'GLH', 'HIP', 'LYS', 'CYS', 'TYR'}
# Find the variant with the most protons
for i, variant in enumerate(variants):
if variant in PROTONATED_FORMS:
return i
# If no protonated form found, check for intermediate forms
# For histidine: HID and HIE are equally protonated (singly)
for i, variant in enumerate(variants):
if variant in {'HID', 'HIE'}:
return i
# Fallback: assume first variant is protonated
# (This handles custom residue types)
print(
f' Warning: Could not identify protonated state for variants {variants}, '
f'assuming index 0'
)
return 0
def _mapStatesToImplicitSystem(self):
    """Map protonation state force indices from ForceField system to ParmEd implicit system.

    This method:
    1. Remaps force indices from the ForceField-created system to the ParmEd
       implicit system (NonbondedForce and, when present, GBSAOBCForce)
    2. Updates atom indices, the residue index and the residue part of the
       exception keys to implicit-system numbering
    3. Ensures all states have consistent atoms with ghost hydrogens, using
       ``self._get_zero_parameters`` for atoms missing from a state

    Raises
    ------
    RuntimeError
        If the implicit system has no NonbondedForce.
    """
    # Find NonbondedForce and GBSAOBCForce indices in the ParmEd implicit system
    implicitNBForceIdx = None
    implicitGBForceIdx = None
    for fi, force in enumerate(self.implicitSystem.getForces()):
        if isinstance(force, NonbondedForce):
            implicitNBForceIdx = fi
        elif isinstance(force, GBSAOBCForce):
            implicitGBForceIdx = fi
    if implicitNBForceIdx is None:
        raise RuntimeError('No NonbondedForce found in implicit system')
    implicitNBForce = self.implicitSystem.getForce(implicitNBForceIdx)
    implicitGBForce = (
        self.implicitSystem.getForce(implicitGBForceIdx)
        if implicitGBForceIdx is not None
        else None
    )

    # Build a temporary ForceField system only to learn which force index
    # holds which force type in the states produced from the ForceField.
    tempSystem = self.implicitForceField.createSystem(
        self.proteinTopology, **self._implicitArgs
    )
    ffForceTypes = {}
    for fi, force in enumerate(tempSystem.getForces()):
        if isinstance(force, NonbondedForce):
            ffForceTypes[fi] = 'NB'
        elif isinstance(force, GBSAOBCForce):
            ffForceTypes[fi] = 'GB'

    # The implicit residue list does not change; build it once.
    implicitResidues = list(self.implicitTopology.residues())

    for resIndex, titration in self.titrations.items():
        # Implicit residue for this titration (None if it is not mapped or
        # falls outside the implicit topology).
        implicitResIdx = self.explicitToImplicitResidueMap.get(resIndex)
        if implicitResIdx is not None and implicitResIdx >= len(implicitResidues):
            implicitResIdx = None

        for state in titration.implicitStates:
            # Locate the ForceField force indices used by this state
            ffNBForceIdx = None
            ffGBForceIdx = None
            for fi in state.particleParameters:
                forceType = ffForceTypes.get(fi)
                if forceType == 'NB':
                    ffNBForceIdx = fi
                elif forceType == 'GB':
                    ffGBForceIdx = fi

            # Remap particle parameters to implicit system force indices
            newParticleParams = {}
            if ffNBForceIdx is not None:
                newParticleParams[implicitNBForceIdx] = state.particleParameters[
                    ffNBForceIdx
                ]
            if ffGBForceIdx is not None and implicitGBForceIdx is not None:
                newParticleParams[implicitGBForceIdx] = state.particleParameters[
                    ffGBForceIdx
                ]

            # Remap exception parameters
            newExceptionParams = {}
            if (
                ffNBForceIdx is not None
                and ffNBForceIdx in state.exceptionParameters
            ):
                ffExceptions = state.exceptionParameters[ffNBForceIdx]
                if implicitResIdx is not None:
                    # Keys were built with the protein-only residue index,
                    # while implicitExceptionIndex is keyed by the implicit
                    # topology's residue index.
                    # NOTE(review): assumes these keys are looked up in
                    # implicitExceptionIndex - confirm against the consumer.
                    ffExceptions = {
                        (implicitResIdx, key[1], key[2]): params
                        for key, params in ffExceptions.items()
                    }
                newExceptionParams[implicitNBForceIdx] = ffExceptions

            # Update the state
            state.particleParameters = newParticleParams
            state.exceptionParameters = newExceptionParams

            # Update atom indices to use implicit system indices
            if implicitResIdx is not None:
                state.atomIndices = {
                    atom.name: atom.index
                    for atom in implicitResidues[implicitResIdx].atoms()
                }
                state.residueIndex = implicitResIdx

        # Second pass: ensure all implicit states have consistent atoms with
        # ghost hydrogens, taking the protonated state as the reference.
        protonated = titration.protonatedIndex
        protonatedState = titration.implicitStates[protonated]
        protonatedNBParams = protonatedState.particleParameters.get(
            implicitNBForceIdx, {}
        )
        # "is not None" (not truthiness): 0 is a valid force index.
        protonatedGBParams = (
            protonatedState.particleParameters.get(implicitGBForceIdx, {})
            if implicitGBForceIdx is not None
            else {}
        )
        protonatedExceptionParams = protonatedState.exceptionParameters.get(
            implicitNBForceIdx, {}
        )
        for i, state in enumerate(titration.implicitStates):
            if i == protonated:
                continue

            # Ensure NB parameters include all atoms from protonated state
            stateNBParams = state.particleParameters.get(implicitNBForceIdx, {})
            for atomName in protonatedNBParams:
                if atomName not in stateNBParams:
                    stateNBParams[atomName] = self._get_zero_parameters(
                        protonatedNBParams[atomName], implicitNBForce
                    )
            state.particleParameters[implicitNBForceIdx] = stateNBParams

            # Ensure GB parameters include all atoms from protonated state
            if implicitGBForceIdx is not None and implicitGBForce is not None:
                stateGBParams = state.particleParameters.get(implicitGBForceIdx, {})
                for atomName in protonatedGBParams:
                    if atomName not in stateGBParams:
                        stateGBParams[atomName] = self._get_zero_parameters(
                            protonatedGBParams[atomName], implicitGBForce
                        )
                state.particleParameters[implicitGBForceIdx] = stateGBParams

            # Exceptions involving ghost atoms: zero charge product, keep
            # the remaining parameters of the protonated state.
            stateExceptionParams = state.exceptionParameters.get(
                implicitNBForceIdx, {}
            )
            for key in protonatedExceptionParams:
                if key not in stateExceptionParams:
                    originalParams = protonatedExceptionParams[key]
                    stateExceptionParams[key] = (
                        0.0 * elementary_charge**2,
                        *originalParams[1:],
                    )
            state.exceptionParameters[implicitNBForceIdx] = stateExceptionParams
def _findResidueStates(self, topology, positions, forcefield, variants, ffargs):
    """Build ResidueState objects for residues with specified variants.

    Hydrogens are added with Modeller for the requested ``variants`` (one
    entry per residue, None meaning "no state wanted"), a System is created
    with ``forcefield.createSystem(..., **ffargs)`` and, for every residue
    with a variant, the per-particle parameters of each force and the
    intra-residue NonbondedForce exceptions are collected.

    Returns
    -------
    list of ResidueState
        One entry per residue whose variant is not None, in residue order.
    """
    builder = Modeller(topology, positions)
    builder.addHydrogens(forcefield=forcefield, variants=variants)
    system = forcefield.createSystem(builder.topology, **ffargs)
    allAtoms = list(builder.topology.atoms())
    allResidues = list(builder.topology.residues())

    collected = []
    for residue, variant in zip(allResidues, variants, strict=True):
        if variant is None:
            continue
        nameToIndex = {atom.name: atom.index for atom in residue.atoms()}
        perForceParticles = {}
        perForceExceptions = {}
        for forceIndex, force in enumerate(system.getForces()):
            # Forces that cannot provide per-particle parameters are skipped
            with contextlib.suppress(Exception):
                perForceParticles[forceIndex] = {
                    atom.name: force.getParticleParameters(atom.index)
                    for atom in residue.atoms()
                }
            if not isinstance(force, NonbondedForce):
                continue
            perForceExceptions[forceIndex] = {}
            for excIndex in range(force.getNumExceptions()):
                p1, p2, chargeProd, sigma, epsilon = force.getExceptionParameters(
                    excIndex
                )
                first = allAtoms[p1]
                second = allAtoms[p2]
                # Only exceptions fully inside this residue are recorded
                if first.residue == residue and second.residue == residue:
                    key = (residue.index, first.name, second.name)
                    perForceExceptions[forceIndex][key] = (
                        chargeProd,
                        sigma,
                        epsilon,
                    )
        hydrogenCount = sum(
            1 for atom in residue.atoms() if atom.element == element.hydrogen
        )
        collected.append(
            ResidueState(
                residue.index,
                nameToIndex,
                perForceParticles,
                perForceExceptions,
                hydrogenCount,
            )
        )
    return collected
def _buildExplicitSystem(
    self, integrator, relaxationIntegrator, platform, properties
):
    """Create the explicit solvent system, its contexts and lookup tables.

    Builds ``explicitSystem`` from the prmtop, a relaxation copy in which
    every residue that is neither 'HOH' nor a single-atom ion gets zero
    mass, the main ``simulation`` and the ``implicitContext`` /
    ``relaxationContext`` on the platform the simulation ended up with.
    Positions are set on all three contexts, then the atom index mapping,
    exception indices, inter-residue exception lists and 1-4 scale factors
    are recorded.
    """
    self.explicitSystem = self.prmtop.createSystem(**self._explicitArgs)
    self.explicitTopology = self.prmtop.topology
    startPositions = self.inpcrd.positions

    # Relaxation system: zero mass for everything that is not water or ion
    relaxationSystem = deepcopy(self.explicitSystem)
    for residue in self.explicitTopology.residues():
        if residue.name == 'HOH':
            continue
        if (
            len(residue) == 1
            and next(iter(residue.atoms())).element in self.ION_ELEMENTS
        ):
            continue
        for atom in residue.atoms():
            relaxationSystem.setParticleMass(atom.index, 0.0)

    # implicitSystem already exists (built by _buildImplicitSystemWithParmEd)
    self.simulation = Simulation(
        self.explicitTopology,
        self.explicitSystem,
        deepcopy(integrator),
        platform,
        properties,
    )
    actualPlatform = self.simulation.context.getPlatform()

    # Pass the properties argument to Context only when it was supplied
    platformArgs = (
        (actualPlatform,) if properties is None else (actualPlatform, properties)
    )
    self.implicitContext = Context(
        self.implicitSystem, deepcopy(integrator), *platformArgs
    )
    self.relaxationContext = Context(
        relaxationSystem, deepcopy(relaxationIntegrator), *platformArgs
    )

    # Explicit coordinates for the full systems, stripped ones for implicit
    self.simulation.context.setPositions(startPositions)
    self.relaxationContext.setPositions(startPositions)
    self.implicitContext.setPositions(self.implicitPositions)

    # Atom index mapping used for copying positions between systems
    self._buildAtomIndexMapping()

    # Exception lookup tables for both systems
    self.explicitExceptionIndex = self._findExceptionIndices(
        self.explicitSystem, self.explicitTopology
    )
    self.implicitExceptionIndex = self._findExceptionIndices(
        self.implicitSystem, self.implicitTopology
    )
    self.explicitInterResidue14 = self._findInterResidue14(
        self.explicitSystem, self.explicitTopology
    )
    self.implicitInterResidue14 = self._findInterResidue14(
        self.implicitSystem, self.implicitTopology
    )

    # 1-4 Coulomb scale factors; the ParmEd implicit system uses the AMBER
    # default of 1/1.2.
    self.explicit14Scale = self._find14Scale(self.explicitSystem)
    self.implicit14Scale = 1.0 / 1.2
def _buildAtomIndexMapping(self):
"""Build mapping from implicit atom indices to explicit atom indices.
This maps each atom in the implicit (stripped) system to its corresponding
atom in the explicit system, enabling position copying between systems.
"""
numImplicitAtoms = self.implicitSystem.getNumParticles()
implicitAtomIndex = np.zeros(numImplicitAtoms, dtype=np.int64)
explicitResidues = list(self.explicitTopology.residues())
# ParmEd stripped structure
implicitResidues = self._strippedParm.residues
# Track atom offset for implicit system
implicitAtomOffset = 0
for implicitIndex, explicitIndex in enumerate(
self.implicitToExplicitResidueMap
):
explicitRes = explicitResidues[explicitIndex]
implicitRes = implicitResidues[implicitIndex]
# Build atom name -> index map for explicit residue
explicitAtoms = {atom.name: atom.index for atom in explicitRes.atoms()}
# Map implicit atoms to explicit atoms using enumeration
# This is more robust than using ParmEd's global atom.idx
for localIdx, atom in enumerate(implicitRes.atoms):
implicitIdx = implicitAtomOffset + localIdx
if atom.name in explicitAtoms:
implicitAtomIndex[implicitIdx] = explicitAtoms[atom.name]
else:
# Fallback: try to match by position in residue
atomList = list(explicitRes.atoms())
if localIdx < len(atomList):
implicitAtomIndex[implicitIdx] = atomList[localIdx].index
implicitAtomOffset += len(implicitRes.atoms)
self.implicitAtomIndex = implicitAtomIndex
def _mapStatesToExplicitSystem(self):
    """Derive the explicit-system states from the implicit ones.

    For each titration:
    1. one explicit ResidueState is created per implicit state, carrying
       the NonbondedForce particle parameters of atoms that exist in the
       explicit residue and the exceptions re-keyed to the explicit residue
       index;
    2. every non-protonated state is completed with zeroed ("ghost")
       parameters for atoms and exceptions it lacks relative to the
       protonated state;
    3. ghost atoms that are hydrogens are recorded in
       ``explicitHydrogenIndices``.

    Raises
    ------
    RuntimeError
        If the explicit system has no NonbondedForce.
    """
    explicitResidues = list(self.explicitTopology.residues())

    # First NonbondedForce of the explicit system
    explicitNBForceIdx = next(
        (
            fi
            for fi, force in enumerate(self.explicitSystem.getForces())
            if isinstance(force, NonbondedForce)
        ),
        None,
    )
    if explicitNBForceIdx is None:
        raise RuntimeError('No NonbondedForce found in explicit system')
    explicitNBForce = self.explicitSystem.getForce(explicitNBForceIdx)

    for resIndex, titration in self.titrations.items():
        protonated = titration.protonatedIndex
        residue = explicitResidues[resIndex]
        explicitAtomIndices = {atom.name: atom.index for atom in residue.atoms()}
        hydrogenNames = {
            atom.name for atom in residue.atoms() if atom.element == element.hydrogen
        }

        # First pass: one explicit state per implicit state
        for implicitState in titration.implicitStates:
            # First NonbondedForce among the forces this state has data for
            implicitNBForceIdx = next(
                (
                    fi
                    for fi in implicitState.particleParameters
                    if isinstance(self.implicitSystem.getForce(fi), NonbondedForce)
                ),
                None,
            )
            particleMap = {}
            exceptionMap = {}
            if implicitNBForceIdx is not None:
                sourceParticles = implicitState.particleParameters[
                    implicitNBForceIdx
                ]
                for atomName, params in sourceParticles.items():
                    if atomName in explicitAtomIndices:
                        particleMap[atomName] = params
                if implicitNBForceIdx in implicitState.exceptionParameters:
                    sourceExceptions = implicitState.exceptionParameters[
                        implicitNBForceIdx
                    ]
                    for key, params in sourceExceptions.items():
                        # Re-key to the explicit residue index
                        exceptionMap[(resIndex, key[1], key[2])] = params

            titration.explicitStates.append(
                ResidueState(
                    resIndex,
                    explicitAtomIndices,
                    {explicitNBForceIdx: particleMap},
                    {explicitNBForceIdx: exceptionMap},
                    implicitState.numHydrogens,
                )
            )

        # Second pass: complete the other states with ghost entries taken
        # from the protonated state.
        reference = titration.explicitStates[protonated]
        referenceParticles = reference.particleParameters.get(
            explicitNBForceIdx, {}
        )
        referenceExceptions = reference.exceptionParameters.get(
            explicitNBForceIdx, {}
        )
        for position, state in enumerate(titration.explicitStates):
            if position == protonated:
                continue
            stateParams = state.particleParameters.get(explicitNBForceIdx, {})
            stateExceptionParams = state.exceptionParameters.get(
                explicitNBForceIdx, {}
            )

            for atomName in referenceParticles:
                if atomName in stateParams:
                    continue
                # Atom absent in this protonation state: zeroed ghost atom
                stateParams[atomName] = self._get_zero_parameters(
                    referenceParticles[atomName], explicitNBForce
                )
                # Remember ghost hydrogens for multi-site moves
                atomIdx = explicitAtomIndices.get(atomName)
                if (
                    atomIdx is not None
                    and atomName in hydrogenNames
                    and atomIdx not in titration.explicitHydrogenIndices
                ):
                    titration.explicitHydrogenIndices.append(atomIdx)

            for key in referenceExceptions:
                if key in stateExceptionParams:
                    continue
                # Zero the charge product, keep the other parameters
                stateExceptionParams[key] = (
                    0.0 * elementary_charge**2,
                    *referenceExceptions[key][1:],
                )

            state.particleParameters[explicitNBForceIdx] = stateParams
            state.exceptionParameters[explicitNBForceIdx] = stateExceptionParams
def _findExceptionIndices(self, system, topology):
"""Map (residue, atom1, atom2) -> exception index in NonbondedForce."""
indices = {}
atoms = list(topology.atoms())
for force in system.getForces():
if isinstance(force, NonbondedForce):
for i in range(force.getNumExceptions()):
p1, p2, _chargeProd, _sigma, _epsilon = force.getExceptionParameters(i)
atom1 = atoms[p1]
atom2 = atoms[p2]
if atom1.residue == atom2.residue:
indices[(atom1.residue.index, atom1.name, atom2.name)] = i
indices[(atom1.residue.index, atom2.name, atom1.name)] = i
return indices
def _findInterResidue14(self, system, topology):
"""Find 1-4 exceptions that span residues for each titratable residue."""
indices = defaultdict(list)
atoms = list(topology.atoms())
for force in system.getForces():
if isinstance(force, NonbondedForce):
for i in range(force.getNumExceptions()):
p1, p2, chargeProd, _sigma, _epsilon = force.getExceptionParameters(i)
atom1 = atoms[p1]
atom2 = atoms[p2]
if (
atom1.residue != atom2.residue
and chargeProd.value_in_unit(elementary_charge**2) != 0.0
):
indices[atom1.residue.index].append(i)
indices[atom2.residue.index].append(i)
return indices
def _find14Scale(self, obj):
    """Return the 1-4 Coulomb scale factor for a ForceField or System.

    For a ForceField the value is read from its first NonbondedGenerator.
    For a System the value is not read from the object at all: the AMBER
    default of 1/1.2 (about 0.8333) is returned. Anything else, including a
    ForceField without a NonbondedGenerator, yields 1.0.
    """
    if isinstance(obj, ForceField):
        nonbondedGenerators = (
            gen for gen in obj.getGenerators() if isinstance(gen, NonbondedGenerator)
        )
        firstGenerator = next(nonbondedGenerators, None)
        if firstGenerator is not None:
            return firstGenerator.coulomb14scale
        return 1.0
    if isinstance(obj, System):
        # AMBER default is 1/1.2 = 0.8333
        return 1.0 / 1.2
    return 1.0
[docs]
def setPH(self, pH, weights=None):
    """Set the pH value(s) for simulation.

    Parameters
    ----------
    pH : float or sequence of float
        A single pH, or the list of pH values to sample. A scalar is
        stored as a one-element list so that ``self.pH`` can always be
        indexed.
    weights : sequence of float, optional
        One weight per pH value. When omitted, all weights start at 0.0 and
        are adapted on the fly (``self._updateWeights`` is set to True and
        the histogram / update-factor bookkeeping is reset). When given,
        the weights are used as-is and not adapted.

    Raises
    ------
    ValueError
        If ``weights`` is given and its length differs from the number of
        pH values.
    """
    if np.ndim(pH) == 0:
        # A bare number has no len(); wrap it so indexing works everywhere.
        pH = [pH]
    numPH = len(pH)
    if weights is not None and len(weights) != numPH:
        # Validate before touching any attribute so a bad call leaves the
        # object unchanged.
        raise ValueError(
            f'weights has {len(weights)} entries but {numPH} pH values were given'
        )
    self.pH = pH
    if weights is None:
        self._weights = [0.0] * numPH
        self._updateWeights = True
        self._weightUpdateFactor = 1.0
        self._histogram = [0] * numPH
        self._hasMadeTransition = False
    else:
        self._weights = weights
        self._updateWeights = False
@property
def weights(self):
    """Current simulated tempering weights, shifted so the first one is 0."""
    stored = self._weights
    shifted = []
    for value in stored:
        # Subtract the first weight from every entry (an empty list stays empty).
        shifted.append(value - stored[0])
    return shifted
[docs]
def printTitrationState(self):
"""Print current state of all titrations for debugging."""
print(f'Current pH: {self.pH[self.currentPHIndex]:.2f}')
print(f'Titratable residues: {len(self.titrations)}')
for resIndex, titration in self.titrations.items():
current = titration.currentIndex
protonated = titration.protonatedIndex
print(
f' Res {resIndex}: currentState={current}, protonatedIdx={protonated}'
)
print(f' variants={titration.variants}')
print(f' refEnergies={titration.referenceEnergies}')
for i, (impl, expl) in enumerate(
zip(titration.implicitStates, titration.explicitStates, strict=True)
):
print(
f' state[{i}]: implicitH={impl.numHydrogens}, explicitH={expl.numHydrogens}, '
f'implAtoms={len(impl.atomIndices)}, explAtoms={len(expl.atomIndices)}'
)
[docs]
def validateStates(self):
    """Validate that titration states are properly set up.

    For every titratable residue this checks that the number of implicit
    and explicit states equals the number of variants, that the states do
    not all share one hydrogen count, and that ``protonatedIndex`` is below
    the number of implicit states. The findings are printed.

    Returns
    -------
    bool
        True when no problem was found, False otherwise.
    """
    problems = []
    for resIndex, titration in self.titrations.items():
        numVariants = len(titration.variants)
        numImplicit = len(titration.implicitStates)
        numExplicit = len(titration.explicitStates)

        # State counts must match the variant count.
        if numImplicit != numVariants:
            problems.append(
                f'Res {resIndex}: {numImplicit} implicit states but {numVariants} variants'
            )
        if numExplicit != numVariants:
            problems.append(
                f'Res {resIndex}: {numExplicit} explicit states but {numVariants} variants'
            )

        # A titration whose states all carry the same hydrogen count is suspect.
        implicitCounts = [st.numHydrogens for st in titration.implicitStates]
        explicitCounts = [st.numHydrogens for st in titration.explicitStates]
        if len(set(implicitCounts)) == 1:
            problems.append(
                f'Res {resIndex}: all implicit states have same numHydrogens={implicitCounts[0]}'
            )
        if len(set(explicitCounts)) == 1:
            problems.append(
                f'Res {resIndex}: all explicit states have same numHydrogens={explicitCounts[0]}'
            )

        # The protonated state must be one of the available states.
        if titration.protonatedIndex >= numImplicit:
            problems.append(
                f'Res {resIndex}: protonatedIndex={titration.protonatedIndex} out of range'
            )

    if not problems:
        print('State validation PASSED')
        return True
    print('State validation FAILED:')
    for problem in problems:
        print(f' - {problem}')
    return False
[docs]
def attemptMCStep(self, temperature, debug=False):
    """
    Attempt Monte Carlo moves to change protonation states.

    The explicit-context positions are copied into the implicit context,
    then every titratable residue is visited once in random order. For each
    residue a new state is proposed (with probability 0.25 a neighbouring
    residue is changed in the same move), the implicit-context energy
    before and after the change is compared, and a Metropolis test decides
    whether to keep it. Accepted states are applied to the explicit
    simulation and relaxation contexts. If anything was accepted, the
    relaxation integrator is stepped ``self.relaxationSteps`` times and the
    resulting positions are written back to the simulation context.

    A move whose acceptance exponent evaluates to NaN is rejected, so a
    broken energy evaluation never changes the protonation state.

    Parameters
    ----------
    temperature : float or Quantity
        Simulation temperature. A bare number is interpreted as kelvin.
    debug : bool
        If True, print debugging information about MC moves
    """
    # Copy positions to implicit context
    state = self.simulation.context.getState(getPositions=True, getParameters=True)
    explicitPositions = state.getPositions(asNumpy=True).value_in_unit(nanometers)
    # Map positions to implicit system
    implicitPositions = explicitPositions[self.implicitAtomIndex]
    self.implicitContext.setPositions(implicitPositions)
    if debug:
        if np.any(np.isnan(implicitPositions)):
            print('WARNING: NaN in implicit positions!')
        testEnergy = self.implicitContext.getState(
            getEnergy=True
        ).getPotentialEnergy()
        if np.isnan(testEnergy.value_in_unit(kilojoules_per_mole)):
            print(
                'WARNING: Implicit context energy is NaN after setting positions!'
            )
            print(f' Explicit pos has NaN: {np.any(np.isnan(explicitPositions))}')
            print(f' Implicit pos has NaN: {np.any(np.isnan(implicitPositions))}')
            print(
                f' implicitAtomIndex range: {self.implicitAtomIndex.min()}-{self.implicitAtomIndex.max()}'
            )
            print(f' explicitPositions shape: {explicitPositions.shape}')
    periodicDistance = compiled.periodicDistance(
        state.getPeriodicBoxVectors().value_in_unit(nanometers)
    )
    # Attempt pH change if using simulated tempering
    if len(self.pH) > 1:
        self._attemptPHChange()
    # kT does not depend on the residue, so compute it once for all moves.
    if not is_quantity(temperature):
        temperature = temperature * kelvin
    kT = MOLAR_GAS_CONSTANT_R * temperature
    # Process residues in random order
    anyChange = False
    for resIndex in np.random.permutation(list(self.titrations)):
        titrations = [self.titrations[resIndex]]
        # Select new state
        stateIndex = [self._selectNewState(titrations[0])]
        # Occasionally attempt multi-site titration
        if np.random.random() < 0.25:
            neighbors = self._findNeighbors(
                resIndex, explicitPositions, periodicDistance
            )
            if len(neighbors) > 0:
                i = np.random.choice(neighbors)
                titrations.append(self.titrations[i])
                stateIndex.append(self._selectNewState(titrations[-1]))
        # Compute implicit energy change
        currentEnergy = self.implicitContext.getState(
            getEnergy=True
        ).getPotentialEnergy()
        for i, t in zip(stateIndex, titrations, strict=True):
            self._applyStateToContext(
                t.implicitStates[i],
                self.implicitContext,
                self.implicitExceptionIndex,
                self.implicitInterResidue14,
                self.implicit14Scale,
            )
        newEnergy = self.implicitContext.getState(
            getEnergy=True
        ).getPotentialEnergy()
        # Metropolis criterion
        deltaRefEnergy = sum(
            [
                t.referenceEnergies[i] - t.referenceEnergies[t.currentIndex]
                for i, t in zip(stateIndex, titrations, strict=True)
            ],
            0.0 * kilojoules_per_mole,
        )
        deltaN = sum(
            [
                t.implicitStates[i].numHydrogens
                - t.implicitStates[t.currentIndex].numHydrogens
                for i, t in zip(stateIndex, titrations, strict=True)
            ]
        )
        w = (newEnergy - currentEnergy - deltaRefEnergy) / kT + deltaN * np.log(
            10.0
        ) * self.pH[self.currentPHIndex]
        if debug:
            currE = currentEnergy.value_in_unit(kilojoules_per_mole)
            newE = newEnergy.value_in_unit(kilojoules_per_mole)
            dE = newE - currE
            dRef = deltaRefEnergy.value_in_unit(kilojoules_per_mole)
            if np.isnan(dE) or np.isnan(dRef):
                print(
                    f' Residue {titrations[0].implicitStates[0].residueIndex}: '
                    f'NaN DETECTED! currE={currE}, newE={newE}, '
                    f'refs=[{titrations[0].referenceEnergies}]'
                )
            else:
                print(
                    f' Residue {titrations[0].implicitStates[0].residueIndex}: '
                    f'state {titrations[0].currentIndex}->{stateIndex[0]}, '
                    f'deltaN={deltaN}, pH={self.pH[self.currentPHIndex]:.2f}, '
                    f'dE={dE:.2f} kJ/mol, '
                    f'dRef={dRef:.2f} kJ/mol, '
                    f'w={float(w):.3f}, '
                    f"accept={'yes' if w <= 0 else f'prob={np.exp(-float(w)):.4f}'}"
                )
        # NaN compares False against everything, so without the explicit
        # isnan test a NaN exponent would fall through to "accept".
        if np.isnan(w) or (w > 0.0 and np.exp(-w) < np.random.random()):
            # Reject: restore previous state
            for t in titrations:
                self._applyStateToContext(
                    t.implicitStates[t.currentIndex],
                    self.implicitContext,
                    self.implicitExceptionIndex,
                    self.implicitInterResidue14,
                    self.implicit14Scale,
                )
            continue
        # Accept the move
        anyChange = True
        for i, t in zip(stateIndex, titrations, strict=True):
            t.currentIndex = i
            self._applyStateToContext(
                t.explicitStates[i],
                self.simulation.context,
                self.explicitExceptionIndex,
                self.explicitInterResidue14,
                self.explicit14Scale,
            )
            self._applyStateToContext(
                t.explicitStates[i],
                self.relaxationContext,
                self.explicitExceptionIndex,
                self.explicitInterResidue14,
                self.explicit14Scale,
            )
    # Relax solvent if any state changed
    if anyChange:
        # Bring the relaxation context in line with the simulation context
        # (positions, box and context parameters) before stepping it.
        self.relaxationContext.setPositions(explicitPositions)
        self.relaxationContext.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
        for param in self.relaxationContext.getParameters():
            self.relaxationContext.setParameter(param, state.getParameters()[param])
        self.relaxationContext.getIntegrator().step(self.relaxationSteps)
        relaxedPositions = self.relaxationContext.getState(
            getPositions=True
        ).getPositions(asNumpy=True)
        self.simulation.context.setPositions(relaxedPositions)
[docs]
def setResidueState(self, residueIndex, stateIndex, relax=False):
    """Manually set a residue to a specific protonation state.

    The explicit state is applied to the simulation and relaxation
    contexts, the implicit state to the implicit context, and the
    residue's ``currentIndex`` is updated.

    Parameters
    ----------
    residueIndex : int
        Key of the residue in ``self.titrations``.
    stateIndex : int
        Index of the state to activate.
    relax : bool
        If True, run ``self.relaxationSteps`` steps of the relaxation
        integrator afterwards and copy the resulting positions back to the
        simulation context. Before stepping, the relaxation context
        receives the simulation context's positions, periodic box vectors
        and context parameters, the same synchronisation that
        ``attemptMCStep`` performs.

    Raises
    ------
    KeyError
        If ``residueIndex`` is not a titratable residue.
    IndexError
        If ``stateIndex`` is out of range. Both state lists are resolved
        before any context is modified, so a bad index leaves all contexts
        untouched.
    """
    titration = self.titrations[residueIndex]
    explicitState = titration.explicitStates[stateIndex]
    implicitState = titration.implicitStates[stateIndex]
    for context in (self.simulation.context, self.relaxationContext):
        self._applyStateToContext(
            explicitState,
            context,
            self.explicitExceptionIndex,
            self.explicitInterResidue14,
            self.explicit14Scale,
        )
    self._applyStateToContext(
        implicitState,
        self.implicitContext,
        self.implicitExceptionIndex,
        self.implicitInterResidue14,
        self.implicit14Scale,
    )
    titration.currentIndex = stateIndex
    if relax:
        state = self.simulation.context.getState(
            getPositions=True, getParameters=True
        )
        self.relaxationContext.setPositions(state.getPositions(asNumpy=True))
        # Without this the relaxation would run with whatever box and
        # parameter values the relaxation context last saw.
        self.relaxationContext.setPeriodicBoxVectors(*state.getPeriodicBoxVectors())
        currentParameters = state.getParameters()
        for param in self.relaxationContext.getParameters():
            self.relaxationContext.setParameter(param, currentParameters[param])
        self.relaxationContext.getIntegrator().step(self.relaxationSteps)
        self.simulation.context.setPositions(
            self.relaxationContext.getState(getPositions=True).getPositions(
                asNumpy=True
            )
        )
def _get_zero_parameters(self, original_parameters, force):
    """Return a copy of per-particle parameters with the charge set to 0.

    This is used for "ghost" hydrogens that exist in the fully protonated
    state but not in deprotonated states.

    Parameters
    ----------
    original_parameters : sequence
        Per-particle parameter values as stored for ``force``.
    force : Force
        The force the parameters belong to. For NonbondedForce and
        GBSAOBCForce the first entry is zeroed; for any other force every
        per-particle parameter named ``'charge'`` is zeroed.

    Returns
    -------
    tuple
        The parameters with the charge entry replaced by ``0.0``.
    """
    zeroed = list(original_parameters)
    if isinstance(force, (NonbondedForce, GBSAOBCForce)):
        # First parameter is charge for both NonbondedForce and GBSAOBCForce
        zeroed[0] = 0.0
        return tuple(zeroed)
    # For custom forces, find the charge parameter by name
    for slot in range(force.getNumPerParticleParameters()):
        if force.getPerParticleParameterName(slot) == 'charge':
            zeroed[slot] = 0.0
    return tuple(zeroed)
def _applyStateToContext(
self, state, context, exceptionIndex, interResidue14, coulomb14Scale
):
"""Update context parameters to match a protonation state.
This modifies Force parameters in the System and then calls
updateParametersInContext() to push changes to the GPU/CPU context.
"""
for forceIndex, params in state.particleParameters.items():
force = context.getSystem().getForce(forceIndex)
# Only NonbondedForce and GBSAOBCForce support setParticleParameters
if not isinstance(force, (NonbondedForce, GBSAOBCForce)):
continue
for atomName, atomParams in params.items():
if atomName not in state.atomIndices:
continue
atomIndex = state.atomIndices[atomName]
try:
force.setParticleParameters(atomIndex, atomParams)
except TypeError:
force.setParticleParameters(atomIndex, *atomParams)
if isinstance(force, NonbondedForce):
# Update intra-residue exceptions
for key, exceptionParams in state.exceptionParameters.get(
forceIndex, {}
).items():
if key in exceptionIndex:
p = force.getExceptionParameters(exceptionIndex[key])
force.setExceptionParameters(
exceptionIndex[key], p[0], p[1], *exceptionParams
)
# Update inter-residue 1-4 interactions
for index in interResidue14.get(state.residueIndex, []):
p1, p2, _, sigma, epsilon = force.getExceptionParameters(index)
q1, _, _ = force.getParticleParameters(p1)
q2, _, _ = force.getParticleParameters(p2)
q1_val = (
q1.value_in_unit(elementary_charge)
if hasattr(q1, 'value_in_unit')
else float(q1)
)
q2_val = (
q2.value_in_unit(elementary_charge)
if hasattr(q2, 'value_in_unit')
else float(q2)
)
chargeProd = coulomb14Scale * q1_val * q2_val * elementary_charge**2
force.setExceptionParameters(
index, p1, p2, chargeProd, sigma, epsilon
)
# Update parameters in context for this force
# Note: no need to call context.reinitialize() - updateParametersInContext is sufficient
force.updateParametersInContext(context)
def _selectNewState(self, titration):
"""Randomly select a new protonation state."""
numStates = len(titration.implicitStates)
if numStates == 2:
return 1 - titration.currentIndex
stateIndex = titration.currentIndex
while stateIndex == titration.currentIndex:
stateIndex = np.random.randint(numStates)
return stateIndex
def _findNeighbors(self, resIndex, explicitPositions, periodicDistance):
"""Find nearby titratable residues for multi-site moves."""
neighbors = []
titration1 = self.titrations[resIndex]
for resIndex2 in self.titrations:
if resIndex2 > resIndex:
titration2 = self.titrations[resIndex2]
isNeighbor = False
for i in titration1.explicitHydrogenIndices:
for j in titration2.explicitHydrogenIndices:
if (i < len(explicitPositions) and
j < len(explicitPositions) and
periodicDistance(explicitPositions[i], explicitPositions[j]) < 0.2):
isNeighbor = True
if isNeighbor:
neighbors.append(resIndex2)
return neighbors
def _attemptPHChange(self):
"""Attempt to change pH (simulated tempering)."""
hydrogens = sum(
t.explicitStates[t.currentIndex].numHydrogens
for t in self.titrations.values()
)
logProbability = [
self._weights[i] - hydrogens * np.log(10.0) * self.pH[i]
for i in range(len(self._weights))
]
maxLogProb = max(logProbability)
offset = maxLogProb + np.log(
sum(np.exp(x - maxLogProb) for x in logProbability)
)
probability = [np.exp(x - offset) for x in logProbability]
r = np.random.random_sample()
for j in range(len(probability)):
if r < probability[j]:
if j != self.currentPHIndex:
self._hasMadeTransition = True
self.currentPHIndex = j
if self._updateWeights:
self._weights[j] -= self._weightUpdateFactor
self._histogram[j] += 1
minCounts = min(self._histogram)
if minCounts > 20 and minCounts >= 0.2 * sum(self._histogram) / len(
self._histogram
):
self._weightUpdateFactor *= 0.5
self._histogram = [0] * len(self.pH)
self._weights = [x - self._weights[0] for x in self._weights]
elif (
not self._hasMadeTransition
and probability[self.currentPHIndex] > 0.99
and self._weightUpdateFactor < 1024.0
):
self._weightUpdateFactor *= 2.0
self._histogram = [0] * len(self.pH)
return
r -= probability[j]