# ruff: noqa: N999
"""
Improved constant pH analysis with UWHAM reweighting.
This implementation adds multistate analysis capabilities to the basic
curve fitting approach. Uses log-space arithmetic for numerical stability.
"""
from __future__ import annotations
import ast
import re
import warnings
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import numpy.typing as npt
import polars as pl
from scipy.optimize import curve_fit
from scipy.special import logsumexp
if TYPE_CHECKING:
import matplotlib.pyplot as plt
[docs]
class UWHAMSolver:
"""
Unbinned Weighted Histogram Analysis Method (UWHAM) solver.
NOTE: This class is NOT currently used because UWHAM/MBAR is designed
for umbrella sampling and replica exchange, not independent constant pH
simulations. For standard constant pH MD where each pH is an independent
equilibrium simulation, simple curve fitting is the correct approach.
This class is retained for potential use with replica exchange constant
pH (REX-cpH) simulations where samples ARE correlated across pH values.
Uses log-space arithmetic throughout for numerical stability with
large systems (100+ titratable residues).
"""
[docs]
def __init__(self, tol: float = 1e-7, maxiter: int = 10000):
self.tol = tol
self.maxiter = maxiter
self.f = None # Log of normalization constants (will be solved)
self.log10 = np.log(10)
[docs]
def load_data(self, df: pl.DataFrame, resid_cols: list[str]):
    """
    Convert a polars DataFrame into the arrays the solver works on.

    Populates pH_values, nsamples, nstates, states, nprotons_total and
    u_kl on the instance.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame with columns: rankid, current_pH, and residue columns
        Residue columns should contain numeric protonation states (0 or 1)
    resid_cols : List[str]
        List of column names corresponding to residue IDs
    """
    # One row per distinct pH together with its number of samples.
    counts_by_pH = df.group_by('current_pH').agg(pl.len().alias('count'))
    unsorted_pH = counts_by_pH['current_pH'].to_numpy()
    unsorted_counts = counts_by_pH['count'].to_numpy().astype(int)

    # group_by gives no ordering guarantee, so order everything by pH.
    order = np.argsort(unsorted_pH)
    self.pH_values = unsorted_pH[order]
    self.nsamples = unsorted_counts[order]
    self.nstates = len(self.pH_values)

    # resid -> list of occupancy arrays (one per pH, in pH order)
    self.states = {name: [] for name in resid_cols}
    # Per pH: total number of protons carried by each sample
    self.nprotons_total = []
    for pH in self.pH_values:
        frame = df.filter(pl.col('current_pH') == pH)
        proton_sum = np.zeros(len(frame))
        for name in resid_cols:
            occupancy = frame[name].to_numpy().astype(float)
            self.states[name].append(occupancy)
            proton_sum += occupancy
        self.nprotons_total.append(proton_sum)

    # u_kl[k] has shape (nstates, n_k): the reduced potential of every
    # sample drawn at state k, evaluated at each state's pH.
    self.u_kl = []
    for k in range(self.nstates):
        table = np.zeros((self.nstates, self.nsamples[k]))
        for row, pH in enumerate(self.pH_values):
            table[row, :] = self.log10 * pH * self.nprotons_total[k]
        self.u_kl.append(table)
[docs]
def solve(self, verbose: bool = False):
"""
Solve UWHAM self-consistent equations iteratively.
Uses the MBAR equation:
f_k = -log(Σ_n exp(-u_k(x_n)) / Σ_l N_l exp(f_l - u_l(x_n)))
where the sum over n includes ALL samples from ALL states.
Returns
-------
f : np.ndarray
Free energy offsets for each pH simulation
"""
# Initialize free energies
f = np.zeros(self.nstates)
log_N = np.log(self.nsamples.astype(float))
total_samples = sum(self.nsamples)
# Precompute reduced potentials for all samples at all target states
# u_all[target_k, sample_idx] = reduced potential at state k for sample idx
# Also store source state for each sample
u_all = np.zeros((self.nstates, total_samples))
sample_source = np.zeros(
total_samples, dtype=int
) # which state each sample came from
idx = 0
for source_i in range(self.nstates):
n_i = self.nsamples[source_i]
for n in range(n_i):
sample_source[idx] = source_i
for target_k in range(self.nstates):
# u_k(x_n) = log10 * pH_k * nprotons(x_n)
u_all[target_k, idx] = (
self.log10
* self.pH_values[target_k]
* self.nprotons_total[source_i][n]
)
idx += 1
for iteration in range(self.maxiter):
f_old = f.copy()
# Compute denominator for each sample: c_n = Σ_l N_l exp(f_l - u_l(x_n))
# log(c_n) = logsumexp(log_N + f - u_l(x_n))
log_c = np.zeros(total_samples)
for n in range(total_samples):
source_i = sample_source[n]
# u_l(x_n) for all states l - this is stored in u_kl[source_i]
log_c[n] = logsumexp(
log_N + f - self.u_kl[source_i][:, n % self.nsamples[source_i]]
)
# Wait, that indexing is wrong. Let me redo this.
# Actually I need to recompute using proper indexing
log_c = np.zeros(total_samples)
idx = 0
for source_i in range(self.nstates):
n_i = self.nsamples[source_i]
for local_n in range(n_i):
# u_l(x_n) for all states l
log_c[idx] = logsumexp(log_N + f - self.u_kl[source_i][:, local_n])
idx += 1
# Update each free energy
for target_k in range(self.nstates):
# f_k = -log(Σ_n exp(-u_k(x_n)) / c_n)
# = -log(Σ_n exp(-u_k(x_n) - log(c_n)))
# = -logsumexp(-u_all[target_k, :] - log_c)
log_weights = -u_all[target_k, :] - log_c
f[target_k] = -logsumexp(log_weights)
# Normalize so f[0] = 0
f = f - f[0]
# Check convergence
delta = np.abs(f - f_old).max()
if verbose and iteration % 100 == 0:
print(f' Iteration {iteration}: max|Δf| = {delta:.2e}')
if delta < self.tol:
if verbose:
print(f' Converged after {iteration + 1} iterations')
break
else:
warnings.warn(
f'UWHAM did not converge after {self.maxiter} iterations '
f'(final delta = {delta:.2e})',
stacklevel=2
)
self.f = f
self.log_c = log_c # Store for weight computation
self.u_all = u_all # Store for weight computation
self.sample_source = sample_source
self.total_samples = total_samples
return f
[docs]
def compute_log_weights(self, target_pH: float) -> tuple[np.ndarray, float]:
"""
Compute log weights for reweighting to target pH.
Uses MBAR formula:
w_n ∝ exp(-u_target(x_n)) / Σ_l N_l exp(f_l - u_l(x_n))
Returns
-------
log_weights : np.ndarray
Log weights for all samples (flattened)
log_norm : float
Log of the normalization constant
"""
if self.f is None:
raise RuntimeError('Must call solve() before computing weights')
# Compute reduced potential at target pH for all samples
u_target = np.zeros(self.total_samples)
idx = 0
for source_i in range(self.nstates):
n_i = self.nsamples[source_i]
for local_n in range(n_i):
u_target[idx] = (
self.log10 * target_pH * self.nprotons_total[source_i][local_n]
)
idx += 1
# log(w_n) = -u_target(x_n) - log(c_n)
# where log(c_n) was precomputed in solve()
log_weights = -u_target - self.log_c
# Normalize
log_norm = logsumexp(log_weights)
return log_weights, log_norm
[docs]
def compute_expectation_at_pH(
self, observable_by_state: list[np.ndarray], target_pH: float
) -> float:
"""
Compute expectation value of observable at arbitrary pH.
Parameters
----------
observable_by_state : List[np.ndarray]
Observable values for each sample, organized by state index
target_pH : float
pH value at which to compute the expectation
Returns
-------
expectation : float
Reweighted expectation value at target_pH
"""
log_weights, log_norm = self.compute_log_weights(target_pH)
# Flatten observable to match log_weights ordering
obs_flat = np.concatenate(observable_by_state)
# Compute weighted sum
# <A> = Σ_n A_n * w_n / Σ_n w_n
# = Σ_n A_n * exp(log_w_n - log_norm)
weights = np.exp(log_weights - log_norm)
return np.sum(obs_flat * weights)
[docs]
def get_occupancy_for_resid(self, resid: str) -> list[np.ndarray]:
"""Get occupancy arrays for a specific residue across all pH values."""
return self.states[resid]
[docs]
class TitrationCurve:
"""
Analyze constant pH simulations with multiple fitting methods.
Available methods:
- curvefit: Simple least squares fit of Hill equation to per-pH averages
- weighted: Weighted least squares (weight by 1/variance)
- bootstrap: Curve fitting with bootstrap confidence intervals
Note: For independent constant pH simulations (not replica exchange),
simple curve fitting is the statistically correct approach. UWHAM/MBAR
is only appropriate for replica exchange constant pH where samples
are correlated across pH values.
"""
[docs]
def __init__(
self,
log_file: Path | list[Path],
make_plots: bool = True,
out: Path = Path('.'),
method: str = 'uwham', # 'curvefit' or 'uwham'
):
if isinstance(log_file, list):
dfs = []
resids = None
for log in log_file:
df, r = self.parse_log(log)
dfs.append(df)
if resids is None:
resids = r
self.df = pl.concat(dfs, how='vertical')
else:
self.df, resids = self.parse_log(log_file)
# Store residue IDs (converted to strings to match column names)
assert resids is not None, "No residue IDs found in any log file"
self.resid_cols = [str(r) for r in resids]
self.make_plots = make_plots
self.out = out
self.method = method
[docs]
@staticmethod
def parse_log(log: Path) -> tuple[pl.DataFrame, list[int]]:
    """Read one OpenMM constant pH log into a table.

    The residue IDs come from the first line matching "cpH: resids ...";
    every line matching the state pattern contributes one row.

    Returns
    -------
    df : pl.DataFrame
        DataFrame with columns: rankid, current_pH, and one column per residue
    resids : List[int]
        List of residue IDs in order

    Raises
    ------
    RuntimeError
        If no residue ID header line is present.
    ValueError
        If a state line lists a different number of states than the
        header lists residues.
    """
    # Header format: "cpH: resids 20 76 83 92 ..."
    header_pattern = re.compile(r'cpH:\s+resids\s+(.+)$')
    state_pattern = re.compile(r'rank=(\d+).*cpH:\s+pH\s+([0-9.]+):\s+(\[.*\])')
    text_lines = log.read_text().splitlines()

    # First header occurrence wins.
    header_match = next(
        (hit for hit in map(header_pattern.search, text_lines) if hit), None
    )
    if header_match is None:
        raise RuntimeError(
            'Could not find cpH residue ID header line in log. '
            'Expected line containing "cpH: resids ..."'
        )
    # Residue IDs are separated by whitespace (possibly multiple spaces)
    resids = [int(token) for token in header_match.group(1).split()]

    records = []
    for text in text_lines:
        found = state_pattern.search(text)
        if found is None:
            continue
        record = {
            'rankid': int(found.group(1)),
            'current_pH': float(found.group(2)),
        }
        # The bracketed state list is a Python literal in the log.
        states_list = ast.literal_eval(found.group(3))
        if len(states_list) != len(resids):
            raise ValueError(
                f'Mismatch between number of residues ({len(resids)}) '
                f'and number of states ({len(states_list)})'
            )
        for resid, state in zip(resids, states_list, strict=True):
            record[str(resid)] = state
        records.append(record)
    return pl.DataFrame(records), resids
