Cryo-EM Density Fitting
Cryogenic electron microscopy (cryo-EM) produces 3D Coulomb potential maps (often called density maps) of macromolecules. A common task is fitting an atomic model into such a map.
The diff_em module allows us to simulate a 3D density map from atomic coordinates using Gaussian blobs, and then compute the Pearson cross-correlation between the simulated map and a target experimental map. Both operations are fully differentiable, meaning we can optimize atomic coordinates directly against the 3D density.
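To make the two operations concrete, here is a minimal NumPy sketch of a Gaussian-blob density and a Pearson cross-correlation. It is not the diff_em implementation: the function names `gaussian_blob_density` and `pearson_cc`, the unnormalized Gaussian form, and the small `eps` term are all assumptions made for illustration.

```python
import numpy as np


def gaussian_blob_density(coords, grid_coords, sigma):
    """Sum one isotropic Gaussian per atom, evaluated at every grid point.

    Assumption: unnormalized blobs exp(-r^2 / (2 sigma^2)); the library's
    simulate_density may use a different prefactor or per-atom weights.
    """
    # Broadcasting gives diff the shape (n_atoms, nz, ny, nx, 3)
    diff = grid_coords[None, ...] - coords[:, None, None, None, :]
    sq_dist = np.sum(diff**2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma**2)).sum(axis=0)


def pearson_cc(a, b, eps=1e-8):
    """Pearson correlation between two maps, treated as flat vectors."""
    a = a.ravel() - a.mean()
    b = b.ravel() - b.mean()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


# Small hypothetical grid and two atoms, chosen only for this sketch
axis = np.linspace(-8.0, 8.0, 16)
Z, Y, X = np.meshgrid(axis, axis, axis, indexing="ij")
grid = np.stack([X, Y, Z], axis=-1)  # (16, 16, 16, 3)

atoms = np.array([[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
reference = gaussian_blob_density(atoms, grid, sigma=1.5)
shifted = gaussian_blob_density(atoms + np.array([3.0, 0.0, 0.0]), grid, sigma=1.5)

# A rescaled, offset copy of the same map still correlates perfectly,
# while a map built from displaced atoms correlates less well.
cc_same = pearson_cc(reference, 2.0 * reference + 5.0)
cc_shifted = pearson_cc(reference, shifted)
```

Because the Pearson correlation ignores overall scale and offset, it rewards agreement in the shape of the density rather than its absolute values, which is convenient when the experimental map has arbitrary units.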
In this tutorial, we will:
- Generate a synthetic target 3D density map.
- Distort the atomic coordinates.
- Optimize the distorted coordinates using JAX gradients to maximize the cross-correlation (CC) with the target density.
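import sys
if "google.colab" in sys.modules:
    # In Colab: install the dependencies (notebook shell syntax)
    !pip install -q diff-em jax jaxlib optax matplotlib biotite
else:
    # Elsewhere: add the directory two levels up to the import path
    sys.path.append("../../")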
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
1. Simulating 3D Density from Coordinates
We'll define a simple 3D grid and a few atoms, and simulate the 3D density using simulate_density. This uses a memory-efficient jax.lax.scan over the atoms.
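The idea behind scanning over atoms can be sketched as follows. This is an illustration of the technique, not the library's code: `density_scan` and `density_broadcast` are hypothetical names, and the unnormalized Gaussian form is an assumption. The scan keeps a single grid-sized accumulator, whereas the broadcast version materializes one grid per atom before summing.

```python
import jax
import jax.numpy as jnp


def density_scan(coords, grid_coords, sigma):
    """Accumulate Gaussian blobs one atom at a time with jax.lax.scan."""

    def add_atom(density, atom):
        # Squared distance from this atom to every grid point
        sq_dist = jnp.sum((grid_coords - atom) ** 2, axis=-1)
        return density + jnp.exp(-sq_dist / (2.0 * sigma**2)), None

    init = jnp.zeros(grid_coords.shape[:-1])
    density, _ = jax.lax.scan(add_atom, init, coords)
    return density


def density_broadcast(coords, grid_coords, sigma):
    """Same result, but builds an (n_atoms, nz, ny, nx) intermediate array."""
    diff = grid_coords[None, ...] - coords[:, None, None, None, :]
    return jnp.exp(-jnp.sum(diff**2, axis=-1) / (2.0 * sigma**2)).sum(axis=0)


# Small hypothetical grid and atoms, chosen only for this sketch
axis = jnp.linspace(-8.0, 8.0, 16)
Z, Y, X = jnp.meshgrid(axis, axis, axis, indexing="ij")
grid = jnp.stack([X, Y, Z], axis=-1)  # (16, 16, 16, 3)
atoms = jnp.array([[-2.0, 0.0, 0.0], [2.0, 1.0, 0.0], [0.0, -3.0, 1.0]])

d_scan = density_scan(atoms, grid, 1.5)
d_full = density_broadcast(atoms, grid, 1.5)

# The scan is differentiable with respect to the atomic coordinates
grads = jax.grad(lambda c: density_scan(c, grid, 1.5).sum())(atoms)
```

For a handful of atoms the two versions cost about the same; the scan matters when the number of atoms times the number of voxels no longer fits comfortably in memory.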
from diff_em.kernels import cross_correlation, simulate_density
# Define a 3D grid: 32x32x32 voxels from -15 to 15 Angstroms
grid_1d = jnp.linspace(-15.0, 15.0, 32)
Z, Y, X = jnp.meshgrid(grid_1d, grid_1d, grid_1d, indexing="ij")
grid_coords = jnp.stack([X, Y, Z], axis=-1) # (32, 32, 32, 3)
# Synthetic atomic coordinates for a simple "U" shaped molecule
true_coords = jnp.array(
    [[-5.0, 5.0, 0.0], [-5.0, 0.0, 0.0], [0.0, -5.0, 0.0], [5.0, 0.0, 0.0], [5.0, 5.0, 0.0]]
)
# Simulate the target density map
target_density = simulate_density(true_coords, grid_coords, sigma=2.0)
# Visualize the 2D slice nearest to Z=0
# (a 32-point grid has no plane exactly at 0; index 16 sits at z ≈ +0.48 Å)
z_slice_idx = 16
plt.figure(figsize=(6, 5))
plt.imshow(
    np.array(target_density[z_slice_idx, :, :]),
    extent=[-15, 15, -15, 15],
    origin="lower",
    cmap="magma",
)
plt.colorbar(label="Simulated Density")
plt.scatter(true_coords[:, 0], true_coords[:, 1], color="cyan", label="True Atoms", marker="o")
plt.xlabel("X (Å)")
plt.ylabel("Y (Å)")
plt.title("Target Cryo-EM Density (Z=0 slice)")
plt.legend()
plt.tight_layout()
plt.show()
2. Gradient-Based Fitting
Now we'll start from a distorted version of our molecule and use gradient-based optimization (the Adam optimizer from optax) to maximize the cross-correlation (CC) with the target density. Since optimizers minimize, the loss is 1 - CC.
# Distort the true coordinates
key = jax.random.PRNGKey(42)
initial_coords = true_coords + 3.0 * jax.random.normal(key, true_coords.shape)
# Define the loss function to minimize: 1 - CC
@jax.jit
def loss_fn(coords):
    simulated_density = simulate_density(coords, grid_coords, sigma=2.0)
    cc = cross_correlation(simulated_density, target_density)
    return 1.0 - cc
# Set up optimizer
optimizer = optax.adam(learning_rate=0.5)
opt_state = optimizer.init(initial_coords)
@jax.jit
def step(coords, opt_state):
    # The returned loss is evaluated at the coordinates before the update
    loss, grads = jax.value_and_grad(loss_fn)(coords)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_coords = optax.apply_updates(coords, updates)
    return new_coords, opt_state, loss
# Run optimization loop
coords = initial_coords
losses = []
for _ in range(100):
    coords, opt_state, loss = step(coords, opt_state)
    losses.append(loss)
final_coords = coords
# Plot the loss trajectory
plt.figure(figsize=(8, 4))
plt.plot(np.array(losses), linewidth=2, color="red")
plt.xlabel("Iteration")
plt.ylabel("Loss (1 - CC)")
plt.title("Cryo-EM Density Fitting Optimization")
plt.tight_layout()
plt.show()
Visualizing the Result
Let's overlay the initial, final, and true coordinates on the density slice, with arrows showing how each atom moved during optimization.
plt.figure(figsize=(8, 7))
plt.imshow(
    np.array(target_density[z_slice_idx, :, :]),
    extent=[-15, 15, -15, 15],
    origin="lower",
    cmap="magma",
)
plt.colorbar(label="Target Density")
plt.scatter(
    true_coords[:, 0], true_coords[:, 1], color="cyan", s=100, label="True Atoms", marker="o"
)
plt.scatter(
    initial_coords[:, 0],
    initial_coords[:, 1],
    color="white",
    s=50,
    label="Initial (Distorted)",
    marker="x",
)
plt.scatter(
    final_coords[:, 0],
    final_coords[:, 1],
    color="lightgreen",
    s=100,
    label="Final (Optimized)",
    marker="*",
)
# Draw arrows showing movement
for i in range(len(true_coords)):
    plt.arrow(
        initial_coords[i, 0],
        initial_coords[i, 1],
        final_coords[i, 0] - initial_coords[i, 0],
        final_coords[i, 1] - initial_coords[i, 1],
        color="white",
        alpha=0.5,
        head_width=0.5,
        length_includes_head=True,
    )
plt.xlabel("X (Å)")
plt.ylabel("Y (Å)")
plt.title("Fitting Trajectory Overlay")
plt.legend()
plt.tight_layout()
plt.show()
# Evaluate the CC at the final coordinates (losses[-1] predates the last update)
final_cc = 1.0 - float(loss_fn(final_coords))
print(f"Final Cross-Correlation: {final_cc:.4f}")
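A high CC does not by itself say how close each atom ended up to its true position, so it can help to also report a coordinate RMSD. The sketch below is one way to do that, under stated assumptions: `rmsd` and `min_rmsd_over_permutations` are hypothetical helpers, and the arrays are stand-in values rather than output from the run above. Since every atom here is simulated with the same sigma, the density should not depend on atom ordering, so two atoms may swap places during fitting; the second helper accounts for that by brute force, which is only practical for a few atoms.

```python
from itertools import permutations

import numpy as np


def rmsd(a, b):
    """Root-mean-square deviation between two (n_atoms, 3) arrays."""
    return float(np.sqrt(np.mean(np.sum((a - b) ** 2, axis=-1))))


def min_rmsd_over_permutations(fitted, reference):
    """Smallest RMSD over all atom orderings (factorial cost in n_atoms)."""
    n = len(reference)
    return min(rmsd(fitted[list(p)], reference) for p in permutations(range(n)))


# Hypothetical stand-in values, not results from the tutorial run
reference = np.array([[-5.0, 5.0, 0.0], [-5.0, 0.0, 0.0], [0.0, -5.0, 0.0]])
fitted = reference[[1, 0, 2]] + 0.1  # first two atoms swapped, small offset

plain = rmsd(fitted, reference)  # penalizes the swap
matched = min_rmsd_over_permutations(fitted, reference)  # ignores ordering
```

In the notebook, the same helpers could be called as `rmsd(np.array(final_coords), np.array(true_coords))` once the optimization has finished.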