⚛️ NMR API
The diff_biophys.nmr subpackage implements four categories of differentiable NMR
observables. Each kernel accepts JAX arrays and returns JAX arrays, making them
trivially composable into multi-observable loss functions.
Karplus J-coupling
diff_biophys.nmr.karplus implements the Karplus equation relating a three-bond
scalar coupling constant \(^3J\) (Hz) to a backbone dihedral angle:
\[ ^3J(\theta) = A\cos^2\theta + B\cos\theta + C \]
For the backbone HN–Hα coupling, the dihedral is \(\theta = \phi - 60°\) (Vuister & Bax 1993 offset convention). The example below uses \(A = 6.98\), \(B = -1.38\), \(C = 1.72\) Hz; all three coefficients are required arguments, so pass them explicitly.
Gradient meaning: \(\partial J / \partial \theta\) tells you how fast the coupling changes with the backbone angle — the key quantity for J-coupling-driven structure refinement.
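from diff_biophys.nmr.karplus import calculate_karplus_j
import jax, jax.numpy as jnp
phi = jnp.array([-0.995]) # φ in radians (≈ −57°, helical)
theta = phi - jnp.deg2rad(60.0) # offset for HN-Hα
# Coefficients are passed by their lowercase names (a, b, c), matching the signature
J = calculate_karplus_j(theta, a=6.98, b=-1.38, c=1.72) # (N,) couplings in Hz
# Sum to a scalar so jax.grad can differentiate; the result is dJ/dθ per angle
dJ_dtheta = jax.grad(lambda t: jnp.sum(calculate_karplus_j(t, 6.98, -1.38, 1.72)))(theta)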
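The autodiff gradient can be checked against the closed-form derivative \(\partial J/\partial\theta = -2A\cos\theta\sin\theta - B\sin\theta\). The following is a minimal sketch that assumes a local stand-in, `karplus_j`, which reimplements the formula \(J = a\cos^2\theta + b\cos\theta + c\); it is not the library function, and `karplus_dj_dtheta` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def karplus_j(theta, a, b, c):
    """Stand-in for the Karplus formula: J = a*cos^2(theta) + b*cos(theta) + c."""
    cos_t = jnp.cos(theta)
    return a * cos_t**2 + b * cos_t + c

def karplus_dj_dtheta(theta, a, b):
    """Closed-form derivative of the Karplus curve with respect to theta."""
    return -2.0 * a * jnp.cos(theta) * jnp.sin(theta) - b * jnp.sin(theta)

a, b, c = 6.98, -1.38, 1.72
# Two backbone phi values (helix-like and sheet-like), offset by 60 degrees
theta = jnp.deg2rad(jnp.array([-57.0, -120.0])) - jnp.deg2rad(60.0)

# Autodiff gradient: sum the couplings so jax.grad sees a scalar
auto_grad = jax.grad(lambda t: jnp.sum(karplus_j(t, a, b, c)))(theta)
exact_grad = karplus_dj_dtheta(theta, a, b)
```

Because each coupling depends only on its own angle, the gradient of the summed couplings equals the per-angle derivative, which is why the two arrays should agree element by element.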
calculate_karplus_j(theta, a, b, c)
Calculate 3J coupling constants using the Karplus equation.
J = a * cos^2(theta) + b * cos(theta) + c
Important: theta is not the same as the raw backbone dihedral phi.
For ³J(H_N, H_α) couplings, the Karplus dihedral is offset from
the IUPAC backbone angle by 60°:
theta = phi - 60° (i.e. subtract 60° from phi before calling)
Passing raw phi directly produces errors of ~3–4 Hz, which invalidates
α-helix / β-sheet discrimination.
The parameters A=6.51, B=−1.76, C=1.60 Hz are the H_N–H_α values of
Vuister & Bax (1993). The function has no default coefficients, so a, b and c must always be passed. Different spin pairs require different offsets
and parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| theta | ndarray | (N,) Karplus dihedral angle in radians. For ³J(HN,HA): pass (phi − 60°), not phi. | required |
| a | float | Cosine-squared coefficient (Hz). | required |
| b | float | Cosine coefficient (Hz). | required |
| c | float | Constant offset (Hz). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) Calculated J-couplings. |
Source code in diff_biophys/nmr/karplus.py
Cα Chemical Shifts
diff_biophys.nmr.chemical_shifts predicts Cα chemical shifts (ppm) from
backbone torsion angles using a softmax-weighted Gaussian secondary-structure
detector in (φ, ψ) space.
Secondary chemical shifts (relative to random-coil):
| Structure | Δδ(Cα) |
|---|---|
| α-helix | +3.1 ppm |
| β-sheet | −1.5 ppm |
| Random coil | 0 ppm |
The library provides RANDOM_COIL_CA — a dictionary of per-residue random-coil
reference values.
from diff_biophys.nmr.chemical_shifts import predict_ca_shifts, RANDOM_COIL_CA
import jax, jax.numpy as jnp
n_res = 10
phi = jnp.full((n_res,), jnp.deg2rad(-57.0)) # α-helix
psi = jnp.full((n_res,), jnp.deg2rad(-47.0))
rc = jnp.full((n_res,), RANDOM_COIL_CA["ALA"])
shifts = predict_ca_shifts(phi, psi, rc) # (n_res,) in ppm
# Refine φ to match target shifts
target = shifts + 0.5 # perturb target
loss = lambda p: jnp.mean((predict_ca_shifts(p, psi, rc) - target) ** 2)
grad_phi = jax.grad(loss)(phi)
make_ca_shift_loss(exp_res_ids, exp_shifts, struct_res_ids, struct_res_names)
Build a differentiable Cα chemical shift RMSD loss.
Matches experimental residues to the structure by residue ID and builds a JAX-differentiable closure that predicts Cα shifts from backbone torsions and returns the RMSD against the matched experimental values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| exp_res_ids | ndarray | | required |
| exp_shifts | ndarray | | required |
| struct_res_ids | ndarray | | required |
| struct_res_names | list[str] | List of N three-letter residue codes, aligned with | required |
Returns:
| Type | Description |
|---|---|
| Callable[[ndarray, ndarray], ndarray] | Tuple |
| int | |
| tuple[Callable[[ndarray, ndarray], ndarray], int] | |
Raises:
| Type | Description |
|---|---|
| ValueError | If no residues overlap between the experimental and structure datasets. |
Source code in diff_biophys/nmr/chemical_shifts.py
predict_ca_shifts(phi, psi, rc_shifts)
Differentiable Cα chemical shift prediction based on backbone torsions.
Uses Gaussian "soft detectors" in (Φ, Ψ) space to classify secondary structure and applies SPARTA-like offsets. The detectors are normalised via a softmax so that helix, sheet, and coil contributions always sum to 1.0, preventing unphysical double-counting.
Reference centres (radians):
* α-helix: Φ = −1.05 rad (−60°), Ψ = −0.78 rad (−45°)
* β-sheet: Φ = −2.09 rad (−120°), Ψ = +2.35 rad (+135°)
* Random coil: treated as the baseline (weight = 1 − w_helix − w_sheet)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| phi | ndarray | (N,) backbone Φ angles in radians. | required |
| psi | ndarray | (N,) backbone Ψ angles in radians. | required |
| rc_shifts | ndarray | (N,) baseline random-coil Cα shifts (ppm). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) predicted Cα chemical shifts (ppm). |
Source code in diff_biophys/nmr/chemical_shifts.py
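The softmax-normalised detector can be illustrated with a short sketch. It is not the library implementation: the detector width `SIGMA`, the constant coil logit `COIL_LOGIT`, the function names and the random-coil values are all assumptions chosen for illustration. The centres and secondary-shift offsets are the ones quoted in this section.

```python
import jax
import jax.numpy as jnp

HELIX_CENTRE = jnp.array([-1.05, -0.78])   # (phi, psi) in radians, as quoted above
SHEET_CENTRE = jnp.array([-2.09, 2.35])
HELIX_OFFSET, SHEET_OFFSET = 3.1, -1.5     # secondary shifts in ppm
SIGMA = 0.5          # assumed detector width (radians)
COIL_LOGIT = -2.0    # assumed constant logit for the coil class

def wrapped_sq_dist(phi, psi, centre):
    """Squared angular distance with each difference wrapped to (-pi, pi]."""
    dphi = jnp.arctan2(jnp.sin(phi - centre[0]), jnp.cos(phi - centre[0]))
    dpsi = jnp.arctan2(jnp.sin(psi - centre[1]), jnp.cos(psi - centre[1]))
    return dphi**2 + dpsi**2

def soft_ca_shifts(phi, psi, rc_shifts):
    """Return (shifts, weights); weights has shape (N, 3) = helix, sheet, coil."""
    logits = jnp.stack([
        -wrapped_sq_dist(phi, psi, HELIX_CENTRE) / (2 * SIGMA**2),
        -wrapped_sq_dist(phi, psi, SHEET_CENTRE) / (2 * SIGMA**2),
        jnp.full_like(phi, COIL_LOGIT),
    ], axis=-1)
    weights = jax.nn.softmax(logits, axis=-1)   # rows sum to 1
    shifts = rc_shifts + weights[:, 0] * HELIX_OFFSET + weights[:, 1] * SHEET_OFFSET
    return shifts, weights

phi = jnp.deg2rad(jnp.array([-60.0, -120.0]))   # helix-like, sheet-like
psi = jnp.deg2rad(jnp.array([-45.0, 135.0]))
rc = jnp.array([52.5, 52.5])                    # illustrative random-coil values (ppm)
shifts, weights = soft_ca_shifts(phi, psi, rc)
```

Because the three weights come from one softmax, the coil weight is automatically 1 − w_helix − w_sheet, and the predicted shift stays between the sheet and helix extremes.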
Residual Dipolar Couplings (RDCs)
diff_biophys.nmr.rdc computes Residual Dipolar Couplings (Hz) from bond
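vectors and a Saupe alignment tensor:
\[ D = D_{\max}\,\mathbf{v}^{\mathsf{T}}\mathbf{S}\,\mathbf{v} = D_{\max}\sum_{i,j} v_i S_{ij} v_j \]
where \(\mathbf{v}\) is the unit bond vector, \(\mathbf{S}\) is the Saupe tensor (3×3 symmetric traceless matrix), and \(D_{\max}\) is the maximum dipolar coupling constant (d_max). For an axially symmetric tensor, \(D = 0\) at the magic angle (\(\theta \approx 54.74°\) from the alignment axis).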
Also provides:
- fit_saupe_tensor: SVD-based least-squares fitting of S from observed RDCs
- calculate_rdc: principal-axis-frame form with axial component and rhombicity (reduces to \(D \propto 3\cos^2\theta - 1\) when the rhombicity is zero)
- calculate_q_factor: Q-factor for RDC quality assessment
from diff_biophys.nmr.rdc import calculate_rdc_from_tensor, fit_saupe_tensor
import jax.numpy as jnp
# N–H bond unit vectors (n_res, 3)
bond_vecs = jnp.array(...)
# Saupe alignment tensor (3, 3), axially symmetric and traceless
S = jnp.diag(jnp.array([-0.05, -0.05, 0.10]))
rdcs = calculate_rdc_from_tensor(bond_vecs, S, d_max=21585.0) # Hz
# Fit tensor from experimental RDCs: rdcs_experimental is an (n_res,) array in Hz.
# Pass the same d_max as the forward model so the fitted tensor is on the same scale.
S_fit = fit_saupe_tensor(bond_vecs, rdcs_experimental, d_max=21585.0)
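The relation \(D = D_{\max}\sum_{ij} v_i S_{ij} v_j\) is a single contraction per bond vector. The following sketch uses a hypothetical stand-in, `rdc_from_tensor` (not the library function), and evaluates three orientations against the axially symmetric tensor used above, with `d_max` set to 1 so the numbers are easy to read.

```python
import math
import jax.numpy as jnp

def rdc_from_tensor(bond_vectors, saupe_tensor, d_max=1.0):
    """D_n = d_max * sum_ij v_ni S_ij v_nj for each unit vector v_n."""
    return d_max * jnp.einsum("ni,ij,nj->n", bond_vectors, saupe_tensor, bond_vectors)

S = jnp.diag(jnp.array([-0.05, -0.05, 0.10]))   # axially symmetric, traceless
magic = math.acos(1.0 / math.sqrt(3.0))         # about 54.74 degrees from the z axis
vecs = jnp.array([
    [0.0, 0.0, 1.0],                            # along the alignment axis
    [1.0, 0.0, 0.0],                            # perpendicular to it
    [math.sin(magic), 0.0, math.cos(magic)],    # at the magic angle
])
rdcs = rdc_from_tensor(vecs, S, d_max=1.0)
```

The three values are Szz = 0.10 along the axis, Sxx = −0.05 perpendicular to it, and zero (to rounding) at the magic angle, where the contributions (2/3)·Sxx and (1/3)·Szz cancel.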
calculate_q_factor(calculated_rdcs, experimental_rdcs)
Calculate the RDC Q-factor (Cornilescu et al., 1998).
Q = sqrt( sum((D_calc - D_exp)^2) / sum(D_exp^2) )
Returns 0.0 when all experimental RDCs are zero (perfect trivial match).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| calculated_rdcs | ndarray | (N,) calculated couplings. | required |
| experimental_rdcs | ndarray | (N,) measured couplings. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Scalar Q-factor. |
Source code in diff_biophys/nmr/rdc.py
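A minimal sketch of the Q-factor formula, including the all-zero guard described above. The name `q_factor` and the `jnp.where` guard are assumptions about one reasonable way to write it, not a copy of the library code.

```python
import jax.numpy as jnp

def q_factor(calculated, experimental):
    """Q = sqrt(sum((D_calc - D_exp)^2) / sum(D_exp^2)); 0.0 if all D_exp are zero."""
    num = jnp.sum((calculated - experimental) ** 2)
    den = jnp.sum(experimental ** 2)
    safe_den = jnp.where(den > 0, den, 1.0)   # avoid 0/0 in the branch that is discarded
    return jnp.where(den > 0, jnp.sqrt(num / safe_den), 0.0)

d_exp = jnp.array([3.0, -4.0])
q_perfect = q_factor(d_exp, d_exp)                      # identical sets
q_half = q_factor(0.5 * d_exp, d_exp)                   # every coupling halved
q_zero = q_factor(jnp.array([1.0, 2.0]), jnp.zeros(2))  # all-zero experiment
```

Halving every calculated coupling gives Q = 0.5, which shows that Q measures the relative size of the residual rather than its absolute value in Hz.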
calculate_rdc(bond_vectors, da, r)
Differentiable RDC calculation in the principal axis frame (PAF).
Important: bond_vectors must be expressed in the principal axis frame
of the alignment tensor, i.e. the frame where the Saupe tensor is
diagonal. Passing lab-frame vectors gives incorrect results
without raising any error.
The formula used is the standard Clore/Bax convention (Clore et al. 1998, J. Magn. Reson. 133, 216–221):
D = Da · [(3 cos²θ − 1) + (3/2) R sin²θ cos 2φ]
  = Da · [(3z² − 1) + (3/2) R (x² − y²)]
where (x, y, z) are the components of the unit bond vector, Da is the axial component and R = (Axx − Ayy) / Azz is
the rhombicity (0 ≤ R ≤ 2/3).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors in the tensor's principal axis frame. | required |
| da | float | Axial component Da in Hz. | required |
| r | float | Rhombicity R (0 ≤ R ≤ 2/3). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) Calculated RDCs in Hz. |
Source code in diff_biophys/nmr/rdc.py
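The Cartesian form of the principal-axis-frame formula, D = Da · [(3z² − 1) + (3/2) R (x² − y²)], is easy to evaluate by hand on the three principal axes. The sketch below uses a hypothetical stand-in `rdc_paf` with illustrative values for Da and R.

```python
import jax.numpy as jnp

def rdc_paf(bond_vectors, da, r):
    """D = Da * [(3z^2 - 1) + 1.5 * R * (x^2 - y^2)] for unit vectors in the PAF."""
    x, y, z = bond_vectors[:, 0], bond_vectors[:, 1], bond_vectors[:, 2]
    return da * ((3.0 * z**2 - 1.0) + 1.5 * r * (x**2 - y**2))

axes = jnp.eye(3)                        # unit vectors along x, y and z
d_axes = rdc_paf(axes, da=10.0, r=0.4)   # illustrative Da (Hz) and rhombicity
```

With Da = 10 Hz and R = 0.4 the couplings along x, y and z are −4, −16 and +20 Hz. They sum to zero, as expected for a traceless alignment tensor.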
calculate_rdc_from_tensor(bond_vectors, saupe_tensor, d_max=1.0)
Calculate RDCs from a full 3x3 Saupe alignment tensor.
D = d_max * sum_ij (v_i * S_ij * v_j)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors | required |
| saupe_tensor | ndarray | (3, 3) symmetric traceless Saupe tensor | required |
| d_max | float | Maximum dipolar coupling constant (Hz) | 1.0 |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Calculated RDCs (N,) |
Source code in diff_biophys/nmr/rdc.py
fit_saupe_tensor(bond_vectors, experimental_rdcs, d_max=1.0)
Fit a Saupe alignment tensor to experimental RDCs using SVD (least squares).
The RDC formula can be rewritten as D = A * s, where s = [Sxx, Syy, Sxy, Sxz, Syz] (5 independent components).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors | required |
| experimental_rdcs | ndarray | (N,) measured RDCs in Hz | required |
| d_max | float | Maximum dipolar coupling constant (Hz) | 1.0 |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (3, 3) Fitted Saupe tensor |
Source code in diff_biophys/nmr/rdc.py
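The linear system D = A · s can be written out explicitly by eliminating Szz = −Sxx − Syy. The sketch below is one possible construction of the design matrix, solved with `jnp.linalg.lstsq` (which is SVD-based). The function name `fit_saupe_lstsq` and the synthetic test tensor are assumptions, and the library's column ordering or scaling may differ.

```python
import jax
import jax.numpy as jnp

def fit_saupe_lstsq(bond_vectors, rdcs, d_max=1.0):
    """Least-squares fit of s = [Sxx, Syy, Sxy, Sxz, Syz], with Szz = -Sxx - Syy."""
    x, y, z = bond_vectors[:, 0], bond_vectors[:, 1], bond_vectors[:, 2]
    # One row per bond vector; columns multiply Sxx, Syy, Sxy, Sxz, Syz
    design = d_max * jnp.stack(
        [x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1
    )
    s, *_ = jnp.linalg.lstsq(design, rdcs)
    sxx, syy, sxy, sxz, syz = s
    return jnp.array([[sxx, sxy, sxz],
                      [sxy, syy, syz],
                      [sxz, syz, -sxx - syy]])

# Synthetic round trip: random unit vectors and a known traceless tensor
vecs = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
vecs = vecs / jnp.linalg.norm(vecs, axis=1, keepdims=True)
S_true = jnp.array([[0.02, 0.01, 0.00],
                    [0.01, -0.05, 0.03],
                    [0.00, 0.03, 0.03]])
rdcs = jnp.einsum("ni,ij,nj->n", vecs, S_true, vecs)
S_fit = fit_saupe_lstsq(vecs, rdcs)
```

With noise-free synthetic data and well-spread vectors, the fit should recover the input tensor to numerical precision; at least five independent orientations are needed for the system to be determined.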
make_rdc_refinement_fns(exp_res_ids, exp_rdcs, struct_res_ids, d_max=21.7)
Build three callables for fixed-tensor RDC-based structure refinement.
The Saupe alignment tensor has 5 free parameters. If it is fitted inside the gradient computation the optimiser trivially drives Q→0 by finding backbone orientations that any tensor can fit — the system is severely underdetermined. The standard solution (X-PLOR/CNS/PALES) is to hold the tensor fixed during each gradient cycle and re-fit it periodically.
This factory returns:
- loss_fn: (coords, fixed_tensor) → scalar MSE. Gradient flows through coords only; fixed_tensor is wrapped in jax.lax.stop_gradient. Use this inside jax.grad.
- q_eval_fn: (coords) → scalar Q-factor. Re-fits the tensor from scratch and returns the best-achievable Q. Never call inside a gradient; use this for monitoring only.
- make_tensor_fn: (coords) → (3, 3) Saupe tensor. Fits and returns the alignment tensor for periodic updates outside the gradient.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| exp_res_ids | ndarray | | required |
| exp_rdcs | ndarray | | required |
| struct_res_ids | ndarray | | required |
| d_max | float | Maximum dipolar coupling constant in Hz (default 21.7 Hz for ¹⁵N–¹H). | 21.7 |
Returns:
| Type | Description |
|---|---|
| Callable[[ndarray, ndarray], ndarray] | Tuple |
| Callable[[ndarray], ndarray] | |
Raises:
| Type | Description |
|---|---|
| ValueError | If no residues overlap between |
Example:
loss_fn, q_eval_fn, make_tensor_fn, n = make_rdc_refinement_fns(
    rdc_data["res_id"], rdc_data["rdc"], res_ids
)
tensor = make_tensor_fn(initial_coords)
def total_loss(params, tensor):
    phi, psi = params
    coords = build(phi, psi)
    return loss_fn(coords, tensor)
# Optimization loop
for i in range(n_steps):
    if i % update_interval == 0:
        tensor = make_tensor_fn(build(*params))
    params = adam_step(total_loss, params, tensor)
Source code in diff_biophys/nmr/rdc.py
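The effect of freezing the tensor can be seen directly in the gradients. The sketch below is a simplified stand-in that takes bond vectors rather than coordinates; `fixed_tensor_loss` and the numerical values are illustrative assumptions, and only `jax.lax.stop_gradient` is taken from the description above.

```python
import jax
import jax.numpy as jnp

def fixed_tensor_loss(bond_vectors, tensor, exp_rdcs, d_max=1.0):
    """MSE between predicted and experimental RDCs with the tensor frozen."""
    frozen = jax.lax.stop_gradient(tensor)   # no gradient flows into the tensor
    pred = d_max * jnp.einsum("ni,ij,nj->n", bond_vectors, frozen, bond_vectors)
    return jnp.mean((pred - exp_rdcs) ** 2)

vecs = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.6, 0.8]])
S = jnp.diag(jnp.array([-0.05, -0.05, 0.10]))
exp_rdcs = jnp.array([0.08, -0.02, 0.01])   # illustrative target couplings

# Gradient with respect to the geometry is non-zero
grad_vecs = jax.grad(fixed_tensor_loss, argnums=0)(vecs, S, exp_rdcs)
# Gradient with respect to the tensor is identically zero
grad_tensor = jax.grad(fixed_tensor_loss, argnums=1)(vecs, S, exp_rdcs)
```

Since the tensor gradient vanishes, an optimiser that differentiates this loss can only reduce it by changing the structure, which is the behaviour the fixed-tensor scheme is meant to enforce.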
nh_bond_vectors(coords)
Reconstruct amide N–H unit vectors from N–CA–C backbone coordinates.
The amide H lies in the peptide plane defined by C(i−1), N(i), CA(i). Its direction is approximated as anti-parallel to the bisector of the N→CA and N→C(i−1) unit vectors, placing H at ~119° from each bond — consistent with standard peptide-plane geometry.
Coordinate layout::
coords = [N₀, CA₀, C₀, N₁, CA₁, C₁, …, Nₙ, CAₙ, Cₙ]
N(i) = coords[3i]
CA(i) = coords[3i+1]
C(i-1) = coords[3i-1] (for i ≥ 1)
Residue 0 has no preceding C; its NH vector falls back to −(N→CA).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: |
Source code in diff_biophys/nmr/rdc.py
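The bisector construction described above can be sketched in a few lines. This is an illustrative reimplementation under the stated coordinate layout (N, CA, C per residue); `unit` and `nh_vectors_sketch` are hypothetical names and the coordinates are made up for the example.

```python
import jax.numpy as jnp

def unit(v):
    """Normalise vectors along the last axis."""
    return v / jnp.linalg.norm(v, axis=-1, keepdims=True)

def nh_vectors_sketch(coords):
    """coords: (3N, 3) ordered N, CA, C per residue. Returns (N, 3) unit vectors."""
    n, ca, c = coords[0::3], coords[1::3], coords[2::3]
    n_to_ca = unit(ca - n)                        # (N, 3)
    n_to_c_prev = unit(c[:-1] - n[1:])            # (N-1, 3), residues 1..N-1
    bisector = unit(n_to_ca[1:] + n_to_c_prev)    # in-plane bisector at each N
    first = -n_to_ca[:1]                          # residue 0: fall back to -(N->CA)
    return jnp.concatenate([first, -bisector], axis=0)

coords = jnp.array([
    [0.0, 0.0, 0.0], [1.46, 0.0, 0.0], [2.0, 1.4, 0.0],   # residue 0: N, CA, C
    [3.0, 1.5, 0.0], [3.8, 2.7, 0.0], [5.2, 2.5, 0.3],    # residue 1: N, CA, C
])
nh = nh_vectors_sketch(coords)
```

Because the result is anti-parallel to the bisector, it makes the same angle with the N→CA and N→C(i−1) directions, which is what places H symmetrically in the peptide plane.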
Ring Current Shifts
diff_biophys.nmr.ring_currents implements the Johnson-Bovey model of
aromatic ring current shielding. Protons directly above the ring plane are
shielded (upfield shift, negative Δδ); protons in the plane are
deshielded (downfield shift, positive Δδ).
The shift falls off as \(\sim 1/r^3\) with distance from the ring centre.
from diff_biophys.nmr.ring_currents import calculate_ring_current_shift
import jax, jax.numpy as jnp
# Proton positions, shape (N, 3); here one proton 3.5 Å above the ring centre
proton_pos = jnp.array([[0.0, 0.0, 3.5]])
ring_pos = jnp.array([0.0, 0.0, 0.0])
ring_normal = jnp.array([0.0, 0.0, 1.0]) # ring in xy-plane
intensity = 1.0 # illustrative scaling factor (required argument)
delta = calculate_ring_current_shift(proton_pos, ring_pos, ring_normal, intensity)
# delta < 0 (shielded above the ring)
# jax.grad needs a scalar output, so sum over nuclei before differentiating
grad = jax.grad(lambda p: jnp.sum(calculate_ring_current_shift(p, ring_pos, ring_normal, intensity)))(proton_pos)
# grad has shape (N, 3); moving a proton along -grad makes its shift more negative (more shielded)
calculate_ring_current_shift(coords, ring_center, ring_normal, intensity)
Calculate chemical shift changes due to aromatic ring currents using the Johnson-Bovey dipolar approximation.
delta = intensity * (1 - 3*cos^2(theta)) / r^3
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (N, 3) coordinates of the nuclei being shielded. | required |
| ring_center | ndarray | (3,) coordinates of the aromatic ring center. | required |
| ring_normal | ndarray | (3,) unit vector normal to the ring plane. | required |
| intensity | float | Scaling factor (proportional to ring area and current). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) shielding values in ppm. |
Source code in diff_biophys/nmr/ring_currents.py
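The formula delta = intensity · (1 − 3cos²θ) / r³ can be checked at a few positions by hand. The sketch below is a stand-in (`ring_shift_sketch` is a hypothetical name) with `intensity` set to 1 for readability.

```python
import jax.numpy as jnp

def ring_shift_sketch(coords, ring_center, ring_normal, intensity):
    """delta = intensity * (1 - 3 cos^2(theta)) / r^3 for each nucleus."""
    disp = coords - ring_center                 # (N, 3) displacement from ring centre
    r = jnp.linalg.norm(disp, axis=-1)          # distance to ring centre
    cos_theta = (disp @ ring_normal) / r        # cosine of the angle to the ring normal
    return intensity * (1.0 - 3.0 * cos_theta**2) / r**3

center = jnp.zeros(3)
normal = jnp.array([0.0, 0.0, 1.0])
nuclei = jnp.array([
    [0.0, 0.0, 2.0],   # above the ring
    [2.0, 0.0, 0.0],   # in the ring plane
    [0.0, 0.0, 4.0],   # above the ring, twice as far
])
delta = ring_shift_sketch(nuclei, center, normal, intensity=1.0)
```

The three values are −0.25, +0.125 and −0.03125: negative above the ring, positive in the plane, and reduced by a factor of eight when the distance doubles, consistent with the 1/r³ fall-off.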
Fixed-Tensor RDC Loss
FixedTensorRDCLoss
diff_integrator.terms.nmr.FixedTensorRDCLoss
A LossTerm that computes the RDC loss while keeping the Saupe alignment tensor frozen during backpropagation via jax.lax.stop_gradient. The tensor is re-fitted from current coordinates every update_interval epochs, preventing the degeneracy exploit where gradient descent drives Q→0 unphysically by distorting the tensor rather than the structure.
Constructor
FixedTensorRDCLoss(
    loss_fn,
    tensor_fn,
    update_interval: int = 50,
    n_rdcs: int | None = None,
    val_q_eval_fn: Callable | None = None,
)
| Parameter | Type | Description |
|---|---|---|
| loss_fn | Callable | Training-set RDC loss function (bond_vecs, S) → scalar returned by make_rdc_cv_refinement_fns |
| tensor_fn | Callable | Tensor-fitting function (bond_vecs, rdcs) → S returned by make_rdc_cv_refinement_fns |
| update_interval | int | Re-fit the Saupe tensor every this many epochs (default 50) |
| n_rdcs | int \| None | Number of training RDCs; used by suggested_weight() |
| val_q_eval_fn | Callable \| None | Held-out Q-factor evaluator returned by make_rdc_cv_refinement_fns; enables evaluate_validation_q() |
Methods
maybe_update_tensor(coords, epoch)
Re-fits the Saupe tensor from the current Cartesian coordinates if epoch % update_interval == 0. Should be called at the top of each training step before the gradient computation.
suggested_weight(base_weight=1.0)
Returns a weight scaled by the overdetermination ratio relative to an ideal 10× ratio:
Use this to auto-scale the RDC term so that systems with fewer RDCs are not over-penalised.
evaluate_validation_q(coords)
Returns the Q-factor on the held-out cross-validation split, or None if val_q_eval_fn was not provided.
make_rdc_cv_refinement_fns
diff_integrator.terms.nmr.make_rdc_cv_refinement_fns
Factory function that partitions experimental RDCs into a training set and a held-out cross-validation set, then returns pre-built closures ready for FixedTensorRDCLoss.
Signature
make_rdc_cv_refinement_fns(
    rdc_res_ids: ArrayLike,
    exp_rdcs: ArrayLike,
    struct_res_ids: ArrayLike,
    cv_fraction: float = 0.2,
) -> tuple[Callable, Callable, Callable, Callable, int, int]
| Parameter | Type | Description |
|---|---|---|
| rdc_res_ids | ArrayLike | Residue IDs corresponding to each experimental RDC value |
| exp_rdcs | ArrayLike | Experimental RDC values (Hz) |
| struct_res_ids | ArrayLike | Residue IDs present in the structure (used for index mapping) |
| cv_fraction | float | Fraction of RDCs to hold out for cross-validation (default 0.2) |
Returns (loss_fn, q_eval_fn, tensor_fn, val_q_fn, n_train, n_val)
| Return value | Description |
|---|---|
| loss_fn | Training-set loss closure (bond_vecs, S) → scalar |
| q_eval_fn | Training-set Q-factor evaluator (bond_vecs, S) → float |
| tensor_fn | Tensor-fitting closure (bond_vecs) → S |
| val_q_fn | Validation Q-factor evaluator (bond_vecs, S) → float |
| n_train | Number of training RDCs |
| n_val | Number of held-out validation RDCs |
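How the RDCs are partitioned is not specified here. As a rough illustration only, a random split by index could look like the sketch below; the function name `split_rdc_indices`, the `seed` argument and the rounding rule for the validation count are all assumptions, not the library's behaviour.

```python
import jax
import jax.numpy as jnp

def split_rdc_indices(n_rdcs, cv_fraction=0.2, seed=0):
    """Randomly partition range(n_rdcs) into (training, validation) index arrays."""
    n_val = int(round(cv_fraction * n_rdcs))          # assumed rounding rule
    perm = jax.random.permutation(jax.random.PRNGKey(seed), n_rdcs)
    return jnp.sort(perm[n_val:]), jnp.sort(perm[:n_val])

train_idx, val_idx = split_rdc_indices(n_rdcs=50, cv_fraction=0.2)
```

The training indices would feed the loss and tensor fit, while the validation indices are used only for the held-out Q-factor, so the two sets must not overlap.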
EarlyStopping
EarlyStopping
diff_integrator.optimizer.EarlyStopping
A dataclass passed as early_stopping= to IntegrativeRefiner.run() that halts training when a monitored loss term stops improving.
Fields
| Field | Type | Default | Description |
|---|---|---|---|
| term_index | int | required | Index into JointLoss.terms of the term to monitor (unweighted value) |
| patience | int | required | Number of epochs with no improvement before stopping |
| min_delta | float | 1e-5 | Minimum change in monitored value to count as an improvement |
| mode | str | "min" | "min" for loss-type metrics (lower is better); "max" for score-type metrics (higher is better) |
Behaviour
- The unweighted value of the monitored term (not its contribution to the total loss) is tracked.
- When no improvement exceeding min_delta occurs for patience consecutive epochs, training stops and the best-checkpoint parameters are returned.
- The stopping event is recorded in RefinementResult.stopped_early and RefinementResult.early_stopping_triggered_by.
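The patience rule can be sketched in plain Python. This is an illustration of the logic described above, not the optimizer's code; `run_with_early_stopping` and its return convention are assumptions.

```python
def run_with_early_stopping(values, patience, min_delta=1e-5, mode="min"):
    """Scan per-epoch metric values; return (stop_epoch, best_epoch, stopped_early)."""
    sign = 1.0 if mode == "min" else -1.0   # turn "max" into a minimisation
    best, best_epoch, wait = float("inf"), -1, 0
    for epoch, value in enumerate(values):
        score = sign * value
        if score < best - min_delta:        # improvement larger than min_delta
            best, best_epoch, wait = score, epoch, 0
        else:
            wait += 1                       # one more epoch without improvement
            if wait >= patience:
                return epoch, best_epoch, True
    return len(values) - 1, best_epoch, False

history = [1.0, 0.9, 0.8, 0.8, 0.8, 0.8, 0.7]
stop_epoch, best_epoch, stopped = run_with_early_stopping(history, patience=3)
```

For this history the best value is reached at epoch 2, three epochs pass without improvement, and the scan stops at epoch 5 before the later drop to 0.7 is seen. This illustrates why `patience` should be generous when the monitored term plateaus before improving again.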
Example
from diff_integrator.optimizer import EarlyStopping, IntegrativeRefiner
refiner = IntegrativeRefiner(loss_fn=joint_loss)
result = refiner.run(
    init_params=starting_coords,
    epochs=2000,
    learning_rate=0.005,
    early_stopping=EarlyStopping(
        term_index=1, # monitor RDC term (index 1 in JointLoss)
        patience=50,
        min_delta=1e-4,
        mode="min",
    ),
)
print(f"Best epoch {result.best_epoch} / 2000")
print(f"Triggered by term {result.early_stopping_triggered_by}")
NOE Distance Restraints
NOELoss
diff_integrator.terms.noe.NOELoss
A LossTerm implementing the standard flat-bottomed harmonic NOE (Nuclear Overhauser
Effect) distance restraint used in XPLOR, CNS, and ARIA. The energy is zero when a
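distance is within bounds and grows quadratically once a bound is violated:
\[ E_{\mathrm{NOE}} = \frac{k}{M} \sum_{m=1}^{M} \left[ \max\left(0,\, d_m - d_m^{\mathrm{upper}}\right)^2 + \max\left(0,\, d_m^{\mathrm{lower}} - d_m\right)^2 \right] \]
where \(M\) is the number of restraints, \(d_m\) is the distance for restraint \(m\), and \(k\) is force_const (the lower-bound term is dropped when d_lower is None). A mean (not sum) is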
used so the weight has a consistent interpretation across datasets of different sizes.
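A flat-bottomed harmonic penalty of this kind can be sketched as follows. The exact prefactor used by NOELoss is not shown here, so the form below (force constant times the mean squared violation) is an assumption, as is the function name `noe_energy`.

```python
import jax
import jax.numpy as jnp

def noe_energy(coords, atom_pairs, d_upper, d_lower=None, force_const=1.0):
    """Mean flat-bottomed harmonic penalty over M restraints (assumed prefactor k)."""
    diff = coords[atom_pairs[:, 0]] - coords[atom_pairs[:, 1]]
    d = jnp.linalg.norm(diff, axis=-1)                    # (M,) pair distances
    viol = jnp.maximum(d - d_upper, 0.0) ** 2             # upper-bound violations
    if d_lower is not None:
        viol = viol + jnp.maximum(d_lower - d, 0.0) ** 2  # lower-bound violations
    return force_const * jnp.mean(viol)

coords = jnp.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
pairs = jnp.array([[0, 1], [0, 2]])    # distances 5.0 Å and 3.0 Å
d_upper = jnp.array([4.0, 4.0])        # only the first restraint is violated, by 1.0 Å
energy = noe_energy(coords, pairs, d_upper, force_const=10.0)
grad = jax.grad(noe_energy)(coords, pairs, d_upper, force_const=10.0)
```

Here the energy is 10 · (1² + 0) / 2 = 5.0, and the gradient on the third atom is zero because its restraint is satisfied: atoms inside the flat bottom feel no force.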
Constructor
NOELoss(
    atom_pairs: jnp.ndarray,             # (M, 2) integer atom index pairs
    d_upper: jnp.ndarray,                # (M,) upper-bound distances in Å
    d_lower: jnp.ndarray | None = None,  # (M,) optional lower bounds
    force_const: float = 1.0,
)
| Parameter | Type | Description |
|---|---|---|
| atom_pairs | (M, 2) int | Each row [i, j] defines one restraint between atoms i and j |
| d_upper | (M,) float | Upper-bound distances in Å; penalty fires when d > d_upper |
| d_lower | (M,) float \| None | Lower-bound distances in Å; penalty fires when d < d_lower. None (default) means upper-bound only, which is the standard NMR convention |
| force_const | float | Harmonic force constant. Default 1.0. Typical values: 5–50 |
name attribute: "noe"
Methods
count_violations(coords) → dict[str, int]
Returns {"upper": n, "lower": n, "total": n}: the number of restraints violating each bound.
Pure diagnostic; not used in the gradient.
rms_violation(coords) → float
Root-mean-square distance violation across all restraints, in Å.
Property
n_restraints → int
Number of distance restraints.
Example
import jax.numpy as jnp
from diff_integrator.terms.noe import NOELoss
noe_loss = NOELoss(
    atom_pairs=jnp.array([[4, 31], [7, 52]]),  # Cα pairs
    d_upper=jnp.array([6.0, 4.5]),             # Å
    d_lower=jnp.array([1.8, 1.8]),             # Å (optional)
    force_const=10.0,
)
make_noe_restraints
diff_integrator.terms.noe.make_noe_restraints
Factory that maps (res_id, atom_name) observations to flat atom indices using the
structure's residue ordering, then returns a ready-to-use NOELoss.
Signature
make_noe_restraints(
    noe_list: list[dict],
    res_ids: np.ndarray,
    atom_names: list[str] | None = None,  # default ["N", "CA", "C"]
    force_const: float = 1.0,
) -> NOELoss
Each dict in noe_list must contain:
| Key | Type | Description |
|---|---|---|
| "res_i" | int | Residue number of atom i |
| "atom_i" | str | Atom name of atom i (e.g. "CA") |
| "res_j" | int | Residue number of atom j |
| "atom_j" | str | Atom name of atom j |
| "d_upper" | float | Upper-bound distance in Å |
| "d_lower" | float | (optional) Lower-bound distance in Å |
Example
from diff_integrator.terms.noe import make_noe_restraints
import numpy as np
noe_observations = [
    {"res_i": 5, "atom_i": "CA", "res_j": 20, "atom_j": "CA",
     "d_upper": 6.0, "d_lower": 1.8},
    {"res_i": 12, "atom_i": "N", "res_j": 45, "atom_j": "CA",
     "d_upper": 5.5},
]
noe_term = make_noe_restraints(noe_observations, struct_res_ids)
violations = noe_term.count_violations(coords)
print(f"{violations['total']} NOE violations at current coordinates")
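The mapping from (res_id, atom_name) to a flat atom index is not spelled out here. One plausible version, assuming coordinates are stored residue by residue in the order given by atom_names (N, CA, C by default), is sketched below; `flat_atom_index` is a hypothetical helper, not part of the API.

```python
def flat_atom_index(res_id, atom_name, res_ids, atom_names=("N", "CA", "C")):
    """Index into a flat coordinate array ordered residue by residue (assumed layout)."""
    res_pos = list(res_ids).index(res_id)   # position of the residue in the structure
    return len(atom_names) * res_pos + list(atom_names).index(atom_name)

struct_res_ids = [3, 4, 5, 6]
idx_ca5 = flat_atom_index(5, "CA", struct_res_ids)   # residue 5 is third in the list
idx_n3 = flat_atom_index(3, "N", struct_res_ids)     # first atom of the first residue
```

Under this layout residue 5 sits at position 2, so its CA is at index 3 · 2 + 1 = 7. Looking residues up by ID rather than by position matters when numbering does not start at 1 or has gaps.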