[docs]
def prepare(self) -> None:
    """
    Prepare data for analysis.

    Builds three attributes:

    - df_long: one row per (sample, residue) with the raw state and its
      protonation number; rows whose state has no protonation mapping
      are dropped.
    - resid_to_resname: canonical residue name per residue column,
      derived from the first non-null state ('UNK' if there is none).
    - titrations: per (resid, current_pH) mean protonation and sample
      count, sorted by resid then pH.
    """
    # Both mappings are properties that build a fresh dict on each access;
    # fetch them once instead of once per residue / per row.
    canonical = self.canonical_resname
    protonation = self.protonation_mapping

    # Melt to long format for curve fitting method
    self.df_long = self.df.unpivot(
        index=['rankid', 'current_pH'],
        on=self.resid_cols,
        variable_name='resid',
        value_name='state',
    )

    # Determine canonical resname for each residue ID from the first
    # non-null state observed for it.
    self.resid_to_resname = {}
    for resid_col in self.resid_cols:
        first_state = self.df[resid_col].drop_nulls().head(1).to_list()
        if first_state:
            state = first_state[0]
            # Unknown state names are kept as-is.
            self.resid_to_resname[resid_col] = canonical.get(state, state)
        else:
            self.resid_to_resname[resid_col] = 'UNK'

    # Map states to protonation (1 or 0); unmapped states become null
    # and are dropped.
    self.df_long = self.df_long.with_columns(
        pl.col('state')
        .map_elements(lambda x: protonation.get(x), return_dtype=pl.Int64)
        .alias('prot')
    ).drop_nulls('prot')

    # Compute per-pH statistics for curve fitting
    self.titrations = (
        self.df_long.group_by(['resid', 'current_pH'])
        .agg(
            [
                pl.col('prot').mean().alias('fraction_protonated'),
                pl.col('prot').count().alias('n_samples'),
            ]
        )
        .sort(['resid', 'current_pH'])
    )
[docs]
def compute_titrations_curvefit(self) -> pl.DataFrame:
    """
    Fit pKa and Hill coefficient per residue with scipy curve_fit.

    This is the simple approach that treats each pH independently.
    Residues with fewer than 3 pH points, or whose fit raises, get NaN
    for all fitted quantities.

    Returns
    -------
    pl.DataFrame
        One row per residue with columns resid, resname, pKa, Hill_n,
        pKa_err, Hill_n_err, n_points, method.
    """
    nan = np.nan
    records = []
    for key, group in self.titrations.group_by('resid', maintain_order=True):
        resid = key[0]  # group key arrives as a 1-tuple
        pH_points = group['current_pH'].to_numpy().astype(float)
        fractions = group['fraction_protonated'].to_numpy().astype(float)

        # Start from the "no fit" record and overwrite on success.
        record = {
            'resid': resid,
            'resname': self.resid_to_resname.get(resid, 'UNK'),
            'pKa': nan,
            'Hill_n': nan,
            'pKa_err': nan,
            'Hill_n_err': nan,
            'n_points': int(pH_points.size),
            'method': 'curvefit',
        }
        if pH_points.size < 3:
            # Not enough data points
            records.append(record)
            continue

        # Initial guess: the pH whose fraction is closest to 0.5, Hill n = 1
        start = [pH_points[np.argmin(np.abs(fractions - 0.5))], 1.0]
        try:
            popt, pcov = curve_fit(
                self.hill_equation,
                pH_points,
                fractions,
                p0=start,
                bounds=([0.0, 0.1], [14.0, 10.0]),
                maxfev=5000,
            )
            pKa, hill_n = popt
            if pcov is not None:
                stderr = np.sqrt(np.diag(pcov))
                pKa_err, hill_n_err = stderr[0], stderr[1]
            else:
                pKa_err, hill_n_err = nan, nan
        except Exception:
            pKa = hill_n = pKa_err = hill_n_err = nan

        record.update(
            pKa=float(pKa),
            Hill_n=float(hill_n),
            pKa_err=float(pKa_err),
            Hill_n_err=float(hill_n_err),
        )
        records.append(record)
    return pl.DataFrame(records)
[docs]
def compute_titrations_weighted(self, verbose: bool = False) -> pl.DataFrame:
    """
    Fit pKa and Hill coefficient per residue with weighted least squares.

    Weights each pH point by 1/variance, giving more influence to
    points with more samples. The variance is the binomial
    p(1-p)/n with p clipped to [0.01, 0.99] so that points at exactly 0
    or 1 do not receive infinite weight.

    Parameters
    ----------
    verbose : bool
        Accepted for signature compatibility; not used by this method.

    Returns
    -------
    pl.DataFrame
        One row per residue with columns resid, resname, pKa, Hill_n,
        pKa_err, Hill_n_err, n_points, method.
    """
    nan = np.nan
    eps = 0.01  # clipping margin for the protonated fraction
    records = []
    for key, group in self.titrations.group_by('resid', maintain_order=True):
        resid = key[0]
        pH_points = group['current_pH'].to_numpy().astype(float)
        fractions = group['fraction_protonated'].to_numpy().astype(float)
        counts = group['n_samples'].to_numpy().astype(float)

        record = {
            'resid': resid,
            'resname': self.resid_to_resname.get(resid, 'UNK'),
            'pKa': nan,
            'Hill_n': nan,
            'pKa_err': nan,
            'Hill_n_err': nan,
            'n_points': int(pH_points.size),
            'method': 'weighted',
        }
        if pH_points.size < 3:
            records.append(record)
            continue

        # Var(p) = p(1-p)/n on the clipped fraction, then normalized
        # inverse-variance weights.
        clipped = np.clip(fractions, eps, 1 - eps)
        variance = clipped * (1 - clipped) / counts
        weights = 1.0 / variance
        weights = weights / weights.sum()

        # Initial guess: pH closest to half protonation, Hill n = 1
        start = [pH_points[np.argmin(np.abs(fractions - 0.5))], 1.0]
        try:
            # Weighted curve fit using sigma = 1/sqrt(weight)
            sigma = 1.0 / np.sqrt(weights * len(weights))
            popt, pcov = curve_fit(
                self.hill_equation,
                pH_points,
                fractions,
                p0=start,
                sigma=sigma,
                absolute_sigma=False,
                bounds=([0.0, 0.1], [14.0, 10.0]),
                maxfev=5000,
            )
            pKa, hill_n = popt
            if pcov is not None:
                stderr = np.sqrt(np.diag(pcov))
                pKa_err, hill_n_err = stderr[0], stderr[1]
            else:
                pKa_err, hill_n_err = nan, nan
        except Exception:
            pKa = hill_n = pKa_err = hill_n_err = nan

        record.update(
            pKa=float(pKa),
            Hill_n=float(hill_n),
            pKa_err=float(pKa_err),
            Hill_n_err=float(hill_n_err),
        )
        records.append(record)
    return pl.DataFrame(records)
