
🔬 SAXS API

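The diff_biophys.saxs subpackage implements differentiable small-angle X-ray scattering (SAXS) kernels. The core function, debye_saxs, computes the full \(O(N^2)\) pairwise Debye sum. It is vectorized over \(q\) with jax.vmap, so it runs on GPU through XLA, and it offers an optional excluded-volume (bulk-solvent subtraction) correction. Because it materializes the \((N, N)\) distance matrix, memory also grows as \(N^2\): for \(N = 5000\) atoms that matrix alone holds \(2.5 \times 10^7\) float32 values, or 100 MB.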


Debye Scattering

The Debye formula computes solution-state X-ray scattering intensity from atomic coordinates and form factors:

\[I(q) = \sum_i \sum_j f_i(q)\, f_j(q)\, \frac{\sin(q r_{ij})}{q r_{ij}}\]

where \(q\) is the momentum transfer (Å⁻¹), \(r_{ij}\) is the pairwise inter-atomic distance, and \(f_i(q)\) are atomic form factors.
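As a sanity check on the formula, the double sum can be written directly in a few lines of plain JAX. This is a minimal reference sketch, not the library's kernel: debye_reference is a hypothetical name, and constant form factors are a simplification. For two atoms a distance \(d\) apart the sum has the closed form \(f_1^2 + f_2^2 + 2 f_1 f_2 \sin(qd)/(qd)\), which the example compares against:

```python
import jax.numpy as jnp

def debye_reference(coords, q, f):
    """Hypothetical reference: I(q) = sum_ij f_i f_j sin(q r_ij) / (q r_ij).

    coords: (N, 3) in Å; q: (M,) in Å^-1; f: (N, M) form factors.
    """
    diff = coords[:, None, :] - coords[None, :, :]      # (N, N, 3)
    r = jnp.linalg.norm(diff, axis=-1)                   # (N, N), zero on the diagonal
    qr = q[:, None, None] * r[None, :, :]                # (M, N, N)
    # jnp.sinc is the normalised sinc sin(pi x) / (pi x) and equals 1 at x = 0,
    # so sinc(qr / pi) = sin(qr) / qr with the i = j terms handled exactly.
    s = jnp.sinc(qr / jnp.pi)
    fT = f.T                                             # (M, N)
    return jnp.einsum("mi,mj,mij->m", fT, fT, s)

# Two atoms 5 Å apart with constant form factors 6 and 8 (simplified).
coords = jnp.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
q = jnp.linspace(0.01, 0.5, 50)
f = jnp.stack([jnp.full_like(q, 6.0), jnp.full_like(q, 8.0)])
I = debye_reference(coords, q, f)

# Closed form for N = 2: f1^2 + f2^2 + 2 f1 f2 sin(q d) / (q d)
I_closed = 36.0 + 64.0 + 96.0 * jnp.sin(5.0 * q) / (5.0 * q)
```

As \(q \to 0\) every sinc term tends to 1, so \(I(0) = (\sum_i f_i)^2\), here \((6 + 8)^2 = 196\).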

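Excluded-volume correction (Fraser et al. 1978): each atom displaces bulk solvent, so its vacuum form factor is reduced by the scattering of the solvent it excludes, modelled as a Gaussian sphere of volume \(V_i\) and solvent electron density \(\rho_s\):

\[f_i^{\text{eff}}(q) = f_i(q) - \rho_s V_i \exp\!\left[-\frac{(q\, r_i)^2}{4\pi}\right], \qquad r_i = \left(\frac{3 V_i}{4\pi}\right)^{1/3}\]

The excess density of the hydration shell itself is not modelled by debye_saxs.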
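A single-atom sketch of the correction, mirroring the formula in step 2 of debye_saxs, shows its size. The helper name excluded_volume_ff is hypothetical, and so are the inputs: a constant vacuum form factor of 6 (carbon's electron count at \(q = 0\); real form factors decay with \(q\)) and a volume of 10 Å³. At \(q = 0\) the subtracted term is \(\rho_s V = 0.334 \times 10 = 3.34\) electrons, so \(f^{\text{eff}}(0) = 6 - 3.34 = 2.66\):

```python
import jax.numpy as jnp

def excluded_volume_ff(f_vac, q, volume, rho_s=0.334):
    # Effective sphere radius from V = (4/3) pi r^3
    r_eff = (3.0 * volume / (4.0 * jnp.pi)) ** (1.0 / 3.0)
    # Gaussian envelope of the displaced solvent, same form as in debye_saxs
    decay = jnp.exp(-((q * r_eff) ** 2) / (4.0 * jnp.pi))
    return f_vac - rho_s * volume * decay

q = jnp.array([0.0, 0.25, 0.5])
# Hypothetical atom: constant f_vac = 6, V = 10 Å³
f_eff = excluded_volume_ff(6.0, q, 10.0)
```

Because the Gaussian envelope decays with \(q\), the correction is largest at \(q = 0\) and \(f^{\text{eff}}\) rises back toward \(f_{\text{vac}}\) at higher \(q\).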

Gradient meaning: \(\partial I(q) / \partial \mathbf{r}_i\) shows which atoms, if displaced, would most change the scattering intensity at momentum transfer \(q\). At low \(q\) the profile, and hence the gradient, is dominated by the overall shape (summarized by \(R_g\)); at high \(q\), by short-range inter-atomic distances.
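The per-atom sensitivity can be read off with jax.grad. The sketch below uses a hypothetical single-\(q\) helper, debye_single_q, written in plain JAX with the same small-epsilon distance trick as debye_saxs, on a made-up four-atom structure with unit form factors. Because \(I(q)\) depends only on inter-atomic distances, the gradients over all atoms sum to zero, which makes a useful check:

```python
import jax
import jax.numpy as jnp

def debye_single_q(coords, q, f):
    """Hypothetical helper: Debye intensity at one q; f holds the (N,) form factors at that q."""
    diff = coords[:, None, :] - coords[None, :, :]       # (N, N, 3)
    # The 1e-12 keeps sqrt differentiable on the diagonal, where r_ii = 0.
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)
    qr = q * r
    return jnp.sum(f[:, None] * f[None, :] * jnp.sin(qr) / qr)

# Made-up four-atom structure (Å) with unit form factors.
coords = jnp.array([[0.0, 0.0, 0.0],
                    [3.8, 0.0, 0.0],
                    [3.8, 3.8, 0.0],
                    [10.0, 2.0, 1.0]])
f = jnp.ones(4)

sensitivity = {}
for q in (0.02, 0.4):
    g = jax.grad(debye_single_q)(coords, q, f)    # (N, 3): dI(q)/dr_i
    sensitivity[q] = jnp.linalg.norm(g, axis=-1)  # one scalar per atom
```

Comparing the rankings in sensitivity[0.02] and sensitivity[0.4] shows which atoms drive the low-\(q\) and high-\(q\) parts of the profile.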

```python
import jax
import jax.numpy as jnp
from diff_biophys.saxs.kernels import debye_saxs

# coords:       (N, 3) atomic positions in Å
# q_vals:       (M,) momentum-transfer grid in Å⁻¹
# form_factors: (N, M) q-dependent atomic form factors (debye_saxs indexes f[:, q_idx])

coords       = jnp.array(...)          # (N, 3)
q_vals       = jnp.linspace(0.01, 0.5, 100)
form_factors = jnp.ones((len(coords), len(q_vals)))   # uniform (simplified)
I_exp        = jnp.array(...)          # (M,) experimental profile on q_vals

I_q = debye_saxs(coords, q_vals, form_factors)    # (M,) in a.u.

# Chi-squared loss vs experimental profile
def saxs_chi2(c, I_exp, sigma=1.0):
    I_calc = debye_saxs(c, q_vals, form_factors)
    return jnp.mean(((I_calc - I_exp) / sigma) ** 2)

grad_coords = jax.grad(saxs_chi2)(coords, I_exp)   # (N, 3)

# JIT-compile at the call site for speed
saxs_jit = jax.jit(lambda c: debye_saxs(c, q_vals, form_factors))
I_q = saxs_jit(coords)
```
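Experimental intensities are on an arbitrary scale, so the calculated profile is usually scaled before it is compared with data. The sketch below uses a hypothetical helper, chi2_with_scale (not part of diff_biophys), that fits the scale factor in closed form: with \(w_m = 1/\sigma_m^2\), the value \(c = \sum_m w_m I_\text{calc}(q_m) I_\text{exp}(q_m) \big/ \sum_m w_m I_\text{calc}(q_m)^2\) minimises \(\sum_m \big((c\, I_\text{calc}(q_m) - I_\text{exp}(q_m))/\sigma_m\big)^2\). The result stays differentiable in I_calc, and hence in the coordinates:

```python
import jax.numpy as jnp

def chi2_with_scale(I_calc, I_exp, sigma):
    """Hypothetical helper: reduced chi^2 after an analytic scale-factor fit.

    c = sum(w * I_calc * I_exp) / sum(w * I_calc**2), with w = 1 / sigma**2,
    minimises sum(((c * I_calc - I_exp) / sigma)**2).
    """
    w = 1.0 / sigma**2
    c = jnp.sum(w * I_calc * I_exp) / jnp.sum(w * I_calc**2)
    chi2 = jnp.mean(((c * I_calc - I_exp) / sigma) ** 2)
    return chi2, c

# Synthetic check: an "experiment" that is exactly 3x the calculated profile.
q = jnp.linspace(0.01, 0.5, 100)
I_calc = jnp.exp(-(q * 10.0) ** 2 / 3.0)
I_exp = 3.0 * I_calc
chi2, c = chi2_with_scale(I_calc, I_exp, sigma=0.05 * I_exp)
```

To use it as a loss, pass debye_saxs(c, q_vals, form_factors) as I_calc and return only the first element.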

Multi-structure ensemble averaging

Use jax.vmap to evaluate the Debye sum over an ensemble of conformers and optimise population weights:

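```python
import jax
import jax.numpy as jnp
from diff_biophys.saxs.kernels import debye_saxs

# q_vals, form_factors and I_exp as defined in the previous example
# ensemble: (K, N, 3), K conformers
ensemble    = jnp.array(...)
K           = ensemble.shape[0]
weights     = jax.nn.softmax(jnp.zeros(K))   # uniform start

# Per-conformer profiles (K, M). They do not depend on the weights,
# so compute them once instead of inside the loss.
batch_saxs  = jax.vmap(lambda c: debye_saxs(c, q_vals, form_factors))
I_k         = batch_saxs(ensemble)
I_ensemble  = jnp.einsum("k,km->m", weights, I_k)

def ensemble_loss(w):
    # softmax keeps the population weights positive and summing to 1
    I = jnp.einsum("k,km->m", jax.nn.softmax(w), I_k)
    return jnp.mean((I - I_exp) ** 2)

grad_w = jax.grad(ensemble_loss)(jnp.zeros(K))
```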
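A complete optimisation loop takes only a few more lines. So that it runs without diff_biophys, this sketch replaces batch_saxs(ensemble) with two synthetic Guinier-shaped profiles (hypothetical conformers with \(R_g\) of 20 Å and 10 Å) and uses plain gradient descent on the softmax logits. The learning rate and step count are assumptions tuned for this toy scale only:

```python
import jax
import jax.numpy as jnp

# Synthetic stand-ins for batch_saxs(ensemble): Guinier-shaped profiles for two
# hypothetical conformers with Rg = 20 Å and 10 Å, normalised to I(0) = 1.
q = jnp.linspace(0.01, 0.5, 50)
profiles = jnp.stack([jnp.exp(-(q * 20.0) ** 2 / 3.0),
                      jnp.exp(-(q * 10.0) ** 2 / 3.0)])   # (K, M)
true_w = jnp.array([0.3, 0.7])
I_exp = true_w @ profiles                                  # synthetic "experiment"

def loss(logits):
    I = jax.nn.softmax(logits) @ profiles
    return jnp.mean((I - I_exp) ** 2)

grad_loss = jax.grad(loss)

@jax.jit
def fit(logits):
    lr, steps = 10.0, 2000   # assumed values, tuned for this toy problem only
    # Plain gradient descent on the logits; softmax keeps the weights
    # positive and summing to 1 throughout.
    return jax.lax.fori_loop(0, steps, lambda _, w: w - lr * grad_loss(w), logits)

weights = jax.nn.softmax(fit(jnp.zeros(2)))
```

Real Debye profiles scale as \((\sum_i f_i)^2\) at \(q \to 0\), so they usually need normalising (for example by \(I(0)\)), or an adaptive optimiser, before a fixed learning rate behaves well.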

Guinier analysis (Rg from low-q slope)

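At low \(q\), \(\ln I(q) \approx \ln I_0 - q^2 R_g^2 / 3\) (the Guinier approximation, typically valid for \(q R_g \lesssim 1.3\) in globular particles). A line fitted to \(\ln I\) versus \(q^2\) has slope \(-R_g^2/3\), so \(R_g = \sqrt{-3 \times \text{slope}}\). When coordinates are available, \(R_g\) can also be computed directly and differentiably: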

```python
from diff_biophys.geometry.macroscopic import compute_rg

# Direct differentiable Rg (faster than Guinier fitting)
rg = compute_rg(coords)

# Rg restraint loss
rg_target = 15.0  # Å, from experiment
rg_loss = (compute_rg(coords) - rg_target) ** 2
```
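When only an intensity profile is available (measured, or computed with debye_saxs), the Guinier fit can be done directly. This is a sketch with a hypothetical helper, guinier_fit. It uses a 0/1 mask rather than boolean indexing so array shapes stay static under jax.jit, and it assumes the caller picks q_max so that \(q_\text{max} R_g \lesssim 1.3\):

```python
import jax.numpy as jnp

def guinier_fit(q, I, q_max):
    """Hypothetical helper: least-squares fit of ln I = ln I0 - (Rg^2 / 3) q^2.

    Only points with q <= q_max are used; returns (Rg, I0).
    """
    m = (q <= q_max).astype(q.dtype)              # 0/1 weights, static shape
    x = q**2
    y = jnp.log(jnp.where(q <= q_max, I, 1.0))    # avoid log of masked-out points
    n = jnp.sum(m)
    x_bar = jnp.sum(m * x) / n
    y_bar = jnp.sum(m * y) / n
    slope = jnp.sum(m * (x - x_bar) * (y - y_bar)) / jnp.sum(m * (x - x_bar) ** 2)
    rg = jnp.sqrt(-3.0 * slope)                   # slope = -Rg^2 / 3
    i0 = jnp.exp(y_bar - slope * x_bar)           # intercept = ln I0
    return rg, i0

# Synthetic profile that follows the Guinier law exactly: Rg = 15 Å, I0 = 100.
q = jnp.linspace(0.005, 0.15, 60)
I = 100.0 * jnp.exp(-(q**2) * 15.0**2 / 3.0)
rg, i0 = guinier_fit(q, I, q_max=1.3 / 15.0)
```

On real data the fit is usually weighted by the experimental errors; that weighting is omitted here for brevity.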

debye_saxs(coords, q_values, form_factors, volumes=None, solvent_density=0.334)

Differentiable Debye Formula in JAX with optional solvent subtraction.

Note: This function is NOT decorated with @jit because the volumes argument may be None (a Python sentinel that is resolved at trace time, not at runtime). JIT-compile the call site instead, e.g.:

```python
jitted_debye = jax.jit(lambda c: debye_saxs(c, q, ff, volumes=vols))
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| coords | ndarray | (N, 3) atomic coordinates in Ångströms. | required |
| q_values | ndarray | (M,) scattering vector magnitudes (Å⁻¹). | required |
| form_factors | ndarray | (N, M) q-dependent vacuum atomic form factors. | required |
| volumes | ndarray \| None | (N,) atomic volumes (Å³) for excluded-volume correction. Pass None (default) to skip solvent subtraction. | None |
| solvent_density | float | Bulk solvent electron density (e/Å³). Default 0.334 e/Å³ for water. | 0.334 |

Returns:

| Type | Description |
| --- | --- |
| ndarray | Scattering intensities I(q), shape (M,). |

Source code in diff_biophys/saxs/kernels.py

```python
def debye_saxs(
    coords: jnp.ndarray,
    q_values: jnp.ndarray,
    form_factors: jnp.ndarray,
    volumes: jnp.ndarray | None = None,
    solvent_density: float = 0.334,
) -> jnp.ndarray:
    """
    Differentiable Debye Formula in JAX with optional solvent subtraction.

    Note: This function is NOT decorated with ``@jit`` because the
    ``volumes`` argument may be ``None`` (a Python sentinel that is
    resolved at trace time, not at runtime).  JIT-compile the *call site*
    instead, e.g.::

        jitted_debye = jax.jit(lambda c: debye_saxs(c, q, ff, volumes=vols))

    Args:
        coords: (N, 3) atomic coordinates in Ångströms.
        q_values: (M,) scattering vector magnitudes (Å⁻¹).
        form_factors: (N, M) q-dependent vacuum atomic form factors.
        volumes: (N,) atomic volumes (Å³) for excluded-volume correction.
            Pass ``None`` (default) to skip solvent subtraction.
        solvent_density: Bulk solvent electron density (e/Å³).
            Default 0.334 e/Å³ for water.

    Returns:
        jnp.ndarray: Scattering intensities I(q), shape (M,).
    """
    # 1. Pairwise distances (N, N)
    sq_norms = jnp.sum(coords**2, axis=-1)
    dist_sq = sq_norms[:, None] + sq_norms[None, :] - 2 * jnp.dot(coords, coords.T)
    # Clamp round-off negatives; the 1e-12 keeps sqrt differentiable at r_ii = 0.
    dist = jnp.sqrt(jnp.maximum(dist_sq, 0.0) + 1e-12)

    # 2. Effective form factors with optional solvent correction.
    # `volumes is None` is a plain Python check resolved at trace time, so
    # each choice traces only its own branch; with volumes=None the vacuum
    # form factors are used unchanged.
    if volumes is None:
        f_eff = form_factors
    else:
        # Effective radius for excluded volume: V = (4/3) π R³  →  R = (3V/4π)^(1/3)
        r_eff = (3.0 * volumes / (4.0 * jnp.pi)) ** (1.0 / 3.0)

        # Gaussian decay for the excluded-volume envelope (Fraser et al. 1978)
        # f_eff(q) = f_vac(q) - ρ_sol · V · exp(−(q·r_eff)² / (4π))
        decay = jnp.exp(-((q_values[None, :] * r_eff[:, None]) ** 2) / (4.0 * jnp.pi))
        f_eff = form_factors - (solvent_density * volumes[:, None] * decay)

    # 3. Debye sum: I(q) = Σ_i Σ_j f_i(q) f_j(q) sinc(q r_ij)
    def compute_intensity(q_idx: Any) -> Any:
        q = q_values[q_idx]

        f_q = f_eff[:, q_idx]
        f_prod = f_q[:, None] * f_q[None, :]
        qr = q * dist
        # Taylor expansion for qr→0; standard formula elsewhere.
        # The epsilon is in the denominator only, *not* inside sin(), to
        # avoid introducing a phase error at large qr.
        sinc_qr = jnp.where(
            qr < 1e-4,
            1.0 - (qr**2) / 6.0,
            jnp.sin(qr) / (qr + 1e-10),
        )
        return jnp.sum(f_prod * sinc_qr)

    return cast(jnp.ndarray, vmap(compute_intensity)(jnp.arange(len(q_values))))
```
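The small-\(qr\) guard in step 3 can be checked in isolation. This sketch copies the same branch logic into a standalone function (safe_sinc is a hypothetical name). Because jnp.where evaluates both branches, the unselected one must stay finite; here the Taylor branch gives the exact limit 1 at \(qr = 0\), and the gradient there comes out as 0 rather than NaN:

```python
import jax
import jax.numpy as jnp

def safe_sinc(x):
    # Same branch logic as step 3 of debye_saxs: second-order Taylor series
    # for x < 1e-4, sin(x) / (x + 1e-10) elsewhere.
    return jnp.where(x < 1e-4, 1.0 - x**2 / 6.0, jnp.sin(x) / (x + 1e-10))

x = jnp.array([0.0, 5e-5, 1.0, jnp.pi])
vals = safe_sinc(x)                        # approximately [1, 1, sin(1), 0]
# The masked branch receives a zero cotangent, so its large but finite
# derivative at x = 0 does not leak into the result.
grad_at_zero = jax.grad(safe_sinc)(0.0)
```

The first term dropped by the Taylor branch is \(x^4/120\), below \(10^{-18}\) for \(x < 10^{-4}\) and far under float32 resolution.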