
⚛️ NMR API

The diff_biophys.nmr subpackage implements four categories of differentiable NMR observables. Each kernel accepts JAX arrays and returns JAX arrays, making them trivially composable into multi-observable loss functions.
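Because every kernel is a pure JAX function, several observables can be summed into one loss and differentiated in a single call. The sketch below is a minimal standalone illustration. It assumes simplified re-implementations of the Karplus and full-tensor RDC formulas documented on this page, so it runs without diff_biophys installed. The target values, the alignment tensor and the weight w_rdc are hypothetical.

```python
import jax
import jax.numpy as jnp

# Standalone re-implementations of two formulas documented on this page
# (not imported from diff_biophys).
def karplus_j(theta, a, b, c):
    cos_t = jnp.cos(theta)
    return a * cos_t**2 + b * cos_t + c

def rdc_from_tensor(vecs, saupe, d_max):
    return d_max * jnp.einsum("ni,ij,nj->n", vecs, saupe, vecs)

# Hypothetical "experimental" targets and alignment tensor, for illustration only
j_target = jnp.array([4.0, 9.0])      # Hz
rdc_target = jnp.array([5.0, -3.0])   # Hz
saupe = jnp.diag(jnp.array([-0.0005, -0.0005, 0.001]))
d_max = 21585.0                       # value used in the RDC usage example

def joint_loss(params, w_rdc=0.1):
    phi, vecs = params
    vecs = vecs / jnp.linalg.norm(vecs, axis=1, keepdims=True)  # keep unit length
    j = karplus_j(phi - jnp.deg2rad(60.0), 6.51, -1.76, 1.60)   # theta = phi - 60 deg
    d = rdc_from_tensor(vecs, saupe, d_max)
    return jnp.mean((j - j_target) ** 2) + w_rdc * jnp.mean((d - rdc_target) ** 2)

params = (
    jnp.deg2rad(jnp.array([-60.0, -120.0])),         # backbone phi (radians)
    jnp.array([[0.0, 0.6, 0.8], [0.8, 0.0, 0.6]]),   # N-H bond vectors
)
loss, (g_phi, g_vecs) = jax.value_and_grad(joint_loss)(params)
```

jax.value_and_grad returns gradients with the same (phi, vecs) tuple structure as params, so each observable contributes gradient only to the inputs it depends on.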


Karplus J-coupling

diff_biophys.nmr.karplus implements the Karplus equation relating a three-bond scalar coupling constant \(^3J\) (Hz) to a backbone dihedral angle:

\[J(\theta) = A\cos^2(\theta) + B\cos(\theta) + C\]

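For the backbone HN–Hα coupling, the Karplus dihedral is \(\theta = \phi - 60°\) (Vuister & Bax 1993 offset convention). calculate_karplus_j has no default coefficients: the example below passes \(A = 6.98\), \(B = -1.38\), \(C = 1.72\) Hz, while the function's docstring cites the Vuister & Bax (1993) set \(A = 6.51\), \(B = -1.76\), \(C = 1.60\) Hz.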

Gradient meaning: \(\partial J / \partial \theta\) tells you how fast the coupling changes with the backbone angle — the key quantity for J-coupling-driven structure refinement.

```python
from diff_biophys.nmr.karplus import calculate_karplus_j
import jax
import jax.numpy as jnp

phi = jnp.array([-0.995])          # φ in radians (≈ −57°, α-helix)
theta = phi - jnp.deg2rad(60.0)    # Karplus offset for HN-Hα

# Keyword names are lowercase (a, b, c), matching the function signature
J = calculate_karplus_j(theta, a=6.98, b=-1.38, c=1.72)   # (1,) array, Hz
dJ_dtheta = jax.grad(lambda t: jnp.sum(calculate_karplus_j(t, 6.98, -1.38, 1.72)))(theta)
```
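The gradient computed above has a closed form, \(\partial J/\partial\theta = -\sin\theta\,(2A\cos\theta + B)\), which makes a convenient check on autodiff. A minimal standalone sketch (it re-implements the Karplus formula rather than importing diff_biophys):

```python
import jax
import jax.numpy as jnp

def karplus_j(theta, a, b, c):
    cos_t = jnp.cos(theta)
    return a * cos_t**2 + b * cos_t + c

def karplus_dj_dtheta(theta, a, b):
    # d/dθ [a cos²θ + b cosθ + c] = -sinθ (2a cosθ + b)
    return -jnp.sin(theta) * (2.0 * a * jnp.cos(theta) + b)

a, b, c = 6.98, -1.38, 1.72
theta = jnp.deg2rad(jnp.array([-117.0, -180.0, -30.0]))

# Summing before jax.grad yields per-element derivatives, because each J_i
# depends only on theta_i.
autodiff = jax.grad(lambda t: jnp.sum(karplus_j(t, a, b, c)))(theta)
analytic = karplus_dj_dtheta(theta, a, b)
```

The derivative vanishes at θ = ±180°, i.e. φ = −120°, the ideal β-sheet angle, so J-coupling restraints alone give little gradient signal there.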

calculate_karplus_j(theta, a, b, c)

Calculate 3J coupling constants using the Karplus equation.

J = a * cos^2(theta) + b * cos(theta) + c

Important: theta is not the same as the raw backbone dihedral phi. For ³J(H_N, H_α) couplings, the Karplus dihedral is offset from the IUPAC backbone angle by 60°:

    theta = phi - 60°   (i.e. subtract 60° from phi before calling)

Passing raw phi directly produces errors of several Hz (about 1.2 Hz at φ = −57° and 5.8 Hz at φ = −120° with the Vuister & Bax parameters), invalidating the α-helix / β-sheet discrimination.

The standard H_N–H_α parameters are A=6.51, B=−1.76, C=1.60 (Vuister & Bax 1993). The function defines no defaults, so a, b and c must always be passed explicitly. Different spin pairs require different offsets and parameters.
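The effect of the offset can be checked numerically. This minimal standalone sketch re-implements the formula above (so it runs without diff_biophys) and evaluates ideal α-helix (φ = −57°) and β-sheet (φ = −120°) angles with the Vuister & Bax parameters quoted in the docstring:

```python
import jax.numpy as jnp

def karplus_j(theta, a, b, c):
    cos_t = jnp.cos(theta)
    return a * cos_t**2 + b * cos_t + c

A, B, C = 6.51, -1.76, 1.60                        # Vuister & Bax (1993) H_N-H_alpha set
phi = jnp.deg2rad(jnp.array([-57.0, -120.0]))     # ideal alpha-helix, beta-sheet

j_correct = karplus_j(phi - jnp.deg2rad(60.0), A, B, C)  # theta = phi - 60°: ≈ [3.74, 9.87] Hz
j_raw = karplus_j(phi, A, B, C)                          # raw phi (mistake): ≈ [2.57, 4.11] Hz
```

With the offset the two states are about 6 Hz apart; with raw φ the β-sheet value drops to about 4.1 Hz, collapsing most of the helix/sheet contrast.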

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| theta | ndarray | (N,) Karplus dihedral angle in radians. For ³J(HN,HA): pass (phi − 60°), not phi. | required |
| a | float | Cosine-squared coefficient (Hz). | required |
| b | float | Cosine coefficient (Hz). | required |
| c | float | Constant offset (Hz). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (N,) calculated J-couplings (Hz). |

Source code in diff_biophys/nmr/karplus.py
@jit
def calculate_karplus_j(theta: jnp.ndarray, a: float, b: float, c: float) -> jnp.ndarray:
    """
    Calculate 3J coupling constants using the Karplus equation.

    J = a * cos^2(theta) + b * cos(theta) + c

    .. important::
        ``theta`` is **not** the same as the raw backbone dihedral ``phi``.
        For ³J(H_N, H_α) couplings, the Karplus dihedral is offset from
        the IUPAC backbone angle by 60°::

            theta = phi - 60°   (i.e. subtract 60° from phi before calling)

        Passing raw ``phi`` directly produces errors of ~3–4 Hz, invalidating
        the α-helix / β-sheet discrimination.

        The default parameters (A=6.51, B=−1.76, C=1.60) are for H_N–H_α
        (Vuister & Bax 1993).  Different spin pairs require different offsets
        and parameters.

    Args:
        theta: (N,) Karplus dihedral angle in radians.
               For ³J(HN,HA): pass (phi − 60°), **not** phi.
        a: Cosine-squared coefficient (Hz).
        b: Cosine coefficient (Hz).
        c: Constant offset (Hz).

    Returns:
        jnp.ndarray: (N,) Calculated J-couplings.
    """
    cos_theta = jnp.cos(theta)
    return a * (cos_theta**2) + b * cos_theta + c

Cα Chemical Shifts

diff_biophys.nmr.chemical_shifts predicts Cα chemical shifts (ppm) from backbone torsion angles using a softmax-weighted Gaussian secondary-structure detector in (φ, ψ) space.

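Secondary chemical shifts (relative to random coil), applied as the nominal offsets OFFSET_HELIX and OFFSET_SHEET:

| Structure | Δδ(Cα) |
|---|---|
| α-helix | +3.1 ppm |
| β-sheet | −1.5 ppm |
| Random coil | 0 ppm |

These are full-weight offsets. Because the coil detector has a fixed raw weight of 1, a residue sitting exactly at the helix or sheet centre receives only about half of its offset (≈ +1.55 ppm for a helix), as explained in the note inside the source listing below.

The library provides RANDOM_COIL_CA, a dictionary of per-residue random-coil Cα reference values (ppm), keyed by three-letter residue name as in the example below.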

```python
from diff_biophys.nmr.chemical_shifts import predict_ca_shifts, RANDOM_COIL_CA
import jax
import jax.numpy as jnp

n_res = 10
phi = jnp.full((n_res,), jnp.deg2rad(-57.0))   # α-helix
psi = jnp.full((n_res,), jnp.deg2rad(-47.0))
rc = jnp.full((n_res,), RANDOM_COIL_CA["ALA"])

shifts = predict_ca_shifts(phi, psi, rc)   # (n_res,) in ppm

# Refine φ to match target shifts
target = shifts + 0.5   # perturbed target
loss = lambda p: jnp.mean((predict_ca_shifts(p, psi, rc) - target) ** 2)
grad_phi = jax.grad(loss)(phi)
```
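In the usage example, φ already sits near the helix centre, where (given the normalisation described in the source note below) the helical contribution is at its ceiling of about +1.55 ppm. A target 0.5 ppm higher therefore cannot be reached by moving φ, and the gradient there is small. A refinement that does converge starts away from the centre and fits a reachable target. The minimal standalone sketch below re-implements the predictor under assumed constants: SS_SIGMA_SQ = 0.5 is a hypothetical stand-in for _SS_SIGMA_SQ (not shown on this page), the offsets are the nominal table values, and the 52.5 ppm baseline is illustrative rather than taken from RANDOM_COIL_CA.

```python
import jax
import jax.numpy as jnp

SS_SIGMA_SQ = 0.5                       # assumption: hypothetical detector width
OFFSET_HELIX, OFFSET_SHEET = 3.1, -1.5  # nominal offsets from the table above

def predict_ca(phi, psi, rc):
    # Same structure as predict_ca_shifts: Gaussian detectors, coil raw weight 1
    w_h = jnp.exp(-((phi + 1.05) ** 2 + (psi + 0.78) ** 2) / SS_SIGMA_SQ)
    w_s = jnp.exp(-((phi + 2.09) ** 2 + (psi - 2.35) ** 2) / SS_SIGMA_SQ)
    total = w_h + w_s + 1.0
    return rc + OFFSET_HELIX * w_h / total + OFFSET_SHEET * w_s / total

psi = jnp.deg2rad(jnp.full((5,), -47.0))
rc = jnp.full((5,), 52.5)  # hypothetical random-coil reference (ppm)
target = predict_ca(jnp.deg2rad(jnp.full((5,), -60.0)), psi, rc)  # reachable helical target

def loss(phi):
    return jnp.mean((predict_ca(phi, psi, rc) - target) ** 2)

phi = jnp.deg2rad(jnp.full((5,), -90.0))  # start away from the helix centre
loss_0 = loss(phi)
step = jax.jit(lambda p: p - 0.5 * jax.grad(loss)(p))  # plain gradient descent
for _ in range(300):
    phi = step(phi)
loss_final = loss(phi)
```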

predict_ca_shifts(phi, psi, rc_shifts)

Differentiable Cα chemical shift prediction based on backbone torsions.

Uses Gaussian "soft detectors" in (Φ, Ψ) space to classify secondary structure and applies SPARTA-like offsets. The detectors are normalised via a softmax so that helix, sheet, and coil contributions always sum to 1.0, preventing unphysical double-counting.

Reference centres (radians):

* α-helix: Φ = −1.05 rad (−60°), Ψ = −0.78 rad (−45°)
* β-sheet: Φ = −2.09 rad (−120°), Ψ = +2.35 rad (+135°)
* Random coil: treated as the baseline (weight = 1 − w_helix − w_sheet)
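Dividing the Gaussian affinities by their sum, with a coil weight of 1, is exactly a softmax over the logits \((-d_{\text{helix}}^2/\sigma^2,\ -d_{\text{sheet}}^2/\sigma^2,\ 0)\), which is why the docstring calls it a softmax. A standalone check, assuming a hypothetical σ² = 0.5 in place of the library's _SS_SIGMA_SQ (its value is not shown on this page):

```python
import jax
import jax.numpy as jnp

SS_SIGMA_SQ = 0.5  # assumption: stand-in for _SS_SIGMA_SQ

phi = jnp.deg2rad(jnp.array([-60.0, -120.0, 60.0]))  # helix, sheet, left-handed region
psi = jnp.deg2rad(jnp.array([-45.0, 135.0, 45.0]))

d_helix = (phi + 1.05) ** 2 + (psi + 0.78) ** 2
d_sheet = (phi + 2.09) ** 2 + (psi - 2.35) ** 2

# Ratio form used by predict_ca_shifts (coil raw weight fixed at 1)
g_h = jnp.exp(-d_helix / SS_SIGMA_SQ)
g_s = jnp.exp(-d_sheet / SS_SIGMA_SQ)
ratio = jnp.stack([g_h, g_s, jnp.ones_like(g_h)], axis=1) / (g_h + g_s + 1.0)[:, None]

# The same weights as a softmax with the coil logit fixed at 0
logits = jnp.stack([-d_helix / SS_SIGMA_SQ, -d_sheet / SS_SIGMA_SQ, jnp.zeros_like(d_helix)], axis=1)
soft = jax.nn.softmax(logits, axis=1)
```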

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| phi | ndarray | (N,) backbone Φ angles in radians. | required |
| psi | ndarray | (N,) backbone Ψ angles in radians. | required |
| rc_shifts | ndarray | (N,) baseline random-coil Cα shifts (ppm). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (N,) predicted Cα chemical shifts (ppm). |

Source code in diff_biophys/nmr/chemical_shifts.py
@jit
def predict_ca_shifts(phi: jnp.ndarray, psi: jnp.ndarray, rc_shifts: jnp.ndarray) -> jnp.ndarray:
    """
    Differentiable Cα chemical shift prediction based on backbone torsions.

    Uses Gaussian "soft detectors" in (Φ, Ψ) space to classify secondary
    structure and applies SPARTA-like offsets.  The detectors are normalised
    via a softmax so that helix, sheet, and coil contributions always sum to
    1.0, preventing unphysical double-counting.

    Reference centres (radians):
        * α-helix: Φ = −1.05 rad (−60°), Ψ = −0.78 rad (−45°)
        * β-sheet:  Φ = −2.09 rad (−120°), Ψ = +2.35 rad (+135°)
        * Random coil: treated as the baseline (weight = 1 − w_helix − w_sheet)

    Args:
        phi: (N,) backbone Φ angles in radians.
        psi: (N,) backbone Ψ angles in radians.
        rc_shifts: (N,) baseline random-coil Cα shifts (ppm).

    Returns:
        jnp.ndarray: (N,) predicted Cα chemical shifts (ppm).
    """
    # --- Unnormalised Gaussian affinities ---
    # Alpha-helix centre: Φ ~ −60°, Ψ ~ −45°
    helix_dist_sq = (phi + 1.05) ** 2 + (psi + 0.78) ** 2
    w_helix_raw = jnp.exp(-helix_dist_sq / _SS_SIGMA_SQ)

    # Beta-sheet centre: Φ ~ −120°, Ψ ~ +135°
    sheet_dist_sq = (phi + 2.09) ** 2 + (psi - 2.35) ** 2
    w_sheet_raw = jnp.exp(-sheet_dist_sq / _SS_SIGMA_SQ)

    # Coil baseline: all residues start with weight 1 (i.e. the neutral state)
    w_coil_raw = jnp.ones_like(phi)

    # --- Softmax normalisation ---
    # Ensures w_helix + w_sheet + w_coil = 1 for every residue,
    # preventing simultaneous helix + sheet double-counting.
    total = w_helix_raw + w_sheet_raw + w_coil_raw
    w_helix = w_helix_raw / total
    w_sheet = w_sheet_raw / total
    # w_coil = w_coil_raw / total  (implicit; contributes zero offset)

    # NOTE — effective offset cap (Issue 3):
    # Because w_coil_raw is fixed at 1.0, at the *exact* helix or sheet centre
    # (where the Gaussian peak = 1 and the opposing class ≈ 0) the denominator
    # is 1 + 0 + 1 = 2, so w_helix_norm ≈ 0.5.  This means the maximum Cα
    # shift a perfectly helical residue receives is:
    #
    #   0.5 × OFFSET_HELIX  ≈  +1.55 ppm  (not the full +3.1 ppm)
    #
    # This is a deliberate approximation: the coil baseline acts as a Bayesian
    # prior that prevents runaway shifts.  It also means the predictor underestimates
    # pure-helix / pure-sheet shifts by ~50% relative to SPARTA+.  Users who need
    # quantitative SPARTA+ parity should double OFFSET_HELIX / OFFSET_SHEET to
    # compensate.  Reducing _SS_SIGMA_SQ only sharpens the detectors: the Gaussian
    # peak is 1 for any width, so w_helix stays capped at ~0.5 at the centre.

    # --- Weighted offset ---
    # Coil weight contributes 0 offset (it is the RC baseline).
    return rc_shifts + (w_helix * OFFSET_HELIX) + (w_sheet * OFFSET_SHEET)
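The cap described in the note above does not depend on the detector width. At the centre, the helix Gaussian equals 1 for every σ², and the coil raw weight is fixed at 1, so w_helix never exceeds 1/2. The standalone sketch below evaluates w_helix at the exact helix centre and 0.3 rad away from it. The σ² values are hypothetical, and OFFSET_HELIX = 3.1 is taken from the note.

```python
import jax.numpy as jnp

OFFSET_HELIX = 3.1  # nominal helix offset quoted in the note above

def helix_weight(phi, psi, sigma_sq):
    # Normalised helix weight with the coil raw weight fixed at 1
    g_h = jnp.exp(-((phi + 1.05) ** 2 + (psi + 0.78) ** 2) / sigma_sq)
    g_s = jnp.exp(-((phi + 2.09) ** 2 + (psi - 2.35) ** 2) / sigma_sq)
    return g_h / (g_h + g_s + 1.0)

widths = (1.0, 0.5, 0.1, 0.01)  # hypothetical values of _SS_SIGMA_SQ
at_centre = jnp.array([helix_weight(-1.05, -0.78, s) for s in widths])
off_centre = jnp.array([helix_weight(-1.05 + 0.3, -0.78, s) for s in widths])
max_offset = at_centre * OFFSET_HELIX  # ≈ 1.55 ppm for every width
```

Narrower Gaussians lower the off-centre weights but leave the 0.5 ceiling untouched. Scaling the offsets is what restores the full-offset magnitude.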

Residual Dipolar Couplings (RDCs)

diff_biophys.nmr.rdc computes Residual Dipolar Couplings (Hz) from bond vectors and a Saupe alignment tensor:

\[D_{NH} = D_{\max} \sum_{i,j} v_i S_{ij} v_j\]

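where \(\mathbf{v}\) is the unit bond vector and \(\mathbf{S}\) is the Saupe tensor (3×3 symmetric traceless matrix). For an axially symmetric tensor, \(D = 0\) when the bond vector makes the magic angle \(\theta = \arccos(1/\sqrt{3}) \approx 54.74°\) with the symmetry axis. For a rhombic tensor, the zero-coupling surface is no longer a cone of fixed \(\theta\).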
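For an axially symmetric, traceless tensor with principal value \(S_{zz}\), the full-tensor formula reduces to \(D = D_{\max} S_{zz} (3\cos^2\theta - 1)/2\), which vanishes at \(\cos^2\theta = 1/3\). The standalone sketch below uses the tensor and d_max from the usage example. It computes the coupling directly rather than through diff_biophys.

```python
import jax.numpy as jnp

d_max = 21585.0                                  # as in the usage example
S = jnp.diag(jnp.array([-0.05, -0.05, 0.10]))    # axially symmetric, traceless

def unit_vector_at(theta):
    # Unit vector in the xz-plane at angle theta from the z (symmetry) axis
    return jnp.array([jnp.sin(theta), 0.0, jnp.cos(theta)])

def rdc(v):
    return d_max * v @ S @ v

magic = jnp.arccos(1.0 / jnp.sqrt(3.0))          # ≈ 54.74°
d_parallel = rdc(unit_vector_at(0.0))            # d_max * Szz = 2158.5 Hz
d_magic = rdc(unit_vector_at(magic))             # ≈ 0
d_perp = rdc(unit_vector_at(jnp.pi / 2))         # -d_max * Szz / 2 = -1079.25 Hz
```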

Also provides:

  • fit_saupe_tensor: SVD-based least-squares fit of S from observed RDCs
  • calculate_rdc: principal-axis-frame form with axial component Da and rhombicity R, which reduces to \(D \propto 3\cos^2\theta - 1\) when R = 0
  • calculate_q_factor: RDC Q-factor (Cornilescu et al. 1998) for assessing agreement with experiment


```python
from diff_biophys.nmr.rdc import calculate_rdc_from_tensor, fit_saupe_tensor
import jax.numpy as jnp

# N–H bond unit vectors (n_res, 3)
bond_vecs = jnp.array(...)

# Saupe alignment tensor (3, 3), axially symmetric and traceless
S = jnp.diag(jnp.array([-0.05, -0.05, 0.10]))

rdcs = calculate_rdc_from_tensor(bond_vecs, S, d_max=21585.0)   # Hz

# Fit a tensor from measured RDCs (n_res,), in Hz. Pass the same d_max as the
# forward model: the default d_max=1.0 would return S scaled by 21585.
rdcs_experimental = jnp.array(...)
S_fit = fit_saupe_tensor(bond_vecs, rdcs_experimental, d_max=21585.0)
```
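A round trip makes the d_max pairing concrete. The standalone sketch below re-implements the forward model and the five-parameter least-squares fit from the source listings on this page (it does not call diff_biophys). It generates noise-free RDCs from a hypothetical tensor and random unit vectors, then fits them back.

```python
import jax
import jax.numpy as jnp

def rdc_from_tensor(vecs, saupe, d_max):
    return d_max * jnp.einsum("ni,ij,nj->n", vecs, saupe, vecs)

def fit_saupe(vecs, rdcs, d_max):
    # Mirrors the five-parameter least-squares fit in the fit_saupe_tensor listing
    x, y, z = vecs[:, 0], vecs[:, 1], vecs[:, 2]
    A = d_max * jnp.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)
    s, *_ = jnp.linalg.lstsq(A, rdcs)
    sxx, syy, sxy, sxz, syz = s
    return jnp.array([[sxx, sxy, sxz], [sxy, syy, syz], [sxz, syz, -(sxx + syy)]])

# Synthetic bond vectors: random directions normalised to unit length (illustration only)
v = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
v = v / jnp.linalg.norm(v, axis=1, keepdims=True)

# Hypothetical symmetric, traceless tensor
S_true = jnp.array([[-0.04, 0.01, 0.0], [0.01, -0.06, 0.02], [0.0, 0.02, 0.10]])
d_max = 21585.0
rdcs = rdc_from_tensor(v, S_true, d_max)

S_ok = fit_saupe(v, rdcs, d_max)     # same d_max as the forward model
S_scaled = fit_saupe(v, rdcs, 1.0)   # default-style d_max=1.0: tensor comes back scaled
```

With matching d_max the tensor is recovered; with d_max=1.0 the result is the true tensor multiplied by 21585.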

calculate_q_factor(calculated_rdcs, experimental_rdcs)

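Calculate the RDC Q-factor (Cornilescu et al., 1998):

\[Q = \sqrt{\frac{\sum_i \left(D_i^{\text{calc}} - D_i^{\text{exp}}\right)^2}{\sum_i \left(D_i^{\text{exp}}\right)^2}}\]

Returns 0.0 whenever all experimental RDCs are zero, regardless of the calculated values, since Q is undefined in that case.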

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| calculated_rdcs | ndarray | (N,) calculated couplings. | required |
| experimental_rdcs | ndarray | (N,) measured couplings. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | Scalar Q-factor (dimensionless). |

Source code in diff_biophys/nmr/rdc.py
@jit
def calculate_q_factor(calculated_rdcs: jnp.ndarray, experimental_rdcs: jnp.ndarray) -> jnp.ndarray:
    """
    Calculate the RDC Q-factor (Cornilescu et al., 1998).
    Q = sqrt( sum((D_calc - D_exp)^2) / sum(D_exp^2) )

    Returns 0.0 when all experimental RDCs are zero (perfect trivial match).

    Args:
        calculated_rdcs: (N,) calculated couplings.
        experimental_rdcs: (N,) measured couplings.

    Returns:
        jnp.ndarray: Scalar Q-factor.
    """
    diff_sq = jnp.sum((calculated_rdcs - experimental_rdcs) ** 2)
    exp_sq = jnp.sum(experimental_rdcs**2)

    # Robust Q-factor calculation to avoid NaN gradients at zero experimental RDCs.
    # We use a safe denominator for the division and then mask the result.
    q = jnp.sqrt(diff_sq / jnp.maximum(exp_sq, 1e-10))
    return jnp.where(exp_sq > 0.0, q, 0.0)
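The Q-factor arithmetic is easy to check by hand. The minimal standalone sketch below mirrors the listing above, so it runs without diff_biophys, and uses three illustrative couplings:

```python
import jax.numpy as jnp

def q_factor(calc, exp):
    # Same computation as calculate_q_factor
    diff_sq = jnp.sum((calc - exp) ** 2)
    exp_sq = jnp.sum(exp**2)
    q = jnp.sqrt(diff_sq / jnp.maximum(exp_sq, 1e-10))
    return jnp.where(exp_sq > 0.0, q, 0.0)

d_exp = jnp.array([10.0, -5.0, 3.0])   # Hz, illustrative
d_calc = jnp.array([9.0, -4.0, 3.0])   # Hz, illustrative

q = q_factor(d_calc, d_exp)                   # sqrt((1 + 1 + 0) / (100 + 25 + 9)) = sqrt(2/134) ≈ 0.122
q_all_zero = q_factor(d_calc, jnp.zeros(3))   # 0.0 by construction, even though d_calc != 0
```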

calculate_rdc(bond_vectors, da, r)

Differentiable RDC calculation in the principal axis frame (PAF).

Important: bond_vectors must be expressed in the principal axis frame of the alignment tensor, i.e. the frame in which the Saupe tensor is diagonal. Passing lab-frame vectors silently gives incorrect results.

The formula follows the standard Clore/Bax convention (Clore et al. 1998, J. Magn. Reson. 133, 216–221):

\[D = D_a\left[(3\cos^2\theta - 1) + \tfrac{3}{2} R \sin^2\theta \cos 2\phi\right] = D_a\left[(3z^2 - 1) + \tfrac{3}{2} R (x^2 - y^2)\right]\]

where \(\theta\) and \(\phi\) are the polar and azimuthal angles of the bond vector in the principal axis frame (this \(\phi\) is not the backbone dihedral), \(D_a\) is the axial component, and \(R = \tfrac{2}{3}(A_{xx} - A_{yy})/A_{zz}\) is the rhombicity (0 ≤ R ≤ 2/3).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors in the tensor's principal axis frame. | required |
| da | float | Axial component Da in Hz. | required |
| r | float | Rhombicity R (0 ≤ R ≤ 2/3). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (N,) calculated RDCs in Hz. |

Source code in diff_biophys/nmr/rdc.py
@jit
def calculate_rdc(bond_vectors: jnp.ndarray, da: float, r: float) -> jnp.ndarray:
    """
    Differentiable RDC calculation in the principal axis frame (PAF).

    .. important::
        ``bond_vectors`` **must be expressed in the principal axis frame**
        of the alignment tensor, i.e. the frame where the Saupe tensor is
        diagonal.  Passing lab-frame vectors will give incorrect results
        without any error.

    The formula used is the standard Clore/Bax convention
    (Clore et al. 1998, *J. Magn. Reson.* **133**, 216–221)::

        D = Da · [(3 cos²θ − 1) + (3/2) R sin²θ cos 2φ]
          = Da · [(3z² − 1) + (3/2) R (x² − y²)]

    where ``Da`` is the axial component and ``R = (2/3)(Axx − Ayy) / Azz`` is
    the rhombicity (0 ≤ R ≤ 2/3).

    Args:
        bond_vectors: (N, 3) unit vectors in the tensor's principal axis frame.
        da: Axial component Da in Hz.
        r: Rhombicity R (0 ≤ R ≤ 2/3).

    Returns:
        jnp.ndarray: (N,) Calculated RDCs in Hz.
    """
    x, y, z = bond_vectors[:, 0], bond_vectors[:, 1], bond_vectors[:, 2]

    axial = 3.0 * z**2 - 1.0
    rhombic = 1.5 * r * (x**2 - y**2)

    return da * (axial + rhombic)
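The principal-axis form agrees with the full-tensor form once Da and R are derived from the principal values. The sketch assumes the convention D = d_max vᵀAv used by calculate_rdc_from_tensor, which gives Da = d_max A_zz / 2 and R = (2/3)(A_xx − A_yy)/A_zz. The principal values and random vectors are hypothetical, and both formulas are re-implemented standalone.

```python
import jax
import jax.numpy as jnp

# Hypothetical traceless principal values, ordered |Azz| >= |Ayy| >= |Axx|
a_xx, a_yy, a_zz = -0.3, -0.7, 1.0
d_max = 1.0

da = d_max * a_zz / 2.0                    # axial component
r = (2.0 / 3.0) * (a_xx - a_yy) / a_zz     # rhombicity, 4/15 here

# Random unit vectors, taken to be in the principal axis frame (illustration only)
v = jax.random.normal(jax.random.PRNGKey(1), (8, 3))
v = v / jnp.linalg.norm(v, axis=1, keepdims=True)
x, y, z = v[:, 0], v[:, 1], v[:, 2]

def clore_form(r_value):
    # Same expression as calculate_rdc
    return da * ((3.0 * z**2 - 1.0) + 1.5 * r_value * (x**2 - y**2))

d_tensor = d_max * jnp.einsum("ni,ij,nj->n", v, jnp.diag(jnp.array([a_xx, a_yy, a_zz])), v)
d_clore = clore_form(r)
d_naive = clore_form((a_xx - a_yy) / a_zz)   # omits the 2/3 factor, overstating the rhombic term
```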

calculate_rdc_from_tensor(bond_vectors, saupe_tensor, d_max=1.0)

Calculate RDCs from a full 3×3 Saupe alignment tensor:

\[D = d_{\max} \sum_{i,j} v_i S_{ij} v_j\]

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors. | required |
| saupe_tensor | ndarray | (3, 3) symmetric traceless Saupe tensor. | required |
| d_max | float | Maximum dipolar coupling constant (Hz). | 1.0 |

Returns:

| Type | Description |
|---|---|
| ndarray | (N,) calculated RDCs. |

Source code in diff_biophys/nmr/rdc.py
@jit
def calculate_rdc_from_tensor(
    bond_vectors: jnp.ndarray, saupe_tensor: jnp.ndarray, d_max: float = 1.0
) -> jnp.ndarray:
    """
    Calculate RDCs from a full 3x3 Saupe alignment tensor.
    D = d_max * sum_ij (v_i * S_ij * v_j)

    Args:
        bond_vectors: (N, 3) unit vectors
        saupe_tensor: (3, 3) symmetric traceless Saupe tensor
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: Calculated RDCs (N,)
    """
    # Vectorized computation of v^T S v
    return d_max * jnp.einsum("ni,ij,nj->n", bond_vectors, saupe_tensor, bond_vectors)
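The einsum string "ni,ij,nj->n" computes \(\mathbf{v}_n^\top \mathbf{S}\, \mathbf{v}_n\) for every row. The standalone sketch below (re-implemented formula; random vectors and tensor are illustrative) checks it against an explicit matrix-product form. It also shows why the tensor should be traceless: adding an isotropic part \(c\mathbf{I}\) shifts every coupling by \(d_{\max} c\), because \(|\mathbf{v}| = 1\).

```python
import jax
import jax.numpy as jnp

def rdc_from_tensor(vecs, saupe, d_max=1.0):
    return d_max * jnp.einsum("ni,ij,nj->n", vecs, saupe, vecs)

v = jax.random.normal(jax.random.PRNGKey(2), (6, 3))
v = v / jnp.linalg.norm(v, axis=1, keepdims=True)
S = jnp.array([[-0.02, 0.01, 0.0], [0.01, -0.03, 0.0], [0.0, 0.0, 0.05]])  # traceless

d = rdc_from_tensor(v, S)
d_matmul = jnp.sum((v @ S) * v, axis=1)   # same contraction, row-wise v_n^T S v_n

# A non-traceless input biases every coupling by d_max * c
c = 0.01
d_biased = rdc_from_tensor(v, S + c * jnp.eye(3))
```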

fit_saupe_tensor(bond_vectors, experimental_rdcs, d_max=1.0)

Fit a Saupe alignment tensor to experimental RDCs using SVD (least squares).

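The RDC formula is linear in the tensor elements, so it can be written as D = A s, where A is an (N, 5) design matrix built from the bond-vector components and s = [Sxx, Syy, Sxy, Sxz, Syz] holds the 5 independent components (Szz = −Sxx − Syy by tracelessness). A unique fit requires at least five bond vectors with sufficiently independent orientations.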

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| bond_vectors | ndarray | (N, 3) unit vectors. | required |
| experimental_rdcs | ndarray | (N,) measured RDCs in Hz. | required |
| d_max | float | Maximum dipolar coupling constant (Hz). | 1.0 |

Returns:

| Type | Description |
|---|---|
| ndarray | (3, 3) fitted Saupe tensor. |

Source code in diff_biophys/nmr/rdc.py
@jit
def fit_saupe_tensor(
    bond_vectors: jnp.ndarray, experimental_rdcs: jnp.ndarray, d_max: float = 1.0
) -> jnp.ndarray:
    """
    Fit a Saupe alignment tensor to experimental RDCs using SVD (least squares).

    The RDC formula can be rewritten as D = A * s
    where s = [Sxx, Syy, Sxy, Sxz, Syz] (5 independent components)

    Args:
        bond_vectors: (N, 3) unit vectors
        experimental_rdcs: (N,) measured RDCs in Hz
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: (3, 3) Fitted Saupe tensor
    """
    x = bond_vectors[:, 0]
    y = bond_vectors[:, 1]
    z = bond_vectors[:, 2]

    # Basis functions for the 5 independent components
    # Using the identity Szz = -Sxx - Syy
    # D = d_max * [ Sxx*x^2 + Syy*y^2 + Szz*z^2 + 2Sxy*xy + 2Sxz*xz + 2Syz*yz ]
    # D = d_max * [ Sxx(x^2 - z^2) + Syy(y^2 - z^2) + 2Sxy*xy + 2Sxz*xz + 2Syz*yz ]

    A = d_max * jnp.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)

    # Solve A * s = experimental_rdcs
    s, _, _, _ = jnp.linalg.lstsq(A, experimental_rdcs)

    sxx, syy, sxy, sxz, syz = s
    szz = -(sxx + syy)

    tensor = jnp.array([[sxx, sxy, sxz], [sxy, syy, syz], [sxz, syz, szz]])

    return tensor
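Because the fit solves for five unknowns, it is unique only when the design matrix A has rank 5. The standalone sketch below mirrors the basis functions in the listing above, using illustrative random vectors. It checks the rank for a generic set, for too few vectors, and for vectors confined to one plane. When A is rank-deficient, jnp.linalg.lstsq returns the minimum-norm least-squares solution rather than raising an error, so the returned tensor is not uniquely determined.

```python
import jax
import jax.numpy as jnp

def design_matrix(vecs, d_max=1.0):
    # Same five basis functions as fit_saupe_tensor
    x, y, z = vecs[:, 0], vecs[:, 1], vecs[:, 2]
    return d_max * jnp.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)

v = jax.random.normal(jax.random.PRNGKey(3), (12, 3))
v = v / jnp.linalg.norm(v, axis=1, keepdims=True)

rank_generic = jnp.linalg.matrix_rank(design_matrix(v))       # 5: unique fit
rank_too_few = jnp.linalg.matrix_rank(design_matrix(v[:4]))   # at most 4: underdetermined

# All vectors in the xy-plane (z = 0): the Sxz and Syz columns vanish, leaving rank 3
v_planar = v.at[:, 2].set(0.0)
v_planar = v_planar / jnp.linalg.norm(v_planar, axis=1, keepdims=True)
rank_planar = jnp.linalg.matrix_rank(design_matrix(v_planar))
```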

Ring Current Shifts

diff_biophys.nmr.ring_currents implements the point-dipole approximation of the Johnson-Bovey model of aromatic ring-current shielding. With a positive intensity, protons directly above the ring plane are shielded (upfield shift, negative Δδ) and protons in the ring plane are deshielded (downfield shift, positive Δδ).

The shift falls off as \(\sim 1/r^3\) with distance from the ring centre.
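At a fixed angle to the ring normal, the angular factor is constant and only the 1/r³ term changes, so doubling the distance divides the shift by 8. The standalone sketch below uses the same point-dipole expression as calculate_ring_current_shift (small epsilons omitted). Its intensity = 1.0 is illustrative.

```python
import jax.numpy as jnp

def ring_current_shift(coords, center, normal, intensity):
    # Point-dipole expression used by calculate_ring_current_shift (epsilons omitted)
    r_vec = coords - center
    r = jnp.linalg.norm(r_vec, axis=-1)
    cos_t = jnp.sum(r_vec * normal, axis=-1) / r
    return intensity * (1.0 - 3.0 * cos_t**2) / r**3

center = jnp.zeros(3)
normal = jnp.array([0.0, 0.0, 1.0])
direction = jnp.array([0.6, 0.0, 0.8])   # fixed unit direction, so the angle is constant

near = ring_current_shift(3.0 * direction[None, :], center, normal, 1.0)
far = ring_current_shift(6.0 * direction[None, :], center, normal, 1.0)
ratio = near / far   # 2**3 = 8
```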

```python
from diff_biophys.nmr.ring_currents import calculate_ring_current_shift
import jax
import jax.numpy as jnp

# Proton position, shape (N, 3) with N = 1
proton_pos = jnp.array([[0.0, 0.0, 3.5]])   # 3.5 Å above the ring centre
ring_pos = jnp.array([0.0, 0.0, 0.0])
ring_normal = jnp.array([0.0, 0.0, 1.0])    # ring in the xy-plane
intensity = 1.0   # illustrative value; intensity is required, calibrate it per ring type

delta = calculate_ring_current_shift(proton_pos, ring_pos, ring_normal, intensity)
# delta < 0 (shielded above the ring)

grad = jax.grad(
    lambda p: jnp.sum(calculate_ring_current_shift(p, ring_pos, ring_normal, intensity))
)(proton_pos)
# -grad points in the direction that increases shielding (makes delta more negative)
```
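On the ring axis the gradient can be worked out by hand. With \(\cos\theta = 1\), \(\Delta\delta = -2I/z^3\), so \(\partial\Delta\delta/\partial z = 6I/z^4\) and the lateral components vanish. The standalone sketch below compares jax.grad with that value, using a re-implemented formula and an illustrative intensity.

```python
import jax
import jax.numpy as jnp

def ring_current_shift(coords, center, normal, intensity):
    # Point-dipole expression used by calculate_ring_current_shift (epsilons omitted)
    r_vec = coords - center
    r = jnp.linalg.norm(r_vec, axis=-1)
    cos_t = jnp.sum(r_vec * normal, axis=-1) / r
    return intensity * (1.0 - 3.0 * cos_t**2) / r**3

intensity = 1.0  # illustrative
center, normal = jnp.zeros(3), jnp.array([0.0, 0.0, 1.0])
proton = jnp.array([[0.0, 0.0, 3.5]])  # (N, 3) with N = 1, on the ring axis

g = jax.grad(lambda p: jnp.sum(ring_current_shift(p, center, normal, intensity)))(proton)
expected_dz = 6.0 * intensity / 3.5**4   # analytic d(delta)/dz on the axis
```

The gradient points away from the ring, so stepping along −grad moves the proton closer and increases shielding.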

calculate_ring_current_shift(coords, ring_center, ring_normal, intensity)

Calculate chemical shift changes due to aromatic ring currents using the Johnson-Bovey dipolar approximation.

delta = intensity * (1 - 3*cos^2(theta)) / r^3

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (N, 3) coordinates of the nuclei being shielded. | required |
| ring_center | ndarray | (3,) coordinates of the aromatic ring center. | required |
| ring_normal | ndarray | (3,) unit vector normal to the ring plane. | required |
| intensity | float | Scaling factor (proportional to ring area and current). | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (N,) shielding values in ppm. |

Source code in diff_biophys/nmr/ring_currents.py
@jit
def calculate_ring_current_shift(
    coords: jnp.ndarray, ring_center: jnp.ndarray, ring_normal: jnp.ndarray, intensity: float
) -> jnp.ndarray:
    """
    Calculate chemical shift changes due to aromatic ring currents using
    the Johnson-Bovey dipolar approximation.

    delta = intensity * (1 - 3*cos^2(theta)) / r^3

    Args:
        coords: (N, 3) coordinates of the nuclei being shielded.
        ring_center: (3,) coordinates of the aromatic ring center.
        ring_normal: (3,) unit vector normal to the ring plane.
        intensity: Scaling factor (proportional to ring area and current).

    Returns:
        jnp.ndarray: (N,) shielding values in ppm.
    """
    # 1. Displacement vectors from ring center
    r_vec = coords - ring_center

    # 2. Distances
    r = jnp.linalg.norm(r_vec, axis=-1)

    # 3. cos(theta) where theta is the angle between r_vec and the ring normal
    # cos(theta) = (r_vec . normal) / (|r_vec| * |normal|)
    # Assume ring_normal is already a unit vector
    cos_theta = jnp.sum(r_vec * ring_normal, axis=-1) / (r + 1e-10)

    # 4. Johnson-Bovey geometric term
    # delta = intensity * (1 - 3 * cos^2(theta)) / r^3
    return cast(jnp.ndarray, intensity * (1.0 - 3.0 * cos_theta**2) / (r**3 + 1e-10))
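The sign of the point-dipole term flips where \(1 - 3\cos^2\theta = 0\), i.e. on a cone at the magic angle (≈ 54.74°) from the ring normal. This is the same angle that zeroes axially symmetric RDCs. The standalone sketch below samples several angles at a fixed distance, with an illustrative intensity = 1.0.

```python
import jax.numpy as jnp

def ring_current_shift(coords, center, normal, intensity):
    # Point-dipole expression used by calculate_ring_current_shift (epsilons omitted)
    r_vec = coords - center
    r = jnp.linalg.norm(r_vec, axis=-1)
    cos_t = jnp.sum(r_vec * normal, axis=-1) / r
    return intensity * (1.0 - 3.0 * cos_t**2) / r**3

angles = jnp.deg2rad(jnp.array([0.0, 30.0, 54.7356, 80.0, 90.0]))  # measured from the normal
r = 4.0
coords = r * jnp.stack([jnp.sin(angles), jnp.zeros_like(angles), jnp.cos(angles)], axis=1)
delta = ring_current_shift(coords, jnp.zeros(3), jnp.array([0.0, 0.0, 1.0]), 1.0)
# shielded (< 0) inside the cone, ≈ 0 on it, deshielded (> 0) outside it
```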