Skip to content

Model

TransformerCoordinatePredictor

Bases: Module

A simple Transformer that takes an amino acid sequence (as token IDs) and predicts 3D coordinates for each token.

Source code in resonance_flow/model.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class TransformerCoordinatePredictor(nn.Module):
    """
    A simple Transformer that takes an amino acid sequence (as token IDs)
    and predicts 3D coordinates for each token.

    The network is: token embedding + learned positional embedding, followed
    by ``num_layers`` residual blocks (self-attention, then a 4x-wide GELU
    MLP, each preceded by LayerNorm), a final LayerNorm and a linear
    projection to 3 output features per token.
    """

    # Number of rows in the token embedding table.
    vocab_size: int = 21
    # Width of the token/positional embeddings and of every residual stream.
    d_model: int = 128
    # Number of attention heads in each self-attention layer.
    num_heads: int = 4
    # Number of (attention + MLP) blocks.
    num_layers: int = 4
    # Number of positions in the learned positional embedding table; inputs
    # longer than this are rejected.
    max_len: int = 512
    # Dropout rate used after the embeddings and on each residual branch.
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
        """
        Predicts 3D coordinates.

        Args:
            x: (batch_size, seq_len) token IDs.
            deterministic: True for eval, False for train.

        Returns:
            Array of shape (batch_size, seq_len, 3) with one predicted
            coordinate triple per input token.

        Raises:
            ValueError: If ``seq_len`` is greater than ``max_len``.
        """
        # Unpacking also enforces that the input is rank-2.
        _, seq_len = x.shape

        # Slicing the positional table past its end silently truncates, which
        # would otherwise surface as an opaque broadcasting failure (or, for
        # max_len == 1, as a silent broadcast of a single position).
        if seq_len > self.max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={self.max_len}; "
                "increase max_len or truncate the input."
            )

        # NOTE: submodules are created in the same order as before so that
        # auto-generated parameter names stay unchanged.
        x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)

        # Learned positional embedding, shared across the batch (leading 1).
        pos_emb = self.param(
            "pos_embedding",
            nn.initializers.normal(stddev=0.02),
            (1, self.max_len, self.d_model),
        )
        x = x + pos_emb[:, :seq_len, :]
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

        for _ in range(self.num_layers):
            # Self-attention sub-layer: normalize, attend, add back residual.
            y = nn.LayerNorm()(x)
            y = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                qkv_features=self.d_model,
                out_features=self.d_model,
            )(y)
            x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

            # Feed-forward sub-layer: expand to 4 * d_model, GELU, project back.
            y = nn.LayerNorm()(x)
            y = nn.Dense(features=self.d_model * 4)(y)
            y = nn.gelu(y)
            y = nn.Dense(features=self.d_model)(y)
            x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

        x = nn.LayerNorm()(x)
        # Small-stddev kernel init for the output head.
        coords = nn.Dense(features=3, kernel_init=nn.initializers.normal(stddev=1e-3))(x)
        return coords

__call__(x, deterministic=True)

Predicts 3D coordinates.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Array` | (batch_size, seq_len) token IDs. | required |
| `deterministic` | `bool` | True for eval, False for train. | `True` |
Source code in resonance_flow/model.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@nn.compact
def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """
    Predicts 3D coordinates.

    Args:
        x: (batch_size, seq_len) token IDs.
        deterministic: True for eval, False for train.

    Returns:
        Array of shape (batch_size, seq_len, 3) with one predicted
        coordinate triple per input token.

    Raises:
        ValueError: If ``seq_len`` is greater than ``max_len``.
    """
    # Unpacking also enforces that the input is rank-2.
    _, seq_len = x.shape

    # Slicing the positional table past its end silently truncates, which
    # would otherwise surface as an opaque broadcasting failure (or, for
    # max_len == 1, as a silent broadcast of a single position).
    if seq_len > self.max_len:
        raise ValueError(
            f"Sequence length {seq_len} exceeds max_len={self.max_len}; "
            "increase max_len or truncate the input."
        )

    # NOTE: submodules are created in the same order as before so that
    # auto-generated parameter names stay unchanged.
    x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)

    # Learned positional embedding, shared across the batch (leading 1).
    pos_emb = self.param(
        "pos_embedding",
        nn.initializers.normal(stddev=0.02),
        (1, self.max_len, self.d_model),
    )
    x = x + pos_emb[:, :seq_len, :]
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

    for _ in range(self.num_layers):
        # Self-attention sub-layer: normalize, attend, add back residual.
        y = nn.LayerNorm()(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.d_model,
            out_features=self.d_model,
        )(y)
        x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

        # Feed-forward sub-layer: expand to 4 * d_model, GELU, project back.
        y = nn.LayerNorm()(x)
        y = nn.Dense(features=self.d_model * 4)(y)
        y = nn.gelu(y)
        y = nn.Dense(features=self.d_model)(y)
        x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

    x = nn.LayerNorm()(x)
    # Small-stddev kernel init for the output head.
    coords = nn.Dense(features=3, kernel_init=nn.initializers.normal(stddev=1e-3))(x)
    return coords