[docs]
def compute_titrations_bootstrap(
    self, n_bootstrap: int = 1000, verbose: bool = False
) -> pl.DataFrame:
    """
    Compute pKa and Hill coefficient with bootstrap confidence intervals.

    For every residue the Hill equation is first fitted to the observed
    fractions (point estimate). Each bootstrap replicate then redraws the
    protonated count at every pH from Binomial(n_samples, fraction) using
    NumPy's global random state and refits. The 2.5 and 97.5 percentiles
    of the replicate parameters are reported when more than 10 replicate
    fits succeeded, otherwise NaN.

    Parameters
    ----------
    n_bootstrap : int
        Number of bootstrap iterations (default 1000)
    verbose : bool
        Print progress and, per residue, how many replicate fits failed

    Returns
    -------
    DataFrame with pKa, Hill_n, and 95% confidence intervals
    """
    fit_rows = []
    bounds = ([0.0, 0.1], [14.0, 10.0])
    if verbose:
        print(f'Running bootstrap with {n_bootstrap} iterations...')
    for i, (resid, subdf) in enumerate(
        self.titrations.group_by('resid', maintain_order=True)
    ):
        resid = resid[0]
        resname = self.resid_to_resname.get(resid, 'UNK')
        x = subdf['current_pH'].to_numpy().astype(float)
        y = subdf['fraction_protonated'].to_numpy().astype(float)
        n_samples = subdf['n_samples'].to_numpy().astype(int)
        if verbose and (i + 1) % 20 == 0:
            print(f' {i + 1}/{len(self.resid_cols)} residues...')
        if x.size < 3:
            fit_rows.append(
                {
                    'resid': resid,
                    'resname': resname,
                    'pKa': np.nan,
                    'pKa_lo': np.nan,
                    'pKa_hi': np.nan,
                    'Hill_n': np.nan,
                    'Hill_n_lo': np.nan,
                    'Hill_n_hi': np.nan,
                    'n_points': int(x.size),
                    'method': 'bootstrap',
                }
            )
            continue

        # First fit to get point estimate
        p0 = [x[np.argmin(np.abs(y - 0.5))], 1.0]
        try:
            popt, _ = curve_fit(
                self.hill_equation, x, y, p0=p0, bounds=bounds, maxfev=5000
            )
            pKa_point, hill_n_point = popt
        except Exception:
            pKa_point, hill_n_point = np.nan, np.nan

        # Bootstrap resampling
        pKa_boots = []
        hill_n_boots = []
        n_failed = 0
        for _ in range(n_bootstrap):
            # One vectorized draw covers every pH point:
            # protonated count ~ Binomial(n_samples, fraction).
            y_boot = np.random.binomial(n_samples, y) / n_samples
            try:
                popt_boot, _ = curve_fit(
                    self.hill_equation,
                    x,
                    y_boot,
                    p0=p0,
                    bounds=bounds,
                    maxfev=2000,
                )
            except Exception:
                # Best effort: a replicate that cannot be fitted is left
                # out of the percentiles, but it is counted.
                n_failed += 1
                continue
            pKa_boots.append(popt_boot[0])
            hill_n_boots.append(popt_boot[1])
        if verbose and n_failed:
            print(
                f' residue {resid}: {n_failed}/{n_bootstrap} '
                'bootstrap fits failed'
            )

        # Compute confidence intervals
        if len(pKa_boots) > 10:
            pKa_lo, pKa_hi = np.percentile(pKa_boots, [2.5, 97.5])
            hill_n_lo, hill_n_hi = np.percentile(hill_n_boots, [2.5, 97.5])
        else:
            pKa_lo, pKa_hi = np.nan, np.nan
            hill_n_lo, hill_n_hi = np.nan, np.nan
        fit_rows.append(
            {
                'resid': resid,
                'resname': resname,
                'pKa': float(pKa_point),
                'pKa_lo': float(pKa_lo),
                'pKa_hi': float(pKa_hi),
                'Hill_n': float(hill_n_point),
                'Hill_n_lo': float(hill_n_lo),
                'Hill_n_hi': float(hill_n_hi),
                'n_points': int(x.size),
                'method': 'bootstrap',
            }
        )
    return pl.DataFrame(fit_rows)
[docs]
def compute_titrations(
self, verbose: bool = False, n_bootstrap: int = 1000
) -> None:
"""Compute titrations using selected method."""
if self.method == 'curvefit':
self.fits = self.compute_titrations_curvefit()
elif self.method == 'weighted':
self.fits = self.compute_titrations_weighted(verbose=verbose)
elif self.method == 'bootstrap':
self.fits = self.compute_titrations_bootstrap(
n_bootstrap=n_bootstrap, verbose=verbose
)
else:
raise ValueError(
f"Unknown method: {self.method}. Use 'curvefit', 'weighted', or 'bootstrap'"
)
[docs]
def postprocess(self) -> None:
"""Generate fitted curves for plotting."""
if self.fits is None:
raise RuntimeError('Must call compute_titrations() first')
pH_min_val = self.df['current_pH'].min()
pH_max_val = self.df['current_pH'].max()
assert isinstance(pH_min_val, (int, float)) and isinstance(pH_max_val, (int, float)), "No pH data found"
pH_grid = np.linspace(float(pH_min_val), float(pH_max_val), 200)
curves = []
for row in self.fits.iter_rows(named=True):
resid = row['resid']
pKa = row['pKa']
n = row['Hill_n']
if np.isnan(pKa) or np.isnan(n):
continue
y_fit = self.hill_equation(pH_grid, pKa, n)
curves.append(
pl.DataFrame(
{
'resid': [resid] * len(pH_grid),
'pH': pH_grid,
'fraction_protonated_fit': y_fit,
}
)
)
self.curves = pl.concat(curves) if curves else None
if self.make_plots:
self.plot()
[docs]
def plot(self) -> None:
"""Generate plots (to be implemented)."""
pass
[docs]
def diagnose_residue(self, resid: str, verbose: bool = True) -> dict:
    """
    Collect (and optionally print) titration diagnostics for one residue.

    Useful for understanding why a residue has no sensible pKa: it
    reports the per-pH protonated fraction, the sample counts and how
    often each raw state was observed.

    Parameters
    ----------
    resid : str
        Residue ID to diagnose
    verbose : bool
        Print diagnostic information

    Returns
    -------
    dict with keys resid, resname, pH, fraction_protonated, n_samples,
    state_distribution, frac_min and frac_max (NaN when the residue has
    no titration points)
    """
    is_target = pl.col('resid') == resid

    # Per-pH fraction protonated from simple averaging
    per_pH = self.titrations.filter(is_target)
    pH_vals = per_pH['current_pH'].to_numpy()
    frac_prot = per_pH['fraction_protonated'].to_numpy()
    n_samples = per_pH['n_samples'].to_numpy()

    # How often each raw state name occurs for this residue
    state_counts = (
        self.df_long.filter(is_target)
        .group_by('state')
        .agg(pl.len().alias('count'))
    )
    resname = self.resid_to_resname.get(resid, 'UNK')

    has_points = len(frac_prot) > 0
    frac_min = frac_prot.min() if has_points else np.nan
    frac_max = frac_prot.max() if has_points else np.nan
    result = {
        'resid': resid,
        'resname': resname,
        'pH': pH_vals,
        'fraction_protonated': frac_prot,
        'n_samples': n_samples,
        'state_distribution': state_counts.to_dict(),
        'frac_min': frac_min,
        'frac_max': frac_max,
    }
    if not verbose:
        return result

    print(f'\nDiagnostics for residue {resid} ({resname}):')
    print(' State distribution:')
    for row in state_counts.iter_rows(named=True):
        print(f" {row['state']}: {row['count']}")
    print('\n Titration curve (simple average):')
    print(f" {'pH':>6s} {'frac':>6s} {'n':>5s}")
    for pH, f, n in zip(pH_vals, frac_prot, n_samples, strict=True):
        print(f' {pH:6.2f} {f:6.3f} {n:5d}')
    print(
        f"\n Fraction range: {result['frac_min']:.3f} - {result['frac_max']:.3f}"
    )
    # Interpret the observed range relative to the half-protonation point.
    if frac_min > 0.5:
        print(
            f' → Always >50% protonated - pKa likely ABOVE pH {pH_vals.max():.1f}'
        )
    elif frac_max < 0.5:
        print(
            f' → Always <50% protonated - pKa likely BELOW pH {pH_vals.min():.1f}'
        )
    elif frac_max - frac_min < 0.1:
        print(
            ' → Very little titration observed - may not titrate in this pH range'
        )
    return result
