-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathdynamic_simple_model.py
More file actions
29 lines (23 loc) · 883 Bytes
/
dynamic_simple_model.py
File metadata and controls
29 lines (23 loc) · 883 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import sys
import torch
import random
from execution import runner
def optim_func(params, lr=0.01):
    """Build the SGD optimizer for the given parameters.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to
            optimize; passed straight through to ``torch.optim.SGD``.
        lr: Learning rate. Defaults to 0.01, the value this script
            previously hard-coded, so one-argument callers see no change.

    Returns:
        A ``torch.optim.SGD`` instance over ``params``.
    """
    return torch.optim.SGD(params, lr=lr)
def input_func(steps, dtype, device, *, batch_size=128, hidden_size=1024,
               min_seq_length=2, max_seq_length=128):
    """Generate one random input batch per step, each with its own sequence length.

    The middle dimension of every tensor is drawn independently with
    ``random.randint``, so consecutive steps generally see differently
    shaped inputs.

    Args:
        steps: Number of input batches to generate. ``0`` (or a negative
            value) yields an empty list.
        dtype: Torch dtype of the generated tensors.
        device: Device on which the tensors are allocated.
        batch_size: Size of the first tensor dimension (default 128).
        hidden_size: Size of the last tensor dimension (default 1024,
            which matches the ``Linear(1024, 1024)`` layer in this file).
        min_seq_length: Smallest sequence length, inclusive (default 2).
        max_seq_length: Largest sequence length, inclusive (default 128).

    Returns:
        A list with ``steps`` entries. Each entry is a one-element list
        holding a ``torch.randn`` tensor of shape
        ``(batch_size, seq_len, hidden_size)``.

    Raises:
        ValueError: From ``random.randint`` if ``min_seq_length`` exceeds
            ``max_seq_length`` and at least one step is requested.
    """
    # Draw every length up front so the Python RNG is consumed in the same
    # order regardless of how tensor creation behaves.
    seq_lengths = [
        random.randint(min_seq_length, max_seq_length) for _ in range(steps)
    ]
    return [
        [torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)]
        for seq_len in seq_lengths
    ]
class TestModule(torch.nn.Module):
    """Linear layer followed by ReLU, with a residual add, reduced to a scalar.

    The linear layer maps 1024 features to 1024 features, so inputs must
    have a last dimension of 1024 for the residual addition to line up.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1024, 1024)
        self.act = torch.nn.ReLU()

    def forward(self, inputs):
        """Compute ``sum(relu(linear(inputs)) + inputs)``.

        Args:
            inputs: Tensor whose last dimension is 1024.

        Returns:
            A one-element tuple containing the scalar sum tensor.
        """
        projected = self.linear(inputs)
        activated = self.act(projected)
        # Residual connection: add the untouched input back in.
        residual = activated + inputs
        return (residual.sum(),)
if __name__ == "__main__":
    # Script entry point: pass the command-line arguments, a run label, a
    # fresh model instance and the optimizer/input factories to the runner.
    # NOTE(review): the meaning of the trailing None is defined by
    # runner.run, which is not visible in this file — confirm there.
    runner.run(sys.argv, 'Dynamic-Simple-Model', TestModule(), optim_func, input_func, None)