API Reference

cross_correlation(map_a, map_b)

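Compute the Pearson correlation coefficient between two 3D maps of identical shape (a ValueError is raised if the shapes differ). This is a standard similarity score for cryo-EM fitting; because higher values mean better agreement, minimize 1 - CC (or -CC) when using it as a loss.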

Parameters:

Name   Type     Description                    Default
map_a  ndarray  (D, H, W) First density map.   required
map_b  ndarray  (D, H, W) Second density map.  required

Returns:

Type     Description
ndarray  Scalar correlation coefficient in [-1, 1].

Source code in diff_em/kernels.py
def cross_correlation(
    map_a: jnp.ndarray,
    map_b: jnp.ndarray,
) -> jnp.ndarray:
    """
    Compute the Pearson correlation coefficient between two 3D maps.
    This is a standard similarity score for cryo-EM fitting; minimize
    1 - CC (or -CC) when using it as a loss.

    Args:
        map_a: (D, H, W) First density map.
        map_b: (D, H, W) Second density map.

    Returns:
        Scalar correlation coefficient in [-1, 1].
    """
    if map_a.shape != map_b.shape:
        raise ValueError("Maps must have the same shape")

    # Flatten maps
    a = map_a.flatten()
    b = map_b.flatten()

    # Center maps
    a_centered = a - jnp.mean(a)
    b_centered = b - jnp.mean(b)

    # Compute CC
    numerator = jnp.sum(a_centered * b_centered)
    denominator = jnp.sqrt(jnp.sum(a_centered**2) * jnp.sum(b_centered**2) + 1e-9)

    return numerator / denominator
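
As a rough illustration of why the correlation coefficient suits map fitting, the sketch below checks that it ignores affine rescaling of a map and shows one way to turn it into a loss. The function body restates the listing above so the example runs on its own; the toy map and the `1 - CC` loss are assumptions made for illustration, not part of diff_em.

```python
import jax.numpy as jnp


def cross_correlation(map_a, map_b):
    # Same computation as the listing above, repeated so this sketch is self-contained.
    a = map_a.flatten() - jnp.mean(map_a)
    b = map_b.flatten() - jnp.mean(map_b)
    denominator = jnp.sqrt(jnp.sum(a**2) * jnp.sum(b**2) + 1e-9)
    return jnp.sum(a * b) / denominator


# Hypothetical toy map: a linear ramp on a 4x4x4 grid.
map_a = jnp.arange(64, dtype=jnp.float32).reshape(4, 4, 4)

cc_same = cross_correlation(map_a, map_a)                 # close to 1
cc_affine = cross_correlation(map_a, 3.0 * map_a + 5.0)   # close to 1: scale and offset drop out
cc_flipped = cross_correlation(map_a, -map_a)             # close to -1
cc_flat = cross_correlation(map_a, jnp.ones_like(map_a))  # 0: the 1e-9 term avoids 0/0

# One possible loss for a minimizer: 0 at perfect agreement, 2 at perfect anti-correlation.
loss = 1.0 - cc_affine
```

The scale invariance is what allows the Gaussian normalization constant to be dropped in simulate_density below.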

simulate_density(coords, grid_coords, sigma=1.0)

Differentiable simulation of 3D density from atomic coordinates. Each atom is represented as an isotropic 3D Gaussian.

Memory-efficient implementation using jax.lax.scan: atoms are processed one at a time and their contributions are accumulated into a single (D, H, W) buffer, giving O(D·H·W) peak memory instead of the O(N·D·H·W) required by a naive broadcast. The scan is JIT-compatible and differentiable as long as sigma is a concrete Python value at trace time, because the sigma <= 0 check runs in Python.

Parameters:

Name         Type     Description                                        Default
coords       ndarray  (N, 3) atomic coordinates.                         required
grid_coords  ndarray  (D, H, W, 3) coordinates of the 3D grid points.    required
sigma        float    Width (standard deviation) of the Gaussian blobs.  1.0

Returns:

Type     Description
ndarray  3D density map (D, H, W).
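
The sketch below shows one way the (D, H, W, 3) grid_coords array might be built and checks the scan-based accumulation against the naive broadcast it replaces. `make_grid` is a hypothetical helper, not part of diff_em, and the grid origin, spacing, and atom positions are made up; `simulate_density` is restated compactly so the block runs on its own.

```python
import jax
import jax.numpy as jnp


def simulate_density(coords, grid_coords, sigma=1.0):
    # Compact restatement of the scan-based implementation documented here.
    def add_atom(density, atom_coord):
        dist_sq = jnp.sum((grid_coords - atom_coord) ** 2, axis=-1)  # (D, H, W)
        return density + jnp.exp(-dist_sq / (2 * sigma**2)), None

    init = jnp.zeros(grid_coords.shape[:-1])
    density, _ = jax.lax.scan(add_atom, init, coords)
    return density


def make_grid(shape, spacing=1.0):
    # Hypothetical helper: grid-point coordinates with the origin at index
    # (0, 0, 0); the last axis holds one coordinate per grid axis, in (D, H, W) order.
    axes = [jnp.arange(n) * spacing for n in shape]
    return jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)


grid_coords = make_grid((6, 7, 8))       # (6, 7, 8, 3)
coords = jnp.array([[1.0, 2.0, 3.0],
                    [4.0, 4.0, 4.0]])    # (N=2, 3)

density = simulate_density(coords, grid_coords, sigma=1.0)  # (6, 7, 8)

# Naive reference: materializes an (N, D, H, W, 3) intermediate before summing.
diff = grid_coords[None] - coords[:, None, None, None, :]
naive = jnp.exp(-jnp.sum(diff**2, axis=-1) / 2.0).sum(axis=0)  # 2 * sigma**2 == 2.0
```

Both paths should agree to floating-point tolerance; the scan version simply never holds the per-atom intermediate for all N atoms at once.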

Source code in diff_em/kernels.py
def simulate_density(
    coords: jnp.ndarray,
    grid_coords: jnp.ndarray,
    sigma: float = 1.0,
) -> jnp.ndarray:
    """
    Differentiable simulation of 3D density from atomic coordinates.
    Each atom is represented as an isotropic 3D Gaussian.

    Memory-efficient implementation using jax.lax.scan: atoms are processed
    one at a time and their contributions are accumulated into a single
    (D, H, W) buffer, giving O(D·H·W) peak memory instead of the O(N·D·H·W)
    required by a naive broadcast. The scan is JIT-compatible and
    differentiable as long as sigma is a concrete Python value at trace
    time, because the sigma <= 0 check below runs in Python.

    Args:
        coords: (N, 3) atomic coordinates.
        grid_coords: (D, H, W, 3) coordinates of the 3D grid points.
        sigma: Width (standard deviation) of the Gaussian blobs.

    Returns:
        3D density map (D, H, W).
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")

    def add_atom(density: jnp.ndarray, atom_coord: jnp.ndarray) -> tuple[jnp.ndarray, None]:
        # grid_coords: (D, H, W, 3); atom_coord: (3,)
        # Broadcasting (D, H, W, 3) - (3,) → (D, H, W, 3)
        diff = grid_coords - atom_coord
        dist_sq = jnp.sum(diff**2, axis=-1)  # (D, H, W)
        # Gaussian kernel — normalization omitted; CC is scale-invariant
        return density + jnp.exp(-dist_sq / (2 * sigma**2)), None

    density_init = jnp.zeros(grid_coords.shape[:-1])  # (D, H, W)
    density, _ = jax.lax.scan(add_atom, density_init, coords)
    return density
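
To show how the two kernels might be combined, here is a minimal sketch that fits atomic coordinates to a target map by gradient descent on 1 - CC. The target coordinates, starting offset, step size, and iteration count are all made-up illustration values, and both functions are restated compactly so the block is self-contained; this is not necessarily how diff_em drives its own fitting.

```python
import jax
import jax.numpy as jnp


def simulate_density(coords, grid_coords, sigma=1.0):
    # Compact restatement of the listing above.
    def add_atom(density, atom_coord):
        dist_sq = jnp.sum((grid_coords - atom_coord) ** 2, axis=-1)
        return density + jnp.exp(-dist_sq / (2 * sigma**2)), None

    init = jnp.zeros(grid_coords.shape[:-1])
    density, _ = jax.lax.scan(add_atom, init, coords)
    return density


def cross_correlation(map_a, map_b):
    # Compact restatement of the Pearson correlation kernel.
    a = map_a.flatten() - jnp.mean(map_a)
    b = map_b.flatten() - jnp.mean(map_b)
    return jnp.sum(a * b) / jnp.sqrt(jnp.sum(a**2) * jnp.sum(b**2) + 1e-9)


# Toy 8x8x8 grid with unit spacing (assumed for illustration).
axes = [jnp.arange(8.0)] * 3
grid_coords = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)

# Hypothetical ground truth and the map it produces.
target_coords = jnp.array([[3.0, 3.0, 3.0], [5.0, 4.0, 3.0]])
target_map = simulate_density(target_coords, grid_coords)


def loss_fn(coords):
    # Higher CC is better, so minimize 1 - CC.
    return 1.0 - cross_correlation(simulate_density(coords, grid_coords), target_map)


# sigma keeps its Python default, so no traced value reaches the sigma check.
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn))

coords = target_coords + 0.5  # perturbed starting model
initial_loss, _ = loss_and_grad(coords)
for _ in range(50):
    _, grads = loss_and_grad(coords)
    coords = coords - 0.5 * grads  # plain gradient descent, hypothetical step size
final_loss, grads = loss_and_grad(coords)
```

If sigma itself needs to vary between jitted calls, one option is to mark it static, for example `jax.jit(simulate_density, static_argnames="sigma")`, so that the `sigma <= 0` check still sees a concrete Python value.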