
💡 CD & Cryo-EM API


Circular Dichroism — CD Matrix Method

diff_biophys.cd.kernels implements the DeVoe coupled-oscillator (matrix) model for simulating circular dichroism spectra from atomic positions.

Each amide bond in the backbone acts as a chromophore with a transition dipole moment \(\boldsymbol{\mu}_i\). When two dipoles interact, their coupling splits the transition into symmetric and antisymmetric combinations with different energies. The rotational strength of each coupled transition determines the sign and magnitude of the CD signal at each wavelength.
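For two identical chromophores the splitting can be made concrete. In the point-dipole approximation (the same form as the V_ij matrix in the source below), the coupled states are the symmetric and antisymmetric combinations at \(E_0 \pm V_{12}\). A minimal sketch in reduced units (unit dipoles, no physical prefactor), using a hypothetical side-by-side geometry:

```python
import jax.numpy as jnp

def dipole_coupling(r1, r2, mu1, mu2):
    """Point-dipole coupling in reduced units: (mu1.mu2 - 3 (mu1.r)(mu2.r)) / |r|^3."""
    d = r1 - r2
    dist = jnp.linalg.norm(d)
    r_hat = d / dist
    return (jnp.dot(mu1, mu2) - 3.0 * jnp.dot(mu1, r_hat) * jnp.dot(mu2, r_hat)) / dist**3

# Hypothetical geometry: two parallel dipoles side by side, 4 Å apart
r1 = jnp.array([0.0, 0.0, 0.0])
r2 = jnp.array([4.0, 0.0, 0.0])
mu = jnp.array([0.0, 0.0, 1.0])

E0 = 0.0  # uncoupled transition energy, used as the reference
V12 = dipole_coupling(r1, r2, mu, mu)

# Two-state exciton Hamiltonian; eigh returns eigenvalues in ascending order
H = jnp.array([[E0, V12], [V12, E0]])
energies, states = jnp.linalg.eigh(H)
# energies: [E0 - |V12|, E0 + |V12|]; columns of states are (1, ±1)/sqrt(2) combinations
```

For this geometry \(V_{12} = +1/64\) in reduced units, so the in-phase (symmetric) combination lies higher in energy; placing the dipoles head-to-tail along the separation axis would make \(V_{12}\) negative and reverse the ordering.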

The kernel computes:

\[\Delta\varepsilon(\lambda) = \varepsilon_L(\lambda) - \varepsilon_R(\lambda)\]

as a function of chromophore positions and transition dipole orientations.
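The formula above is written in terms of \(\Delta\varepsilon\), while the kernel output is labelled as molar ellipticity. The two differ by a constant, \([\theta] = 3298.2\,\Delta\varepsilon\), with \([\theta]\) in deg cm² dmol⁻¹ and \(\Delta\varepsilon\) in M⁻¹ cm⁻¹. Note that simulate_cd_matrix does not apply this factor: per its source, it multiplies by a fixed 1e5, so its output is in arbitrary units. A minimal conversion sketch for working with experimental data (the helper names are hypothetical):

```python
import math

# [theta] (deg cm^2 dmol^-1) = 100 * ln(10) * 180 / (4 * pi) * delta_eps (M^-1 cm^-1)
THETA_PER_DELTA_EPS = 100.0 * math.log(10.0) * 180.0 / (4.0 * math.pi)  # ≈ 3298.2


def delta_eps_to_theta(delta_eps):
    """Convert Δε (M^-1 cm^-1) to molar ellipticity [θ] (deg cm^2 dmol^-1)."""
    return THETA_PER_DELTA_EPS * delta_eps


def theta_to_delta_eps(theta):
    """Convert molar ellipticity [θ] (deg cm^2 dmol^-1) to Δε (M^-1 cm^-1)."""
    return theta / THETA_PER_DELTA_EPS
```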

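Gradient meaning: \(\partial [\theta](\lambda) / \partial \mathbf{r}_i\) identifies which chromophore, if moved, most changes the CD signal at wavelength \(\lambda\). At 222 nm (the canonical helix marker), chromophores that are strongly coupled to their helical neighbours typically have the largest gradients. Note that the kernel models a single transition centred at lambda_0 (190 nm by default), so the signal it produces at 222 nm comes from the red tail of that coupled band rather than from a separate \(n \to \pi^*\) oscillator.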

import jax
import jax.numpy as jnp
from diff_biophys.cd.kernels import simulate_cd_matrix

# Inputs (build these for your structure, e.g. as in Notebook 03):
# chromophore_positions:    (N, 3)  amide nitrogen positions (Å)
# dipole_orientations:      (N, 3)  unit transition-dipole vectors
# wavelengths:              (M,)    wavelengths in nm

wavelengths = jnp.linspace(180.0, 260.0, 81)  # 1 nm steps

cd_spectrum = simulate_cd_matrix(
    chromophore_positions,
    dipole_orientations,
    wavelengths,
    f_osc=0.2,       # oscillator strength
    gamma=10.0,      # linewidth (nm)
    lambda_0=190.0,  # transition wavelength (nm)
)
# cd_spectrum: (M,) on the molar-ellipticity scale [deg cm² dmol⁻¹], but with an
# arbitrary absolute scale (fixed 1e5 factor; calibrate against experiment)

# Gradient of [θ] at 222 nm w.r.t. chromophore positions.
# Only one wavelength is needed, so pass a length-1 array instead of the full grid.
wl_222 = jnp.array([222.0])
grad_222 = jax.grad(
    lambda pos: simulate_cd_matrix(pos, dipole_orientations, wl_222)[0]
)(chromophore_positions)
# grad_222: (N, 3); rank chromophores by gradient magnitude
influence = jnp.linalg.norm(grad_222, axis=-1)   # (N,)
most_influential = int(jnp.argmax(influence))

Typical α-helix signature

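| Wavelength | Sign | Assignment |
|---|---|---|
| 222 nm | negative | \(n \to \pi^*\) |
| 208 nm | negative | \(\pi \to \pi^*\), exciton component polarized parallel to the helix axis |
| 193 nm | positive | \(\pi \to \pi^*\), exciton component polarized perpendicular to the helix axis |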

Building chromophore coordinates

See Notebook 03 · CD Spectroscopy for a complete example of building a helix from scratch and computing its CD spectrum.
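As a quick stand-in while reading the notebook, an idealized α-helix can be laid out from its standard helical parameters: 3.6 residues per turn (100° per residue) and a 1.5 Å rise per residue. The helper name, the 1.6 Å radius and the dipole direction below are hypothetical placeholders (dipoles point along the local helix tangent); a realistic model would place each chromophore at its amide atoms and orient the \(\pi \to \pi^*\) transition dipole within the peptide plane.

```python
import jax.numpy as jnp


def ideal_helix_chromophores(n_res, radius=1.6, rise=1.5, twist_deg=100.0):
    """Hypothetical helper: chromophore sites on an ideal helix.

    Positions lie on a helix with the given radius (Å), rise per residue (Å)
    and twist per residue (degrees). Dipoles are unit vectors along the local
    tangent of that helix (a placeholder orientation, not a fitted one).
    """
    k = jnp.arange(n_res, dtype=jnp.float32)
    phi = jnp.deg2rad(twist_deg) * k
    positions = jnp.stack(
        [radius * jnp.cos(phi), radius * jnp.sin(phi), rise * k], axis=-1
    )
    # Tangent: derivative of the position with respect to phi
    rise_per_rad = rise / jnp.deg2rad(twist_deg)
    tangent = jnp.stack(
        [-radius * jnp.sin(phi), radius * jnp.cos(phi), jnp.full_like(phi, rise_per_rad)],
        axis=-1,
    )
    dipoles = tangent / jnp.linalg.norm(tangent, axis=-1, keepdims=True)
    return positions, dipoles


positions, dipoles = ideal_helix_chromophores(18)  # 5 turns
# positions: (18, 3) in Å; dipoles: (18, 3) unit vectors
```

The two arrays have the (N, 3) shapes that simulate_cd_matrix expects for peptide_positions and dipole_orientations.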

simulate_cd_matrix(peptide_positions, dipole_orientations, wavelengths, f_osc=0.2, gamma=10.0, lambda_0=190.0)

Matrix-Method CD Simulation (DeVoe Theory).

Implements the coupled-oscillator model for transition dipole coupling. Calculates the interaction matrix and solves for the complex polarizability response to determine molar ellipticity.
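To see what gamma and lambda_0 control, the single-oscillator lineshape can be evaluated on its own. The sketch below copies the expression used in compute_at_wavelength in the source further down; the helper name polarizability is hypothetical. Near resonance, the absorptive part \(|\operatorname{Im}\alpha|\) peaks close to lambda_0 and its full width at half maximum is approximately gamma.

```python
import jax.numpy as jnp


def polarizability(lmbda, f_osc=0.2, gamma=10.0, lambda_0=190.0):
    """Complex polarizability with the same lineshape as compute_at_wavelength."""
    denom = (1.0 / lmbda**2 - 1.0 / lambda_0**2) + 1j * (gamma / (lmbda * lambda_0**2))
    return f_osc / denom


wl = jnp.linspace(150.0, 250.0, 10001)               # 0.01 nm grid
absorptive = jnp.abs(jnp.imag(polarizability(wl)))   # |Im alpha|, single peak

peak_wl = float(wl[jnp.argmax(absorptive)])
above_half = wl[absorptive >= 0.5 * absorptive.max()]
fwhm = float(above_half[-1] - above_half[0])
```

With the defaults, the peak sits within a fraction of a nanometre of 190 nm and the half-height width comes out close to 10 nm, so gamma behaves roughly like a FWHM in nm.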

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| peptide_positions | ndarray | (N, 3) positions of amide chromophores in Angstroms. | required |
| dipole_orientations | ndarray | (N, 3) unit vectors for transition dipoles. | required |
| wavelengths | ndarray | (M,) wavelengths in nm to simulate. | required |
| f_osc | float | Oscillator strength of the transition (default 0.2 for \(\pi \to \pi^*\)). | 0.2 |
| gamma | float | Linewidth parameter in nm. | 10.0 |
| lambda_0 | float | Resonance wavelength in nm. | 190.0 |

Returns:

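| Type | Description |
|---|---|
| ndarray | (M,) CD spectrum on the molar-ellipticity scale [θ] (deg cm² dmol⁻¹). The kernel multiplies by a fixed factor of 1e5, so absolute values are arbitrary until calibrated against experimental data. |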

Source code in diff_biophys/cd/kernels.py
def simulate_cd_matrix(
    peptide_positions: jnp.ndarray,
    dipole_orientations: jnp.ndarray,
    wavelengths: jnp.ndarray,
    f_osc: float = 0.2,
    gamma: float = 10.0,
    lambda_0: float = 190.0,
) -> jnp.ndarray:
    """
    Matrix-Method CD Simulation (DeVoe Theory).

    Implements the coupled-oscillator model for transition dipole coupling.
    Calculates the interaction matrix and solves for the complex polarizability
    response to determine molar ellipticity.

    Args:
        peptide_positions: (N, 3) positions of amide chromophores in Angstroms.
        dipole_orientations: (N, 3) unit vectors for transition dipoles.
        wavelengths: (M,) wavelengths in nm to simulate.
        f_osc: Oscillator strength of the transition (default 0.2 for pi->pi*).
        gamma: Linewidth parameter in nm (default 10.0).
        lambda_0: Resonance wavelength in nm (default 190.0).

    Returns:
        Molar ellipticity [θ] in deg cm^2 / dmol (M,).
    """
    n_chromophores = peptide_positions.shape[0]

    # 1. Compute dipole-dipole interaction matrix V_ij
    # V_ij = (1/r_ij^3) * [ mu_i . mu_j - 3 (mu_i . r_hat_ij)(mu_j . r_hat_ij) ], r_hat_ij = unit vector from j to i
    diff = peptide_positions[:, None, :] - peptide_positions[None, :, :]
    dist_sq = jnp.sum(diff**2, axis=-1)

    # Safe distance for gradients (avoid sqrt(0) and 1/0 on the diagonal, where i == j):
    # zero distances are replaced by 1.0 before the sqrt and masked back to 0 afterwards
    mask = dist_sq > 0
    safe_dist_sq = jnp.where(mask, dist_sq, 1.0)
    r_ij = jnp.sqrt(safe_dist_sq)
    r_ij_inv3 = jnp.where(mask, 1.0 / r_ij**3, 0.0)

    # Unit vectors between chromophores
    r_hat = diff * jnp.where(mask[:, :, None], 1.0 / r_ij[:, :, None], 0.0)

    # Dot products
    mu_i_mu_j = jnp.sum(dipole_orientations[:, None, :] * dipole_orientations[None, :, :], axis=-1)
    mu_i_r = jnp.sum(dipole_orientations[:, None, :] * r_hat, axis=-1)
    mu_j_r = jnp.sum(dipole_orientations[None, :, :] * r_hat, axis=-1)

    # Interaction energy V (N, N)
    V = r_ij_inv3 * (mu_i_mu_j - 3 * mu_i_r * mu_j_r)

    # 2. Frequency-dependent response
    def compute_at_wavelength(lmbda: jnp.ndarray) -> jnp.ndarray:
        # Complex polarizability alpha(lambda)
        # Lorentzian-like response
        denom = (1.0 / lmbda**2 - 1.0 / lambda_0**2) + 1j * (gamma / (lmbda * lambda_0**2))
        alpha = f_osc / denom

        # Interaction matrix (I - alpha * V)
        # alpha is scalar for all identical chromophores here
        M = jnp.eye(n_chromophores) - alpha * V

        # We'll use the matrix inverse to find the coupled response
        # Note: jnp.linalg.inv is differentiable but can be numerically sensitive when M is ill-conditioned
        inv_M = jnp.linalg.inv(M)

        # Geometric factor for CD (Scalar triple product mu_i x mu_j . r_ij)
        # This represents the chiral arrangement.
        cross_mu = jnp.cross(dipole_orientations[:, None, :], dipole_orientations[None, :, :])
        R_ij = jnp.sum(cross_mu * diff, axis=-1)

        # Total CD response at this wavelength
        coupled_V = inv_M @ (alpha * V)
        cd_val = jnp.imag(jnp.sum(coupled_V * R_ij))

        return cd_val

    # Vectorize over wavelengths
    cd_spectrum = jax.vmap(compute_at_wavelength)(wavelengths)

    # Scale to molar ellipticity (arbitrary units for this kernel;
    # should be calibrated to experimental data)
    return cd_spectrum * 1e5
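The source forms the coupled response as inv(M) @ (alpha * V). Where conditioning is a concern, the same quantity can be obtained with jnp.linalg.solve, which avoids forming the explicit inverse and is also differentiable. A sketch comparing the two on a hypothetical random symmetric V with a zero diagonal (this is an alternative, not what the library currently does):

```python
import jax
import jax.numpy as jnp


def coupled_response_inv(alpha, V):
    # As in the source: explicit inverse of M = I - alpha * V
    M = jnp.eye(V.shape[0]) - alpha * V
    return jnp.linalg.inv(M) @ (alpha * V)


def coupled_response_solve(alpha, V):
    # Solves M X = alpha * V for X instead of forming inv(M)
    M = jnp.eye(V.shape[0]) - alpha * V
    return jnp.linalg.solve(M, alpha * V)


key = jax.random.PRNGKey(0)
A = 0.05 * jax.random.normal(key, (8, 8))
V = A + A.T                      # symmetric, like the dipole-dipole matrix
V = V - jnp.diag(jnp.diag(V))    # zero diagonal, like the masked source V
alpha = 0.3 + 0.8j               # hypothetical complex polarizability value

X_inv = coupled_response_inv(alpha, V)
X_solve = coupled_response_solve(alpha, V)
```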

Cryo-EM — Fourier Shell Correlation

diff_biophys.cryo_em implements the Fourier Shell Correlation (FSC), the standard figure-of-merit for cryo-EM reconstruction quality.

The FSC measures the normalised cross-correlation between two independently reconstructed half-maps as a function of spatial frequency:

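\[\text{FSC}(\nu) = \frac{\operatorname{Re} \sum_{\mathbf{k} \in \text{shell}(\nu)} F_1(\mathbf{k})\, F_2^*(\mathbf{k})}{\sqrt{\sum_{\mathbf{k} \in \text{shell}(\nu)} |F_1(\mathbf{k})|^2 \cdot \sum_{\mathbf{k} \in \text{shell}(\nu)} |F_2(\mathbf{k})|^2}}\]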

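In the gold-standard procedure the two half-maps are refined independently, and the reported resolution is the spatial frequency at which their FSC first drops below 0.143. The half-maps are still partially correlated at that point: an FSC of 0.143 between half-maps corresponds to a correlation of about 0.5 between the full map and the true structure, which is taken as the limit up to which the reconstruction is reliable.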
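The 0.143 value comes from the Rosenthal and Henderson relation between the half-map FSC and the expected correlation of the full map with the true structure, \(C_{\text{ref}} = \sqrt{2\,\text{FSC}/(1+\text{FSC})}\). A quick numerical check (the helper name c_ref is hypothetical):

```python
import math


def c_ref(fsc_half):
    """Expected full-map vs. true-structure correlation from a half-map FSC value."""
    return math.sqrt(2.0 * fsc_half / (1.0 + fsc_half))


print(round(c_ref(0.143), 3))  # 0.5
```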

Gradient meaning: \(\partial \text{FSC}(\nu) / \partial \text{map}_1\) gives, for each voxel of the first half-map, how changing its density would change the correlation at frequency \(\nu\); voxels with the largest gradient magnitude have the most influence. This can drive iterative map sharpening or density modification.

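import jax
import jax.numpy as jnp
from diff_biophys.cryo_em import compute_fsc

# map1, map2: (D, H, W) float32 real-space half-maps (load your own data here)
# voxel_size: (dz, dy, dx) in Å

voxel_size = (1.0, 1.0, 1.0)
frequencies, fsc_curve = compute_fsc(map1, map2, voxel_size=voxel_size)
# frequencies: (n_shells,) in Å⁻¹
# fsc_curve:   (n_shells,) values in [−1, 1]

# Resolution at the 0.143 threshold: take the last shell before the FSC first
# drops below 0.143 (noise can push it back above the threshold at higher frequency)
below = fsc_curve < 0.143
first_below = int(jnp.argmax(below)) if bool(jnp.any(below)) else len(fsc_curve)
resolution_angstrom = 1.0 / float(frequencies[max(first_below - 1, 0)])

# Gradient w.r.t. the first half-map
def fsc_sum(m1):
    _, fsc = compute_fsc(m1, map2, voxel_size=voxel_size)
    return jnp.sum(fsc)

grad_map1 = jax.grad(fsc_sum)(map1)
# grad_map1: (D, H, W), same shape as map1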
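The cutoff can be refined by linear interpolation between the two shells that bracket the first crossing of the threshold, instead of snapping to a shell centre. A self-contained sketch on a toy curve (the helper name fsc_resolution and the toy values are hypothetical):

```python
import jax.numpy as jnp


def fsc_resolution(frequencies, fsc, threshold=0.143):
    """Resolution (Å) at the first shell where FSC drops below `threshold`.

    Linearly interpolates between the last shell above and the first shell below.
    Returns None if the curve never drops below the threshold.
    """
    below = fsc < threshold
    if not bool(jnp.any(below)):
        return None
    i = int(jnp.argmax(below))  # index of the first shell below the threshold
    if i == 0:
        return float(1.0 / frequencies[0])
    f0, f1 = frequencies[i - 1], frequencies[i]
    c0, c1 = fsc[i - 1], fsc[i]
    f_cross = f0 + (c0 - threshold) * (f1 - f0) / (c0 - c1)
    return float(1.0 / f_cross)


# Toy curve: crosses 0.143 between 0.2 and 0.3 Å⁻¹, then a noise bump above it at 0.4
freqs = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5])
fsc_toy = jnp.array([0.95, 0.40, 0.05, 0.20, 0.01])
res = fsc_resolution(freqs, fsc_toy)
```

Taking the last shell above the threshold anywhere on this toy curve would instead report 2.5 Å, because of the bump at 0.4 Å⁻¹.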

compute_fsc(data1, data2, voxel_size)

Compute the fully differentiable Fourier Shell Correlation (FSC) between two 3D maps using JAX. Returns frequencies and correlation values.

This function matches the implementation in synth-core, but uses jax.numpy so that gradients can flow through the FSC calculation to the input maps.

Source code in diff_biophys/cryo_em.py
@jax.jit
def compute_fsc(
    data1: jax.Array, data2: jax.Array, voxel_size: tuple[float, float, float]
) -> tuple[jax.Array, jax.Array]:
    """
    Compute the fully differentiable Fourier Shell Correlation (FSC) between two 3D maps using JAX.
    Returns frequencies and correlation values.

    This function matches the implementation in synth-core, but uses jax.numpy
    so that gradients can flow through the FSC calculation to the input maps.
    """
    # Fourier transforms in JAX
    f1 = jnp.fft.rfftn(data1)
    f2 = jnp.fft.rfftn(data2)

    # Re(F1 * conj(F2)) and |F|^2 are computed from real and imaginary parts (real arithmetic);
    # complex multiplication is avoided to mirror the NumPy memory-stability fix
    cross = f1.real * f2.real + f1.imag * f2.imag
    p1 = f1.real**2 + f1.imag**2
    p2 = f2.real**2 + f2.imag**2

    # Calculate radial bins
    nz, ny, nx = data1.shape
    kz = jnp.fft.fftfreq(nz, d=voxel_size[0])
    ky = jnp.fft.fftfreq(ny, d=voxel_size[1])
    kx = jnp.fft.rfftfreq(nx, d=voxel_size[2])

    # Create 3D grid of frequencies
    kz_grid, ky_grid, kx_grid = jnp.meshgrid(kz, ky, kx, indexing="ij")

    # Calculate magnitude of frequency vector for each voxel
    k = jnp.sqrt(kz_grid**2 + ky_grid**2 + kx_grid**2)

    # Flatten everything
    k = k.ravel()
    cross = cross.ravel()
    p1 = p1.ravel()
    p2 = p2.ravel()

    # Sort by frequency
    idx = jnp.argsort(k)
    k_sorted = k[idx]
    cross_sorted = cross[idx]
    p1_sorted = p1[idx]
    p2_sorted = p2[idx]

    n_bins = min(nx, ny, nz) // 2
    k_max = k_sorted[-1]
    k_eps = k_max / (10 * n_bins)
    bins = jnp.linspace(k_eps, k_max, n_bins + 1)

    # We use vmap to compute the bin sums to keep the function differentiable and JIT-compatible
    # We avoid python loops with dynamic shapes.

    def compute_bin(i: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        bin_start = bins[i]
        bin_end = bins[i + 1]
        mask = (k_sorted >= bin_start) & (k_sorted < bin_end)

        sum_cross = jnp.sum(jnp.where(mask, cross_sorted, 0.0))
        sum_p1 = jnp.sum(jnp.where(mask, p1_sorted, 0.0))
        sum_p2 = jnp.sum(jnp.where(mask, p2_sorted, 0.0))

        num = sum_cross
        den = jnp.sqrt(sum_p1 * sum_p2)

        # Avoid division by zero
        val = jnp.where(den > 0, num / den, 0.0)
        # Clamp to [-1, 1]
        val = jnp.clip(val, -1.0, 1.0)
        freq = (bin_start + bin_end) / 2.0

        # We need to return valid mask too, because some bins might be empty
        is_valid = jnp.any(mask)
        return freq, val, is_valid

    indices = jnp.arange(n_bins)
    freqs, vals, is_valid = jax.vmap(compute_bin)(indices)

    # Note: filtering with jnp.where/boolean indexing would produce data-dependent shapes, which JIT cannot trace.
    # The full arrays are therefore returned: empty shells already evaluate to 0 (den == 0 above),
    # and is_valid is computed but not currently returned to the caller.

    return freqs, vals
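The source evaluates every shell with a full-size boolean mask inside jax.vmap, so memory grows with n_bins times the number of Fourier voxels. An alternative sketch (not the library's implementation) assigns each Fourier voxel a shell index and accumulates the sums with jax.ops.segment_sum. It assumes uniform shells one frequency step wide up to the Nyquist frequency of the coarsest axis, which differs from the source's binning; the function name fsc_segment_sum is hypothetical.

```python
import jax
import jax.numpy as jnp


def fsc_segment_sum(data1, data2, voxel_size, n_bins):
    """Illustrative FSC that accumulates shell sums with jax.ops.segment_sum."""
    f1 = jnp.fft.rfftn(data1)
    f2 = jnp.fft.rfftn(data2)
    # Re(F1 * conj(F2)) and power spectra from real/imaginary parts, as in the source
    cross = (f1.real * f2.real + f1.imag * f2.imag).ravel()
    p1 = (f1.real**2 + f1.imag**2).ravel()
    p2 = (f2.real**2 + f2.imag**2).ravel()

    nz, ny, nx = data1.shape
    kz = jnp.fft.fftfreq(nz, d=voxel_size[0])
    ky = jnp.fft.fftfreq(ny, d=voxel_size[1])
    kx = jnp.fft.rfftfreq(nx, d=voxel_size[2])
    kz_grid, ky_grid, kx_grid = jnp.meshgrid(kz, ky, kx, indexing="ij")
    k = jnp.sqrt(kz_grid**2 + ky_grid**2 + kx_grid**2).ravel()

    # Assumption: shell i is centred on (i + 1) * width, up to the coarsest-axis Nyquist
    k_nyq = 0.5 / max(voxel_size)
    width = k_nyq / n_bins
    shell = jnp.round(k / width).astype(jnp.int32) - 1   # DC (k = 0) maps to -1
    valid = (shell >= 0) & (shell < n_bins)
    shell = jnp.where(valid, shell, n_bins)               # extra overflow bucket

    def shell_sum(x):
        return jax.ops.segment_sum(x, shell, num_segments=n_bins + 1)[:n_bins]

    s_cross, s_p1, s_p2 = shell_sum(cross), shell_sum(p1), shell_sum(p2)
    den_sq = s_p1 * s_p2
    # Double-where keeps gradients finite if a shell is empty
    den = jnp.sqrt(jnp.where(den_sq > 0, den_sq, 1.0))
    fsc = jnp.where(den_sq > 0, s_cross / den, 0.0)
    freqs = (jnp.arange(n_bins) + 1.0) * width  # shell centres in Å^-1
    return freqs, fsc
```

To JIT this sketch, voxel_size and n_bins can be marked static, for example jax.jit(fsc_segment_sum, static_argnames=("voxel_size", "n_bins")), since both are used to build Python-level shapes and constants.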