⚛️ NMR API

The diff_biophys.nmr subpackage implements four categories of differentiable NMR observables. Each kernel accepts JAX arrays and returns JAX arrays, making them trivially composable into multi-observable loss functions.


Karplus J-coupling

diff_biophys.nmr.karplus implements the Karplus equation relating a three-bond scalar coupling constant \(^3J\) (Hz) to a backbone dihedral angle:

\[J(\theta) = A\cos^2(\theta) + B\cos(\theta) + C\]

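For the backbone HN–Hα coupling, the Karplus dihedral is \(\theta = \phi - 60°\) (Vuister & Bax 1993 offset convention). calculate_karplus_j has no default coefficients: the example below passes the Wang & Bax (1996) values \(A = 6.98\), \(B = -1.38\), \(C = 1.72\) Hz, while the docstring quotes the Vuister & Bax (1993) values \(A = 6.51\), \(B = -1.76\), \(C = 1.60\) Hz.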

Gradient meaning: \(\partial J / \partial \theta\) tells you how fast the coupling changes with the backbone angle — the key quantity for J-coupling-driven structure refinement.

from diff_biophys.nmr.karplus import calculate_karplus_j
import jax, jax.numpy as jnp

phi = jnp.array([-0.995])          # φ in radians (≈ −57° helix)
theta = phi - jnp.deg2rad(60.0)   # offset for HN-Hα

J = calculate_karplus_j(theta, a=6.98, b=-1.38, c=1.72)   # → [Hz]
dJ_dtheta = jax.grad(lambda t: jnp.sum(calculate_karplus_j(t, 6.98, -1.38, 1.72)))(theta)
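The plain-NumPy sketch below is a stand-alone check of the formula and the θ = φ − 60° offset, not the package's code. It re-implements the Karplus relation with the coefficients used in the example above, adds the analytic derivative ∂J/∂θ = −sin θ (2A cos θ + B) that jax.grad computes, and verifies it by finite differences; karplus_j and karplus_dj_dtheta are illustrative names.

```python
import numpy as np

# Coefficients used in the example above (Hz)
A, B, C = 6.98, -1.38, 1.72

def karplus_j(theta):
    """3J(HN,HA) in Hz for the Karplus dihedral theta (radians)."""
    c = np.cos(theta)
    return A * c**2 + B * c + C

def karplus_dj_dtheta(theta):
    """Analytic derivative dJ/dtheta = -sin(theta) * (2*A*cos(theta) + B)."""
    return -np.sin(theta) * (2.0 * A * np.cos(theta) + B)

phi = np.deg2rad(np.array([-57.0, -120.0]))   # alpha-helix, beta-sheet
theta = phi - np.deg2rad(60.0)                # HN-HA offset
J = karplus_j(theta)
print(np.round(J, 2))   # helix ~3.8 Hz, sheet ~10.1 Hz

# Central finite-difference check of the analytic gradient
h = 1e-6
fd = (karplus_j(theta + h) - karplus_j(theta - h)) / (2 * h)
assert np.allclose(fd, karplus_dj_dtheta(theta), atol=1e-5)
```

A helix (φ ≈ −57°) lands near 3.8 Hz and a β-strand (φ ≈ −120°) near 10.1 Hz; that spread is what makes ³J(HN,Hα) useful for secondary-structure discrimination.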

calculate_karplus_j(theta, a, b, c)

Calculate 3J coupling constants using the Karplus equation.

J = a * cos^2(theta) + b * cos(theta) + c

Important: theta is not the same as the raw backbone dihedral phi. For ³J(H_N, H_α) couplings, the Karplus dihedral is offset from the IUPAC backbone angle by 60°:

    theta = phi - 60°   (i.e. subtract 60° from phi before calling)

Passing raw phi directly produces errors of several Hz (about 6 Hz at a β-sheet φ of −120° with the coefficients below), invalidating the α-helix / β-sheet discrimination.

The coefficients a, b, c have no defaults. A=6.51, B=−1.76, C=1.60 Hz are the H_N–H_α values of Vuister & Bax (1993); different spin pairs require different offsets and parameters.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| theta | ndarray | (N,) Karplus dihedral angle in radians. For ³J(HN,HA): pass (phi − 60°), not phi. | required |
| a | float | Cosine-squared coefficient (Hz). | required |
| b | float | Cosine coefficient (Hz). | required |
| c | float | Constant offset (Hz). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) Calculated J-couplings (Hz). |

Source code in diff_biophys/nmr/karplus.py
@jit
def calculate_karplus_j(theta: jnp.ndarray, a: float, b: float, c: float) -> jnp.ndarray:
    """
    Calculate 3J coupling constants using the Karplus equation.

    J = a * cos^2(theta) + b * cos(theta) + c

    .. important::
        ``theta`` is **not** the same as the raw backbone dihedral ``phi``.
        For ³J(H_N, H_α) couplings, the Karplus dihedral is offset from
        the IUPAC backbone angle by 60°::

            theta = phi - 60°   (i.e. subtract 60° from phi before calling)

        Passing raw ``phi`` directly produces errors of several Hz (about
        6 Hz at a β-sheet φ of −120°), invalidating the α-helix / β-sheet
        discrimination.

        The coefficients have no defaults.  A=6.51, B=−1.76, C=1.60 are the
        H_N–H_α values of Vuister & Bax (1993).  Different spin pairs require
        different offsets and parameters.

    Args:
        theta: (N,) Karplus dihedral angle in radians.
               For ³J(HN,HA): pass (phi − 60°), **not** phi.
        a: Cosine-squared coefficient (Hz).
        b: Cosine coefficient (Hz).
        c: Constant offset (Hz).

    Returns:
        jnp.ndarray: (N,) Calculated J-couplings.
    """
    cos_theta = jnp.cos(theta)
    return a * (cos_theta**2) + b * cos_theta + c

Cα Chemical Shifts

diff_biophys.nmr.chemical_shifts predicts Cα chemical shifts (ppm) from backbone torsion angles using a softmax-weighted Gaussian secondary-structure detector in (φ, ψ) space.

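Secondary chemical shift offsets (relative to random coil):

| Structure | Δδ(Cα) offset |
|---|---|
| α-helix | +3.1 ppm |
| β-sheet | −1.5 ppm |
| Random coil | 0 ppm |

These are the full offsets. Because the coil weight in the softmax is fixed at 1, a residue sitting exactly at the helix or sheet centre receives at most half of its offset (≈ +1.55 ppm for a helix), as noted in the source of predict_ca_shifts.

The library provides RANDOM_COIL_CA, a dictionary of per-residue random-coil reference values (ppm) keyed by three-letter residue code.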

from diff_biophys.nmr.chemical_shifts import predict_ca_shifts, RANDOM_COIL_CA
import jax, jax.numpy as jnp

n_res = 10
phi = jnp.full((n_res,), jnp.deg2rad(-57.0))   # α-helix
psi = jnp.full((n_res,), jnp.deg2rad(-47.0))
rc  = jnp.full((n_res,), RANDOM_COIL_CA["ALA"])

shifts = predict_ca_shifts(phi, psi, rc)   # (n_res,) in ppm

# Refine φ to match target shifts
target = shifts + 0.5   # perturb target
loss = lambda p: jnp.mean((predict_ca_shifts(p, psi, rc) - target) ** 2)
grad_phi = jax.grad(loss)(phi)

make_ca_shift_loss(exp_res_ids, exp_shifts, struct_res_ids, struct_res_names)

Build a differentiable Cα chemical shift RMSD loss.

Matches experimental residues to the structure by residue ID and builds a JAX-differentiable closure that predicts Cα shifts from backbone torsions and returns the RMSD against the matched experimental values. Residue names missing from RANDOM_COIL_CA fall back to a random-coil value of 55.0 ppm.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| exp_res_ids | ndarray | (M,) residue IDs from the experimental dataset. | required |
| exp_shifts | ndarray | (M,) experimental Cα chemical shifts in ppm. | required |
| struct_res_ids | ndarray | (N,) residue IDs present in the structure. | required |
| struct_res_names | list[str] | List of N three-letter residue codes, aligned with struct_res_ids. | required |

Returns:

| Type | Description |
|---|---|
| tuple[Callable[[ndarray, ndarray], ndarray], int] | Tuple (loss_fn, n_matched) where loss_fn (phi, psi) → scalar RMSD (ppm) is differentiable with respect to both torsion arrays, and n_matched is the number of residues found in both datasets. |

Raises:

| Type | Description |
|---|---|
| ValueError | If no residues overlap between the experimental and structure datasets. |

Source code in diff_biophys/nmr/chemical_shifts.py
def make_ca_shift_loss(
    exp_res_ids: np.ndarray,
    exp_shifts: np.ndarray,
    struct_res_ids: np.ndarray,
    struct_res_names: list[str],
) -> tuple[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], int]:
    """Build a differentiable Cα chemical shift RMSD loss.

    Matches experimental residues to the structure by residue ID and builds a
    JAX-differentiable closure that predicts Cα shifts from backbone torsions
    and returns the RMSD against the matched experimental values.

    Args:
        exp_res_ids: ``(M,)`` residue IDs from the experimental dataset.
        exp_shifts: ``(M,)`` experimental Cα chemical shifts in ppm.
        struct_res_ids: ``(N,)`` residue IDs present in the structure.
        struct_res_names: List of N three-letter residue codes, aligned with
            ``struct_res_ids``.

    Returns:
        Tuple ``(loss_fn, n_matched)`` where:

        * **loss_fn** ``(phi, psi) → scalar RMSD (ppm)`` — differentiable with
          respect to both torsion arrays.
        * **n_matched** — number of residues found in both datasets.

    Raises:
        ValueError: If no residues overlap between the experimental and
            structure datasets.
    """
    res_id_to_idx = {int(rid): i for i, rid in enumerate(struct_res_ids)}

    matched_struct_idx: list[int] = []
    matched_exp: list[float] = []
    matched_names: list[str] = []

    for rid, shift in zip(exp_res_ids, exp_shifts, strict=False):
        if int(rid) in res_id_to_idx:
            si = res_id_to_idx[int(rid)]
            matched_struct_idx.append(si)
            matched_exp.append(float(shift))
            matched_names.append(struct_res_names[si])

    if not matched_struct_idx:
        raise ValueError(
            "No residues overlap between exp_res_ids and struct_res_ids. "
            "Check that residue numbering conventions match."
        )

    rc = np.array([RANDOM_COIL_CA.get(name, 55.0) for name in matched_names], dtype=np.float32)
    rc_jax = jnp.array(rc)
    exp_jax = jnp.array(matched_exp, dtype=jnp.float32)
    idx_jax = jnp.array(matched_struct_idx, dtype=jnp.int32)
    n_matched = len(matched_struct_idx)

    def loss_fn(phi: jnp.ndarray, psi: jnp.ndarray) -> jnp.ndarray:
        """Cα shift RMSD loss.

        Args:
            phi: ``(N,)`` backbone φ angles in radians.
            psi: ``(N,)`` backbone ψ angles in radians.

        Returns:
            jnp.ndarray: Scalar RMSD in ppm.
        """
        pred = predict_ca_shifts(phi[idx_jax], psi[idx_jax], rc_jax)
        return jnp.sqrt(jnp.mean((pred - exp_jax) ** 2))

    return loss_fn, n_matched

predict_ca_shifts(phi, psi, rc_shifts)

Differentiable Cα chemical shift prediction based on backbone torsions.

Uses Gaussian "soft detectors" in (Φ, Ψ) space to classify secondary structure and applies SPARTA-like offsets. The detectors are normalised via a softmax so that helix, sheet, and coil contributions always sum to 1.0, preventing unphysical double-counting.

Reference centres (radians):

  • α-helix: Φ = −1.05 rad (−60°), Ψ = −0.78 rad (−45°)
  • β-sheet: Φ = −2.09 rad (−120°), Ψ = +2.35 rad (+135°)
  • Random coil: treated as the baseline (weight = 1 − w_helix − w_sheet)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| phi | ndarray | (N,) backbone Φ angles in radians. | required |
| psi | ndarray | (N,) backbone Ψ angles in radians. | required |
| rc_shifts | ndarray | (N,) baseline random-coil Cα shifts (ppm). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) predicted Cα chemical shifts (ppm). |

Source code in diff_biophys/nmr/chemical_shifts.py
@jit
def predict_ca_shifts(phi: jnp.ndarray, psi: jnp.ndarray, rc_shifts: jnp.ndarray) -> jnp.ndarray:
    """
    Differentiable Cα chemical shift prediction based on backbone torsions.

    Uses Gaussian "soft detectors" in (Φ, Ψ) space to classify secondary
    structure and applies SPARTA-like offsets.  The detectors are normalised
    via a softmax so that helix, sheet, and coil contributions always sum to
    1.0, preventing unphysical double-counting.

    Reference centres (radians):
        * α-helix: Φ = −1.05 rad (−60°), Ψ = −0.78 rad (−45°)
        * β-sheet:  Φ = −2.09 rad (−120°), Ψ = +2.35 rad (+135°)
        * Random coil: treated as the baseline (weight = 1 − w_helix − w_sheet)

    Args:
        phi: (N,) backbone Φ angles in radians.
        psi: (N,) backbone Ψ angles in radians.
        rc_shifts: (N,) baseline random-coil Cα shifts (ppm).

    Returns:
        jnp.ndarray: (N,) predicted Cα chemical shifts (ppm).
    """
    # --- Unnormalised Gaussian affinities ---
    # Alpha-helix centre: Φ ~ −60°, Ψ ~ −45°
    helix_dist_sq = (phi + 1.05) ** 2 + (psi + 0.78) ** 2
    w_helix_raw = jnp.exp(-helix_dist_sq / _SS_SIGMA_SQ)

    # Beta-sheet centre: Φ ~ −120°, Ψ ~ +135°
    sheet_dist_sq = (phi + 2.09) ** 2 + (psi - 2.35) ** 2
    w_sheet_raw = jnp.exp(-sheet_dist_sq / _SS_SIGMA_SQ)

    # Coil baseline: all residues start with weight 1 (i.e. the neutral state)
    w_coil_raw = jnp.ones_like(phi)

    # --- Softmax normalisation ---
    # Ensures w_helix + w_sheet + w_coil = 1 for every residue,
    # preventing simultaneous helix + sheet double-counting.
    total = w_helix_raw + w_sheet_raw + w_coil_raw
    w_helix = w_helix_raw / total
    w_sheet = w_sheet_raw / total
    # w_coil = w_coil_raw / total  (implicit; contributes zero offset)

    # NOTE — effective offset cap (Issue 3):
    # Because w_coil_raw is fixed at 1.0, at the *exact* helix or sheet centre
    # (where the Gaussian peak = 1 and the opposing class ≈ 0) the denominator
    # is 1 + 0 + 1 = 2, so w_helix_norm ≈ 0.5.  This means the maximum Cα
    # shift a perfectly helical residue receives is:
    #
    #   0.5 × OFFSET_HELIX  ≈  +1.55 ppm  (not the full +3.1 ppm)
    #
    # This is a deliberate approximation: the coil baseline acts as a Bayesian
    # prior that prevents runaway shifts.  It also means the predictor underestimates
    # pure-helix / pure-sheet shifts by ~50% relative to SPARTA+.  Users who need
    # quantitative SPARTA+ parity should double OFFSET_HELIX / OFFSET_SHEET to
    # compensate.  Reducing _SS_SIGMA_SQ alone does not lift the cap: at a
    # class centre the raw Gaussian is exactly 1 and the coil term is fixed
    # at 1, so the normalised weight cannot exceed 0.5.

    # --- Weighted offset ---
    # Coil weight contributes 0 offset (it is the RC baseline).
    return rc_shifts + (w_helix * OFFSET_HELIX) + (w_sheet * OFFSET_SHEET)
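The effective-offset cap described in the source comment can be checked numerically. The NumPy sketch below mirrors predict_ca_shifts but returns only the secondary shift Δδ (random-coil baseline set to zero). The offsets +3.1 and −1.5 ppm come from the secondary-shift table; the value of _SS_SIGMA_SQ is not shown on this page, so SIGMA_SQ = 0.5 rad² is an assumption, and secondary_ca_shift is an illustrative name.

```python
import numpy as np

OFFSET_HELIX = 3.1    # ppm, from the secondary-shift table
OFFSET_SHEET = -1.5   # ppm, from the secondary-shift table
SIGMA_SQ = 0.5        # rad^2; ASSUMED stand-in for _SS_SIGMA_SQ

def secondary_ca_shift(phi, psi, sigma_sq=SIGMA_SQ):
    """Delta-delta(CA) in ppm, mirroring predict_ca_shifts with rc_shifts = 0."""
    w_helix_raw = np.exp(-((phi + 1.05) ** 2 + (psi + 0.78) ** 2) / sigma_sq)
    w_sheet_raw = np.exp(-((phi + 2.09) ** 2 + (psi - 2.35) ** 2) / sigma_sq)
    total = w_helix_raw + w_sheet_raw + 1.0   # coil raw weight fixed at 1
    w_helix, w_sheet = w_helix_raw / total, w_sheet_raw / total
    return w_helix * OFFSET_HELIX + w_sheet * OFFSET_SHEET

at_helix = secondary_ca_shift(-1.05, -0.78)   # exact helix centre
at_sheet = secondary_ca_shift(-2.09, 2.35)    # exact sheet centre
print(round(at_helix, 3), round(at_sheet, 3))   # about half of each offset

# A much sharper Gaussian does not lift the cap at the exact centre
sharp = secondary_ca_shift(-1.05, -0.78, sigma_sq=0.05)
print(round(sharp, 3))
```

At either centre the raw Gaussian is 1 and the coil term is 1, so the normalised weight is 0.5 whatever the width; only rescaling the offsets recovers the full SPARTA-like values.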

Residual Dipolar Couplings (RDCs)

diff_biophys.nmr.rdc computes Residual Dipolar Couplings (Hz) from bond vectors and a Saupe alignment tensor:

\[D_{NH} = D_{\max} \sum_{i,j} v_i S_{ij} v_j\]

where \(\mathbf{v}\) is the unit bond vector and \(\mathbf{S}\) is the Saupe tensor (3×3 symmetric traceless matrix). For an axially symmetric tensor, \(D = 0\) when the bond makes the magic angle \(\theta = \arccos(1/\sqrt{3}) \approx 54.74°\) with the tensor's unique axis.
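For an axially symmetric tensor S = diag(−s/2, −s/2, s), the double sum reduces to D = D_max · s · (3cos²θ − 1)/2, which vanishes at the magic angle. The NumPy sketch below evaluates vᵀSv directly (the same contraction as calculate_rdc_from_tensor) to confirm this; s = 1e-3 and D_max = 21 585 Hz are illustrative values, and rdc is an illustrative helper name.

```python
import numpy as np

d_max = 21585.0   # Hz, illustrative
s = 1e-3          # axial order, illustrative magnitude
S = np.diag([-s / 2, -s / 2, s])   # traceless, axially symmetric about z

def rdc(v, S, d_max):
    """D = d_max * v^T S v for a batch of unit vectors v (N, 3)."""
    return d_max * np.einsum("ni,ij,nj->n", v, S, v)

magic = np.arccos(1.0 / np.sqrt(3.0))        # ~54.74 degrees
thetas = np.array([0.0, magic, np.pi / 2])    # along z, magic angle, in-plane
v = np.stack([np.sin(thetas), np.zeros_like(thetas), np.cos(thetas)], axis=1)

D = rdc(v, S, d_max)
print(np.degrees(magic), D)   # D[1] is ~0 and D[2] = -D[0] / 2
```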

Also provides:

  • fit_saupe_tensor: SVD-based least-squares fit of S from observed RDCs
  • calculate_rdc: principal-axis-frame form with axial (\(D_a\)) and rhombic (\(R\)) terms, \(D = D_a[(3\cos^2\theta - 1) + \tfrac{3}{2}R\sin^2\theta\cos 2\phi]\)
  • calculate_q_factor: Q-factor (Cornilescu et al., 1998) for assessing RDC agreement

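from diff_biophys.nmr.rdc import calculate_rdc_from_tensor, fit_saupe_tensor
import jax, jax.numpy as jnp

# N–H bond unit vectors (n_res, 3); random directions stand in for real ones
# (in practice, e.g. nh_bond_vectors(coords))
v = jax.random.normal(jax.random.PRNGKey(0), (50, 3))
bond_vecs = v / jnp.linalg.norm(v, axis=1, keepdims=True)

# Saupe alignment tensor (3, 3): axially symmetric, traceless, typical 1e-3 scale
S = jnp.diag(jnp.array([-0.5e-3, -0.5e-3, 1.0e-3]))

rdcs = calculate_rdc_from_tensor(bond_vecs, S, d_max=21585.0)   # Hz

# Fit the tensor back from experimental RDCs (here the synthetic ones);
# pass the same d_max, otherwise S_fit comes out scaled by 21585
rdcs_experimental = rdcs
S_fit = fit_saupe_tensor(bond_vecs, rdcs_experimental, d_max=21585.0)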

calculate_q_factor(calculated_rdcs, experimental_rdcs)

Calculate the RDC Q-factor (Cornilescu et al., 1998):

\[Q = \sqrt{\frac{\sum (D_{\text{calc}} - D_{\text{exp}})^2}{\sum D_{\text{exp}}^2}}\]

Returns 0.0 when all experimental RDCs are zero, a case where Q is undefined; 0.0 is returned whatever the calculated values are.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| calculated_rdcs | ndarray | (N,) calculated couplings. | required |
| experimental_rdcs | ndarray | (N,) measured couplings. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Scalar Q-factor. |

Source code in diff_biophys/nmr/rdc.py
@jit
def calculate_q_factor(calculated_rdcs: jnp.ndarray, experimental_rdcs: jnp.ndarray) -> jnp.ndarray:
    """
    Calculate the RDC Q-factor (Cornilescu et al., 1998).
    Q = sqrt( sum((D_calc - D_exp)^2) / sum(D_exp^2) )

    Returns 0.0 when all experimental RDCs are zero (perfect trivial match).

    Args:
        calculated_rdcs: (N,) calculated couplings.
        experimental_rdcs: (N,) measured couplings.

    Returns:
        jnp.ndarray: Scalar Q-factor.
    """
    diff_sq = jnp.sum((calculated_rdcs - experimental_rdcs) ** 2)
    exp_sq = jnp.sum(experimental_rdcs**2)

    # Robust Q-factor calculation to avoid NaN gradients at zero experimental RDCs.
    # We use a safe denominator for the division and then mask the result.
    q = jnp.sqrt(diff_sq / jnp.maximum(exp_sq, 1e-10))
    return jnp.where(exp_sq > 0.0, q, 0.0)
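A quick numerical reading of the Q-factor: if every calculated coupling is 10 % smaller than the measured one, the ratio inside the square root is 0.01, so Q = 0.1 regardless of the RDC magnitudes. The NumPy sketch below reproduces the formula, including the zero-denominator guard; q_factor is an illustrative helper and the RDC values are arbitrary synthetic numbers.

```python
import numpy as np

def q_factor(calc, exp):
    """Q = sqrt(sum((calc - exp)^2) / sum(exp^2)); 0.0 if all exp are zero."""
    diff_sq = np.sum((calc - exp) ** 2)
    exp_sq = np.sum(exp ** 2)
    if exp_sq == 0.0:
        return 0.0
    return float(np.sqrt(diff_sq / exp_sq))

exp_rdcs = np.array([12.0, -5.0, 3.5, -18.0, 7.0])   # arbitrary synthetic RDCs (Hz)
q_scaled = q_factor(0.9 * exp_rdcs, exp_rdcs)          # uniform 10 % underestimate
q_perfect = q_factor(exp_rdcs.copy(), exp_rdcs)
print(q_scaled, q_perfect)   # 0.1 (up to rounding) and 0.0
```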

calculate_rdc(bond_vectors, da, r)

Differentiable RDC calculation in the principal axis frame (PAF).

Important: bond_vectors must be expressed in the principal axis frame of the alignment tensor, i.e. the frame where the Saupe tensor is diagonal. Passing lab-frame vectors gives incorrect results without raising any error.

The formula used is the standard Clore/Bax convention (Clore et al. 1998, J. Magn. Reson. 133, 216–221):

\[D = D_a\left[(3\cos^2\theta - 1) + \tfrac{3}{2} R \sin^2\theta \cos 2\phi\right] = D_a\left[(3z^2 - 1) + \tfrac{3}{2} R\,(x^2 - y^2)\right]\]

where \(D_a\) is the axial component and \(R = D_r/D_a\) (rhombic over axial component) is the rhombicity, \(0 \le R \le 2/3\); for a diagonal alignment tensor this is \(R = \tfrac{2}{3}(A_{xx} - A_{yy})/A_{zz}\).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors in the tensor's principal axis frame. | required |
| da | float | Axial component Da in Hz. | required |
| r | float | Rhombicity R (0 ≤ R ≤ 2/3). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) Calculated RDCs in Hz. |

Source code in diff_biophys/nmr/rdc.py
@jit
def calculate_rdc(bond_vectors: jnp.ndarray, da: float, r: float) -> jnp.ndarray:
    """
    Differentiable RDC calculation in the principal axis frame (PAF).

    .. important::
        ``bond_vectors`` **must be expressed in the principal axis frame**
        of the alignment tensor, i.e. the frame where the Saupe tensor is
        diagonal.  Passing lab-frame vectors will give incorrect results
        without any error.

    The formula used is the standard Clore/Bax convention
    (Clore et al. 1998, *J. Magn. Reson.* **133**, 216–221)::

        D = Da · [(3 cos²θ − 1) + (3/2) R sin²θ cos 2φ]
          = Da · [(3z² − 1) + (3/2) R (x² − y²)]

    where ``Da`` is the axial component and ``R = Dr / Da = (2/3)(Axx − Ayy) / Azz``
    is the rhombicity (0 ≤ R ≤ 2/3).

    Args:
        bond_vectors: (N, 3) unit vectors in the tensor's principal axis frame.
        da: Axial component Da in Hz.
        r: Rhombicity R (0 ≤ R ≤ 2/3).

    Returns:
        jnp.ndarray: (N,) Calculated RDCs in Hz.
    """
    x, y, z = bond_vectors[:, 0], bond_vectors[:, 1], bond_vectors[:, 2]

    axial = 3.0 * z**2 - 1.0
    rhombic = 1.5 * r * (x**2 - y**2)

    return da * (axial + rhombic)
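Expanding vᵀAv for a traceless diagonal tensor A = diag(Axx, Ayy, Azz) and a unit vector gives (Azz/2)(3z² − 1) + ((Axx − Ayy)/2)(x² − y²). Matching this term by term against the principal-axis formula identifies Da = d_max·Azz/2 and R = (2/3)(Axx − Ayy)/Azz, which is what keeps R inside [0, 2/3]. The NumPy sketch below checks that the two forms agree for random unit vectors; the tensor values are arbitrary and rdc_paf / rdc_tensor are illustrative names.

```python
import numpy as np

def rdc_paf(v, da, r):
    """Principal-axis-frame form: Da * [(3z^2 - 1) + 1.5 * R * (x^2 - y^2)]."""
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return da * ((3.0 * z**2 - 1.0) + 1.5 * r * (x**2 - y**2))

def rdc_tensor(v, A, d_max):
    """Full-tensor form: d_max * v^T A v."""
    return d_max * np.einsum("ni,ij,nj->n", v, A, v)

d_max = 21.7
axx, ayy = -0.2, -0.8      # arbitrary; |Azz| is the largest component
azz = -(axx + ayy)         # traceless: Azz = 1.0
A = np.diag([axx, ayy, azz])

da = d_max * azz / 2.0
r = (2.0 / 3.0) * (axx - ayy) / azz   # 0.4, inside [0, 2/3]

rng = np.random.default_rng(0)
v = rng.normal(size=(100, 3))
v /= np.linalg.norm(v, axis=1, keepdims=True)   # identity needs unit vectors

assert np.allclose(rdc_paf(v, da, r), rdc_tensor(v, A, d_max))
print(da, r)
```

Using (Axx − Ayy)/Azz without the 2/3 factor (0.6 here) makes the two forms disagree, which is a quick way to catch a convention mismatch.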

calculate_rdc_from_tensor(bond_vectors, saupe_tensor, d_max=1.0)

Calculate RDCs from a full 3×3 Saupe alignment tensor:

\[D = d_{\max} \sum_{i,j} v_i S_{ij} v_j\]

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors | required |
| saupe_tensor | ndarray | (3, 3) symmetric traceless Saupe tensor | required |
| d_max | float | Maximum dipolar coupling constant (Hz) | 1.0 |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: Calculated RDCs (N,) |

Source code in diff_biophys/nmr/rdc.py
@jit
def calculate_rdc_from_tensor(
    bond_vectors: jnp.ndarray, saupe_tensor: jnp.ndarray, d_max: float = 1.0
) -> jnp.ndarray:
    """
    Calculate RDCs from a full 3x3 Saupe alignment tensor.
    D = d_max * sum_ij (v_i * S_ij * v_j)

    Args:
        bond_vectors: (N, 3) unit vectors
        saupe_tensor: (3, 3) symmetric traceless Saupe tensor
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: Calculated RDCs (N,)
    """
    # Vectorized computation of v^T S v
    return d_max * jnp.einsum("ni,ij,nj->n", bond_vectors, saupe_tensor, bond_vectors)

fit_saupe_tensor(bond_vectors, experimental_rdcs, d_max=1.0)

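Fit a Saupe alignment tensor to experimental RDCs using SVD (least squares).

The RDC formula can be rewritten as a linear system D = A s, where s = [Sxx, Syy, Sxy, Sxz, Syz] holds the 5 independent components and Szz = −Sxx − Syy follows from tracelessness. At least five bond vectors in sufficiently varied orientations are therefore needed for a unique fit.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors | required |
| experimental_rdcs | ndarray | (N,) measured RDCs in Hz | required |
| d_max | float | Maximum dipolar coupling constant (Hz) | 1.0 |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (3, 3) Fitted Saupe tensor |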

Source code in diff_biophys/nmr/rdc.py
@jit
def fit_saupe_tensor(
    bond_vectors: jnp.ndarray, experimental_rdcs: jnp.ndarray, d_max: float = 1.0
) -> jnp.ndarray:
    """
    Fit a Saupe alignment tensor to experimental RDCs using SVD (least squares).

    The RDC formula can be rewritten as D = A * s
    where s = [Sxx, Syy, Sxy, Sxz, Syz] (5 independent components)

    Args:
        bond_vectors: (N, 3) unit vectors
        experimental_rdcs: (N,) measured RDCs in Hz
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: (3, 3) Fitted Saupe tensor
    """
    x = bond_vectors[:, 0]
    y = bond_vectors[:, 1]
    z = bond_vectors[:, 2]

    # Basis functions for the 5 independent components
    # Using the identity Szz = -Sxx - Syy
    # D = d_max * [ Sxx*x^2 + Syy*y^2 + Szz*z^2 + 2Sxy*xy + 2Sxz*xz + 2Syz*yz ]
    # D = d_max * [ Sxx(x^2 - z^2) + Syy(y^2 - z^2) + 2Sxy*xy + 2Sxz*xz + 2Syz*yz ]

    A = d_max * jnp.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)

    # Solve A * s = experimental_rdcs
    s, _, _, _ = jnp.linalg.lstsq(A, experimental_rdcs)

    sxx, syy, sxy, sxz, syz = s
    szz = -(sxx + syy)

    tensor = jnp.array([[sxx, sxy, sxz], [sxy, syy, syz], [sxz, syz, szz]])

    return tensor
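Because D is linear in the five independent tensor elements, five or more bond vectors in varied orientations determine S, and extra vectors give a least-squares estimate. The NumPy sketch below builds the same (N, 5) design matrix as fit_saupe_tensor, uses np.linalg.lstsq in place of jnp.linalg.lstsq, and recovers an arbitrary known tensor from noise-free synthetic couplings; fit_saupe is an illustrative name.

```python
import numpy as np

def fit_saupe(v, rdcs, d_max=1.0):
    """Least-squares Saupe tensor from unit vectors v (N, 3) and RDCs (N,)."""
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    # Columns multiply [Sxx, Syy, Sxy, Sxz, Syz]; Szz = -(Sxx + Syy)
    A = d_max * np.stack([x**2 - z**2, y**2 - z**2, 2*x*y, 2*x*z, 2*y*z], axis=1)
    s, *_ = np.linalg.lstsq(A, rdcs, rcond=None)
    sxx, syy, sxy, sxz, syz = s
    return np.array([[sxx, sxy, sxz], [sxy, syy, syz], [sxz, syz, -(sxx + syy)]])

rng = np.random.default_rng(1)
v = rng.normal(size=(30, 3))
v /= np.linalg.norm(v, axis=1, keepdims=True)

# Arbitrary symmetric, traceless "true" tensor (0.3 - 0.5 + 0.2 = 0)
S_true = np.array([[0.3, 0.1, -0.2],
                   [0.1, -0.5, 0.05],
                   [-0.2, 0.05, 0.2]])
d_max = 21.7
rdcs = d_max * np.einsum("ni,ij,nj->n", v, S_true, v)

S_fit = fit_saupe(v, rdcs, d_max)
print(np.round(S_fit, 6))
```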

make_rdc_refinement_fns(exp_res_ids, exp_rdcs, struct_res_ids, d_max=21.7)

Build three callables for fixed-tensor RDC-based structure refinement.

The Saupe alignment tensor has 5 free parameters. If it is fitted inside the gradient computation, the optimiser can trivially drive Q→0 by finding backbone orientations that any tensor can fit; the system is severely underdetermined. The standard solution (X-PLOR/CNS/PALES) is to hold the tensor fixed during each gradient cycle and re-fit it periodically.

This factory returns:

  • loss_fn (coords, fixed_tensor) → scalar MSE. Gradient flows through coords only; fixed_tensor is wrapped in jax.lax.stop_gradient. Use this inside jax.grad.

  • q_eval_fn (coords) → scalar Q-factor. Re-fits the tensor from scratch and returns the best-achievable Q. Never call it inside a gradient; use it for monitoring only.

  • make_tensor_fn (coords) → (3, 3) Saupe tensor. Fits and returns the alignment tensor for periodic updates outside the gradient.
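The alternating scheme can be illustrated without the package. In the NumPy toy below, the "structure" is simply a set of unnormalised bond vectors u (a simplification: the real loss_fn goes through nh_bond_vectors(coords)), the tensor is refitted by least squares every update_interval steps, and between refits it is treated as a constant, which is the role stop_gradient plays in the JAX version. All data are synthetic and every helper name is illustrative.

```python
import numpy as np

def fit_tensor(v, rdcs, d_max):
    """Least-squares Saupe tensor with 5 free elements (Szz = -Sxx - Syy)."""
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    A = d_max * np.stack([x**2 - z**2, y**2 - z**2, 2*x*y, 2*x*z, 2*y*z], axis=1)
    s, *_ = np.linalg.lstsq(A, rdcs, rcond=None)
    sxx, syy, sxy, sxz, syz = s
    return np.array([[sxx, sxy, sxz], [sxy, syy, syz], [sxz, syz, -(sxx + syy)]])

def rdcs_from_tensor(v, S, d_max):
    """D = d_max * v^T S v for each row of v."""
    return d_max * np.einsum("ni,ij,nj->n", v, S, v)

def loss_and_grad(u, S, exp_rdcs, d_max):
    """MSE loss and its gradient w.r.t. u, treating S as a constant."""
    norm = np.linalg.norm(u, axis=1, keepdims=True)
    v = u / norm
    resid = rdcs_from_tensor(v, S, d_max) - exp_rdcs
    loss = np.mean(resid**2)
    # dL/dv_n = (4 * d_max / N) * resid_n * S v_n   (S symmetric)
    g_v = (4.0 * d_max / len(exp_rdcs)) * resid[:, None] * (v @ S)
    # Chain rule through v = u / |u|: (I - v v^T) g_v / |u|
    g_u = (g_v - np.sum(g_v * v, axis=1, keepdims=True) * v) / norm
    return loss, g_u

rng = np.random.default_rng(0)
d_max = 1.0
S_true = np.diag([-0.3, -0.7, 1.0])                     # traceless "true" tensor
v_true = rng.normal(size=(40, 3))
v_true /= np.linalg.norm(v_true, axis=1, keepdims=True)
exp_rdcs = rdcs_from_tensor(v_true, S_true, d_max)      # synthetic "measured" RDCs

u = v_true + 0.3 * rng.normal(size=v_true.shape)         # perturbed starting model
update_interval, lr = 20, 0.2
losses = []
for step in range(200):
    if step % update_interval == 0:
        # Periodic refit outside the gradient (the role of make_tensor_fn)
        S = fit_tensor(u / np.linalg.norm(u, axis=1, keepdims=True), exp_rdcs, d_max)
    loss, g = loss_and_grad(u, S, exp_rdcs, d_max)       # S held fixed
    losses.append(loss)
    u = u - lr * g

print(f"loss: {losses[0]:.4g} -> {losses[-1]:.4g}")
```

Neither kind of step raises this loss for a small enough learning rate: each refit is the least-squares optimum for the current vectors, and between refits the gradient steps see a frozen tensor.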

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| exp_res_ids | ndarray | (M,) residue IDs where RDCs were measured. | required |
| exp_rdcs | ndarray | (M,) experimental RDC values in Hz. | required |
| struct_res_ids | ndarray | (N,) residue IDs present in the structure. | required |
| d_max | float | Maximum dipolar coupling constant. The default 21.7 is the ≈21.7 kHz ¹⁵N–¹H value expressed in kHz, so with RDCs in Hz the fitted tensor comes out in units of 10⁻³. | 21.7 |

Returns:

| Type | Description |
|---|---|
| tuple[Callable[[ndarray, ndarray], ndarray], Callable[[ndarray], ndarray], Callable[[ndarray], ndarray], int] | Tuple (loss_fn, q_eval_fn, make_tensor_fn, n_matched) where n_matched is the number of residues found in both datasets. |

Raises:

| Type | Description |
|---|---|
| ValueError | If no residues overlap between exp_res_ids and struct_res_ids. |

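Example (rdc_data, res_ids, initial_coords, params, n_steps, update_interval, build and adam_step are placeholders for your own data, structure builder and optimiser):

loss_fn, q_eval_fn, make_tensor_fn, n = make_rdc_refinement_fns(
    rdc_data["res_id"], rdc_data["rdc"], res_ids
)
tensor = make_tensor_fn(initial_coords)

def total_loss(params, tensor):
    phi, psi = params
    coords = build(phi, psi)          # torsions -> (3N, 3) backbone coordinates
    return loss_fn(coords, tensor)    # tensor is stop_gradient-ed inside loss_fn

# Optimization loop
for i in range(n_steps):
    if i % update_interval == 0:
        # periodic re-fit, outside the gradient
        tensor = make_tensor_fn(build(*params))
    params = adam_step(total_loss, params, tensor)
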
Source code in diff_biophys/nmr/rdc.py
def make_rdc_refinement_fns(
    exp_res_ids: np.ndarray,
    exp_rdcs: np.ndarray,
    struct_res_ids: np.ndarray,
    d_max: float = 21.7,
) -> tuple[
    Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    Callable[[jnp.ndarray], jnp.ndarray],
    Callable[[jnp.ndarray], jnp.ndarray],
    int,
]:
    """Build three callables for fixed-tensor RDC-based structure refinement.

    The Saupe alignment tensor has 5 free parameters.  If it is fitted inside
    the gradient computation the optimiser trivially drives Q→0 by finding
    backbone orientations that any tensor can fit — the system is severely
    underdetermined.  The standard solution (X-PLOR/CNS/PALES) is to hold the
    tensor fixed during each gradient cycle and re-fit it periodically.

    This factory returns:

    * **loss_fn** ``(coords, fixed_tensor) → scalar MSE``
      Gradient flows through ``coords`` only; ``fixed_tensor`` is wrapped in
      :func:`jax.lax.stop_gradient`.  Use this inside ``jax.grad``.

    * **q_eval_fn** ``(coords) → scalar Q-factor``
      Re-fits the tensor from scratch and returns the best-achievable Q.
      Never call inside a gradient — use this for monitoring only.

    * **make_tensor_fn** ``(coords) → (3, 3) Saupe tensor``
      Fits and returns the alignment tensor for periodic updates outside the
      gradient.

    Args:
        exp_res_ids: ``(M,)`` residue IDs where RDCs were measured.
        exp_rdcs: ``(M,)`` experimental RDC values in Hz.
        struct_res_ids: ``(N,)`` residue IDs present in the structure.
        d_max: Maximum dipolar coupling constant (default 21.7, i.e. the
            ≈21.7 kHz ¹⁵N–¹H value in kHz; the fitted tensor absorbs the scale).

    Returns:
        Tuple ``(loss_fn, q_eval_fn, make_tensor_fn, n_matched)`` where
        ``n_matched`` is the number of residues found in both datasets.

    Raises:
        ValueError: If no residues overlap between ``exp_res_ids`` and
            ``struct_res_ids``.

    Example::

        loss_fn, q_eval_fn, make_tensor_fn, n = make_rdc_refinement_fns(
            rdc_data["res_id"], rdc_data["rdc"], res_ids
        )
        tensor = make_tensor_fn(initial_coords)

        def total_loss(params, tensor):
            phi, psi = params
            coords = build(phi, psi)
            return loss_fn(coords, tensor)

        # Optimization loop
        for i in range(n_steps):
            if i % update_interval == 0:
                tensor = make_tensor_fn(build(*params))
            params = adam_step(total_loss, params, tensor)
    """
    res_id_to_idx = {int(rid): i for i, rid in enumerate(struct_res_ids)}

    matched_struct_idx: list[int] = []
    matched_rdcs: list[float] = []
    for rid, rdc in zip(exp_res_ids, exp_rdcs, strict=False):
        if int(rid) in res_id_to_idx:
            matched_struct_idx.append(res_id_to_idx[int(rid)])
            matched_rdcs.append(float(rdc))

    if not matched_struct_idx:
        raise ValueError(
            "No residues overlap between exp_res_ids and struct_res_ids. "
            "Check that residue numbering conventions match."
        )

    n_matched = len(matched_struct_idx)
    exp_jax = jnp.array(matched_rdcs, dtype=jnp.float32)
    idx_jax = jnp.array(matched_struct_idx, dtype=jnp.int32)

    def _matched_nh(coords: jnp.ndarray) -> jnp.ndarray:
        return cast(jnp.ndarray, nh_bond_vectors(coords)[idx_jax])  # (n_matched, 3)

    def loss_fn(coords: jnp.ndarray, fixed_tensor: jnp.ndarray) -> jnp.ndarray:
        """Fixed-tensor MSE loss; gradient flows through coords only."""
        tensor = jax.lax.stop_gradient(fixed_tensor)
        nh = _matched_nh(coords)
        calc = calculate_rdc_from_tensor(nh, tensor, d_max=d_max)
        return cast(jnp.ndarray, jnp.mean((calc - exp_jax) ** 2))

    def q_eval_fn(coords: jnp.ndarray) -> jnp.ndarray:
        """Re-fit tensor and return honest Q-factor (monitoring only)."""
        nh = _matched_nh(coords)
        tensor = fit_saupe_tensor(nh, exp_jax, d_max=d_max)
        calc = calculate_rdc_from_tensor(nh, tensor, d_max=d_max)
        return cast(jnp.ndarray, calculate_q_factor(calc, exp_jax))

    def make_tensor_fn(coords: jnp.ndarray) -> jnp.ndarray:
        """Fit and return the Saupe tensor for periodic updates."""
        nh = _matched_nh(coords)
        return cast(jnp.ndarray, fit_saupe_tensor(nh, exp_jax, d_max=d_max))

    return loss_fn, q_eval_fn, make_tensor_fn, n_matched

nh_bond_vectors(coords)

Reconstruct amide N–H unit vectors from N–CA–C backbone coordinates.

The amide H lies in the peptide plane defined by C(i−1), N(i), CA(i). Its direction is approximated as anti-parallel to the bisector of the N→CA and N→C(i−1) unit vectors. For a C(i−1)–N–CA angle of about 122°, this places H at (360° − 122°)/2 ≈ 119° from each bond, consistent with standard peptide-plane geometry.

Coordinate layout:

coords = [N₀, CA₀, C₀,  N₁, CA₁, C₁,  …,  Nₙ, CAₙ, Cₙ]
N(i)   = coords[3i]
CA(i)  = coords[3i+1]
C(i-1) = coords[3i-1]  (for i ≥ 1)

Residue 0 has no preceding C; its NH vector falls back to −(N→CA).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (3N, 3) backbone atom coordinates. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N, 3) unit vectors pointing N→H. |

Source code in diff_biophys/nmr/rdc.py
@jit
def nh_bond_vectors(coords: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct amide N–H unit vectors from N–CA–C backbone coordinates.

    The amide H lies in the peptide plane defined by C(i−1), N(i), CA(i).
    Its direction is approximated as anti-parallel to the bisector of the
    N→CA and N→C(i−1) unit vectors, placing H at ~119° from each bond —
    consistent with standard peptide-plane geometry.

    Coordinate layout::

        coords = [N₀, CA₀, C₀,  N₁, CA₁, C₁,  …,  Nₙ, CAₙ, Cₙ]
        N(i)   = coords[3i]
        CA(i)  = coords[3i+1]
        C(i-1) = coords[3i-1]  (for i ≥ 1)

    Residue 0 has no preceding C; its NH vector falls back to −(N→CA).

    Args:
        coords: ``(3N, 3)`` backbone atom coordinates.

    Returns:
        jnp.ndarray: ``(N, 3)`` unit vectors pointing N→H.
    """
    n_atoms = coords[0::3]  # (N_res, 3)
    ca_atoms = coords[1::3]  # (N_res, 3)
    c_atoms = coords[2::3]  # (N_res, 3)

    # Unit vector N→CA
    n_to_ca = ca_atoms - n_atoms
    n_to_ca = n_to_ca / jnp.maximum(jnp.linalg.norm(n_to_ca, axis=-1, keepdims=True), 1e-8)

    # Unit vector N→C(i-1):  use C(i-1) = c_atoms[i-1], with dummy for i=0
    c_prev = jnp.concatenate([c_atoms[:1], c_atoms[:-1]], axis=0)
    n_to_cprev = c_prev - n_atoms
    n_to_cprev = n_to_cprev / jnp.maximum(jnp.linalg.norm(n_to_cprev, axis=-1, keepdims=True), 1e-8)

    # Bisector of the two bonds emanating from N; H is anti-parallel
    bisector = n_to_ca + n_to_cprev
    bisector = bisector / jnp.maximum(jnp.linalg.norm(bisector, axis=-1, keepdims=True), 1e-8)
    nh = -bisector

    # Residue 0 fallback: no C(i-1) available, use −(N→CA)
    nh = jnp.concatenate([-n_to_ca[:1], nh[1:]], axis=0)
    return nh  # (N_res, 3), already unit vectors
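The anti-bisector construction can be checked on a single ideal peptide unit. The NumPy sketch below places C(i−1), N(i) and CA(i) in the xy-plane with an assumed C–N–CA angle of 122° (bond lengths of 1.46 Å and 1.33 Å are typical values and do not affect the result, since only unit vectors are used), applies the same construction as nh_bond_vectors, and measures the angles; unit is an illustrative helper.

```python
import numpy as np

def unit(a):
    """Normalise a 3-vector."""
    return a / np.linalg.norm(a)

angle = np.deg2rad(122.0)                        # assumed C(i-1)-N-CA angle
n = np.zeros(3)
ca = np.array([1.46, 0.0, 0.0])                  # N->CA along +x
c_prev = 1.33 * np.array([np.cos(angle), np.sin(angle), 0.0])   # in the xy-plane

n_to_ca = unit(ca - n)
n_to_cprev = unit(c_prev - n)
nh = -unit(n_to_ca + n_to_cprev)                 # anti-parallel to the bisector

ang_ca = np.degrees(np.arccos(np.dot(nh, n_to_ca)))
ang_c = np.degrees(np.arccos(np.dot(nh, n_to_cprev)))
print(round(ang_ca, 3), round(ang_c, 3))         # both ~119 degrees
```

The H direction stays in the plane of the three atoms (its z component is zero here), as the peptide-plane description requires.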

Ring Current Shifts

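diff_biophys.nmr.ring_currents implements a point-dipole approximation to the Johnson-Bovey model of aromatic ring-current shielding, \(\Delta\delta = I\,(1 - 3\cos^2\theta)/r^3\), where \(\theta\) is the angle between the ring normal and the vector from the ring centre. For a positive intensity \(I\), protons directly above the ring plane are shielded (upfield shift, negative Δδ); protons in the plane are deshielded (downfield shift, positive Δδ).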

The shift falls off as \(\sim 1/r^3\) with distance from the ring centre.

from diff_biophys.nmr.ring_currents import calculate_ring_current_shift
import jax, jax.numpy as jnp

# Proton position(s), shape (N, 3)
proton_pos  = jnp.array([[0.0, 0.0, 3.5]])   # 3.5 Å above the ring
ring_pos    = jnp.array([0.0, 0.0, 0.0])
ring_normal = jnp.array([0.0, 0.0, 1.0])     # ring in xy-plane
intensity   = 1.0                            # illustrative scale factor (required argument)

delta = calculate_ring_current_shift(proton_pos, ring_pos, ring_normal, intensity)
# delta < 0  (shielded above the ring)

shift_sum = lambda x: jnp.sum(calculate_ring_current_shift(x, ring_pos, ring_normal, intensity))
grad = jax.grad(shift_sum)(proton_pos)
# grad points towards increasing Δδ; moving the proton along -grad increases shielding

calculate_ring_current_shift(coords, ring_center, ring_normal, intensity)

Calculate chemical shift changes due to aromatic ring currents using the Johnson-Bovey dipolar approximation.

delta = intensity * (1 - 3*cos^2(theta)) / r^3

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (N, 3) coordinates of the nuclei being shielded. | required |
| ring_center | ndarray | (3,) coordinates of the aromatic ring center. | required |
| ring_normal | ndarray | (3,) unit vector normal to the ring plane. | required |
| intensity | float | Scaling factor (proportional to ring area and current). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N,) shielding values in ppm. |

Source code in diff_biophys/nmr/ring_currents.py
@jit
def calculate_ring_current_shift(
    coords: jnp.ndarray, ring_center: jnp.ndarray, ring_normal: jnp.ndarray, intensity: float
) -> jnp.ndarray:
    """
    Calculate chemical shift changes due to aromatic ring currents using
    the Johnson-Bovey dipolar approximation.

    delta = intensity * (1 - 3*cos^2(theta)) / r^3

    Args:
        coords: (N, 3) coordinates of the nuclei being shielded.
        ring_center: (3,) coordinates of the aromatic ring center.
        ring_normal: (3,) unit vector normal to the ring plane.
        intensity: Scaling factor (proportional to ring area and current).

    Returns:
        jnp.ndarray: (N,) shielding values in ppm.
    """
    # 1. Displacement vectors from ring center
    r_vec = coords - ring_center

    # 2. Distances
    r = jnp.linalg.norm(r_vec, axis=-1)

    # 3. cos(theta) where theta is the angle between r_vec and the ring normal
    # cos(theta) = (r_vec . normal) / (|r_vec| * |normal|)
    # Assume ring_normal is already a unit vector
    cos_theta = jnp.sum(r_vec * ring_normal, axis=-1) / (r + 1e-10)

    # 4. Johnson-Bovey geometric term
    # delta = intensity * (1 - 3 * cos^2(theta)) / r^3
    return cast(jnp.ndarray, intensity * (1.0 - 3.0 * cos_theta**2) / (r**3 + 1e-10))
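The sign convention and the 1/r³ fall-off follow directly from the geometric factor (1 − 3cos²θ)/r³. The NumPy sketch below re-implements it for a ring in the xy-plane with an illustrative intensity of 1.0 (not a calibrated value): a nucleus on the ring normal gets −2·intensity/r³, one in the ring plane +intensity/r³, and doubling the distance divides the shift by 8. ring_current_shift is an illustrative name.

```python
import numpy as np

def ring_current_shift(coords, center, normal, intensity):
    """Point-dipole term intensity * (1 - 3 cos^2 theta) / r^3 for (N, 3) coords."""
    r_vec = coords - center
    r = np.linalg.norm(r_vec, axis=-1)
    cos_theta = (r_vec @ normal) / r
    return intensity * (1.0 - 3.0 * cos_theta**2) / r**3

center = np.zeros(3)
normal = np.array([0.0, 0.0, 1.0])   # ring in the xy-plane
intensity = 1.0                       # illustrative scale factor

coords = np.array([
    [0.0, 0.0, 3.5],   # on the normal, 3.5 A above the ring
    [0.0, 0.0, 7.0],   # twice as far along the normal
    [3.5, 0.0, 0.0],   # in the ring plane, 3.5 A from the centre
])
delta = ring_current_shift(coords, center, normal, intensity)
print(delta)   # negative above the ring, positive in the plane
```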

Fixed-Tensor RDC Loss

FixedTensorRDCLoss

diff_integrator.terms.nmr.FixedTensorRDCLoss

A LossTerm that computes the RDC loss while keeping the Saupe alignment tensor frozen during backpropagation via jax.lax.stop_gradient. The tensor is re-fitted from current coordinates every update_interval epochs, preventing the degeneracy exploit where gradient descent drives Q→0 unphysically by distorting the tensor rather than the structure.

Constructor

FixedTensorRDCLoss(
    loss_fn,
    tensor_fn,
    update_interval: int = 50,
    n_rdcs: int | None = None,
    val_q_eval_fn: Callable | None = None,
)
Parameter Type Description
loss_fn Callable Training-set RDC loss function (bond_vecs, S) → scalar returned by make_rdc_cv_refinement_fns
tensor_fn Callable Tensor-fitting function (bond_vecs, rdcs) → S returned by make_rdc_cv_refinement_fns
update_interval int Re-fit the Saupe tensor every this many epochs (default 50)
n_rdcs int | None Number of training RDCs; used by suggested_weight()
val_q_eval_fn Callable | None Held-out Q-factor evaluator returned by make_rdc_cv_refinement_fns; enables evaluate_validation_q()

Methods

maybe_update_tensor(coords, epoch)

Re-fits the Saupe tensor from the current Cartesian coordinates if epoch % update_interval == 0. Call it at the top of each training step, before the gradient computation.

rdc_term.maybe_update_tensor(coords, epoch=epoch)
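The refit schedule can be sketched as below, assuming the tensor is obtained by a linear least-squares fit of its five independent elements (S is symmetric and traceless, so Szz = -Sxx - Syy). fit_saupe and maybe_refit are hypothetical stand-ins for tensor_fn and maybe_update_tensor; they work on bond vectors directly rather than on Cartesian coordinates, and RDCs are assumed pre-scaled so that D_max = 1.

```python
import jax
import jax.numpy as jnp


def fit_saupe(bond_vecs, rdcs):
    """Least-squares fit of a symmetric, traceless Saupe matrix (D_max = 1)."""
    u = bond_vecs / jnp.linalg.norm(bond_vecs, axis=-1, keepdims=True)
    x, y, z = u[:, 0], u[:, 1], u[:, 2]
    # Columns multiply Sxx, Syy, Sxy, Sxz, Syz after substituting Szz = -Sxx - Syy.
    design = jnp.stack(
        [x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1
    )
    sxx, syy, sxy, sxz, syz = jnp.linalg.lstsq(design, rdcs)[0]
    return jnp.array([[sxx, sxy, sxz],
                      [sxy, syy, syz],
                      [sxz, syz, -sxx - syy]])


def maybe_refit(saupe, bond_vecs, rdcs, epoch, update_interval=50):
    # Refit only on multiples of update_interval; otherwise keep the frozen tensor.
    if epoch % update_interval == 0:
        return fit_saupe(bond_vecs, rdcs)
    return saupe


# Noise-free synthetic data generated from a known tensor.
bond_vecs = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
true_saupe = jnp.diag(jnp.array([0.5, 0.2, -0.7]))
unit = bond_vecs / jnp.linalg.norm(bond_vecs, axis=-1, keepdims=True)
rdcs = jnp.einsum("ni,ij,nj->n", unit, true_saupe, unit)

saupe = maybe_refit(None, bond_vecs, rdcs, epoch=0)   # epoch 0: refit
same = maybe_refit(saupe, bond_vecs, rdcs, epoch=7)   # epoch 7: tensor reused as-is
```

Between refits the tensor is simply reused, which is what lets the loss treat it as a constant.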

suggested_weight(base_weight=1.0)

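Returns a weight scaled by the overdetermination ratio (training RDCs per independent Saupe-tensor element; a symmetric, traceless tensor has five) relative to an ideal 10× ratio:

\[w = \text{base\_weight} \times \frac{n_{\text{rdcs}} / 5}{10}\]

Use this to auto-scale the RDC term: because w grows linearly with n_rdcs (w = base_weight at 50 RDCs), systems with few RDCs, whose tensor fit is poorly overdetermined, are down-weighted rather than over-penalised.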

rdc_weight = rdc_term.suggested_weight(base_weight=1.0)
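A standalone restatement of the formula makes the scaling concrete. suggested_weight here is a hypothetical free function that takes n_rdcs as an argument, whereas the method reads it from the instance.

```python
def suggested_weight(n_rdcs: int, base_weight: float = 1.0) -> float:
    # Overdetermination ratio: RDCs per independent Saupe element (5 of them).
    overdetermination = n_rdcs / 5
    # Normalise by the "ideal" 10x overdetermination.
    return base_weight * overdetermination / 10


print(suggested_weight(50))  # 1.0
print(suggested_weight(25))  # 0.5
```

So 50 training RDCs (10 per tensor element) reproduce base_weight exactly, and halving the RDC count halves the weight.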

evaluate_validation_q(coords)

Returns the Q-factor on the held-out cross-validation split, or None if val_q_eval_fn was not provided.

val_q = rdc_term.evaluate_validation_q(coords)
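This section does not spell out which Q-factor definition is used; a commonly used one is Q = sqrt(Σ(D_obs - D_calc)² / Σ D_obs²). The sketch below assumes that definition, applied to held-out RDCs back-calculated with a tensor fitted on the training split only, which is what makes the validation Q an independent check. q_factor and the numbers are illustrative.

```python
import jax.numpy as jnp


def q_factor(d_obs, d_calc):
    # Q = sqrt(sum((D_obs - D_calc)^2) / sum(D_obs^2)); Q = 0 is a perfect fit.
    return jnp.sqrt(jnp.sum((d_obs - d_calc) ** 2) / jnp.sum(d_obs**2))


# Held-out RDCs (Hz) versus values back-calculated with the training-set tensor.
d_obs_val = jnp.array([3.0, 4.0])
d_calc_val = jnp.array([3.0, 1.0])
q_val = q_factor(d_obs_val, d_calc_val)  # sqrt(9 / 25), i.e. about 0.6
```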

make_rdc_cv_refinement_fns

diff_integrator.terms.nmr.make_rdc_cv_refinement_fns

Factory function that partitions experimental RDCs into a training set and a held-out cross-validation set, then returns pre-built closures ready for FixedTensorRDCLoss.

Signature

make_rdc_cv_refinement_fns(
    rdc_res_ids: ArrayLike,
    exp_rdcs: ArrayLike,
    struct_res_ids: ArrayLike,
    cv_fraction: float = 0.2,
) -> tuple[Callable, Callable, Callable, Callable, int, int]
Parameter Type Description
rdc_res_ids ArrayLike Residue IDs corresponding to each experimental RDC value
exp_rdcs ArrayLike Experimental RDC values (Hz)
struct_res_ids ArrayLike Residue IDs present in the structure (used for index mapping)
cv_fraction float Fraction of RDCs to hold out for cross-validation (default 0.2)

Returns (loss_fn, q_eval_fn, tensor_fn, val_q_fn, n_train, n_val)

Return value Description
loss_fn Training-set loss closure (bond_vecs, S) → scalar
q_eval_fn Training-set Q-factor evaluator (bond_vecs, S) → float
tensor_fn Tensor-fitting closure (bond_vecs) → S
val_q_fn Validation Q-factor evaluator (bond_vecs, S) → float
n_train Number of training RDCs
n_val Number of held-out validation RDCs
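The partitioning and index mapping can be sketched with NumPy as follows. The random permutation, the seed argument and the round(cv_fraction · n) rule for the validation size are assumptions for illustration; the actual factory may split differently, and split_and_map is a hypothetical name.

```python
import numpy as np


def split_and_map(rdc_res_ids, exp_rdcs, struct_res_ids, cv_fraction=0.2, seed=0):
    rdc_res_ids = np.asarray(rdc_res_ids)
    exp_rdcs = np.asarray(exp_rdcs, dtype=float)
    # Shuffle once; the first n_val shuffled entries form the held-out set.
    order = np.random.default_rng(seed).permutation(len(exp_rdcs))
    n_val = int(round(cv_fraction * len(exp_rdcs)))
    val_idx, train_idx = order[:n_val], order[n_val:]

    # Map residue numbers to positions in the structure (for bond-vector lookup).
    pos = {int(r): i for i, r in enumerate(struct_res_ids)}

    def to_pos(ids):
        return np.array([pos[int(r)] for r in ids], dtype=int)

    train = (to_pos(rdc_res_ids[train_idx]), exp_rdcs[train_idx])
    val = (to_pos(rdc_res_ids[val_idx]), exp_rdcs[val_idx])
    return train, val


struct_res_ids = np.arange(1, 11)  # residues 1..10 in structure order
rdc_res_ids = np.array([2, 3, 5, 6, 7, 8, 9, 10, 4, 1])
exp_rdcs = np.linspace(-10.0, 10.0, 10)
(train_pos, train_d), (val_pos, val_d) = split_and_map(rdc_res_ids, exp_rdcs, struct_res_ids)
print(len(train_d), len(val_d))  # 8 2
```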

EarlyStopping

EarlyStopping

diff_integrator.optimizer.EarlyStopping

A dataclass passed as early_stopping= to IntegrativeRefiner.run() that halts training when a monitored loss term stops improving.

Fields

Field Type Default Description
term_index int required Index into JointLoss.terms of the term to monitor (unweighted value)
patience int required Number of epochs with no improvement before stopping
min_delta float 1e-5 Minimum change in monitored value to count as an improvement
mode str "min" "min" for loss-type metrics (lower is better); "max" for score-type metrics (higher is better)

Behaviour

  • The unweighted value of the monitored term (not the contribution to total loss) is tracked.
  • When no improvement exceeding min_delta occurs for patience consecutive epochs, training stops and the best-checkpoint parameters are returned.
  • The stopping event is recorded in RefinementResult.stopped_early and RefinementResult.early_stopping_triggered_by.
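The behaviour described above corresponds to the usual patience-based rule. EarlyStopSketch below is a hypothetical stand-alone re-implementation, not diff_integrator's class: it tracks one unweighted value per epoch and reports a stop once patience consecutive epochs pass without an improvement larger than min_delta.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopSketch:
    patience: int
    min_delta: float = 1e-5
    mode: str = "min"

    def __post_init__(self):
        self.best = None       # best monitored value so far
        self.best_epoch = -1   # epoch of the best checkpoint
        self.wait = 0          # consecutive epochs without improvement

    def update(self, value: float, epoch: int) -> bool:
        """Record one epoch's value; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best, self.best_epoch, self.wait = value, epoch, 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopSketch(patience=3)
for epoch, value in enumerate([1.0, 0.5, 0.5, 0.5, 0.5, 0.4]):
    if stopper.update(value, epoch):
        break
print(epoch, stopper.best_epoch)  # 4 1
```

Here the value plateaus at 0.5 after epoch 1, so the third non-improving epoch (epoch 4) triggers the stop and the best checkpoint is the one from epoch 1.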

Example

from diff_integrator.optimizer import EarlyStopping, IntegrativeRefiner

# joint_loss (a JointLoss) and starting_coords are assumed to be defined earlier
refiner = IntegrativeRefiner(loss_fn=joint_loss)
result = refiner.run(
    init_params=starting_coords,
    epochs=2000,
    learning_rate=0.005,
    early_stopping=EarlyStopping(
        term_index=1,      # monitor RDC term (index 1 in JointLoss)
        patience=50,
        min_delta=1e-4,
        mode="min",
    ),
)
print(f"Best epoch: {result.best_epoch} (stopped early: {result.stopped_early})")
print(f"Triggered by term {result.early_stopping_triggered_by}")

NOE Distance Restraints

NOELoss

diff_integrator.terms.noe.NOELoss

A LossTerm implementing the standard flat-bottomed harmonic NOE (Nuclear Overhauser Effect) distance restraint used in XPLOR, CNS, and ARIA. The energy is zero when a distance is within bounds and grows quadratically once a bound is violated:

\[E(d) = \frac{k}{M} \sum_{m=1}^{M} \left[ \max(0,\, d_m - d_m^{\text{upper}})^2 + \max(0,\, d_m^{\text{lower}} - d_m)^2 \right]\]

where \(M\) is the number of restraints and \(k\) is force_const. A mean (not sum) is used so the weight has a consistent interpretation across datasets of different sizes.
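A minimal JAX sketch of this energy follows, assuming coords is an (n_atoms, 3) array indexed by the integers in atom_pairs. noe_energy is an illustrative function, not the NOELoss implementation. Because it is built from jnp operations, jax.grad gives the restraint gradients directly.

```python
import jax
import jax.numpy as jnp


def noe_energy(coords, atom_pairs, d_upper, d_lower=None, force_const=1.0):
    # Inter-atomic distance for each restrained pair.
    d = jnp.linalg.norm(coords[atom_pairs[:, 0]] - coords[atom_pairs[:, 1]], axis=-1)
    upper_viol = jnp.maximum(0.0, d - d_upper)
    lower_viol = 0.0 if d_lower is None else jnp.maximum(0.0, d_lower - d)
    # Mean (not sum) over restraints, matching the formula above.
    return force_const * jnp.mean(upper_viol**2 + lower_viol**2)


coords = jnp.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
pairs = jnp.array([[0, 1], [0, 2]])    # distances 5 Å and 1 Å
d_upper = jnp.array([4.0, 4.0])
d_lower = jnp.array([1.8, 1.8])

energy = noe_energy(coords, pairs, d_upper, d_lower)  # (1.0**2 + 0.8**2) / 2 = 0.82
grads = jax.grad(noe_energy)(coords, pairs, d_upper, d_lower)
```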

Constructor

NOELoss(
    atom_pairs:  jnp.ndarray,             # (M, 2) integer atom index pairs
    d_upper:     jnp.ndarray,             # (M,)  upper-bound distances in Å
    d_lower:     jnp.ndarray | None = None,  # (M,) optional lower bounds
    force_const: float = 1.0,
)
Parameter Type Description
atom_pairs (M, 2) int Each row [i, j] defines one restraint between atoms i and j
d_upper (M,) float Upper-bound distances in Å — penalty fires when d > d_upper
d_lower (M,) float | None Lower-bound distances in Å; penalty fires when d < d_lower. None (default) means upper-bound only, which is the standard NMR convention
force_const float Harmonic force constant. Default 1.0. Typical values: 550

name attribute: "noe"

Methods

count_violations(coords) → dict[str, int]

Returns {"upper": n, "lower": n, "total": n}: the number of restraints violating each bound, and the total. Pure diagnostic; not used in the gradient.

rms_violation(coords) → float

Root-mean-square distance violation across all restraints, in Å.
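Both diagnostics can be sketched from the same per-restraint violations used in the energy. The helper below is hypothetical. It assumes that a restraint's violation is its distance beyond whichever bound it breaks (0 when within bounds), and it takes the RMS over all restraints, including satisfied ones.

```python
import jax.numpy as jnp


def noe_diagnostics(coords, atom_pairs, d_upper, d_lower=None):
    d = jnp.linalg.norm(coords[atom_pairs[:, 0]] - coords[atom_pairs[:, 1]], axis=-1)
    upper = jnp.maximum(0.0, d - d_upper)
    lower = jnp.zeros_like(d) if d_lower is None else jnp.maximum(0.0, d_lower - d)
    # Count restraints (not atoms) that break each bound.
    counts = {"upper": int(jnp.sum(upper > 0)), "lower": int(jnp.sum(lower > 0))}
    counts["total"] = counts["upper"] + counts["lower"]
    # RMS of the per-restraint violation over every restraint.
    rms = float(jnp.sqrt(jnp.mean((upper + lower) ** 2)))
    return counts, rms


coords = jnp.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
pairs = jnp.array([[0, 1], [0, 2]])
counts, rms = noe_diagnostics(coords, pairs, jnp.array([4.0, 4.0]), jnp.array([1.8, 1.8]))
# counts: one upper and one lower violation; rms = sqrt((1.0**2 + 0.8**2) / 2)
```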

Property

n_restraints → int

Number of distance restraints.

Example

import jax.numpy as jnp
from diff_integrator.terms.noe import NOELoss

noe_loss = NOELoss(
    atom_pairs  = jnp.array([[4, 31], [7, 52]]),   # Cα pairs
    d_upper     = jnp.array([6.0, 4.5]),            # Å
    d_lower     = jnp.array([1.8, 1.8]),            # Å (optional)
    force_const = 10.0,
)

make_noe_restraints

diff_integrator.terms.noe.make_noe_restraints

Factory that maps (res_id, atom_name) observations to flat atom indices using the structure's residue ordering, then returns a ready-to-use NOELoss.

Signature

make_noe_restraints(
    noe_list:   list[dict],
    res_ids:    np.ndarray,
    atom_names: list[str] | None = None,   # default ["N", "CA", "C"]
    force_const: float = 1.0,
) -> NOELoss

Each dict in noe_list must contain:

Key Type Description
"res_i" int Residue number of atom i
"atom_i" str Atom name of atom i (e.g. "CA")
"res_j" int Residue number of atom j
"atom_j" str Atom name of atom j
"d_upper" float Upper-bound distance in Å
"d_lower" float (optional) Lower-bound distance in Å

Example

from diff_integrator.terms.noe import make_noe_restraints
import numpy as np

noe_observations = [
    {"res_i":  5, "atom_i": "CA", "res_j": 20, "atom_j": "CA",
     "d_upper": 6.0, "d_lower": 1.8},
    {"res_i": 12, "atom_i": "N",  "res_j": 45, "atom_j": "CA",
     "d_upper": 5.5},
]

# struct_res_ids: residue numbers in structure order; coords: current model coordinates
noe_term = make_noe_restraints(noe_observations, struct_res_ids)
violations = noe_term.count_violations(coords)
print(f"{violations['total']} NOE violations at current coordinates")
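The index mapping can be approximated as below. This sketch assumes a residue-major coordinate layout with len(atom_names) atoms per residue, in atom_names order, and that a missing d_lower means no lower bound. Both are guesses about make_noe_restraints, and build_noe_arrays is a hypothetical helper that returns plain arrays rather than an NOELoss.

```python
import numpy as np


def build_noe_arrays(noe_list, res_ids, atom_names=("N", "CA", "C")):
    # ASSUMED layout: residue-major coordinates, len(atom_names) atoms per residue.
    res_pos = {int(r): i for i, r in enumerate(res_ids)}
    n_atoms = len(atom_names)

    def flat_index(res, atom):
        return res_pos[int(res)] * n_atoms + atom_names.index(atom)

    pairs = np.array(
        [[flat_index(o["res_i"], o["atom_i"]), flat_index(o["res_j"], o["atom_j"])]
         for o in noe_list]
    )
    d_upper = np.array([o["d_upper"] for o in noe_list], dtype=float)
    # A missing "d_lower" becomes 0.0, which can never be violated (distances >= 0).
    d_lower = np.array([o.get("d_lower", 0.0) for o in noe_list], dtype=float)
    return pairs, d_upper, d_lower


noe_observations = [
    {"res_i": 5, "atom_i": "CA", "res_j": 20, "atom_j": "CA", "d_upper": 6.0, "d_lower": 1.8},
    {"res_i": 12, "atom_i": "N", "res_j": 45, "atom_j": "CA", "d_upper": 5.5},
]
pairs, d_upper, d_lower = build_noe_arrays(noe_observations, np.arange(1, 51))
print(pairs.tolist())  # [[13, 58], [33, 133]]
```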