[docs]
@staticmethod
def hill_equation(
pH: float | npt.NDArray[np.floating], pKa: float, n: float
) -> float | npt.NDArray[np.floating]:
"""
Hill equation for acid-base equilibrium.
Returns fraction protonated as function of pH.
"""
return 1.0 / (1.0 + 10.0 ** (n * (pH - pKa)))
@property
def protonation_mapping(self) -> dict[str, int]:
"""Map state names to protonation numbers (1 = protonated, 0 = not)."""
return {
'ASH': 1,
'ASP': 0,
'GLH': 1,
'GLU': 0,
'LYS': 1,
'LYN': 0,
'CYS': 1,
'CYX': 0,
'HIP': 1,
'HIE': 0,
'HID': 0,
}
@property
def canonical_resname(self) -> dict[str, str]:
"""Map any state name to canonical residue name."""
return {
'ASH': 'ASP',
'ASP': 'ASP',
'GLH': 'GLU',
'GLU': 'GLU',
'LYS': 'LYS',
'LYN': 'LYS',
'CYS': 'CYS',
'CYX': 'CYS',
'HIP': 'HIS',
'HIE': 'HIS',
'HID': 'HIS',
}
[docs]
def compare_methods(self, resids: list[str] | None = None) -> pl.DataFrame:
    """
    Compare unweighted and weighted curve fit results.

    Both fits are always run for every residue; the optional filter is
    applied to the joined table afterwards.

    Parameters
    ----------
    resids : List[str], optional
        Residues to compare. If None, compares all.

    Returns
    -------
    DataFrame with the curvefit columns, the weighted pKa / Hill_n
    (suffix '_weighted') and the differences pKa_diff and Hill_n_diff
    (curvefit minus weighted)
    """
    unweighted = self.compute_titrations_curvefit()
    weighted = self.compute_titrations_weighted(verbose=False)

    pKa_diff = (pl.col('pKa') - pl.col('pKa_weighted')).alias('pKa_diff')
    hill_diff = (pl.col('Hill_n') - pl.col('Hill_n_weighted')).alias(
        'Hill_n_diff'
    )
    # Join on resid; weighted columns get the '_weighted' suffix.
    side_by_side = unweighted.join(
        weighted.select(['resid', 'pKa', 'Hill_n']),
        on='resid',
        suffix='_weighted',
    ).with_columns([pKa_diff, hill_diff])

    if resids is None:
        return side_by_side
    return side_by_side.filter(pl.col('resid').is_in(resids))
[docs]
class TitrationAnalyzer:
"""
High-level analyzer for constant pH simulations.
Provides a streamlined API that runs unweighted, weighted, and (optionally) bootstrap curve fitting,
generates comparisons, and creates publication-quality plots.
Example usage
-------------
>>> analyzer = TitrationAnalyzer(["cpH.log"])
>>> analyzer.run()
>>> analyzer.summary()
>>> analyzer.plot_residue("145")
>>> analyzer.plot_all(output_dir="plots/")
>>> analyzer.save_results("results/")
"""
[docs]
def __init__(
self,
log_files: Path | list[Path] | str | list[str],
output_dir: Path | str | None = None,
):
"""
Initialize the analyzer.
Parameters
----------
log_files : Path, str, or list thereof
Path(s) to constant pH log file(s)
output_dir : Path or str, optional
Directory for output files. If None, uses current directory.
"""
log_file_list: list[Path | str] = [log_files] if isinstance(log_files, (str, Path)) else list(log_files)
self.log_files = [Path(f) for f in log_file_list]
self.output_dir = Path(output_dir) if output_dir else Path('.')
self.output_dir.mkdir(parents=True, exist_ok=True)
# Results storage
self.fits_curvefit: pl.DataFrame | None = None
self.fits_weighted: pl.DataFrame | None = None
self.fits_bootstrap: pl.DataFrame | None = None
self.comparison: pl.DataFrame | None = None
self.titration_data: pl.DataFrame | None = None
# Internal objects
self._tc: TitrationCurve | None = None
# Metadata
self.resid_to_resname: dict[str, str] = {}
self.resid_cols: list[str] = []
self._analyzed = False
[docs]
def run(
    self,
    methods: list[str] = ['curvefit', 'weighted'],  # noqa: B006
    verbose: bool = True,
    n_bootstrap: int = 1000,
) -> TitrationAnalyzer:
    """
    Parse the logs and run the requested fitting methods.

    Methods always execute in the order curvefit, weighted, bootstrap,
    whatever their order in `methods`; names other than these three are
    ignored. A curvefit-vs-weighted comparison is generated whenever
    both result sets are available.

    Parameters
    ----------
    methods : list of str
        Methods to run: 'curvefit', 'weighted', 'bootstrap'
        - curvefit: Simple least squares fit of Hill equation
        - weighted: Weighted least squares (by 1/variance)
        - bootstrap: Curve fit with bootstrap confidence intervals
    verbose : bool
        Print progress information
    n_bootstrap : int
        Number of bootstrap iterations (only used if 'bootstrap' in methods)

    Returns
    -------
    self : for method chaining
    """

    def announce(*lines: str) -> None:
        # Print each line, but only in verbose mode.
        if verbose:
            for line in lines:
                print(line)

    def report_success(fits: pl.DataFrame) -> None:
        # Count residues whose fitted pKa is not NaN.
        if verbose:
            n_success = fits.filter(pl.col('pKa').is_not_nan()).height
            print(f' Success: {n_success}/{len(fits)} residues')

    rule = '=' * 60
    section = '\n' + '-' * 40

    announce(
        rule,
        'Constant pH Titration Analysis',
        rule,
        f'Log files: {[str(f) for f in self.log_files]}',
    )

    # Initialize and prepare
    tc = TitrationCurve(self.log_files, make_plots=False)
    self._tc = tc
    tc.prepare()

    # Keep private copies of what the plotting methods need
    self.titration_data = tc.titrations.clone()
    self.resid_to_resname = tc.resid_to_resname.copy()
    self.resid_cols = tc.resid_cols.copy()

    if verbose:
        n_residues = len(tc.resid_cols)
        pH_vals = tc.df['current_pH'].unique().sort()
        print(f'Residues: {n_residues}')
        print(f'pH values: {pH_vals.to_list()}')
        print(f'Total samples: {len(tc.df)}')

    if 'curvefit' in methods:
        announce(section, 'Running curve fitting...')
        self.fits_curvefit = tc.compute_titrations_curvefit()
        report_success(self.fits_curvefit)

    if 'weighted' in methods:
        announce(section, 'Running weighted curve fitting...')
        self.fits_weighted = tc.compute_titrations_weighted(verbose=verbose)
        report_success(self.fits_weighted)

    if 'bootstrap' in methods:
        announce(section, f'Running bootstrap ({n_bootstrap} iterations)...')
        self.fits_bootstrap = tc.compute_titrations_bootstrap(
            n_bootstrap=n_bootstrap, verbose=verbose
        )
        report_success(self.fits_bootstrap)

    # Generate comparison if both curvefit and weighted results exist
    if self.fits_curvefit is not None and self.fits_weighted is not None:
        self._generate_comparison()

    self._analyzed = True
    announce('\n' + rule, 'Analysis complete!', rule)
    return self
