📐 Geometry API
The diff_biophys.geometry subpackage provides differentiable structural primitives —
the mathematical backbone for converting between internal coordinates (bond lengths,
bond angles, dihedral angles) and Cartesian 3D positions, plus tools for alignment
and macroscopic property calculation.
All functions are compiled with jax.jit and are fully compatible with jax.grad,
jax.vmap, and jax.pmap.
NeRF — Forward Kinematics
diff_biophys.geometry.nerf implements the Natural Extension Reference Frame
algorithm. Given three reference atoms and a set of internal coordinates, it places
a new atom in 3D space. The recurrence can be chained to build an entire polymer
backbone from scratch — entirely differentiably.
When to use:
- Convert torsion-angle parameters (φ, ψ, ω) to Cartesian coordinates
- Build synthetic structures for gradient-descent refinement
- Provide a differentiable mapping from parameter space to observable space
```python
import jax.numpy as jnp
from diff_biophys.geometry.nerf import position_atom_3d

# Place one atom given three reference atoms + internal coordinates
p1 = jnp.array([0.0, 0.0, 0.0])
p2 = jnp.array([1.52, 0.0, 0.0])
# p3 is chosen so that the angle ∠(p1, p2, p3) is 1.94 rad
p3 = jnp.array([1.52 + 1.52 * jnp.cos(jnp.pi - 1.94), 1.52 * jnp.sin(jnp.pi - 1.94), 0.0])
p4 = position_atom_3d(
    p1, p2, p3,
    bond_length=jnp.array(1.33),           # Å (≈ peptide C–N bond)
    bond_angle_rad=jnp.array(1.94),        # rad (≈ 111°)
    dihedral_angle_rad=jnp.array(-0.994),  # rad (≈ -57°, close to the ideal α-helix φ)
)
print(p4)  # → (3,) array with Cartesian coordinates
```
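To make the geometry concrete, here is a minimal, self-contained sketch of one NeRF placement step using the standard construction: build an orthonormal frame from the three reference atoms, express the new atom's offset in that frame from (bond length, bond angle, dihedral), and map it back to global coordinates. `nerf_place` is a hypothetical name, not the source of `position_atom_3d`; the argument order mirrors it, and we assume the library uses the same IUPAC sign convention that this sketch produces.

```python
import jax
import jax.numpy as jnp


def nerf_place(a, b, c, bond_length, bond_angle, dihedral):
    """Hypothetical sketch of one NeRF step (not the library's source).

    Returns d such that |d - c| = bond_length, angle(b, c, d) = bond_angle and
    dihedral(a, b, c, d) = dihedral (radians, IUPAC sign convention).
    """
    bc = (c - b) / jnp.linalg.norm(c - b)                 # unit vector along the parent bond b -> c
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)                            # unit normal of the (a, b, c) plane
    frame = jnp.stack([bc, jnp.cross(n, bc), n], axis=1)  # local frame, one axis per column
    d_local = bond_length * jnp.array([
        -jnp.cos(bond_angle),
        jnp.sin(bond_angle) * jnp.cos(dihedral),
        jnp.sin(bond_angle) * jnp.sin(dihedral),
    ])
    return c + frame @ d_local                            # rotate into the global frame, shift to c


a = jnp.array([0.0, 1.0, 0.0])
b = jnp.array([0.0, 0.0, 0.0])
c = jnp.array([1.5, 0.0, 0.0])
r, theta, phi = jnp.array(1.33), jnp.array(1.94), jnp.array(-0.994)
d = nerf_place(a, b, c, r, theta, phi)

# Differentiability: sensitivity of the new position to the dihedral, shape (3,)
dd_dphi = jax.jacobian(nerf_place, argnums=5)(a, b, c, r, theta, phi)
```

As the dihedral varies, the new atom sweeps a circle of radius `bond_length * sin(bond_angle)` around the parent bond, so that is also the magnitude of `dd_dphi`.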
chain_nerf(init_coords, bond_lengths, bond_angles, dihedrals)
Build a chain of atoms using the NeRF algorithm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| init_coords | ndarray | (3, 3) initial coordinates for the first 3 atoms | required |
| bond_lengths | ndarray | (N,) bond lengths for atoms 4 to N+3 | required |
| bond_angles | ndarray | (N,) bond angles (in radians) for atoms 4 to N+3 | required |
| dihedrals | ndarray | (N,) dihedral angles (in radians) for atoms 4 to N+3 | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (N+3, 3) coordinates for the entire chain |
Source code in diff_biophys/geometry/nerf.py
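The `chain_nerf` contract (three seed atoms plus N sets of internal coordinates giving N+3 atoms) maps naturally onto `jax.lax.scan`, carrying a sliding window of the last three placed atoms. The following is a hypothetical sketch of such a recurrence, not the library's source; `chain_nerf_sketch` and `_place` are illustrative names, and the seed atoms are assumed not to be collinear.

```python
import jax
import jax.numpy as jnp


def _place(a, b, c, r, theta, phi):
    # One NeRF step (hypothetical helper): place d from the window (a, b, c)
    bc = (c - b) / jnp.linalg.norm(c - b)
    n = jnp.cross(b - a, bc)
    n = n / jnp.linalg.norm(n)
    frame = jnp.stack([bc, jnp.cross(n, bc), n], axis=1)
    d_local = r * jnp.array([-jnp.cos(theta),
                             jnp.sin(theta) * jnp.cos(phi),
                             jnp.sin(theta) * jnp.sin(phi)])
    return c + frame @ d_local


def chain_nerf_sketch(init_coords, bond_lengths, bond_angles, dihedrals):
    """Hypothetical sketch of the chain_nerf contract: (3, 3) + 3 x (N,) -> (N+3, 3)."""
    def step(window, internals):
        a, b, c = window
        d = _place(a, b, c, *internals)
        return (b, c, d), d  # slide the window forward by one atom and emit d
    window0 = (init_coords[0], init_coords[1], init_coords[2])
    _, new_atoms = jax.lax.scan(step, window0, (bond_lengths, bond_angles, dihedrals))
    return jnp.concatenate([init_coords, new_atoms], axis=0)


init = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
n_new = 10
lengths = jnp.full(n_new, 1.5)
angles = jnp.full(n_new, 1.94)
torsions = jnp.full(n_new, -0.994)
coords = chain_nerf_sketch(init, lengths, angles, torsions)  # (13, 3)

# Gradient of the last atom's x coordinate w.r.t. every dihedral, shape (10,)
g = jax.grad(lambda t: chain_nerf_sketch(init, lengths, angles, t)[-1, 0])(torsions)
```

Using `lax.scan` traces a single step body instead of unrolling N placements, although the recurrence itself remains sequential in N.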
position_atom_3d(p1, p2, p3, bond_length, bond_angle_rad, dihedral_angle_rad)
Differentiable NeRF implementation in JAX for a single atom.
Places atom p4 given three reference atoms (p1, p2, p3) and the internal coordinates (bond length, bond angle, dihedral angle) that define its position relative to p3.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| p1 | ndarray | (3,) first reference atom coordinate. | required |
| p2 | ndarray | (3,) second reference atom coordinate. | required |
| p3 | ndarray | (3,) third reference atom coordinate (parent of p4). | required |
| bond_length | ndarray | Scalar distance p3→p4 in Ångströms. | required |
| bond_angle_rad | ndarray | Scalar bond angle ∠(p2, p3, p4) in radians. | required |
| dihedral_angle_rad | ndarray | Scalar dihedral angle ∠(p1, p2, p3, p4) in radians. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | jnp.ndarray: (3,) Cartesian coordinates of the new atom p4. |
Source code in diff_biophys/geometry/nerf.py
Torsions — Internal Coordinate Extraction
diff_biophys.geometry.torsions extracts bond lengths, bond angles, and dihedral
(torsion) angles from a Cartesian coordinate array. Together with NeRF this forms a
round-trip: internal → Cartesian → internal.
```python
import jax
from diff_biophys.geometry.torsions import compute_dihedrals, compute_bond_lengths, compute_bond_angles

# coords: (N, 3) backbone atom positions (random placeholder here)
coords = jax.random.normal(jax.random.PRNGKey(0), (12, 3))
dihedrals = compute_dihedrals(coords)        # (N-3,)
bond_lengths = compute_bond_lengths(coords)  # (N-1,)
bond_angles = compute_bond_angles(coords)    # (N-2,)
```
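The output shapes follow from simple slicing: N atoms give N−1 consecutive pairs and N−2 consecutive triplets. A minimal vectorised sketch, assuming each angle is measured at the middle atom of its triplet; `bond_lengths_sketch` and `bond_angles_sketch` are hypothetical names, not the library's implementation.

```python
import jax.numpy as jnp


def bond_lengths_sketch(coords):
    # Hypothetical helper: |x[i+1] - x[i]| for each consecutive pair -> (N-1,)
    return jnp.linalg.norm(coords[1:] - coords[:-1], axis=-1)


def bond_angles_sketch(coords):
    # Hypothetical helper: angle at x[i+1] in each triplet (x[i], x[i+1], x[i+2]) -> (N-2,)
    u = coords[:-2] - coords[1:-1]
    v = coords[2:] - coords[1:-1]
    # atan2(|u x v|, u . v) avoids the unbounded derivative of arccos near 0 and pi
    return jnp.arctan2(jnp.linalg.norm(jnp.cross(u, v), axis=-1), jnp.sum(u * v, axis=-1))


coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])
lengths = bond_lengths_sketch(coords)  # [1, 1, 1]
angles = bond_angles_sketch(coords)    # [pi/2, pi/2]
```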
compute_bond_angles(coords)
Compute bond angles (in radians), one at the middle atom of each run of three consecutive atoms.
Source code in diff_biophys/geometry/torsions.py
compute_bond_lengths(coords)
Compute bond lengths between adjacent atoms.
compute_dihedrals(coords)
Compute dihedral angles (in radians) for each run of four consecutive atoms. Follows the IUPAC sign convention and matches synth-pdb. Uses the numerically robust "Praxeolitic" formula.
Source code in diff_biophys/geometry/torsions.py
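A minimal sketch of the projection-plus-`atan2` idea behind the Praxeolitic formula, vectorised over every run of four consecutive atoms: both outer bonds are projected onto the plane perpendicular to the central bond, and the signed angle between the projections is recovered with `atan2`, avoiding `arccos` entirely. `dihedrals_sketch` is a hypothetical name and an illustration of the technique, not the library's source.

```python
import jax.numpy as jnp


def dihedrals_sketch(coords):
    """Hypothetical vectorised dihedral (Praxeolitic-style formula): (N, 3) -> (N-3,) radians."""
    b0 = coords[:-3] - coords[1:-2]   # p0 - p1
    b1 = coords[2:-1] - coords[1:-2]  # p2 - p1, the central bond
    b2 = coords[3:] - coords[2:-1]    # p3 - p2
    b1 = b1 / jnp.linalg.norm(b1, axis=-1, keepdims=True)
    # Project the outer bonds onto the plane perpendicular to the central bond
    v = b0 - jnp.sum(b0 * b1, axis=-1, keepdims=True) * b1
    w = b2 - jnp.sum(b2 * b1, axis=-1, keepdims=True) * b1
    x = jnp.sum(v * w, axis=-1)                 # proportional to cos(dihedral)
    y = jnp.sum(jnp.cross(b1, v) * w, axis=-1)  # proportional to sin(dihedral)
    return jnp.arctan2(y, x)                    # signed angle in [-pi, pi]


quad = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0]])
print(dihedrals_sketch(quad))  # ≈ [pi/2], positive by the IUPAC convention
```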
Superposition — Kabsch Alignment
diff_biophys.geometry.superposition implements the Kabsch algorithm for
optimal RMSD superposition of two structures via SVD. The rotation matrix and
translation vector are returned, allowing the aligned RMSD to be used as a
differentiable loss term.
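```python
import jax
import jax.numpy as jnp
from diff_biophys.geometry.superposition import kabsch_alignment

# P: mobile structure (N, 3), Q: reference (N, 3); random placeholders here
P = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
Q = jax.random.normal(jax.random.PRNGKey(1), (20, 3))

def aligned_rmsd(p, q):
    R, t = kabsch_alignment(p, q)
    p_aligned = p @ R.T + t  # rotate, then translate onto q
    return jnp.sqrt(jnp.mean(jnp.sum((p_aligned - q) ** 2, axis=-1)))

rmsd = aligned_rmsd(P, Q)
# Gradient of the aligned RMSD w.r.t. the mobile coordinates, shape (N, 3)
# (undefined at RMSD = 0, where the square root is not differentiable)
grad = jax.grad(aligned_rmsd)(P, Q)
```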
kabsch_alignment(P, Q)
Optimal superposition of P onto Q using the Kabsch algorithm in JAX.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| P | ndarray | (N, 3) mobile coordinates | required |
| Q | ndarray | (N, 3) reference coordinates | required |
Returns:
| Type | Description |
|---|---|
| tuple[ndarray, ndarray] | tuple[jnp.ndarray, jnp.ndarray]: (3×3 rotation matrix, 3-element translation vector) |
Source code in diff_biophys/geometry/superposition.py
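For reference, the Kabsch algorithm centres both structures, takes the SVD of their 3×3 cross-covariance, and flips the smallest singular direction when needed so the result is a proper rotation rather than a reflection. A minimal sketch under the same `P @ R.T + t` convention as the example above; `kabsch_sketch` and `aligned_msd` are hypothetical names, and the library may differ in details such as weighting or reflection handling.

```python
import jax
import jax.numpy as jnp


def kabsch_sketch(P, Q):
    """Hypothetical Kabsch sketch: returns (R, t) with P @ R.T + t ≈ Q."""
    p_mean, q_mean = P.mean(axis=0), Q.mean(axis=0)
    H = (P - p_mean).T @ (Q - q_mean)  # 3x3 cross-covariance
    U, _, Vt = jnp.linalg.svd(H, full_matrices=False)
    # Reflection guard: flip the last axis if V U^T would be an improper rotation
    d = jnp.sign(jnp.linalg.det(Vt.T @ U.T))
    R = Vt.T @ jnp.diag(jnp.ones(3).at[2].set(d)) @ U.T
    t = q_mean - p_mean @ R.T
    return R, t


P = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
cz, sz = jnp.cos(0.7), jnp.sin(0.7)
R_true = jnp.array([[cz, -sz, 0.0], [sz, cz, 0.0], [0.0, 0.0, 1.0]])
t_true = jnp.array([1.0, -2.0, 0.5])
Q = P @ R_true.T + t_true

R, t = kabsch_sketch(P, Q)
rmsd = jnp.sqrt(jnp.mean(jnp.sum((P @ R.T + t - Q) ** 2, axis=-1)))  # ≈ 0


def aligned_msd(p, q):
    # Hypothetical loss: mean squared deviation after optimal superposition
    R_, t_ = kabsch_sketch(p, q)
    return jnp.mean(jnp.sum((p @ R_.T + t_ - q) ** 2, axis=-1))


g = jax.grad(aligned_msd)(P, jax.random.normal(jax.random.PRNGKey(1), (20, 3)))  # (20, 3)
```

SVD gradients involve terms like 1/(σᵢ² − σⱼ²), so a loss built this way can become unstable for point sets whose cross-covariance has (nearly) repeated singular values.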
Macroscopic Properties
diff_biophys.geometry.macroscopic computes bulk structural descriptors.
Radius of Gyration
Useful as a compaction restraint: minimise \((R_g - R_g^{target})^2\) to drive a structure toward a target size extracted from the Guinier region of a SAXS profile.
```python
import jax
import jax.numpy as jnp
from diff_biophys.geometry.macroscopic import compute_rg

# (N, 3) coordinates; random placeholder here, substitute your structure
coords = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
masses = jnp.ones(len(coords))  # uniform masses
rg = compute_rg(coords, masses)
# Gradient: which atoms, if moved, most change Rg?
grad_rg = jax.grad(lambda c: compute_rg(c, masses))(coords)
```
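The compaction restraint described above can be driven by plain gradient descent. A minimal sketch, assuming the standard mass-weighted Rg definition; `rg`, `rg_restraint`, `descent_step`, the step size, and the target value are hypothetical stand-ins, not diff_biophys API. With uniform weights, each atom's gradient scales as 1/N, which is why the toy step size is large.

```python
import jax
import jax.numpy as jnp


def rg(coords, masses):
    # Hypothetical helper: mass-weighted radius of gyration about the centre of mass
    w = masses / jnp.sum(masses)
    com = w @ coords
    return jnp.sqrt(w @ jnp.sum((coords - com) ** 2, axis=-1))


def rg_restraint(coords, masses, rg_target):
    # Hypothetical harmonic compaction restraint (Rg - Rg_target)^2
    return (rg(coords, masses) - rg_target) ** 2


LR = 10.0  # toy step size, tuned for N = 50 uniform weights


@jax.jit
def descent_step(coords, masses, rg_target):
    # Hypothetical single gradient-descent update on the restraint
    loss, grad = jax.value_and_grad(rg_restraint)(coords, masses, rg_target)
    return coords - LR * grad, loss


coords = 5.0 * jax.random.normal(jax.random.PRNGKey(0), (50, 3))
masses = jnp.ones(coords.shape[0])
rg_target = 4.0  # hypothetical target, e.g. from a Guinier fit
for _ in range(50):
    coords, loss = descent_step(coords, masses, rg_target)
```

Because the gradient of Rg on each atom points along its offset from the centre of mass, each step rescales the structure uniformly about that centre, so the shape is preserved while the size converges to the target.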
compute_rg(coords, masses=None)
Compute the Radius of Gyration (Rg) for a set of coordinates.
Rg is a macroscopic measure of compactness, commonly used in SAXS and polymer physics. This implementation is fully differentiable.
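For reference, the standard weighted definition, which we assume `compute_rg` follows (with \(m_i = 1\) for all \(i\) when `masses` is None):

\[
R_g = \sqrt{\frac{\sum_{i=1}^{N} m_i \,\lVert \mathbf{r}_i - \mathbf{r}_{\mathrm{cm}} \rVert^2}{\sum_{i=1}^{N} m_i}},
\qquad
\mathbf{r}_{\mathrm{cm}} = \frac{\sum_{i=1}^{N} m_i \,\mathbf{r}_i}{\sum_{i=1}^{N} m_i}
\]

Every term is smooth except the square root at \(R_g = 0\), so the gradient is well defined for any structure whose points do not all coincide.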
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| coords | ndarray | (N, 3) array of coordinates. | required |
| masses | ndarray \| None | Optional (N,) array of weights (e.g., atomic masses or electron counts). If None, all points are weighted equally. | None |
Returns:
| Type | Description |
|---|---|
| ndarray | A scalar jnp.ndarray representing the Radius of Gyration. |