Bases: Module
A simple Transformer that takes an amino acid sequence (as token IDs)
and predicts 3D coordinates for each token.
Source code in resonance_flow/model.py
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class TransformerCoordinatePredictor(nn.Module):
    """
    A simple Transformer that takes an amino acid sequence (as token IDs)
    and predicts 3D coordinates for each token.

    The network is a token embedding plus a learned positional embedding,
    followed by ``num_layers`` pre-LayerNorm self-attention / MLP blocks with
    residual connections, a final LayerNorm, and a linear head with 3 outputs.
    """

    # Number of rows in the token embedding table.
    vocab_size: int = 21
    # Width of the token/positional embeddings and of every residual stream.
    d_model: int = 128
    # Number of attention heads in each self-attention layer.
    num_heads: int = 4
    # Number of attention + MLP blocks stacked in the encoder.
    num_layers: int = 4
    # Number of learned positional embeddings, i.e. the longest accepted input.
    max_len: int = 512
    # Dropout rate applied after the embeddings and on each residual branch.
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
        """
        Predicts 3D coordinates.

        Args:
            x: (batch_size, seq_len) token IDs.
            deterministic: True for eval, False for train. When False the
                dropout layers are active.

        Returns:
            Array of shape (batch_size, seq_len, 3) with one coordinate
            triple per input token.

        Raises:
            ValueError: If seq_len is larger than ``max_len``.
        """
        # Unpacking doubles as a rank check: x must be exactly 2-D.
        _, seq_len = x.shape
        if seq_len > self.max_len:
            # Without this guard the slice below would silently stop at
            # max_len rows and the addition would fail with an opaque
            # broadcasting error (or broadcast wrongly when max_len == 1).
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={self.max_len}; "
                "no positional embedding exists beyond max_len."
            )

        x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)

        # Learned positional table; the leading axis of size 1 broadcasts
        # over the batch dimension.
        pos_emb = self.param(
            "pos_embedding",
            nn.initializers.normal(stddev=0.02),
            (1, self.max_len, self.d_model),
        )
        x = x + pos_emb[:, :seq_len, :]
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

        # NOTE: submodules are created in a fixed order inside this compact
        # method; keep that order stable so auto-generated parameter names
        # stay compatible with existing checkpoints.
        for _ in range(self.num_layers):
            # Self-attention sub-block (normalise, attend, add residual).
            y = nn.LayerNorm()(x)
            y = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                qkv_features=self.d_model,
                out_features=self.d_model,
            )(y)
            x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

            # Feed-forward sub-block with a 4x hidden expansion.
            y = nn.LayerNorm()(x)
            y = nn.Dense(features=self.d_model * 4)(y)
            y = nn.gelu(y)
            y = nn.Dense(features=self.d_model)(y)
            x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

        x = nn.LayerNorm()(x)
        # Small-stddev kernel init keeps the initial coordinate outputs small.
        coords = nn.Dense(features=3, kernel_init=nn.initializers.normal(stddev=1e-3))(x)
        return coords
|
Predicts 3D coordinates.
Parameters:
| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Array` | (batch_size, seq_len) token IDs. | required |
| `deterministic` | `bool` | True for eval, False for train. | `True` |
Source code in resonance_flow/model.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56 | @nn.compact
def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
    """
    Predicts 3D coordinates.

    Args:
        x: (batch_size, seq_len) token IDs.
        deterministic: True for eval, False for train. When False the
            dropout layers are active.

    Returns:
        Array of shape (batch_size, seq_len, 3) with one coordinate
        triple per input token.

    Raises:
        ValueError: If seq_len is larger than ``max_len``.
    """
    # Unpacking doubles as a rank check: x must be exactly 2-D.
    _, seq_len = x.shape
    if seq_len > self.max_len:
        # Without this guard the slice below would silently stop at
        # max_len rows and the addition would fail with an opaque
        # broadcasting error (or broadcast wrongly when max_len == 1).
        raise ValueError(
            f"Sequence length {seq_len} exceeds max_len={self.max_len}; "
            "no positional embedding exists beyond max_len."
        )

    x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)

    # Learned positional table; the leading axis of size 1 broadcasts
    # over the batch dimension.
    pos_emb = self.param(
        "pos_embedding",
        nn.initializers.normal(stddev=0.02),
        (1, self.max_len, self.d_model),
    )
    x = x + pos_emb[:, :seq_len, :]
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

    # NOTE: submodules are created in a fixed order inside this compact
    # method; keep that order stable so auto-generated parameter names
    # stay compatible with existing checkpoints.
    for _ in range(self.num_layers):
        # Self-attention sub-block (normalise, attend, add residual).
        y = nn.LayerNorm()(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            qkv_features=self.d_model,
            out_features=self.d_model,
        )(y)
        x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

        # Feed-forward sub-block with a 4x hidden expansion.
        y = nn.LayerNorm()(x)
        y = nn.Dense(features=self.d_model * 4)(y)
        y = nn.gelu(y)
        y = nn.Dense(features=self.d_model)(y)
        x = x + nn.Dropout(rate=self.dropout_rate)(y, deterministic=deterministic)

    x = nn.LayerNorm()(x)
    # Small-stddev kernel init keeps the initial coordinate outputs small.
    coords = nn.Dense(features=3, kernel_init=nn.initializers.normal(stddev=1e-3))(x)
    return coords
|