def _generate_comparison(self) -> None:
"""Generate comparison DataFrame between curvefit and weighted methods."""
assert self.fits_curvefit is not None and self.fits_weighted is not None
self.comparison = self.fits_curvefit.join(
self.fits_weighted.select(['resid', 'pKa', 'Hill_n']),
on='resid',
suffix='_weighted',
).with_columns(
[
(pl.col('pKa') - pl.col('pKa_weighted')).alias('pKa_diff'),
(pl.col('Hill_n') - pl.col('Hill_n_weighted')).alias('Hill_n_diff'),
]
)
[docs]
def summary(self, show_all: bool = False) -> pl.DataFrame | None:
"""
Print and return summary of results.
Parameters
----------
show_all : bool
If True, show all residues. Otherwise show first 20.
Returns
-------
DataFrame with comparison results
"""
if not self._analyzed:
raise RuntimeError('Must call run() before summary()')
if self.comparison is not None:
successful = self.comparison.filter(
pl.col('pKa').is_not_nan() & pl.col('pKa_weighted').is_not_nan()
)
print(
f'\nComparison Summary ({len(successful)} residues with both methods successful):'
)
print('-' * 60)
if len(successful) > 0:
delta = successful['pKa_diff'].to_numpy()
print('ΔpKa (curvefit - weighted):')
print(f' Mean: {np.mean(delta):+.3f}')
print(f' Std: {np.std(delta):.3f}')
print(f' Median: {np.median(delta):+.3f}')
print(f' Range: [{np.min(delta):.3f}, {np.max(delta):.3f}]')
display_df = successful.select(
[
'resid',
'resname',
'pKa',
'pKa_weighted',
'pKa_diff',
'Hill_n',
'Hill_n_weighted',
]
)
if not show_all and len(display_df) > 20:
print(
f'\nShowing first 20 of {len(display_df)} residues (use show_all=True for all):'
)
print(display_df.head(20))
else:
print(display_df)
return self.comparison
elif self.fits_curvefit is not None:
print('\nCurve Fitting Results:')
print(self.fits_curvefit if show_all else self.fits_curvefit.head(20))
return self.fits_curvefit
elif self.fits_weighted is not None:
print('\nWeighted Fitting Results:')
print(self.fits_weighted if show_all else self.fits_weighted.head(20))
return self.fits_weighted
elif self.fits_bootstrap is not None:
print('\nBootstrap Results:')
print(self.fits_bootstrap if show_all else self.fits_bootstrap.head(20))
return self.fits_bootstrap
return None
[docs]
def get_results(self, method: str = 'curvefit') -> pl.DataFrame | None:
"""
Get results DataFrame for specified method.
Parameters
----------
method : str
'curvefit', 'weighted', 'bootstrap', or 'comparison'
"""
if method == 'curvefit':
return self.fits_curvefit
elif method == 'weighted':
return self.fits_weighted
elif method == 'bootstrap':
return self.fits_bootstrap
elif method == 'comparison':
return self.comparison
else:
raise ValueError(f'Unknown method: {method}')
[docs]
def plot_residue(
    self,
    resid: str,
    ax: plt.Axes | None = None,
    show_curvefit: bool = True,
    show_weighted: bool = True,
    show_data: bool = True,
    figsize: tuple[float, float] = (8, 6),
    save: str | Path | None = None,
) -> plt.Figure:
    """
    Plot titration curve for a single residue.

    Draws the per-pH protonated fractions with binomial standard errors
    and overlays the fitted Hill curves that are available.

    Parameters
    ----------
    resid : str
        Residue ID to plot
    ax : matplotlib Axes, optional
        Axes to plot on. If None, creates new figure.
    show_curvefit : bool
        Show curve fitting result
    show_weighted : bool
        Show weighted fit result
    show_data : bool
        Show raw data points
    figsize : tuple
        Figure size if creating new figure
    save : str or Path, optional
        Path to save figure

    Returns
    -------
    matplotlib Figure

    Raises
    ------
    RuntimeError
        If run() has not been called.
    ValueError
        If there are no titration points for `resid`.
    """
    import matplotlib.pyplot as plt

    if not self._analyzed:
        raise RuntimeError('Must call run() before plotting')
    assert self.titration_data is not None, "No titration data available"

    # Validate the residue before any figure is created, so a bad ID
    # neither leaks an empty figure nor fails later inside min().
    resid_data = self.titration_data.filter(pl.col('resid') == resid)
    if len(resid_data) == 0:
        raise ValueError(f'No titration data for residue {resid!r}')

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        _fig = ax.get_figure()
        assert _fig is not None, "Axes has no associated Figure"
        fig = _fig

    resname = self.resid_to_resname.get(resid, 'UNK')
    pH_data = resid_data['current_pH'].to_numpy()
    frac_data = resid_data['fraction_protonated'].to_numpy()
    n_samples = resid_data['n_samples'].to_numpy()
    # Standard error for binomial
    se = np.sqrt(frac_data * (1 - frac_data) / np.maximum(n_samples, 1))

    # Plot data points
    if show_data:
        ax.errorbar(
            pH_data,
            frac_data,
            yerr=se,
            fmt='o',
            color='black',
            markersize=8,
            capsize=3,
            capthick=1,
            elinewidth=1,
            label='Data',
            zorder=10,
        )

    # pH grid for curves: observed range padded by 0.5 on both sides
    pH_grid = np.linspace(min(pH_data) - 0.5, max(pH_data) + 0.5, 200)

    # (enabled, fit table, line style, color, legend prefix)
    curve_specs = (
        (show_curvefit, self.fits_curvefit, '-', 'blue', 'Curve fit'),
        (show_weighted, self.fits_weighted, '--', 'red', 'Weighted'),
    )
    for enabled, fits, linestyle, color, label in curve_specs:
        if not enabled or fits is None:
            continue
        fit_row = fits.filter(pl.col('resid') == resid)
        if len(fit_row) == 0:
            continue
        pKa = fit_row['pKa'][0]
        hill_n = fit_row['Hill_n'][0]
        if np.isnan(pKa) or np.isnan(hill_n):
            continue
        ax.plot(
            pH_grid,
            TitrationCurve.hill_equation(pH_grid, pKa, hill_n),
            linestyle,
            color=color,
            linewidth=2,
            label=f'{label} (pKa={pKa:.2f}, n={hill_n:.2f})',
        )
        # Dotted vertical marker at the fitted pKa
        ax.axvline(pKa, color=color, linestyle=':', alpha=0.5)

    # Formatting
    ax.set_xlabel('pH', fontsize=12)
    ax.set_ylabel('Fraction Protonated', fontsize=12)
    ax.set_title(f'Residue {resid} ({resname})', fontsize=14)
    ax.set_ylim(-0.05, 1.05)
    ax.axhline(0.5, color='gray', linestyle='--', alpha=0.3)
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)

    # Lay out the figure that owns `ax`, not pyplot's current figure
    # (they differ when the caller passes its own Axes). Fall back to
    # pyplot if the owner offers no tight_layout.
    getattr(fig, 'tight_layout', plt.tight_layout)()
    if save:
        fig.savefig(save, dpi=150, bbox_inches='tight')
    return fig
[docs]
def plot_all(
    self,
    output_dir: str | Path | None = None,
    format: str = 'png',
    show_curvefit: bool = True,
    show_weighted: bool = True,
    residues: list[str] | None = None,
    verbose: bool = True,
) -> None:
    """
    Write one titration plot per residue to disk.

    Each residue is rendered through ``self.plot_residue`` and saved as
    ``{resname}_{resid}.{format}``; the figure is closed right after it
    has been saved.

    Parameters
    ----------
    output_dir : str or Path, optional
        Destination directory. Falls back to ``self.output_dir / 'plots'``
        when not given. Created if it does not exist.
    format : str
        File extension used for the saved images (e.g. 'png', 'pdf', 'svg').
    show_curvefit : bool
        Forwarded to ``plot_residue``.
    show_weighted : bool
        Forwarded to ``plot_residue``.
    residues : list of str, optional
        Residues to plot. Defaults to ``self.resid_cols``.
    verbose : bool
        Print progress (a line every 20 plots plus start/end messages).

    Raises
    ------
    RuntimeError
        If the analyzer has not been run yet.
    """
    import matplotlib.pyplot as plt

    if not self._analyzed:
        raise RuntimeError('Must call run() before plotting')

    if output_dir:
        plot_dir = Path(output_dir)
    else:
        plot_dir = self.output_dir / 'plots'
    plot_dir.mkdir(parents=True, exist_ok=True)

    targets = self.resid_cols if residues is None else residues

    if verbose:
        print(f'Generating {len(targets)} plots in {plot_dir}/')

    for done, resid in enumerate(targets, start=1):
        # Residues without a known name are filed under 'UNK'.
        resname = self.resid_to_resname.get(resid, 'UNK')
        figure = self.plot_residue(
            resid,
            show_curvefit=show_curvefit,
            show_weighted=show_weighted,
            save=plot_dir / f'{resname}_{resid}.{format}',
        )
        # Close immediately so figures do not accumulate in memory.
        plt.close(figure)
        if verbose and done % 20 == 0:
            print(f' {done}/{len(targets)} plots generated...')

    if verbose:
        print(f' All {len(targets)} plots saved to {plot_dir}/')
