
📐 Geometry API

The diff_biophys.geometry subpackage provides differentiable structural primitives — the mathematical backbone for converting between internal coordinates (bond lengths, bond angles, dihedral angles) and Cartesian 3D positions, plus tools for alignment and macroscopic property calculation.

All functions are compiled with jax.jit and are fully compatible with jax.grad, jax.vmap, and jax.pmap.
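The sketch below shows how these transformations compose. It re-implements a bond-length helper inline so it runs without the package installed; the helper mirrors `compute_bond_lengths` further down. It then takes per-structure gradients of a harmonic bond loss over a batch with `jax.jit(jax.vmap(jax.grad(...)))`. The loss function and its 1.5 Å target are illustrative assumptions, not part of `diff_biophys`.

```python
import jax
import jax.numpy as jnp


def bond_lengths(coords: jnp.ndarray) -> jnp.ndarray:
    # Inline copy of the compute_bond_lengths logic: |r_{i+1} - r_i|
    return jnp.linalg.norm(coords[1:] - coords[:-1], axis=-1)


def bond_loss(coords: jnp.ndarray, target: float = 1.5) -> jnp.ndarray:
    # Hypothetical harmonic restraint pulling every bond toward `target` (Å)
    return jnp.sum((bond_lengths(coords) - target) ** 2)


# Gradient w.r.t. coordinates, vectorised over a leading batch axis, then compiled
batched_grad = jax.jit(jax.vmap(jax.grad(bond_loss)))

key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (4, 10, 3))  # 4 structures x 10 atoms x xyz
grads = batched_grad(batch)
print(grads.shape)  # (4, 10, 3)
```

Each structure in the batch gets its own gradient. The result is identical to looping over structures and calling `jax.grad(bond_loss)` on each, but it runs as one compiled, vectorised kernel.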


NeRF — Forward Kinematics

diff_biophys.geometry.nerf implements the Natural Extension Reference Frame algorithm. Given three reference atoms and a set of internal coordinates, it places a new atom in 3D space. Chaining this step builds an entire polymer backbone from scratch, and every step remains differentiable.

When to use:

  • Convert torsion-angle parameters (φ, ψ, ω) to Cartesian coordinates
  • Build synthetic structures for gradient-descent refinement
  • Provide a differentiable mapping from parameter space to observable space
from diff_biophys.geometry.nerf import position_atom_3d, chain_nerf
import jax.numpy as jnp

# Place one atom given three reference atoms + internal coordinates
p1 = jnp.array([0.0, 0.0, 0.0])
p2 = jnp.array([1.52, 0.0, 0.0])
# p3 lies 1.52 Å from p2, with ∠(p1, p2, p3) = 1.94 rad
p3 = jnp.array([1.52 + 1.52 * jnp.cos(jnp.pi - 1.94), 1.52 * jnp.sin(jnp.pi - 1.94), 0.0])

p4 = position_atom_3d(
    p1, p2, p3,
    bond_length=jnp.array(1.33),           # Å, distance p3→p4 (≈ C–N peptide bond)
    bond_angle_rad=jnp.array(1.94),        # rad, ∠(p2, p3, p4) ≈ 111°
    dihedral_angle_rad=jnp.array(-0.994),  # rad, ∠(p1, p2, p3, p4) ≈ −57° (α-helix-like)
)
print(p4)  # → (3,) array of Cartesian coordinates
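Because the placement is a smooth function of the internal coordinates, gradients flow from Cartesian observables back to the torsions. The self-contained sketch below copies the `position_atom_3d` arithmetic shown further down into a local helper, `place_atom` (a hypothetical name, used so the snippet runs without the package). It differentiates the p1–p4 distance with respect to the dihedral and compares the result against a central finite difference.

```python
import jax
import jax.numpy as jnp


def place_atom(p1, p2, p3, bond_length, bond_angle, dihedral):
    # Same NeRF arithmetic as position_atom_3d
    u2 = (p3 - p2) / (jnp.linalg.norm(p3 - p2) + 1e-10)
    n = jnp.cross(p1 - p2, u2)
    n = n / (jnp.linalg.norm(n) + 1e-10)
    m = jnp.cross(n, u2)
    return p3 + bond_length * (
        -jnp.cos(bond_angle) * u2
        - jnp.sin(bond_angle) * jnp.cos(dihedral) * m
        - jnp.sin(bond_angle) * jnp.sin(dihedral) * n
    )


p1 = jnp.array([0.0, 0.0, 0.0])
p2 = jnp.array([1.52, 0.0, 0.0])
p3 = jnp.array([1.52 + 1.52 * jnp.cos(jnp.pi - 1.94), 1.52 * jnp.sin(jnp.pi - 1.94), 0.0])


def end_to_end(dihedral):
    # Distance between the first reference atom and the newly placed atom
    p4 = place_atom(p1, p2, p3, 1.33, 1.94, dihedral)
    return jnp.linalg.norm(p4 - p1)


phi = jnp.array(-0.994)
g = jax.grad(end_to_end)(phi)

# Central finite difference as a sanity check on the autodiff gradient
h = 1e-3
fd = (end_to_end(phi + h) - end_to_end(phi - h)) / (2 * h)
print(g, fd)
```

The gradient is negative here. Rotating the dihedral from −57° toward 0° (cis) brings p4 closer to p1.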

chain_nerf(init_coords, bond_lengths, bond_angles, dihedrals)

Build a chain of atoms using the NeRF algorithm.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| init_coords | ndarray | (3, 3) initial coordinates for the first 3 atoms | required |
| bond_lengths | ndarray | (N,) bond lengths for atoms 4 to N+3 | required |
| bond_angles | ndarray | (N,) bond angles (in radians) for atoms 4 to N+3 | required |
| dihedrals | ndarray | (N,) dihedral angles (in radians) for atoms 4 to N+3 | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (N+3, 3) coordinates for the entire chain |

Source code in diff_biophys/geometry/nerf.py
@jit
def chain_nerf(
    init_coords: jnp.ndarray,
    bond_lengths: jnp.ndarray,
    bond_angles: jnp.ndarray,
    dihedrals: jnp.ndarray,
) -> jnp.ndarray:
    """
    Build a chain of atoms using the NeRF algorithm.

    Args:
        init_coords: (3, 3) initial coordinates for the first 3 atoms
        bond_lengths: (N,) bond lengths for atoms 4 to N+3
        bond_angles: (N,) bond angles (in radians) for atoms 4 to N+3
        dihedrals: (N,) dihedral angles (in radians) for atoms 4 to N+3

    Returns:
        jnp.ndarray: (N+3, 3) coordinates for the entire chain
    """

    def body_fun(
        carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], i: Any
    ) -> tuple[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]:
        p1, p2, p3 = carry

        p4 = position_atom_3d(p1, p2, p3, bond_lengths[i], bond_angles[i], dihedrals[i])
        return (p2, p3, p4), p4

    # bond_lengths.shape[0] is a static Python int even while tracing under
    # jit/vmap, so the scan length is fixed at trace time.
    indices = jnp.arange(bond_lengths.shape[0])
    init_carry = (init_coords[0], init_coords[1], init_coords[2])
    _, final_coords = lax.scan(body_fun, init_carry, indices)

    return cast(jnp.ndarray, jnp.concatenate([init_coords, final_coords], axis=0))
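The scan carries a sliding window of the last three atoms, so each step needs the bond length, bond angle, and dihedral of the atom being appended. The sketch below shows how backbone torsions map onto those per-atom arrays. It re-implements the same recurrence locally; `place_atom` and `build_chain` are hypothetical stand-ins for `position_atom_3d` and `chain_nerf`. It then builds a short N–Cα–C backbone. The bond lengths and angles are approximate textbook values assumed for illustration, not parameters shipped by the package. The dihedral attached to each appended N, Cα, and C is ψ, ω, and φ respectively.

```python
import jax
import jax.numpy as jnp
from jax import lax


def place_atom(p1, p2, p3, bond_length, bond_angle, dihedral):
    # Same NeRF arithmetic as position_atom_3d
    u2 = (p3 - p2) / (jnp.linalg.norm(p3 - p2) + 1e-10)
    n = jnp.cross(p1 - p2, u2)
    n = n / (jnp.linalg.norm(n) + 1e-10)
    m = jnp.cross(n, u2)
    return p3 + bond_length * (
        -jnp.cos(bond_angle) * u2
        - jnp.sin(bond_angle) * jnp.cos(dihedral) * m
        - jnp.sin(bond_angle) * jnp.sin(dihedral) * n
    )


@jax.jit
def build_chain(init_coords, bond_lengths, bond_angles, dihedrals):
    # Same lax.scan recurrence as chain_nerf: carry the last three atoms
    def step(carry, x):
        p1, p2, p3 = carry
        p4 = place_atom(p1, p2, p3, *x)
        return (p2, p3, p4), p4

    init = (init_coords[0], init_coords[1], init_coords[2])
    _, new = lax.scan(step, init, (bond_lengths, bond_angles, dihedrals))
    return jnp.concatenate([init_coords, new], axis=0)


# Approximate backbone geometry (Å, degrees), assumed for illustration
b_cn, b_nca, b_cac = 1.329, 1.458, 1.525
a_cacn, a_cnca, a_ncac = 116.2, 121.7, 111.2
phi, psi, omega = -57.0, -47.0, 180.0  # α-helix-like torsions

n_res = 4  # residues appended after the first N, CA, C triplet
deg = jnp.deg2rad
# Per-residue pattern of appended atoms: N (after C), CA (after N), C (after CA)
bond_lengths = jnp.tile(jnp.array([b_cn, b_nca, b_cac]), n_res)
bond_angles = deg(jnp.tile(jnp.array([a_cacn, a_cnca, a_ncac]), n_res))
dihedrals = deg(jnp.tile(jnp.array([psi, omega, phi]), n_res))

theta = deg(a_ncac)
init = jnp.array([
    [0.0, 0.0, 0.0],  # N
    [b_nca, 0.0, 0.0],  # CA
    [b_nca + b_cac * jnp.cos(jnp.pi - theta), b_cac * jnp.sin(jnp.pi - theta), 0.0],  # C
])

backbone = build_chain(init, bond_lengths, bond_angles, dihedrals)
print(backbone.shape)  # (15, 3): 5 residues x (N, CA, C)
```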

position_atom_3d(p1, p2, p3, bond_length, bond_angle_rad, dihedral_angle_rad)

Differentiable NeRF implementation in JAX for a single atom.

Places atom p4 given three reference atoms (p1, p2, p3) and the internal coordinates (bond length, bond angle, dihedral angle) that define its position relative to p3.
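The equations below transcribe the source; ℓ, θ and φ denote `bond_length`, `bond_angle_rad` and `dihedral_angle_rad`. The new atom is placed in a local frame built from the three reference atoms:

\[\hat{\mathbf{u}} = \frac{\mathbf{p}_3 - \mathbf{p}_2}{\|\mathbf{p}_3 - \mathbf{p}_2\|}, \qquad \hat{\mathbf{n}} = \frac{(\mathbf{p}_1 - \mathbf{p}_2) \times \hat{\mathbf{u}}}{\|(\mathbf{p}_1 - \mathbf{p}_2) \times \hat{\mathbf{u}}\|}, \qquad \hat{\mathbf{m}} = \hat{\mathbf{n}} \times \hat{\mathbf{u}}\]

\[\mathbf{p}_4 = \mathbf{p}_3 + \ell\left(-\cos\theta\,\hat{\mathbf{u}} - \sin\theta\cos\phi\,\hat{\mathbf{m}} - \sin\theta\sin\phi\,\hat{\mathbf{n}}\right)\]

The \(-\cos\theta\) term fixes the angle at \(\mathbf{p}_3\), while \(\phi\) rotates \(\mathbf{p}_4\) about \(\hat{\mathbf{u}}\). Setting \(\phi = 0\) puts \(\mathbf{p}_4\) cis to \(\mathbf{p}_1\). The source also adds \(10^{-10}\) to each norm to avoid division by zero.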

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| p1 | ndarray | (3,) first reference atom coordinate. | required |
| p2 | ndarray | (3,) second reference atom coordinate. | required |
| p3 | ndarray | (3,) third reference atom coordinate (parent of p4). | required |
| bond_length | ndarray | Scalar distance p3→p4 in Ångströms. | required |
| bond_angle_rad | ndarray | Scalar bond angle ∠(p2, p3, p4) in radians. | required |
| dihedral_angle_rad | ndarray | Scalar dihedral angle ∠(p1, p2, p3, p4) in radians. | required |

Returns:

| Type | Description |
|---|---|
| ndarray | (3,) Cartesian coordinates of the new atom p4. |

Source code in diff_biophys/geometry/nerf.py
@jit
def position_atom_3d(
    p1: jnp.ndarray,
    p2: jnp.ndarray,
    p3: jnp.ndarray,
    bond_length: jnp.ndarray,
    bond_angle_rad: jnp.ndarray,
    dihedral_angle_rad: jnp.ndarray,
) -> jnp.ndarray:
    """
    Differentiable NeRF implementation in JAX for a single atom.

    Places atom p4 given three reference atoms (p1, p2, p3) and the internal
    coordinates (bond length, bond angle, dihedral angle) that define its
    position relative to p3.

    Args:
        p1: (3,) first reference atom coordinate.
        p2: (3,) second reference atom coordinate.
        p3: (3,) third reference atom coordinate (parent of p4).
        bond_length: Scalar distance p3→p4 in Ångströms.
        bond_angle_rad: Scalar bond angle ∠(p2, p3, p4) in radians.
        dihedral_angle_rad: Scalar dihedral angle ∠(p1, p2, p3, p4) in radians.

    Returns:
        jnp.ndarray: (3,) Cartesian coordinates of the new atom p4.
    """
    v1 = p1 - p2
    v2 = p3 - p2

    u2 = v2 / (jnp.linalg.norm(v2) + 1e-10)

    n = jnp.cross(v1, u2)
    n /= jnp.linalg.norm(n) + 1e-10

    m = jnp.cross(n, u2)

    p4 = p3 + bond_length * (
        -jnp.cos(bond_angle_rad) * u2
        - jnp.sin(bond_angle_rad) * jnp.cos(dihedral_angle_rad) * m
        - jnp.sin(bond_angle_rad) * jnp.sin(dihedral_angle_rad) * n
    )
    return cast(jnp.ndarray, p4)

Torsions — Internal Coordinate Extraction

diff_biophys.geometry.torsions extracts bond lengths, bond angles, and dihedral (torsion) angles from a Cartesian coordinate array. Together with NeRF, this forms a round trip: internal → Cartesian → internal.

from diff_biophys.geometry.torsions import compute_dihedrals, compute_bond_lengths, compute_bond_angles

# coords: (N, 3) backbone atom positions
dihedrals   = compute_dihedrals(coords)    # (N-3,)
bond_lengths = compute_bond_lengths(coords) # (N-1,)
bond_angles  = compute_bond_angles(coords)  # (N-2,)
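The round trip can be checked end to end. The sketch below re-implements the placement step and the three extractors inline so it runs without the package installed; the local names are hypothetical and mirror the source on this page. It builds a chain from chosen internal coordinates, then recovers them from the Cartesian positions. Dihedrals are kept inside (−π, π) and bond angles away from 0 and π, where the inversion is well defined.

```python
import jax.numpy as jnp


def place_atom(p1, p2, p3, bond_length, bond_angle, dihedral):
    # Same NeRF arithmetic as position_atom_3d
    u2 = (p3 - p2) / jnp.linalg.norm(p3 - p2)
    n = jnp.cross(p1 - p2, u2)
    n = n / jnp.linalg.norm(n)
    m = jnp.cross(n, u2)
    return p3 + bond_length * (
        -jnp.cos(bond_angle) * u2
        - jnp.sin(bond_angle) * jnp.cos(dihedral) * m
        - jnp.sin(bond_angle) * jnp.sin(dihedral) * n
    )


def bond_lengths(c):
    # (N-1,) distances between consecutive atoms
    return jnp.linalg.norm(c[1:] - c[:-1], axis=-1)


def bond_angles(c):
    # (N-2,) angle at the middle atom of each consecutive triplet
    v1 = c[:-2] - c[1:-1]
    v2 = c[2:] - c[1:-1]
    cos = jnp.sum(v1 * v2, axis=-1) / (
        jnp.linalg.norm(v1, axis=-1) * jnp.linalg.norm(v2, axis=-1))
    return jnp.arccos(jnp.clip(cos, -1.0, 1.0))


def dihedrals(c):
    # (N-3,) torsions via the same projection + arctan2 formula as compute_dihedrals
    b0 = c[:-3] - c[1:-2]
    b1 = c[2:-1] - c[1:-2]
    b2 = c[3:] - c[2:-1]
    u1 = b1 / jnp.linalg.norm(b1, axis=-1, keepdims=True)
    v = b0 - jnp.sum(b0 * u1, axis=-1, keepdims=True) * u1
    w = b2 - jnp.sum(b2 * u1, axis=-1, keepdims=True) * u1
    return jnp.arctan2(jnp.sum(jnp.cross(u1, v) * w, axis=-1), jnp.sum(v * w, axis=-1))


# Internal coordinates for atoms 4..8
lengths = jnp.array([1.33, 1.46, 1.52, 1.33, 1.46])
angles = jnp.array([2.03, 2.12, 1.94, 2.03, 2.12])
torsions = jnp.array([-0.82, 3.10, -1.00, 2.40, -3.00])

# Forward: internal -> Cartesian
coords = [jnp.array([0.0, 0.0, 0.0]), jnp.array([1.5, 0.0, 0.0]), jnp.array([2.0, 1.4, 0.0])]
for L, a, d in zip(lengths, angles, torsions):
    coords.append(place_atom(coords[-3], coords[-2], coords[-1], L, a, d))
coords = jnp.stack(coords)  # (8, 3)

# Backward: Cartesian -> internal (slice off entries involving only the seed atoms)
print(bond_lengths(coords)[2:])  # ≈ lengths
print(bond_angles(coords)[1:])   # ≈ angles
print(dihedrals(coords))         # ≈ torsions
```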

compute_bond_angles(coords)

Compute the bond angle (in radians) at the central atom of each triplet of consecutive atoms. Returns an (N-2,) array.

Source code in diff_biophys/geometry/torsions.py
@jit
def compute_bond_angles(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Compute bond angles (in radians) between three adjacent atoms.
    """
    v1 = coords[:-2] - coords[1:-1]
    v2 = coords[2:] - coords[1:-1]

    v1_norm = v1 / (jnp.linalg.norm(v1, axis=-1, keepdims=True) + 1e-10)
    v2_norm = v2 / (jnp.linalg.norm(v2, axis=-1, keepdims=True) + 1e-10)

    cos_angle = jnp.sum(v1_norm * v2_norm, axis=-1)
    return cast(jnp.ndarray, jnp.arccos(jnp.clip(cos_angle, -1.0 + 1e-7, 1.0 - 1e-7)))

compute_bond_lengths(coords)

Compute the distance between each pair of consecutive atoms. Returns an (N-1,) array.

Source code in diff_biophys/geometry/torsions.py
@jit
def compute_bond_lengths(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Compute bond lengths between adjacent atoms.
    """
    vectors = coords[1:] - coords[:-1]
    return cast(jnp.ndarray, jnp.linalg.norm(vectors, axis=-1))

compute_dihedrals(coords)

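Compute the dihedral angle (in radians) for each window of four consecutive atoms, returning an (N-3,) array with values in [−π, π]. Follows the IUPAC sign convention and matches synth-pdb. Uses the numerically robust, arctan2-based "Praxeolitic" formula.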
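The source below uses \(\mathbf{b}_0 = \mathbf{p}_1 - \mathbf{p}_2\), \(\mathbf{b}_1 = \mathbf{p}_3 - \mathbf{p}_2\) and \(\mathbf{b}_2 = \mathbf{p}_4 - \mathbf{p}_3\). Written out with those vectors, the formula projects the two outer bonds onto the plane perpendicular to the central bond and takes a two-argument arctangent:

\[\hat{\mathbf{u}} = \frac{\mathbf{b}_1}{\|\mathbf{b}_1\|}, \qquad \mathbf{v} = \mathbf{b}_0 - (\mathbf{b}_0 \cdot \hat{\mathbf{u}})\,\hat{\mathbf{u}}, \qquad \mathbf{w} = \mathbf{b}_2 - (\mathbf{b}_2 \cdot \hat{\mathbf{u}})\,\hat{\mathbf{u}}\]

\[\phi = \operatorname{atan2}\big((\hat{\mathbf{u}} \times \mathbf{v}) \cdot \mathbf{w},\; \mathbf{v} \cdot \mathbf{w}\big)\]

Both a sine-like and a cosine-like component enter the arctangent. As a result, the sign comes out directly, and the angle stays well conditioned near 0 and ±π, where an arccos-based formula loses precision.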

Source code in diff_biophys/geometry/torsions.py
@jit
def compute_dihedrals(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Compute dihedral angles (in radians) for four adjacent atoms.
    Follows the IUPAC convention and matches synth-pdb.
    Uses the robust Praxeolitic formula.
    """
    # Vectors: p1-p2, p3-p2, p4-p3
    b0 = coords[:-3] - coords[1:-2]
    b1 = coords[2:-1] - coords[1:-2]
    b2 = coords[3:] - coords[2:-1]

    # Normalize b1
    b1_norm = jnp.linalg.norm(b1, axis=-1, keepdims=True)
    u1 = b1 / (b1_norm + 1e-10)

    # v = orthogonal component of b0 with respect to b1
    v = b0 - jnp.sum(b0 * u1, axis=-1, keepdims=True) * u1
    # w = orthogonal component of b2 with respect to b1
    w = b2 - jnp.sum(b2 * u1, axis=-1, keepdims=True) * u1

    # x = dot product of v and w
    x = jnp.sum(v * w, axis=-1)
    # y = dot product of cross(u1, v) and w
    y = jnp.sum(jnp.cross(u1, v) * w, axis=-1)

    # Use a small epsilon to ensure x and y are not both zero, avoiding NaN gradients
    return cast(jnp.ndarray, jnp.arctan2(y, x + 1e-10))
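A hand-checkable sketch makes the sign convention concrete. It keeps p1, p2, p3 fixed on the corners of a unit square and moves only p4. The helper is an inline copy of the `compute_dihedrals` source, renamed so it runs without the package. Viewed down the p2→p3 bond, a clockwise rotation from p1 to p4 gives a positive angle.

```python
import jax.numpy as jnp


def dihedrals(coords):
    # Inline copy of compute_dihedrals (Praxeolitic formula)
    b0 = coords[:-3] - coords[1:-2]
    b1 = coords[2:-1] - coords[1:-2]
    b2 = coords[3:] - coords[2:-1]
    u1 = b1 / (jnp.linalg.norm(b1, axis=-1, keepdims=True) + 1e-10)
    v = b0 - jnp.sum(b0 * u1, axis=-1, keepdims=True) * u1
    w = b2 - jnp.sum(b2 * u1, axis=-1, keepdims=True) * u1
    x = jnp.sum(v * w, axis=-1)
    y = jnp.sum(jnp.cross(u1, v) * w, axis=-1)
    return jnp.arctan2(y, x + 1e-10)


# p1, p2, p3 fixed; only the position of p4 changes
base = [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
cases = {
    "cis": [1.0, 1.0, 0.0],
    "+90 deg": [1.0, 0.0, 1.0],
    "-90 deg": [1.0, 0.0, -1.0],
}
for name, p4 in cases.items():
    angle = dihedrals(jnp.array(base + [p4]))[0]
    print(name, jnp.rad2deg(angle))  # ≈ 0, 90, -90
```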

Superposition — Kabsch Alignment

diff_biophys.geometry.superposition implements the Kabsch algorithm for optimal RMSD superposition of two structures via SVD. The rotation matrix and translation vector are returned, allowing the aligned RMSD to be used as a differentiable loss term.

from diff_biophys.geometry.superposition import kabsch_alignment
import jax, jax.numpy as jnp

# P: mobile structure (N, 3),  Q: reference (N, 3)
R, t = kabsch_alignment(P, Q)
P_aligned = P @ R.T + t
rmsd = jnp.sqrt(jnp.mean(jnp.sum((P_aligned - Q) ** 2, axis=-1)))

# Gradient of the aligned RMSD w.r.t. the mobile coordinates
def aligned_rmsd(p):
    R, t = kabsch_alignment(p, Q)
    diff = p @ R.T + t - Q
    return jnp.sqrt(jnp.mean(jnp.sum(diff ** 2, axis=-1)))

# Note: sqrt has an infinite derivative at 0, so this gradient is NaN
# when P is already perfectly superimposed on Q.
grad = jax.grad(aligned_rmsd)(P)

kabsch_alignment(P, Q)

Optimal superposition of P onto Q using Kabsch algorithm in JAX.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| P | ndarray | (N, 3) mobile coordinates | required |
| Q | ndarray | (N, 3) reference coordinates | required |

Returns:

| Type | Description |
|---|---|
| tuple[ndarray, ndarray] | (3, 3) rotation matrix R and 3-element translation vector t; P @ R.T + t is P superimposed onto Q |

Source code in diff_biophys/geometry/superposition.py
@jit
def kabsch_alignment(P: jnp.ndarray, Q: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Optimal superposition of P onto Q using Kabsch algorithm in JAX.

    Args:
        P: (N, 3) mobile coordinates
        Q: (N, 3) reference coordinates

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: (3x3 rotation matrix, 3-element translation vector)
    """
    p_center = jnp.mean(P, axis=0)
    q_center = jnp.mean(Q, axis=0)

    P_c = P - p_center
    Q_c = Q - q_center

    H = jnp.dot(P_c.T, Q_c)

    U, S, Vt = jnp.linalg.svd(H)

    d = jnp.linalg.det(jnp.dot(Vt.T, U.T))
    step = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, jnp.where(d > 0, 1.0, -1.0)]])

    R = jnp.dot(Vt.T, jnp.dot(step, U.T))
    t = q_center - jnp.dot(R, p_center)

    return R, t
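The sketch below copies the function inline under the hypothetical name `kabsch` so it runs standalone. It checks two properties. First, a known rigid motion is recovered exactly. Second, the determinant correction (`step`) keeps the result a proper rotation even when the target is a mirror image that no rotation can reach. The point cloud and transform are arbitrary test data.

```python
import jax
import jax.numpy as jnp


def kabsch(P, Q):
    # Inline copy of kabsch_alignment: returns R, t with Q ≈ P @ R.T + t
    p_c, q_c = P.mean(axis=0), Q.mean(axis=0)
    H = (P - p_c).T @ (Q - q_c)
    U, S, Vt = jnp.linalg.svd(H)
    d = jnp.linalg.det(Vt.T @ U.T)
    D = jnp.diag(jnp.array([1.0, 1.0, jnp.where(d > 0, 1.0, -1.0)]))
    R = Vt.T @ D @ U.T
    return R, q_c - R @ p_c


P = jax.random.normal(jax.random.PRNGKey(42), (20, 3))

# Known proper rotation (60 degrees about z) plus a translation
c, s = jnp.cos(jnp.pi / 3), jnp.sin(jnp.pi / 3)
R_true = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_true = jnp.array([1.0, -2.0, 0.5])
Q = P @ R_true.T + t_true

R, t = kabsch(P, Q)
rmsd = jnp.sqrt(jnp.mean(jnp.sum((P @ R.T + t - Q) ** 2, axis=-1)))
print(rmsd)               # ≈ 0
print(jnp.linalg.det(R))  # ≈ +1

# A mirror image cannot be reached by rotation; the sign correction keeps det(R) = +1
Q_mirror = P * jnp.array([1.0, 1.0, -1.0])
R_m, _ = kabsch(P, Q_mirror)
print(jnp.linalg.det(R_m))  # ≈ +1, not -1
```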

Macroscopic Properties

diff_biophys.geometry.macroscopic computes bulk structural descriptors.

Radius of Gyration

\[R_g^2 = \frac{\sum_i m_i \|\mathbf{r}_i - \mathbf{r}_{cm}\|^2}{\sum_i m_i}\]
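Here \(\mathbf{r}_{cm}\) is the mass-weighted center of mass, matching the weighted branch of the source below. With `masses=None`, every weight is equal and both averages become plain means over the \(N\) points:

\[\mathbf{r}_{cm} = \frac{\sum_i m_i \mathbf{r}_i}{\sum_i m_i}, \qquad R_g^2\Big|_{m_i = 1} = \frac{1}{N}\sum_{i=1}^{N} \left\|\mathbf{r}_i - \frac{1}{N}\sum_{j=1}^{N}\mathbf{r}_j\right\|^2\]

The source returns \(\sqrt{R_g^2 + 10^{-10}}\), so the gradient stays finite even when every point coincides with the center of mass.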

Useful as a compaction restraint: minimise \((R_g - R_g^{target})^2\) to drive a structure toward a target size extracted from the Guinier region of a SAXS profile.

from diff_biophys.geometry.macroscopic import compute_rg
import jax, jax.numpy as jnp

coords = jnp.array(...)         # (N, 3)
masses = jnp.ones(len(coords))  # uniform masses

rg = compute_rg(coords, masses)

# Gradient: which atoms, if moved, most change Rg?
grad_rg = jax.grad(lambda c: compute_rg(c, masses))(coords)
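The sketch below applies the compaction restraint described above: plain gradient descent on \((R_g - R_g^{target})^2\). It uses an inline copy of the weighted `compute_rg` logic so it runs standalone. The target, step size, and iteration count are arbitrary illustrative choices. With uniform masses, the gradient of Rg points radially from the center of mass, so this term alone only rescales the cloud. A real refinement would combine it with geometric restraints.

```python
import jax
import jax.numpy as jnp


def rg(coords, masses):
    # Inline copy of the weighted branch of compute_rg
    total = jnp.sum(masses)
    com = jnp.sum(coords * masses[:, None], axis=0) / total
    sq_dist = jnp.sum((coords - com) ** 2, axis=-1)
    return jnp.sqrt(jnp.sum(sq_dist * masses) / total + 1e-10)


def restraint(coords, masses, target):
    # Squared deviation of Rg from its target value
    return (rg(coords, masses) - target) ** 2


coords = 5.0 * jax.random.normal(jax.random.PRNGKey(1), (50, 3))
masses = jnp.ones(coords.shape[0])
target = 4.0          # Å, hypothetical target (e.g. from a Guinier fit)
learning_rate = 10.0  # arbitrary; suitable for 50 atoms with unit masses


@jax.jit
def step(c):
    # One plain gradient-descent step on the restraint
    return c - learning_rate * jax.grad(restraint)(c, masses, target)


print("start Rg:", rg(coords, masses))
for _ in range(100):
    coords = step(coords)
print("final Rg:", rg(coords, masses))  # ≈ 4.0
```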

compute_rg(coords, masses=None)

Compute the Radius of Gyration (Rg) for a set of coordinates.

Rg is a macroscopic measure of compactness, commonly used in SAXS and polymer physics. This implementation is fully differentiable.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (N, 3) array of coordinates. | required |
| masses | ndarray \| None | Optional (N,) array of weights (e.g., atomic masses or electron counts). If None, all points are weighted equally. | None |

Returns:

| Type | Description |
|---|---|
| ndarray | A scalar jnp.ndarray representing the Radius of Gyration. |

Source code in diff_biophys/geometry/macroscopic.py
@jit
def compute_rg(coords: jnp.ndarray, masses: jnp.ndarray | None = None) -> jnp.ndarray:
    """
    Compute the Radius of Gyration (Rg) for a set of coordinates.

    Rg is a macroscopic measure of compactness, commonly used in SAXS and polymer physics.
    This implementation is fully differentiable.

    Args:
        coords: (N, 3) array of coordinates.
        masses: Optional (N,) array of weights (e.g., atomic masses or electron counts).
                If None, all points are weighted equally.

    Returns:
        A scalar jnp.ndarray representing the Radius of Gyration.
    """
    if masses is None:
        # Unweighted center of mass
        com = jnp.mean(coords, axis=0)
        # Mean squared distance from COM
        sq_dist = jnp.sum((coords - com) ** 2, axis=-1)
        rg_sq = jnp.mean(sq_dist)
    else:
        # Weighted center of mass
        total_mass = jnp.sum(masses)
        com = jnp.sum(coords * masses[:, None], axis=0) / total_mass
        # Weighted mean squared distance
        sq_dist = jnp.sum((coords - com) ** 2, axis=-1)
        rg_sq = jnp.sum(sq_dist * masses) / total_mass

    return cast(
        jnp.ndarray, jnp.sqrt(rg_sq + 1e-10)
    )  # Add epsilon for numerical stability of sqrt near 0
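The sketch below illustrates a few properties that follow from the definition: unit masses match `masses=None`, Rg is unchanged by translation, and Rg is unchanged when every mass is scaled by the same factor. It also includes a hand-checkable case. The function is an inline copy of the source above, renamed `compute_rg_sketch` so it is clearly not the package import. The coordinates and masses are random test data.

```python
import jax
import jax.numpy as jnp


def compute_rg_sketch(coords, masses=None):
    # Inline copy of compute_rg (both branches), so the checks run standalone
    if masses is None:
        com = jnp.mean(coords, axis=0)
        rg_sq = jnp.mean(jnp.sum((coords - com) ** 2, axis=-1))
    else:
        total = jnp.sum(masses)
        com = jnp.sum(coords * masses[:, None], axis=0) / total
        rg_sq = jnp.sum(jnp.sum((coords - com) ** 2, axis=-1) * masses) / total
    return jnp.sqrt(rg_sq + 1e-10)


coords = jax.random.normal(jax.random.PRNGKey(0), (30, 3))
masses = jax.random.uniform(jax.random.PRNGKey(1), (30,), minval=1.0, maxval=16.0)

rg_none = compute_rg_sketch(coords)                     # equal weights
rg_ones = compute_rg_sketch(coords, jnp.ones(30))       # explicit unit masses
rg_w = compute_rg_sketch(coords, masses)                # mass-weighted
rg_shift = compute_rg_sketch(coords + jnp.array([10.0, -3.0, 7.0]), masses)
rg_w2 = compute_rg_sketch(coords, 2.0 * masses)         # uniformly rescaled masses

# Two points at x = -1 and x = +1: every point is 1 Å from the center, so Rg ≈ 1
rg_pair = compute_rg_sketch(jnp.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]))
print(rg_none, rg_ones, rg_w, rg_shift, rg_w2, rg_pair)
```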