
Losses

calculate_pseudo_torsions(ca_coords)

Calculates pseudo-torsion angles for consecutive Cα atoms.

A pseudo-torsion is the dihedral angle formed by four consecutive Cα atoms (i-1, i, i+1, i+2). In Cα-only models, these angles are the primary indicator of backbone conformation (analogous to Ramachandran angles for full-atom models).

Typical pseudo-torsion values:

α-helix: ~ +50°
β-strand: ~ ±180°

Parameters:

Name Type Description Default
ca_coords Array

(N, 3) array of Cα coordinates.

required

Returns:

Type Description
Array

(N-3,) array of pseudo-torsion angles in degrees, range [-180, 180].

Source code in resonance_flow/losses.py
def calculate_pseudo_torsions(ca_coords: jax.Array) -> jax.Array:
    """
    Calculates pseudo-torsion angles for consecutive Cα atoms.

    A pseudo-torsion is the dihedral angle formed by four consecutive
    Cα atoms (i-1, i, i+1, i+2).  In Cα-only models, these angles are
    the primary indicator of backbone conformation (analogous to
    Ramachandran angles for full-atom models).

    Typical pseudo-torsion values:
        α-helix:  ~ +50°
        β-strand: ~ ±180°

    Args:
        ca_coords: (N, 3) array of Cα coordinates.

    Returns:
        (N-3,) array of pseudo-torsion angles in degrees, range [-180, 180].
    """
    # 1. Compute bond vectors
    b1 = ca_coords[1:-2] - ca_coords[:-3]
    b2 = ca_coords[2:-1] - ca_coords[1:-2]
    b3 = ca_coords[3:] - ca_coords[2:-1]

    # 2. Compute normal vectors to the planes
    n1 = jnp.cross(b1, b2)
    n2 = jnp.cross(b2, b3)

    # 3. Compute the angle using the atan2(y, x) robust formula.
    # n1 and n2 are normals to the two planes formed by (b1, b2) and (b2, b3).
    # The dihedral angle is the angle between these normals.
    n1_norm = n1 / (jnp.linalg.norm(n1, axis=-1, keepdims=True) + 1e-8)
    n2_norm = n2 / (jnp.linalg.norm(n2, axis=-1, keepdims=True) + 1e-8)
    b2_unit = b2 / (jnp.linalg.norm(b2, axis=-1, keepdims=True) + 1e-8)

    # y = [n1 x n2] . b2_unit
    # x = n1 . n2
    y = jnp.sum(jnp.cross(n1_norm, n2_norm) * b2_unit, axis=-1)
    x = jnp.sum(n1_norm * n2_norm, axis=-1)

    return cast(jax.Array, jnp.arctan2(y, x) * (180.0 / jnp.pi))
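
To sanity-check the sign convention, the sketch below mirrors the same atan2 formula in NumPy (used only so it runs without JAX) and applies it to two idealised traces. The first is a right-handed helix with assumed Cα geometry (radius 2.3 Å, rise 1.5 Å, 100° twist per residue). The second is a planar zigzag that stands in for a fully extended strand. `pseudo_torsions_np` and `ideal_helix_ca` are hypothetical helpers, not part of resonance_flow.

```python
import numpy as np


def pseudo_torsions_np(ca: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of calculate_pseudo_torsions."""
    b1 = ca[1:-2] - ca[:-3]
    b2 = ca[2:-1] - ca[1:-2]
    b3 = ca[3:] - ca[2:-1]
    n1 = np.cross(b1, b2)  # normal to the plane (i-1, i, i+1)
    n2 = np.cross(b2, b3)  # normal to the plane (i, i+1, i+2)
    b2_unit = b2 / np.linalg.norm(b2, axis=-1, keepdims=True)
    # Dividing n1 and n2 by their (positive) norms scales x and y by the same
    # factor, which leaves atan2 unchanged, so normalisation is skipped here.
    y = np.sum(np.cross(n1, n2) * b2_unit, axis=-1)
    x = np.sum(n1 * n2, axis=-1)
    return np.degrees(np.arctan2(y, x))


def ideal_helix_ca(n: int, radius: float = 2.3, rise: float = 1.5,
                   twist_deg: float = 100.0) -> np.ndarray:
    """Cα trace of an idealised right-handed helix (assumed parameters)."""
    k = np.arange(n)
    t = np.radians(twist_deg) * k
    return np.stack([radius * np.cos(t), radius * np.sin(t), rise * k], axis=-1)


helix_angles = pseudo_torsions_np(ideal_helix_ca(10))
# Illustrative planar zigzag: the fully extended limit of a strand.
zigzag = np.array([[0.0, 0.0, 0.0], [3.3, 1.0, 0.0], [6.6, 0.0, 0.0],
                   [9.9, 1.0, 0.0], [13.2, 0.0, 0.0]])
strand_angles = pseudo_torsions_np(zigzag)
print(helix_angles.round(1))   # about +50° for every quadruple
print(strand_angles.round(1))  # ±180° (planar, trans)
```

The helix parameters are a textbook idealisation chosen for illustration; they are not constants defined by this module.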

calculate_rdcs(predicted_vectors, saupe_tensor, d_max=21700.0)

Back-calculates RDCs for a set of vectors given a Saupe tensor.

Parameters:

Name Type Description Default
predicted_vectors Array

(N, 3) internuclear vectors.

required
saupe_tensor Array

(5,) array of independent Saupe tensor components [Sxx, Syy, Sxy, Sxz, Syz], in the order returned by fit_saupe_tensor.

required
d_max float

Maximum dipolar coupling constant (Hz).

21700.0

Returns:

Type Description
Array

(N,) predicted RDC values in Hz.

Source code in resonance_flow/losses.py
def calculate_rdcs(
    predicted_vectors: jax.Array, saupe_tensor: jax.Array, d_max: float = 21700.0
) -> jax.Array:
    """
    Back-calculates RDCs for a set of vectors given a Saupe tensor.

    Args:
        predicted_vectors: (N, 3) internuclear vectors.
        saupe_tensor: (5,) array of tensor components [Sxx, Syy, Sxy, Sxz, Syz],
                      in the order returned by fit_saupe_tensor.
        d_max: Maximum dipolar coupling constant (Hz).

    Returns:
        (N,) predicted RDC values in Hz.
    """
    norms = jnp.linalg.norm(predicted_vectors, axis=-1, keepdims=True)
    v = predicted_vectors / (norms + 1e-8)

    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    A = d_max * jnp.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)
    return cast(jax.Array, A @ saupe_tensor)
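
The five-column design matrix is a compact form of the full contraction D = d_max · vᵀ S v, where S is the symmetric, traceless 3×3 Saupe matrix and Szz = −Sxx − Syy. The NumPy sketch below (NumPy only so it runs without JAX; the tensor values are arbitrary) checks that the two forms agree. `rdcs_from_five` and `full_saupe` are hypothetical helpers.

```python
import numpy as np

D_MAX = 21700.0  # Hz, same default as calculate_rdcs


def rdcs_from_five(vectors: np.ndarray, s5: np.ndarray, d_max: float = D_MAX) -> np.ndarray:
    """Hypothetical mirror of calculate_rdcs: linear in the five components."""
    v = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    a = d_max * np.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)
    return a @ s5


def full_saupe(s5: np.ndarray) -> np.ndarray:
    """Rebuild the symmetric, traceless 3x3 tensor; Szz = -Sxx - Syy."""
    sxx, syy, sxy, sxz, syz = s5
    return np.array([[sxx, sxy, sxz],
                     [sxy, syy, syz],
                     [sxz, syz, -sxx - syy]])


rng = np.random.default_rng(0)
vectors = rng.normal(size=(20, 3))
s5 = np.array([4e-4, -1e-4, 2e-4, -3e-4, 1e-4])  # arbitrary illustrative tensor

unit = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
rdc_full = D_MAX * np.einsum("ni,ij,nj->n", unit, full_saupe(s5), unit)
rdc_five = rdcs_from_five(vectors, s5)
print(np.allclose(rdc_full, rdc_five))  # True
```

Because the vectors are normalised first, only their directions matter: scaling every vector by a constant leaves the back-calculated couplings unchanged.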

estimate_nh_proxy_vectors(ca_coords)

Estimates backbone N-H proxy vectors from Cα coordinates.

Uses the anti-parallel virtual-bond approximation: for each interior residue i, the proxy N–H direction is taken as the unit vector from Cα(i+1) to Cα(i-1), i.e. anti-parallel to the local backbone tangent. This is a coarse Cα-only heuristic rather than a reconstruction of the amide geometry: in β-strands, real N–H bonds point roughly perpendicular to the strand direction, whereas this proxy lies along it. For background on predicting the alignment tensor itself from molecular shape, see Zweckstetter & Bax, J. Am. Chem. Soc. 2000.

Note: for full-atom models, real N–H internuclear vectors should be supplied directly to rdc_loss instead of using this approximation.

Parameters:

Name Type Description Default
ca_coords Array

(N, 3) array of Cα coordinates.

required

Returns:

Type Description
Array

(N-2, 3) unit proxy vectors for residues 1 … N-2.

Source code in resonance_flow/losses.py
def estimate_nh_proxy_vectors(ca_coords: jax.Array) -> jax.Array:
    """
    Estimates backbone N-H proxy vectors from Cα coordinates.

    Uses the anti-parallel virtual-bond approximation: for each interior
    residue i the proxy N-H direction is taken as the unit vector from
    Cα(i+1) to Cα(i-1), i.e. anti-parallel to the local backbone tangent.
    This is a coarse Cα-only heuristic rather than a reconstruction of the
    amide geometry: in β-strands, real N-H bonds point roughly perpendicular
    to the strand direction, whereas this proxy lies along it.  For
    background on predicting the alignment tensor itself from molecular
    shape, see Zweckstetter & Bax, J. Am. Chem. Soc. 2000.

    Note: for full-atom models, real N–H internuclear vectors should be
    supplied directly to rdc_loss instead of using this approximation.

    Args:
        ca_coords: (N, 3) array of Cα coordinates.

    Returns:
        (N-2, 3) unit proxy vectors for residues 1 … N-2.
    """
    # Anti-parallel virtual bond: Cα(i-1) − Cα(i+1), normalised.
    raw = ca_coords[:-2] - ca_coords[2:]  # shape (N-2, 3)
    norms = jnp.linalg.norm(raw, axis=-1, keepdims=True)
    return cast(jax.Array, raw / (norms + 1e-8))
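
A quick geometric check: for an idealised, fully extended chain the proxy points along the chain axis. That is the direction to keep in mind when comparing it with real N–H bonds in β-strands, which lie roughly perpendicular to the strand. The sketch is a NumPy mirror; `nh_proxy_np` is a hypothetical helper and the coordinates are illustrative.

```python
import numpy as np


def nh_proxy_np(ca: np.ndarray) -> np.ndarray:
    """Hypothetical mirror of estimate_nh_proxy_vectors: unit(Cα(i-1) - Cα(i+1))."""
    raw = ca[:-2] - ca[2:]
    return raw / np.linalg.norm(raw, axis=-1, keepdims=True)


# Illustrative planar zigzag running along +x.
zigzag = np.array([[0.0, 0.0, 0.0], [3.3, 1.0, 0.0], [6.6, 0.0, 0.0],
                   [9.9, 1.0, 0.0], [13.2, 0.0, 0.0]])
proxies = nh_proxy_np(zigzag)
print(proxies.shape)  # (3, 3): one vector per interior residue
print(proxies)        # every row is (-1, 0, 0), i.e. along the chain axis
```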

fit_saupe_tensor(predicted_vectors, measured_rdcs, d_max=21700.0)

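Fits the five independent components of the traceless, symmetric Saupe alignment tensor to the vectors and RDCs by linear least squares (Szz = −Sxx − Syy follows from tracelessness).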

Parameters:

Name Type Description Default
predicted_vectors Array

(N, 3) internuclear vectors.

required
measured_rdcs Array

(N,) experimental RDC values in Hz.

required
d_max float

Maximum dipolar coupling constant (Hz).

21700.0

Returns:

Type Description
Array

(5,) array containing the independent components of the Saupe tensor [Sxx, Syy, Sxy, Sxz, Syz].

Source code in resonance_flow/losses.py
def fit_saupe_tensor(
    predicted_vectors: jax.Array, measured_rdcs: jax.Array, d_max: float = 21700.0
) -> jax.Array:
    """
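    Fits the five independent components of the traceless, symmetric Saupe
    alignment tensor to the vectors and RDCs by linear least squares
    (Szz = −Sxx − Syy follows from tracelessness).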

    Args:
        predicted_vectors: (N, 3) internuclear vectors.
        measured_rdcs: (N,) experimental RDC values in Hz.
        d_max: Maximum dipolar coupling constant (Hz).

    Returns:
        (5,) array containing the independent components of the Saupe tensor
        [Sxx, Syy, Sxy, Sxz, Syz].
    """
    norms = jnp.linalg.norm(predicted_vectors, axis=-1, keepdims=True)
    v = predicted_vectors / (norms + 1e-8)

    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    A = d_max * jnp.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)
    s, _, _, _ = jnp.linalg.lstsq(A, measured_rdcs, rcond=1e-5)
    return cast(jax.Array, s)
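
A round trip shows what the fit does: generate noise-free couplings from a known tensor, then recover it. The system has five unknowns, so it needs at least five vectors with sufficiently diverse orientations; with fewer, it is underdetermined. NumPy's `lstsq` is used with the same `rcond=1e-5` cutoff. `design_matrix` and `fit_saupe_np` are hypothetical helpers, and the tensor and vectors are synthetic.

```python
import numpy as np

D_MAX = 21700.0  # Hz, the module's default d_max


def design_matrix(vectors, d_max=D_MAX):
    """Hypothetical helper: the (N, 5) matrix used by fit_saupe_tensor."""
    v = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return d_max * np.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)


def fit_saupe_np(vectors, rdcs, d_max=D_MAX):
    """Least-squares fit of [Sxx, Syy, Sxy, Sxz, Syz], mirroring fit_saupe_tensor."""
    s, *_ = np.linalg.lstsq(design_matrix(vectors, d_max), rdcs, rcond=1e-5)
    return s


rng = np.random.default_rng(1)
vectors = rng.normal(size=(30, 3))
s_true = np.array([5e-4, -2e-4, 1e-4, 3e-4, -1e-4])  # illustrative tensor
rdcs = design_matrix(vectors) @ s_true              # noise-free synthetic data
s_fit = fit_saupe_np(vectors, rdcs)
print(np.allclose(s_fit, s_true))  # True
```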

get_bond_length_loss(target_distance=3.8)

Penalises deviations from the ideal Cα–Cα virtual bond length.

The canonical Cα–Cα distance in a peptide chain is 3.80 ± 0.02 Å (Engh & Huber, Acta Crystallogr. A, 1991). This is the virtual bond between sequential alpha-carbons across the full peptide unit; it is NOT the C–C covalent bond length (1.52 Å).

Parameters:

Name Type Description Default
target_distance float

Ideal Cα–Cα virtual bond length in Angstroms. Default 3.8 Å (Engh & Huber 1991).

3.8
Source code in resonance_flow/losses.py
def get_bond_length_loss(
    target_distance: float = 3.8,
) -> Callable[[jax.Array], jax.Array]:
    """
    Penalises deviations from the ideal Cα–Cα virtual bond length.

    The canonical Cα–Cα distance in a peptide chain is 3.80 ± 0.02 Å
    (Engh & Huber, Acta Crystallogr. A, 1991).  This is the virtual bond
    between sequential alpha-carbons across the full peptide unit; it is
    NOT the C–C covalent bond length (1.52 Å).

    Args:
        target_distance: Ideal Cα–Cα virtual bond length in Angstroms.
                         Default 3.8 Å (Engh & Huber 1991).
    """

    def bond_length_loss(positions: jax.Array) -> jax.Array:
        # Compute distances between consecutive Cα atoms.
        diffs = positions[1:] - positions[:-1]
        distances = jnp.linalg.norm(diffs, axis=-1)
        return cast(jax.Array, jnp.mean((distances - target_distance) ** 2))

    return bond_length_loss
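
The factory fixes target_distance once and returns a single-argument function of the positions, a convenient shape to pass to jax.grad in the JAX original. The NumPy mirror below (`make_bond_length_loss` is a hypothetical name; the chains are illustrative) shows the loss on an ideal and a uniformly stretched chain.

```python
from typing import Callable

import numpy as np


def make_bond_length_loss(target_distance: float = 3.8) -> Callable[[np.ndarray], float]:
    """Hypothetical mirror of get_bond_length_loss: mean squared Cα–Cα deviation."""
    def bond_length_loss(positions: np.ndarray) -> float:
        distances = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
        return float(np.mean((distances - target_distance) ** 2))
    return bond_length_loss


loss_fn = make_bond_length_loss()
ideal = np.array([[3.8 * i, 0.0, 0.0] for i in range(5)])      # every bond 3.8 Å
stretched = np.array([[4.0 * i, 0.0, 0.0] for i in range(5)])  # every bond 4.0 Å
print(loss_fn(ideal))      # 0.0 up to rounding
print(loss_fn(stretched))  # ≈ 0.04, i.e. (4.0 - 3.8)²
```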

get_steric_clash_loss(box_size=None, exclude_bonded_range=0)

Returns a function to compute the steric clash (atom overlap) penalty.

Parameters:

Name Type Description Default
box_size float | None

Optional edge length of a cubic periodic box, in the same units as the positions. If provided, distances use the minimum-image convention; otherwise free space is assumed.

None
exclude_bonded_range int

Exclude atom pairs whose sequential index separation is <= this value. The default 0 excludes only self-interactions; set to 1 to also exclude directly bonded 1-2 neighbours, or 2 to exclude 1-2 and 1-3 pairs (standard AMBER / CHARMM convention). Index separation stands in for bond topology, which is accurate for a linear Cα-only chain.

0
Source code in resonance_flow/losses.py
def get_steric_clash_loss(
    box_size: float | None = None, exclude_bonded_range: int = 0
) -> Callable[[jax.Array, jax.Array], jax.Array]:
    """
    Returns a function to compute the steric clash (atom overlap) penalty.

    Args:
        box_size: Optional edge length of a cubic periodic box, in the same
                  units as positions.  If provided, distances use the
                  minimum-image convention; otherwise free space is assumed.
        exclude_bonded_range: Exclude atom pairs whose sequential index
                              separation is <= this value.  The default 0
                              excludes only self-interactions.  Set to 1 to
                              also exclude directly bonded 1-2 neighbours,
                              or 2 for 1-2 and 1-3 pairs (standard AMBER /
                              CHARMM convention).  Index separation stands
                              in for bond topology, which is accurate for a
                              linear Cα-only chain.
    """

    def steric_clash_loss(positions: jax.Array, atom_radii: jax.Array) -> jax.Array:
        """
        Computes the penalty for overlapping atoms.

        Args:
            positions: (N, 3) array of atomic coordinates.
            atom_radii: (N,) array of atomic van der Waals radii.

        Returns:
            A scalar loss representing the total steric clash penalty.
        """
        n = positions.shape[0]

        diff = positions[:, None, :] - positions[None, :, :]
        if box_size is not None:
            diff = diff - box_size * jnp.round(diff / box_size)

        dr2 = jnp.sum(diff**2, axis=-1)
        # Avoid NaN gradients at zero distance (e.g. self-interactions)
        dr = jnp.sqrt(jnp.where(dr2 == 0.0, 1e-10, dr2))

        radii_sum = atom_radii[:, None] + atom_radii[None, :]
        overlap = jnp.maximum(radii_sum - dr, 0.0)

        # Build mask: exclude pairs with index separation <= exclude_bonded_range.
        # exclude_bonded_range=0  →  only self excluded  (default)
        # exclude_bonded_range=1  →  self + 1-2 bonded excluded
        # exclude_bonded_range=2  →  self + 1-2 + 1-3 excluded
        indices = jnp.arange(n)
        pair_sep = jnp.abs(indices[:, None] - indices[None, :])
        mask = (pair_sep > exclude_bonded_range).astype(jnp.float32)
        overlap = overlap * mask

        # The full N x N matrix counts each unordered pair twice, hence / 2.
        loss = jnp.sum(overlap**2) / 2.0
        return cast(jax.Array, loss)

    return steric_clash_loss
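
The masking and minimum-image logic are easiest to see on tiny systems. The NumPy mirror below (`make_steric_clash_loss` is a hypothetical name) omits the zero-distance guard, which matters only for gradients in JAX. The radii and coordinates are illustrative, not van der Waals values supplied by the module.

```python
import numpy as np


def make_steric_clash_loss(box_size=None, exclude_bonded_range=0):
    """Hypothetical NumPy mirror of get_steric_clash_loss."""
    def steric_clash_loss(positions, atom_radii):
        n = positions.shape[0]
        diff = positions[:, None, :] - positions[None, :, :]
        if box_size is not None:
            diff = diff - box_size * np.round(diff / box_size)  # minimum image
        dr = np.sqrt(np.sum(diff**2, axis=-1))
        overlap = np.maximum(atom_radii[:, None] + atom_radii[None, :] - dr, 0.0)
        idx = np.arange(n)
        mask = np.abs(idx[:, None] - idx[None, :]) > exclude_bonded_range
        # Every unordered pair appears twice in the n x n matrix, hence / 2.
        return float(np.sum((overlap * mask) ** 2) / 2.0)
    return steric_clash_loss


chain = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [7.6, 0.0, 0.0]])
radii = np.full(3, 2.0)  # illustrative: bonded neighbours overlap by 0.2 Å
print(make_steric_clash_loss()(chain, radii))                        # ≈ 0.08
print(make_steric_clash_loss(exclude_bonded_range=1)(chain, radii))  # 0.0

pair = np.array([[0.5, 0.0, 0.0], [9.8, 0.0, 0.0]])  # 0.7 Å apart across a 10 Å box
print(make_steric_clash_loss(box_size=10.0)(pair, np.full(2, 0.5)))  # ≈ 0.09
```

With the default exclude_bonded_range=0, the two bonded contacts each contribute 0.2² = 0.04. Raising it to 1 removes them, and the 1-3 pair is too far apart to overlap.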

noe_upper_bound_loss(positions, noe_pairs, upper_bounds)

Penalises violations of NOE-derived inter-proton distance upper bounds.

NOE distance restraints are the primary source of 3D structural information in protein NMR, providing upper bounds on inter-proton distances typically in the range 1.8–6.0 Å (Wüthrich, NMR of Proteins and Nucleic Acids, 1986; Güntert et al., J. Mol. Biol., 1997).

A flat-bottomed harmonic penalty is applied only to upper-bound violations (no lower-bound penalty, since NOE cross-peaks are only observed when protons are close):

L_NOE = mean( max(0, d_ij − d_upper)² )

Parameters:

Name Type Description Default
positions Array

(N, 3) atomic coordinates in Angstroms.

required
noe_pairs Array

(M, 2) integer array of atom-index pairs.

required
upper_bounds Array

(M,) upper distance bounds in Angstroms.

required

Returns:

Type Description
Array

Scalar NOE violation loss.

Source code in resonance_flow/losses.py
def noe_upper_bound_loss(
    positions: jax.Array, noe_pairs: jax.Array, upper_bounds: jax.Array
) -> jax.Array:
    """
    Penalises violations of NOE-derived inter-proton distance upper bounds.

    NOE distance restraints are the primary source of 3D structural
    information in protein NMR, providing upper bounds on inter-proton
    distances typically in the range 1.8–6.0 Å (Wüthrich, *NMR of Proteins
    and Nucleic Acids*, 1986; Güntert et al., J. Mol. Biol., 1997).

    A flat-bottomed harmonic penalty is applied only to upper-bound
    violations (no lower-bound penalty, since NOE cross-peaks are only
    observed when protons are close):

        L_NOE = mean( max(0, d_ij − d_upper)² )

    Args:
        positions: (N, 3) atomic coordinates in Angstroms.
        noe_pairs: (M, 2) integer array of atom-index pairs.
        upper_bounds: (M,) upper distance bounds in Angstroms.

    Returns:
        Scalar NOE violation loss.
    """
    ri = positions[noe_pairs[:, 0]]
    rj = positions[noe_pairs[:, 1]]
    dists = jnp.linalg.norm(ri - rj, axis=-1)
    violations = jnp.maximum(dists - upper_bounds, 0.0)
    return jnp.mean(violations**2)
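
A worked instance of L_NOE with two illustrative restraints: one violated by 1 Å and one satisfied. `noe_upper_bound_loss_np` is a hypothetical NumPy mirror of the function above.

```python
import numpy as np


def noe_upper_bound_loss_np(positions, noe_pairs, upper_bounds):
    """Hypothetical mirror of noe_upper_bound_loss: flat-bottomed, upper bound only."""
    ri = positions[noe_pairs[:, 0]]
    rj = positions[noe_pairs[:, 1]]
    dists = np.linalg.norm(ri - rj, axis=-1)
    violations = np.maximum(dists - upper_bounds, 0.0)
    return float(np.mean(violations**2))


positions = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
noe_pairs = np.array([[0, 1], [0, 2]])
upper_bounds = np.array([4.0, 5.0])  # illustrative restraint list (Å)
print(noe_upper_bound_loss_np(positions, noe_pairs, upper_bounds))  # 0.5
```

Pair (0, 1) is 5 Å apart against a 4 Å bound and contributes 1². Pair (0, 2) is inside its bound and contributes 0. Satisfied restraints still count in the mean's denominator, so the same violation weighs less in a longer restraint list.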

rdc_loss(predicted_vectors, measured_rdcs, d_max=21700.0)

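RDC loss based on a best-fit Saupe tensor. Fits the alignment tensor to the predicted vectors by least squares (fit_saupe_tensor), back-calculates the couplings (calculate_rdcs), and returns the mean squared residual in Hz². Because the tensor is refit on every call, the loss does not depend on the overall orientation of the structure.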

References

Bax & Tjandra, J. Biomol. NMR, 1997. Cornilescu, Marquardt, Ottiger & Bax, J. Am. Chem. Soc., 1998.

Source code in resonance_flow/losses.py
def rdc_loss(
    predicted_vectors: jax.Array, measured_rdcs: jax.Array, d_max: float = 21700.0
) -> jax.Array:
    """
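    RDC loss based on a best-fit Saupe tensor.

    Fits the alignment tensor to the predicted vectors by least squares
    (fit_saupe_tensor), back-calculates the couplings (calculate_rdcs) and
    returns the mean squared residual in Hz².  Because the tensor is refit
    on every call, the loss does not depend on the overall orientation of
    the structure.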

    References:
        Bax & Tjandra, J. Biomol. NMR, 1997.
        Cornilescu, Marquardt, Ottiger & Bax, J. Am. Chem. Soc., 1998.
    """
    s = fit_saupe_tensor(predicted_vectors, measured_rdcs, d_max)
    predicted_rdcs = calculate_rdcs(predicted_vectors, s, d_max)
    return jnp.mean((predicted_rdcs - measured_rdcs) ** 2)
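
Refitting the tensor on each call makes the loss invariant to a rigid rotation of all vectors. A rotated structure simply gets a rotated tensor (RᵀSR is still symmetric and traceless), so the best achievable residual is the same. The NumPy sketch below checks this on synthetic data; `design_matrix`, `rdc_loss_np` and `rotation_z` are hypothetical helpers.

```python
import numpy as np

D_MAX = 21700.0  # Hz, the module's default d_max


def design_matrix(vectors, d_max=D_MAX):
    v = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return d_max * np.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)


def rdc_loss_np(vectors, measured, d_max=D_MAX):
    """Hypothetical mirror of rdc_loss: refit the tensor, then mean squared residual (Hz²)."""
    a = design_matrix(vectors, d_max)
    s, *_ = np.linalg.lstsq(a, measured, rcond=1e-5)
    return float(np.mean((a @ s - measured) ** 2))


def rotation_z(angle_rad):
    c, s = np.cos(angle_rad), np.sin(angle_rad)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


rng = np.random.default_rng(2)
vectors = rng.normal(size=(25, 3))
measured = rng.normal(scale=10.0, size=25)  # illustrative "measured" RDCs in Hz
rotated = vectors @ rotation_z(0.7).T       # same structure, different orientation
print(np.isclose(rdc_loss_np(vectors, measured), rdc_loss_np(rotated, measured)))  # True
```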

rdc_q_factor(predicted_vectors, measured_rdcs, d_max=21700.0)

Computes the RDC Q-factor (Cornilescu, Marquardt, Ottiger & Bax, JACS 1998).

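The Q-factor is the NMR analogue of the crystallographic R-factor:

Q = RMS(D_calc − D_obs) / RMS(D_obs)

Here D_calc is back-calculated from a Saupe tensor fitted to the same couplings, so this is a working Q-factor; rdc_q_free gives the cross-validated variant.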

Source code in resonance_flow/losses.py
def rdc_q_factor(
    predicted_vectors: jax.Array, measured_rdcs: jax.Array, d_max: float = 21700.0
) -> jax.Array:
    """
    Computes the RDC Q-factor (Cornilescu, Marquardt, Ottiger & Bax, JACS 1998).

    The Q-factor is the NMR analogue of the crystallographic R-factor:
        Q = RMSD(D_calc − D_obs) / RMS(D_obs)
    """
    s = fit_saupe_tensor(predicted_vectors, measured_rdcs, d_max)
    predicted_rdcs = calculate_rdcs(predicted_vectors, s, d_max)

    rmsd = jnp.sqrt(jnp.mean((predicted_rdcs - measured_rdcs) ** 2))
    rms_obs = jnp.sqrt(jnp.mean(measured_rdcs**2))
    return rmsd / (rms_obs + 1e-10)
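
Two properties follow directly from the definition. Noise-free couplings that are consistent with some tensor give Q ≈ 0. Scaling every measured coupling by a common factor leaves Q unchanged, because the fitted tensor absorbs the scale. The NumPy sketch below uses synthetic data; `design_matrix` and `q_factor_np` are hypothetical helpers.

```python
import numpy as np

D_MAX = 21700.0  # Hz, the module's default d_max


def design_matrix(vectors, d_max=D_MAX):
    v = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return d_max * np.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)


def q_factor_np(vectors, measured, d_max=D_MAX):
    """Hypothetical mirror of rdc_q_factor (tensor fitted on the same data)."""
    a = design_matrix(vectors, d_max)
    s, *_ = np.linalg.lstsq(a, measured, rcond=1e-5)
    rmsd = np.sqrt(np.mean((a @ s - measured) ** 2))
    rms_obs = np.sqrt(np.mean(measured**2))
    return float(rmsd / (rms_obs + 1e-10))


rng = np.random.default_rng(3)
vectors = rng.normal(size=(40, 3))
s_true = np.array([4e-4, -1e-4, 2e-4, -3e-4, 1e-4])  # illustrative tensor
clean = design_matrix(vectors) @ s_true
noisy = clean + rng.normal(scale=0.5, size=40)  # 0.5 Hz illustrative noise

print(q_factor_np(vectors, clean))  # ~0: the model reproduces noise-free data
print(q_factor_np(vectors, noisy))  # > 0
print(np.isclose(q_factor_np(vectors, noisy), q_factor_np(vectors, 3.0 * noisy)))  # True
```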

rdc_q_free(predicted_vectors, measured_rdcs, train_mask, d_max=21700.0)

Computes the Q_free cross-validation metric (Clore & Garrett, JACS 1999).

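Fits the Saupe tensor using only data where train_mask is True, then calculates the Q-factor on the held-out data (where train_mask is False). When the held-out couplings are also kept out of structure refinement, Q_free can expose overfitting to RDCs that the working Q-factor (rdc_q_factor) does not reveal. Boolean-mask indexing produces data-dependent array shapes, so as written the function cannot be traced by jax.jit when train_mask is a traced argument.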

Parameters:

Name Type Description Default
predicted_vectors Array

(N, 3) internuclear vectors.

required
measured_rdcs Array

(N,) experimental RDC values.

required
train_mask Array

(N,) boolean mask (True = use for fitting, False = use for Q_free).

required
d_max float

Maximum dipolar coupling constant (Hz).

21700.0

Returns:

Type Description
Array

Q_free (dimensionless).

Source code in resonance_flow/losses.py
def rdc_q_free(
    predicted_vectors: jax.Array,
    measured_rdcs: jax.Array,
    train_mask: jax.Array,
    d_max: float = 21700.0,
) -> jax.Array:
    """
    Computes the Q_free cross-validation metric (Clore & Garrett, JACS 1999).

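    Fits the Saupe tensor using only data where train_mask is True, then
    calculates the Q-factor on the held-out data (where train_mask is False).
    When the held-out couplings are also kept out of structure refinement,
    Q_free can expose overfitting to RDCs that rdc_q_factor does not reveal.
    Boolean-mask indexing gives data-dependent shapes, so this function is
    not jax.jit-traceable when train_mask is a traced argument.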

    Args:
        predicted_vectors: (N, 3) internuclear vectors.
        measured_rdcs: (N,) experimental RDC values.
        train_mask: (N,) boolean mask (True = use for fitting, False = use for Q_free).
        d_max: Maximum dipolar coupling constant (Hz).

    Returns:
        Q_free (dimensionless).
    """
    # 1. Fit tensor on the training subset.
    s = fit_saupe_tensor(predicted_vectors[train_mask], measured_rdcs[train_mask], d_max)

    # 2. Evaluate on the test subset (the 'free' set).
    test_mask = ~train_mask
    v_test = predicted_vectors[test_mask]
    d_test = measured_rdcs[test_mask]

    predicted_test = calculate_rdcs(v_test, s, d_max)

    rmsd = jnp.sqrt(jnp.mean((predicted_test - d_test) ** 2))
    rms_obs = jnp.sqrt(jnp.mean(d_test**2))
    return rmsd / (rms_obs + 1e-10)
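
One way to keep every array shape fixed, which is what jax.jit needs, is to replace boolean indexing with 0/1 weights. Zeroed rows add nothing to the least-squares objective, so the fit matches a fit on the training rows alone. The NumPy sketch below compares an indexed mirror of rdc_q_free with such a weighted variant. It is an assumption, not part of resonance_flow, and the jnp translation is untested here. `design_matrix`, `q_free_indexed` and `q_free_weighted` are hypothetical helpers, and the data are synthetic.

```python
import numpy as np

D_MAX = 21700.0  # Hz, the module's default d_max


def design_matrix(vectors, d_max=D_MAX):
    v = vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    return d_max * np.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)


def q_free_indexed(vectors, measured, train_mask, d_max=D_MAX):
    """Hypothetical mirror of rdc_q_free using boolean indexing."""
    s, *_ = np.linalg.lstsq(design_matrix(vectors[train_mask], d_max),
                            measured[train_mask], rcond=1e-5)
    pred = design_matrix(vectors[~train_mask], d_max) @ s
    d_test = measured[~train_mask]
    rmsd = np.sqrt(np.mean((pred - d_test) ** 2))
    return float(rmsd / (np.sqrt(np.mean(d_test**2)) + 1e-10))


def q_free_weighted(vectors, measured, train_mask, d_max=D_MAX):
    """Fixed-shape variant: 0/1 weights instead of boolean indexing."""
    w_train = train_mask.astype(float)
    w_test = 1.0 - w_train
    a = design_matrix(vectors, d_max)
    # Zeroed rows contribute nothing to the least-squares objective.
    s, *_ = np.linalg.lstsq(a * w_train[:, None], measured * w_train, rcond=1e-5)
    resid2 = np.sum(w_test * (a @ s - measured) ** 2) / np.sum(w_test)
    obs2 = np.sum(w_test * measured**2) / np.sum(w_test)
    return float(np.sqrt(resid2) / (np.sqrt(obs2) + 1e-10))


rng = np.random.default_rng(4)
vectors = rng.normal(size=(50, 3))
s_true = np.array([4e-4, -1e-4, 2e-4, -3e-4, 1e-4])  # illustrative tensor
measured = design_matrix(vectors) @ s_true + rng.normal(scale=0.5, size=50)
train_mask = np.arange(50) % 5 != 0  # hold out every fifth coupling (illustrative)
print(np.isclose(q_free_indexed(vectors, measured, train_mask),
                 q_free_weighted(vectors, measured, train_mask)))  # True
```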