Skip to content

Commit 220e938

Browse files
committed
initial commit
0 parents  commit 220e938

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
/venv/
2+
/.idea/

learnable_fourier_pos_encoding.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
5+
6+
class LearnableFourierPositionalEncoding(nn.Module):
    """Learnable Fourier feature positional encoding for multi-dimensional positions.

    Implements Algorithm 1 of "Learnable Fourier Features for Multi-Dimensional
    Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
    """

    def __init__(self, G: int, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """
        Learnable Fourier Features from https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)
        Implementation of Algorithm 1: Compute the Fourier feature positional encoding of a multi-dimensional position
        Computes the positional encoding of a tensor of shape [N, G, M]
        :param G: positional groups (positions in different groups are independent)
        :param M: each point has a M-dimensional positional values
        :param F_dim: depth of the Fourier feature dimension (must be even: half cosines, half sines)
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension (must be divisible by G)
        :param gamma: parameter to initialize Wr; Wr is drawn with standard deviation 1 / gamma
        :raises ValueError: if F_dim is odd, if G is not positive, if D is not a multiple of G,
            or if gamma is not strictly positive
        """
        super().__init__()

        # Fail early: each of these configurations would otherwise only blow up
        # inside forward() with an opaque shape-mismatch error.
        if F_dim % 2 != 0:
            raise ValueError(f"F_dim must be even (cos/sin halves are concatenated), got {F_dim}")
        if G <= 0 or D % G != 0:
            raise ValueError(f"D must be divisible by a positive G, got D={D}, G={G}")
        if gamma <= 0:
            raise ValueError(f"gamma must be strictly positive, got {gamma}")

        self.G = G
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = D
        self.gamma = gamma

        # Projection matrix on learned lines (used in eq. 2)
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        # MLP: GeLU(F @ W1 + B1) @ W2 + B2 (eq. 6)
        # Each group produces D // G features so that the G groups concatenate to D.
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GELU(),
            nn.Linear(self.H_dim, self.D // self.G)
        )

        self.init_weights()

    def init_weights(self):
        """Initialise Wr from a zero-mean normal distribution with standard deviation 1 / gamma."""
        # The paper specifies N(0, gamma^-2); the second argument is a variance,
        # whereas nn.init.normal_ expects a standard deviation, hence gamma^-1.
        nn.init.normal_(self.Wr.weight.data, mean=0.0, std=1.0 / self.gamma)

    def forward(self, x):
        """
        Produce positional encodings from x
        :param x: tensor of shape [N, G, M] that represents N positions where each position is in the shape of [G, M],
                  where G is the positional group and each group has M-dimensional positional values.
                  Positions in different positional groups are independent.
                  Any number of leading batch dimensions is accepted ([..., G, M]).
        :return: positional encoding for x, of shape [N, D] (or [..., D] for extra leading dimensions)
        """
        # Everything except the trailing (G, M) axes is treated as batch.
        batch_shape = x.shape[:-2]
        # Step 1. Compute Fourier features (eq. 2)
        projected = self.Wr(x)
        cosines = torch.cos(projected)
        sines = torch.sin(projected)
        F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1)
        # Step 2. Compute projected Fourier features (eq. 6)
        Y = self.mlp(F)
        # Step 3. Concatenate the G per-group encodings (D // G each) into one D-vector
        PEx = Y.reshape(*batch_shape, self.D)
        return PEx
60+
61+
62+
if __name__ == '__main__':
    # Smoke test: encode 97 random positions, each with 3 groups of 17-dimensional
    # values, and print the shape of the resulting encoding.
    num_groups, pos_dim = 3, 17
    positions = torch.randn((97, num_groups, pos_dim))
    encoder = LearnableFourierPositionalEncoding(
        G=num_groups,
        M=pos_dim,
        F_dim=768,
        H_dim=32,
        D=768,
        gamma=10,
    )
    print(encoder(positions).shape)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
numpy
2+
torch

0 commit comments

Comments
 (0)