[docs]
def plot_summary(
    self,
    figsize: tuple[float, float] = (12, 5),
    save: str | Path | None = None,
) -> plt.Figure:
    """
    Draw a two-panel comparison of curve-fit and weighted pKa values.

    Left panel: scatter of weighted pKa against curve-fit pKa, with the
    identity line and the correlation coefficient from ``np.corrcoef``.
    Right panel: histogram of the ``pKa_diff`` column, with vertical
    markers at zero and at the mean difference.

    Parameters
    ----------
    figsize : tuple of float
        Figure size handed to ``plt.subplots``.
    save : str or Path, optional
        If given, the figure is also written to this path.

    Returns
    -------
    matplotlib Figure

    Raises
    ------
    RuntimeError
        If ``self.comparison`` is None.
    ValueError
        If no row has non-NaN values in both ``pKa`` and ``pKa_weighted``.
    """
    import matplotlib.pyplot as plt

    if self.comparison is None:
        raise RuntimeError(
            'Need both curvefit and weighted methods for summary plot'
        )

    # Keep only residues for which both fits produced a number.
    both_fitted = pl.col('pKa').is_not_nan() & pl.col('pKa_weighted').is_not_nan()
    successful = self.comparison.filter(both_fitted)
    if len(successful) == 0:
        raise ValueError('No residues with both methods successful')

    fig, axes = plt.subplots(1, 2, figsize=figsize)
    scatter_ax = axes[0]
    hist_ax = axes[1]

    pKa_cf = successful['pKa'].to_numpy()
    pKa_wt = successful['pKa_weighted'].to_numpy()
    diff = successful['pKa_diff'].to_numpy()

    # --- Left panel: method-vs-method scatter -------------------------
    scatter_ax.scatter(
        pKa_cf, pKa_wt, alpha=0.6, edgecolor='black', linewidth=0.5
    )
    lower = min(pKa_cf.min(), pKa_wt.min()) - 0.5
    upper = max(pKa_cf.max(), pKa_wt.max()) + 0.5
    lims = [lower, upper]
    scatter_ax.plot(lims, lims, 'k--', alpha=0.5)  # identity line
    scatter_ax.set_xlim(lims)
    scatter_ax.set_ylim(lims)
    scatter_ax.set_xlabel('pKa (Curve Fit)', fontsize=12)
    scatter_ax.set_ylabel('pKa (Weighted)', fontsize=12)
    scatter_ax.set_title('Method Comparison', fontsize=14)
    scatter_ax.set_aspect('equal')
    scatter_ax.grid(True, alpha=0.3)

    corr = np.corrcoef(pKa_cf, pKa_wt)[0, 1]
    scatter_ax.text(
        0.05,
        0.95,
        f'r = {corr:.3f}',
        transform=scatter_ax.transAxes,
        fontsize=11,
        verticalalignment='top',
        bbox={'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8},
    )

    # --- Right panel: distribution of differences ---------------------
    mean_diff = np.mean(diff)
    hist_ax.hist(diff, bins=20, edgecolor='black', alpha=0.7)
    hist_ax.axvline(0, color='red', linestyle='--', linewidth=2)
    hist_ax.axvline(
        mean_diff,
        color='blue',
        linestyle='-',
        linewidth=2,
        label=f'Mean = {mean_diff:.3f}',
    )
    hist_ax.set_xlabel('ΔpKa (Curve Fit - Weighted)', fontsize=12)
    hist_ax.set_ylabel('Count', fontsize=12)
    hist_ax.set_title('pKa Difference Distribution', fontsize=14)
    hist_ax.legend(fontsize=10)
    hist_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    if save:
        fig.savefig(save, dpi=150, bbox_inches='tight')
    return fig
[docs]
def save_results(
    self,
    output_dir: str | Path | None = None,
    prefix: str = '',
    formats: list[str] = ['csv'],  # noqa: B006
) -> None:
    """
    Save all available result tables to files.

    Every result attribute that is not None (``fits_curvefit``,
    ``fits_weighted``, ``fits_bootstrap``, ``comparison``,
    ``titration_data``) is written once per requested format as
    ``{prefix}_{name}.{format}`` (no underscore when ``prefix`` is empty).

    Parameters
    ----------
    output_dir : str or Path, optional
        Output directory. Uses self.output_dir if None. Created if missing.
    prefix : str
        Prefix for filenames
    formats : list of str
        Output formats: 'csv', 'parquet', 'json'. A single string is
        treated as a one-element list. Unsupported formats are skipped
        with a warning instead of being reported as saved.
    """
    out_dir = Path(output_dir) if output_dir else self.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    prefix = f'{prefix}_' if prefix else ''

    # A bare string would otherwise be iterated character by character.
    # The default list is only read, never mutated.
    requested = [formats] if isinstance(formats, str) else list(formats)

    # Format name -> name of the DataFrame writer method.
    writers = {
        'csv': 'write_csv',
        'parquet': 'write_parquet',
        'json': 'write_json',
    }
    unsupported = [fmt for fmt in requested if fmt not in writers]
    if unsupported:
        warnings.warn(
            f'Unsupported output format(s) skipped: {unsupported}; '
            f'supported formats are {sorted(writers)}',
            stacklevel=2,
        )
    supported = [fmt for fmt in requested if fmt in writers]

    tables = (
        (self.fits_curvefit, 'pKa_curvefit'),
        (self.fits_weighted, 'pKa_weighted'),
        (self.fits_bootstrap, 'pKa_bootstrap'),
        (self.comparison, 'pKa_comparison'),
        (self.titration_data, 'titration_data'),
    )

    print(f'Saving results to {out_dir}/')
    for df, name in tables:
        if df is None:
            continue
        for fmt in supported:
            filepath = out_dir / f'{prefix}{name}.{fmt}'
            getattr(df, writers[fmt])(filepath)
            # Only report files that were actually written.
            print(f' Saved {filepath}')
[docs]
def diagnose(self, resid: str) -> dict:
    """
    Return diagnostic information for one residue.

    Delegates to ``self._tc.diagnose_residue`` with ``verbose=True``.

    Raises
    ------
    RuntimeError
        If ``self._tc`` has not been set yet (``run()`` not called).
    """
    curve = self._tc
    if curve is None:
        raise RuntimeError('Must call run() first')
    return curve.diagnose_residue(resid, verbose=True)
[docs]
def recommend_protonation(
    self,
    target_pH: float,
    confidence_threshold: float = 0.7,
    use_bootstrap: bool = False,
    verbose: bool = True,
) -> pl.DataFrame:
    """
    Recommend protonation states for a target pH.

    For every row of the fit table the probability of being protonated at
    ``target_pH`` is computed with ``TitrationCurve.hill_equation`` using
    the fitted pKa and Hill coefficient. Rows without a usable fit (null
    or NaN pKa / Hill coefficient) fall back to a reference pKa for the
    residue type (7.0 for unknown types) with a unit Hill coefficient.

    The curve-fit table is used when present; otherwise the weighted-fit
    table is used.

    Parameters
    ----------
    target_pH : float
        pH value to make predictions for (e.g., 3.0, 7.4)
    confidence_threshold : float
        Probability threshold for "confident" predictions (default 0.7).
        Must lie in [0.5, 1.0]. Residues with P(protonated) strictly
        between (1-threshold) and threshold are marked as "uncertain".
    use_bootstrap : bool
        Accepted for API compatibility; this implementation does not
        read it.
    verbose : bool
        Print summary of recommendations

    Returns
    -------
    DataFrame with columns:
        - resid: residue ID
        - resname: residue name as stored in the fit table
        - pKa: pKa used for the prediction (fitted or reference)
        - pKa_source: 'fitted' or 'reference'
        - prob_protonated: probability of being protonated at target pH
        - recommendation: 'protonated', 'deprotonated', or 'uncertain'
        - confidence: 'high', 'medium', or 'low'
        - state_name: recommended state name (e.g., 'ASH' or 'ASP');
          for 'uncertain' rows this is the majority state

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called or no fit table is available.
    ValueError
        If ``confidence_threshold`` is outside [0.5, 1.0].
    """
    if not self._analyzed:
        raise RuntimeError('Must call run() before recommend_protonation()')
    # Below 0.5 the "protonated" and "deprotonated" bands overlap and the
    # labels contradict the majority state; above 1.0 nothing can qualify.
    if not 0.5 <= confidence_threshold <= 1.0:
        raise ValueError(
            'confidence_threshold must be between 0.5 and 1.0, '
            f'got {confidence_threshold}'
        )

    # Prefer the unweighted curve fit; fall back to the weighted fit.
    fits = self.fits_curvefit
    if fits is None:
        fits = self.fits_weighted
    if fits is None:
        raise RuntimeError('No fitting results available')

    # State name mappings
    protonated_state = {
        'ASP': 'ASH',
        'GLU': 'GLH',
        'HIS': 'HIP',
        'LYS': 'LYS',
        'CYS': 'CYS',
    }
    deprotonated_state = {
        'ASP': 'ASP',
        'GLU': 'GLU',
        'HIS': 'HIE',
        'LYS': 'LYN',
        'CYS': 'CYX',
    }
    # Reference pKa values, used when a residue has no usable fit
    reference_pKa = {'ASP': 3.9, 'GLU': 4.3, 'HIS': 6.0, 'LYS': 10.5, 'CYS': 8.3}

    recommendations = []
    for row in fits.iter_rows(named=True):
        resid = row['resid']
        resname = row['resname']
        pKa = row['pKa']
        hill_n = row['Hill_n']

        # Null cells arrive as None, for which np.isnan raises; check first.
        fit_missing = (
            pKa is None
            or hill_n is None
            or np.isnan(pKa)
            or np.isnan(hill_n)
        )
        if fit_missing:
            # No fit available - use reference pKa with unit Hill coefficient
            pKa_used = reference_pKa.get(resname, 7.0)
            prob_prot = 1.0 / (1.0 + 10 ** (target_pH - pKa_used))
            fit_quality = 'reference'
        else:
            prob_prot = float(
                TitrationCurve.hill_equation(target_pH, pKa, hill_n)
            )
            pKa_used = pKa
            fit_quality = 'fitted'

        if prob_prot >= confidence_threshold:
            recommendation = 'protonated'
        elif prob_prot <= (1 - confidence_threshold):
            recommendation = 'deprotonated'
        else:
            recommendation = 'uncertain'

        # With threshold >= 0.5 the confident labels always agree with the
        # majority state, so one rule picks the state name for all cases
        # (uncertain residues simply go with the majority).
        state_table = protonated_state if prob_prot >= 0.5 else deprotonated_state
        state_name = state_table.get(resname, resname)

        # Confidence based on distance from 0.5
        prob_distance = abs(prob_prot - 0.5)
        if prob_distance > 0.4:  # >90% or <10%
            confidence = 'high'
        elif prob_distance > 0.2:  # >70% or <30%
            confidence = 'medium'
        else:
            confidence = 'low'

        recommendations.append(
            {
                'resid': resid,
                'resname': resname,
                'pKa': pKa_used,
                'pKa_source': fit_quality,
                'prob_protonated': prob_prot,
                'recommendation': recommendation,
                'confidence': confidence,
                'state_name': state_name,
            }
        )

    if recommendations:
        result = pl.DataFrame(recommendations)
    else:
        # An empty fit table must still yield the documented columns so the
        # report below and downstream sorts/filters do not fail.
        result = pl.DataFrame(
            schema={
                'resid': pl.Utf8,
                'resname': pl.Utf8,
                'pKa': pl.Float64,
                'pKa_source': pl.Utf8,
                'prob_protonated': pl.Float64,
                'recommendation': pl.Utf8,
                'confidence': pl.Utf8,
                'state_name': pl.Utf8,
            }
        )

    if verbose:

        def count(frame: pl.DataFrame, label: str) -> int:
            """Number of rows in ``frame`` carrying the given recommendation."""
            return frame.filter(pl.col('recommendation') == label).height

        print(f"\n{'=' * 60}")
        print(f'Protonation Recommendations at pH {target_pH}')
        print(f"{'=' * 60}")

        # Summary counts
        n_prot = count(result, 'protonated')
        n_deprot = count(result, 'deprotonated')
        n_uncertain = count(result, 'uncertain')
        print('\nSummary:')
        print(f' Protonated: {n_prot:3d} residues')
        print(f' Deprotonated: {n_deprot:3d} residues')
        print(f' Uncertain: {n_uncertain:3d} residues')

        # Group by residue type
        print('\nBy residue type:')
        for restype in ['ASP', 'GLU', 'HIS', 'LYS', 'CYS']:
            subset = result.filter(pl.col('resname') == restype)
            if len(subset) > 0:
                n_p = count(subset, 'protonated')
                n_d = count(subset, 'deprotonated')
                n_u = count(subset, 'uncertain')
                ref = reference_pKa.get(restype, '?')
                print(
                    f' {restype} (ref pKa={ref}): {n_p} prot, {n_d} deprot, {n_u} uncertain'
                )

        # Show uncertain residues (most important to check)
        uncertain = result.filter(pl.col('recommendation') == 'uncertain')
        if len(uncertain) > 0:
            print(
                f'\n⚠️ Uncertain residues (prob between {1 - confidence_threshold:.0%}-{confidence_threshold:.0%}):'
            )
            for row in uncertain.sort('prob_protonated', descending=True).iter_rows(
                named=True
            ):
                print(
                    f" {row['resname']} {row['resid']}: "
                    f"P(prot)={row['prob_protonated']:.1%}, "
                    f"pKa={row['pKa']:.1f} → {row['state_name']}"
                )

        # Show residues whose fitted pKa lies near the target pH
        near_pKa = result.filter(
            (pl.col('pKa') > target_pH - 1.5)
            & (pl.col('pKa') < target_pH + 1.5)
            & (pl.col('pKa_source') == 'fitted')
        )
        if len(near_pKa) > 0:
            print(f'\n📍 Residues with pKa near pH {target_pH} (±1.5 units):')
            for row in near_pKa.sort('pKa').iter_rows(named=True):
                print(
                    f" {row['resname']} {row['resid']}: "
                    f"pKa={row['pKa']:.2f}, "
                    f"P(prot)={row['prob_protonated']:.1%} → {row['state_name']}"
                )

    return result
[docs]
def get_protonation_string(
    self,
    target_pH: float,
    confidence_threshold: float = 0.7,
) -> str:
    """
    Return the recommended protonation states as one compact string.

    Calls ``recommend_protonation`` (non-verbose) and joins one
    ``resid:state_name`` token per returned row with commas.

    Parameters
    ----------
    target_pH : float
        pH value to make predictions for
    confidence_threshold : float
        Forwarded to ``recommend_protonation``

    Returns
    -------
    String with format: "resid:state,resid:state,..." (empty when there
    are no rows)
    """
    recommended = self.recommend_protonation(
        target_pH,
        confidence_threshold=confidence_threshold,
        verbose=False,
    )
    return ','.join(
        f"{entry['resid']}:{entry['state_name']}"
        for entry in recommended.iter_rows(named=True)
    )
[docs]
def export_protonation_states(
    self,
    target_pH: float,
    output_file: str | Path | None = None,
    format: str = 'csv',
    confidence_threshold: float = 0.7,
) -> pl.DataFrame:
    """
    Export protonation state recommendations to file.

    Parameters
    ----------
    target_pH : float
        pH value to make predictions for
    output_file : str or Path, optional
        Output file path. If None, uses
        output_dir/protonation_pH{pH}.{format}. Missing parent
        directories are created.
    format : str
        Output format: 'csv', 'json', or 'txt'. Any other value writes
        nothing and emits a warning.
    confidence_threshold : float
        Forwarded to ``recommend_protonation``

    Returns
    -------
    DataFrame with recommendations (returned even when nothing was written)
    """
    recs = self.recommend_protonation(
        target_pH, confidence_threshold=confidence_threshold, verbose=False
    )

    supported = ('csv', 'json', 'txt')
    if format not in supported:
        # Do not claim a file was saved when no writer exists for the format.
        warnings.warn(
            f'Unsupported export format {format!r}; nothing written '
            f'(supported formats: {", ".join(supported)})',
            stacklevel=2,
        )
        return recs

    if output_file is None:
        output_file = self.output_dir / f'protonation_pH{target_pH:.1f}.{format}'
    else:
        output_file = Path(output_file)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    if format == 'csv':
        recs.write_csv(output_file)
    elif format == 'json':
        recs.write_json(output_file)
    else:
        # Simple text format for easy reading
        lines = [
            f'# Protonation states at pH {target_pH}\n',
            f'# confidence_threshold = {confidence_threshold}\n',
            '#\n',
            '# resid resname state prob_prot confidence\n',
        ]
        # ``!s`` keeps the column formatting working for non-string resids.
        lines.extend(
            f"{row['resid']!s:>6s} {row['resname']:>7s} "
            f"{row['state_name']:>5s} {row['prob_protonated']:>9.3f} "
            f"{row['confidence']}\n"
            for row in recs.iter_rows(named=True)
        )
        with output_file.open('w', encoding='utf-8') as f:
            f.writelines(lines)

    print(f'Saved protonation recommendations to {output_file}')
    return recs
[docs]
def plot_protonation_summary(
    self,
    target_pH: float,
    figsize: tuple[float, float] = (12, 6),
    save: str | Path | None = None,
) -> plt.Figure:
    """
    Bar chart of P(protonated) per residue at ``target_pH``.

    Residues are ordered by decreasing probability and coloured by residue
    name (gray for names without an assigned colour). Horizontal guide
    lines are drawn at 0.5, 0.7 and 0.3, and a text box reports how many
    residues fall at or above / below 0.5.

    Parameters
    ----------
    target_pH : float
        pH value to visualize
    figsize : tuple
        Figure size
    save : str or Path, optional
        Path to save figure

    Returns
    -------
    matplotlib Figure
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Patch

    ranked = self.recommend_protonation(target_pH, verbose=False).sort(
        'prob_protonated', descending=True
    )

    fig, ax = plt.subplots(figsize=figsize)

    # One colour per residue type
    colors = {
        'ASP': '#e41a1c',  # red
        'GLU': '#ff7f00',  # orange
        'HIS': '#4daf4a',  # green
        'LYS': '#377eb8',  # blue
        'CYS': '#984ea3',  # purple
    }

    probs = ranked['prob_protonated'].to_numpy()
    resnames = ranked['resname'].to_list()
    resids = ranked['resid'].to_list()
    x = np.arange(len(ranked))

    ax.bar(
        x,
        probs,
        color=[colors.get(name, 'gray') for name in resnames],
        edgecolor='black',
        linewidth=0.5,
    )

    # Guide lines: the 0.5 midpoint plus the 0.7 / 0.3 levels
    ax.axhline(0.5, color='black', linestyle='--', linewidth=2, alpha=0.7)
    for level in (0.7, 0.3):
        ax.axhline(level, color='gray', linestyle=':', linewidth=1, alpha=0.5)

    ax.set_xlabel('Residue', fontsize=12)
    ax.set_ylabel('P(protonated)', fontsize=12)
    ax.set_title(f'Protonation Probabilities at pH {target_pH}', fontsize=14)
    ax.set_ylim(0, 1.05)

    # With more than 50 bars, label only every ``step``-th one.
    n_residues = len(x)
    tick_positions = x[:: n_residues // 20] if n_residues > 50 else x
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(
        [f'{resnames[i]}{resids[i]}' for i in tick_positions],
        rotation=45,
        ha='right',
        fontsize=8,
    )

    # Legend lists only the residue types that actually occur
    present = set(resnames)
    legend_elements = [
        Patch(facecolor=shade, edgecolor='black', label=name)
        for name, shade in colors.items()
        if name in present
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)

    # Counts of residues on either side of 0.5
    n_prot = int(np.count_nonzero(probs >= 0.5))
    n_deprot = int(np.count_nonzero(probs < 0.5))
    ax.text(
        0.02,
        0.98,
        f'Protonated: {n_prot}\nDeprotonated: {n_deprot}',
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment='top',
        bbox={'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8},
    )

    plt.tight_layout()
    if save:
        fig.savefig(save, dpi=150, bbox_inches='tight')
    return fig
def __repr__(self) -> str:
status = 'analyzed' if self._analyzed else 'not analyzed'
return f'TitrationAnalyzer({len(self.log_files)} log files, {status})'
[docs]
def analyze_cph(
    log_files: Path | list[Path] | str | list[str],
    output_dir: str | Path | None = None,
    methods: list[str] = ['curvefit', 'weighted'],  # noqa: B006
    plot: bool = True,
    verbose: bool = True,
) -> TitrationAnalyzer:
    """
    Convenience function to run complete constant pH analysis.

    Builds a ``TitrationAnalyzer``, runs the requested methods, optionally
    generates the per-residue and summary plots, and saves the result
    tables. Plotting is best-effort: a missing matplotlib or a plot that
    cannot be produced never prevents the results from being saved.

    Parameters
    ----------
    log_files : path(s) to log files
    output_dir : output directory
    methods : list of methods to run ('curvefit', 'weighted', 'bootstrap')
    plot : whether to generate plots
    verbose : print progress

    Returns
    -------
    TitrationAnalyzer with results

    Example
    -------
    >>> results = analyze_cph("cpH.log", output_dir="analysis/")
    >>> results.summary()
    >>> results.plot_residue("145")
    """
    analyzer = TitrationAnalyzer(log_files, output_dir=output_dir)
    analyzer.run(methods=methods, verbose=verbose)

    if plot:
        try:
            import matplotlib.pyplot as plt

            analyzer.plot_all(verbose=verbose)
            summary_fig = analyzer.plot_summary(
                save=analyzer.output_dir / 'summary.png'
            )
            # The figure is already on disk; release it.
            plt.close(summary_fig)
        except ImportError:
            if verbose:
                print('matplotlib not available, skipping plots')
        except (RuntimeError, ValueError) as exc:
            # plot_summary raises RuntimeError when only one method was run
            # and ValueError when no residue succeeded with both methods.
            # Neither should stop the results from being saved.
            if verbose:
                print(f'Skipping remaining plots: {exc}')

    analyzer.save_results()
    return analyzer
if __name__ == '__main__':
    import contextlib
    import sys

    # Get log files from command line or use default
    cli_paths = [Path(arg) for arg in sys.argv[1:]]
    log_paths = cli_paths or [Path('cpH.log')]

    # =========================================================================
    # STREAMLINED API - TitrationAnalyzer
    # =========================================================================
    #
    # Available methods:
    # - curvefit: Simple least squares fit (default)
    # - weighted: Weighted least squares (by 1/variance)
    # - bootstrap: Curve fit with bootstrap confidence intervals
    #
    # Basic usage:
    # analyzer = TitrationAnalyzer(log_paths)
    # analyzer.run()
    # analyzer.summary()
    #
    # Protonation recommendations:
    # recs = analyzer.recommend_protonation(target_pH=3.0)
    # analyzer.plot_protonation_summary(target_pH=3.0)
    #
    # =========================================================================

    # Create analyzer
    analyzer = TitrationAnalyzer(log_paths, output_dir='cph_analysis')

    # Run curve fitting and weighted fitting
    analyzer.run(methods=['curvefit', 'weighted'], verbose=True)

    # Print summary
    analyzer.summary()

    # Generate all plots (if matplotlib available). Plotting is best-effort:
    # a failed plot must not prevent the results below from being saved.
    try:
        analyzer.plot_all(verbose=True)
        analyzer.plot_summary(save='cph_analysis/summary.png')
        print('\nPlots saved to cph_analysis/plots/')
    except ImportError:
        print('\nSkipping plots (matplotlib not installed)')
    except (RuntimeError, ValueError) as exc:
        # Raised by plot_summary when the method comparison is unavailable
        # or no residue was fitted successfully by both methods.
        print(f'\nSkipping remaining plots: {exc}')

    # Save results
    analyzer.save_results()

    # =========================================================================
    # PROTONATION RECOMMENDATIONS
    # =========================================================================

    # Get recommendations for pH 3.0
    print('\n')
    recs = analyzer.recommend_protonation(target_pH=3.0)

    # Export to file
    analyzer.export_protonation_states(target_pH=3.0, format='csv')

    # Visualize (skipped silently when matplotlib is not installed)
    with contextlib.suppress(ImportError):
        analyzer.plot_protonation_summary(
            target_pH=3.0, save='cph_analysis/protonation_pH3.0.png'
        )

    # Can also get recommendations for physiological pH
    # analyzer.recommend_protonation(target_pH=7.4)