-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
337 lines (283 loc) · 13.4 KB
/
model.py
File metadata and controls
337 lines (283 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import torch
from torch import nn
import numpy as np
import tonic
import tonic.torch.models
# ==============================================================================
# 1. Biological Constraints & Initializers
# ==============================================================================
# These functions force the network to behave like biological neurons (e.g., only positive weights).
def excitatory(w, upper=None):
    """Constrain a weight to be excitatory: non-negative, optionally bounded above."""
    return torch.clamp(w, min=0, max=upper)
def inhibitory(w, lower=None):
    """Constrain a weight to be inhibitory: non-positive, optionally bounded below."""
    return torch.clamp(w, min=lower, max=0)
def unsigned(w, lower=None, upper=None):
    """Leave a weight unconstrained, or clamp it into [lower, upper] if bounds are given."""
    if lower is None and upper is None:
        return w
    return w.clamp(min=lower, max=upper)
def graded(x):
    """Graded (non-spiking) neuron activation: clip activity into [0, 1]."""
    return torch.clamp(x, min=0, max=1)
def excitatory_uniform(shape=(1,), lower=0., upper=1.):
    """Create an excitatory weight parameter initialized uniformly in [lower, upper].

    Args:
        shape (tuple): Shape of the parameter tensor.
        lower (float): Lower bound of the uniform init; must be >= 0 so the
            initial weight is excitatory (non-negative).
        upper (float): Upper bound of the uniform init.

    Returns:
        nn.Parameter: Uniformly-initialized parameter of the given shape.

    Raises:
        ValueError: If ``lower`` is negative. (Previously an ``assert``,
            which is silently stripped under ``python -O``.)
    """
    if lower < 0:
        raise ValueError(f'excitatory init requires lower >= 0, got {lower}')
    return nn.init.uniform_(nn.Parameter(torch.empty(shape)), a=lower, b=upper)
def inhibitory_uniform(shape=(1,), lower=-1., upper=0.):
    """Create an inhibitory weight parameter initialized uniformly in [lower, upper].

    Args:
        shape (tuple): Shape of the parameter tensor.
        lower (float): Lower bound of the uniform init.
        upper (float): Upper bound of the uniform init; must be <= 0 so the
            initial weight is inhibitory (non-positive).

    Returns:
        nn.Parameter: Uniformly-initialized parameter of the given shape.

    Raises:
        ValueError: If ``upper`` is positive. (Previously an ``assert``,
            which is silently stripped under ``python -O``.)
    """
    if upper > 0:
        raise ValueError(f'inhibitory init requires upper <= 0, got {upper}')
    return nn.init.uniform_(nn.Parameter(torch.empty(shape)), a=lower, b=upper)
def unsigned_uniform(shape=(1,), lower=-1., upper=1.):
    """Create an unconstrained weight parameter initialized uniformly in [lower, upper]."""
    weight = nn.Parameter(torch.empty(shape))
    return nn.init.uniform_(weight, a=lower, b=upper)
def excitatory_constant(shape=(1,), value=1.):
    """Create an excitatory weight parameter initialized to a positive constant."""
    init = torch.full(shape, value)
    return nn.Parameter(init)
def inhibitory_constant(shape=(1,), value=-1.):
    """Create an inhibitory weight parameter initialized to a negative constant."""
    init = torch.full(shape, value)
    return nn.Parameter(init)
def unsigned_constant(shape=(1,), lower=-1., upper=1., p=0.5):
    """Create a weight parameter whose entries are `upper` with probability `p`, else `lower`."""
    with torch.no_grad():
        # One uniform draw per entry decides which of the two constants it takes.
        draw = torch.empty(shape).uniform_(0, 1)
        weight = torch.where(
            draw < p,
            torch.full(shape, upper),
            torch.full(shape, lower),
        )
    return nn.Parameter(weight)
# ==============================================================================
# 2. Network Classes (NCAP)
# ==============================================================================
class SwimmerModule(nn.Module):
    """C.-elegans-inspired neural circuit architectural prior (NCAP).

    Maps joint positions (plus optional oscillator, speed, and turn control
    signals) to joint torques via a dorsal/ventral B-neuron and muscle
    circuit with sign-constrained, optionally shared weights.
    """

    def __init__(
        self,
        n_joints: int,
        n_turn_joints: int = 1,
        oscillator_period: int = 60,
        use_weight_sharing: bool = True,
        use_weight_constraints: bool = True,
        use_weight_constant_init: bool = True,
        include_proprioception: bool = True,
        include_head_oscillators: bool = True,
        include_speed_control: bool = False,
        include_turn_control: bool = False,
    ):
        """Build the circuit.

        Args:
            n_joints: Number of body joints (one torque output per joint).
            n_turn_joints: Number of head joints modulated by turn control.
            oscillator_period: Full period (in timesteps) of the head oscillator.
            use_weight_sharing: Share one parameter per connection type across joints.
            use_weight_constraints: Clamp weights to excitatory/inhibitory signs.
            use_weight_constant_init: Initialize weights to +/-1 constants
                (otherwise uniform random init).
            include_proprioception: Propagate waves via previous-joint input.
            include_head_oscillators: Drive the first joint with a square-wave oscillator.
            include_speed_control: Enable the speed (brake) control input.
            include_turn_control: Enable the left/right turn control inputs.
        """
        super().__init__()
        self.n_joints = n_joints
        self.n_turn_joints = n_turn_joints
        self.oscillator_period = oscillator_period
        self.include_proprioception = include_proprioception
        self.include_head_oscillators = include_head_oscillators
        self.include_speed_control = include_speed_control
        self.include_turn_control = include_turn_control
        # Activity log: list of (timestep, 'exc'|'inh', connection_name) tuples.
        # NOTE(review): never cleared automatically (not even by reset()), so it
        # grows for the lifetime of the module while logging is enabled.
        self.connections_log = []
        # Timestep counter (drives the head oscillator when `timesteps` is not given).
        self.timestep = 0
        # Weight sharing switch function: picks the shared or per-joint param name.
        self.ws = lambda nonshared, shared: shared if use_weight_sharing else nonshared
        # Weight constraint and init functions.
        if use_weight_constraints:
            self.exc = excitatory
            self.inh = inhibitory
            if use_weight_constant_init:
                exc_param = excitatory_constant
                inh_param = inhibitory_constant
            else:
                exc_param = excitatory_uniform
                inh_param = inhibitory_uniform
        else:
            self.exc = unsigned
            self.inh = unsigned
            if use_weight_constant_init:
                exc_param = inh_param = unsigned_constant
            else:
                exc_param = inh_param = unsigned_uniform
        # Learnable parameters (one per connection type when sharing, else per joint).
        self.params = nn.ParameterDict()
        if use_weight_sharing:
            if self.include_proprioception:
                self.params['bneuron_prop'] = exc_param()
            if self.include_speed_control:
                self.params['bneuron_speed'] = inh_param()
            if self.include_turn_control:
                self.params['bneuron_turn'] = exc_param()
            if self.include_head_oscillators:
                self.params['bneuron_osc'] = exc_param()
            self.params['muscle_ipsi'] = exc_param()
            self.params['muscle_contra'] = inh_param()
        else:
            for i in range(self.n_joints):
                # Joint 0 has no previous joint, so no proprioceptive input.
                if self.include_proprioception and i > 0:
                    self.params[f'bneuron_d_prop_{i}'] = exc_param()
                    self.params[f'bneuron_v_prop_{i}'] = exc_param()
                if self.include_speed_control:
                    self.params[f'bneuron_d_speed_{i}'] = inh_param()
                    self.params[f'bneuron_v_speed_{i}'] = inh_param()
                if self.include_turn_control and i < self.n_turn_joints:
                    self.params[f'bneuron_d_turn_{i}'] = exc_param()
                    self.params[f'bneuron_v_turn_{i}'] = exc_param()
                if self.include_head_oscillators and i == 0:
                    self.params[f'bneuron_d_osc_{i}'] = exc_param()
                    self.params[f'bneuron_v_osc_{i}'] = exc_param()
                self.params[f'muscle_d_d_{i}'] = exc_param()
                self.params[f'muscle_d_v_{i}'] = inh_param()
                self.params[f'muscle_v_v_{i}'] = exc_param()
                self.params[f'muscle_v_d_{i}'] = inh_param()

    def reset(self):
        """Reset the internal oscillator timestep counter to zero."""
        self.timestep = 0

    def log_activity(self, activity_type, neuron):
        """Record an active connection as (timestep, activity_type, neuron)."""
        self.connections_log.append((self.timestep, activity_type, neuron))

    def forward(
        self,
        joint_pos,
        right_control=None,
        left_control=None,
        speed_control=None,
        timesteps=None,
        log_activity=True,
        log_file='log.txt'
    ):
        """Forward pass.

        Args:
            joint_pos (torch.Tensor): Joint positions in [-1, 1], shape (..., n_joints).
            right_control (torch.Tensor): Right turn control in [0, 1], shape (..., 1).
            left_control (torch.Tensor): Left turn control in [0, 1], shape (..., 1).
            speed_control (torch.Tensor): Speed control in [0, 1], 0 stopped, 1 fastest, shape (..., 1).
            timesteps (torch.Tensor): Timesteps in [0, max_env_steps], shape (..., 1).
                When None, the internal self.timestep counter drives the oscillator.
            log_activity (bool): When True, record active connections in
                connections_log. (BUGFIX: this flag was previously ignored and
                logging happened unconditionally.)
            log_file (str): Unused; kept for backward compatibility. TODO: wire
                up or remove at the call sites.

        Returns:
            (torch.Tensor): Joint torques in [-1, 1], shape (..., n_joints).
        """
        exc = self.exc
        inh = self.inh
        ws = self.ws
        # Separate into dorsal and ventral sensor values in [0, 1], shape (..., n_joints).
        joint_pos_d = joint_pos.clamp(min=0, max=1)
        joint_pos_v = joint_pos.clamp(min=-1, max=0).neg()
        # Convert speed signal from acceleration into brake.
        if self.include_speed_control:
            assert speed_control is not None
            speed_control = 1 - speed_control.clamp(min=0, max=1)
        joint_torques = []  # [shape (..., 1)]
        for i in range(self.n_joints):
            bneuron_d = bneuron_v = torch.zeros_like(joint_pos[..., 0, None])  # shape (..., 1)
            # B-neurons receive proprioceptive input from the previous joint to
            # propagate waves down the body.
            if self.include_proprioception and i > 0:
                bneuron_d = bneuron_d + joint_pos_d[
                    ..., i - 1, None] * exc(self.params[ws(f'bneuron_d_prop_{i}', 'bneuron_prop')])
                bneuron_v = bneuron_v + joint_pos_v[
                    ..., i - 1, None] * exc(self.params[ws(f'bneuron_v_prop_{i}', 'bneuron_prop')])
                if log_activity:
                    self.log_activity('exc', f'bneuron_d_prop_{i}')
                    self.log_activity('exc', f'bneuron_v_prop_{i}')
            # Speed control unit modulates all B-neurons.
            if self.include_speed_control:
                bneuron_d = bneuron_d + speed_control * inh(
                    self.params[ws(f'bneuron_d_speed_{i}', 'bneuron_speed')]
                )
                bneuron_v = bneuron_v + speed_control * inh(
                    self.params[ws(f'bneuron_v_speed_{i}', 'bneuron_speed')]
                )
                if log_activity:
                    self.log_activity('inh', f'bneuron_d_speed_{i}')
                    self.log_activity('inh', f'bneuron_v_speed_{i}')
            # Turn control units modulate head B-neurons.
            if self.include_turn_control and i < self.n_turn_joints:
                assert right_control is not None
                assert left_control is not None
                turn_control_d = right_control.clamp(min=0, max=1)  # shape (..., 1)
                turn_control_v = left_control.clamp(min=0, max=1)
                bneuron_d = bneuron_d + turn_control_d * exc(
                    self.params[ws(f'bneuron_d_turn_{i}', 'bneuron_turn')]
                )
                bneuron_v = bneuron_v + turn_control_v * exc(
                    self.params[ws(f'bneuron_v_turn_{i}', 'bneuron_turn')]
                )
                if log_activity:
                    self.log_activity('exc', f'bneuron_d_turn_{i}')
                    self.log_activity('exc', f'bneuron_v_turn_{i}')
            # Oscillator units modulate first B-neurons: a square wave that
            # alternates dorsal/ventral every half period.
            if self.include_head_oscillators and i == 0:
                if timesteps is not None:
                    phase = timesteps.round().remainder(self.oscillator_period)
                    mask = phase < self.oscillator_period // 2
                    oscillator_d = torch.zeros_like(timesteps)  # shape (..., 1)
                    oscillator_v = torch.zeros_like(timesteps)  # shape (..., 1)
                    oscillator_d[mask] = 1.
                    oscillator_v[~mask] = 1.
                else:
                    phase = self.timestep % self.oscillator_period  # in [0, oscillator_period)
                    if phase < self.oscillator_period // 2:
                        oscillator_d, oscillator_v = 1.0, 0.0
                    else:
                        oscillator_d, oscillator_v = 0.0, 1.0
                bneuron_d = bneuron_d + oscillator_d * exc(
                    self.params[ws(f'bneuron_d_osc_{i}', 'bneuron_osc')]
                )
                bneuron_v = bneuron_v + oscillator_v * exc(
                    self.params[ws(f'bneuron_v_osc_{i}', 'bneuron_osc')]
                )
                if log_activity:
                    self.log_activity('exc', f'bneuron_d_osc_{i}')
                    self.log_activity('exc', f'bneuron_v_osc_{i}')
            # B-neuron activation.
            bneuron_d = graded(bneuron_d)
            bneuron_v = graded(bneuron_v)
            # Muscles receive excitatory ipsilateral and inhibitory contralateral input.
            muscle_d = graded(
                bneuron_d * exc(self.params[ws(f'muscle_d_d_{i}', 'muscle_ipsi')]) +
                bneuron_v * inh(self.params[ws(f'muscle_d_v_{i}', 'muscle_contra')])
            )
            muscle_v = graded(
                bneuron_v * exc(self.params[ws(f'muscle_v_v_{i}', 'muscle_ipsi')]) +
                bneuron_d * inh(self.params[ws(f'muscle_v_d_{i}', 'muscle_contra')])
            )
            # Joint torque from antagonistic contraction of dorsal and ventral muscles.
            joint_torque = muscle_d - muscle_v
            joint_torques.append(joint_torque)
        # Advance the internal oscillator clock once per forward call.
        self.timestep += 1
        out = torch.cat(joint_torques, -1)  # shape (..., n_joints)
        return out
class SwimmerActor(nn.Module):
    """Actor that wraps a SwimmerModule behind a Tonic policy head."""

    def __init__(
        self,
        swimmer,
        controller=None,
        distribution=None,
        timestep_transform=(-1, 1, 0, 1000),
    ):
        super().__init__()
        self.swimmer = swimmer
        self.controller = controller
        # Fall back to Tonic's detached-scale Gaussian head when no
        # distribution is supplied.
        if distribution is not None:
            self.distribution = distribution
        else:
            self.distribution = tonic.torch.models.DetachedScaleGaussianPolicyHead()
        self.timestep_transform = timestep_transform

    def initialize(
        self,
        observation_space,
        action_space,
        observation_normalizer=None,
    ):
        """Size the distribution head from the environment's action space."""
        n_actions = action_space.shape[0]
        self.action_size = n_actions
        # The swimmer emits one value per joint, so the distribution's input
        # size equals its action size — pass the same number twice.
        self.distribution.initialize(n_actions, n_actions)

    def forward(self, observations):
        """Map observations to an action probability distribution."""
        # The first action_size entries are joint angles; the last entry is time.
        raw_joints = observations[..., :self.action_size]
        time_obs = observations[..., -1, None]
        # Normalize joint angles into [-1, 1] using the physical joint limit.
        limit = 2 * np.pi / (self.action_size + 1)
        norm_joints = torch.clamp(raw_joints / limit, min=-1, max=1)
        # Linearly remap the timestep observation into the configured range.
        if self.timestep_transform:
            lo_in, hi_in, lo_out, hi_out = self.timestep_transform
            time_obs = (time_obs - lo_in) / (hi_in - lo_in) * (hi_out - lo_out) + lo_out
        # Optional high-level controller produces turn/speed commands.
        if self.controller:
            right_cmd, left_cmd, speed_cmd = self.controller(observations)
        else:
            right_cmd = left_cmd = speed_cmd = None
        # Low-level NCAP circuit yields the mean action.
        mean_action = self.swimmer(
            norm_joints,
            timesteps=time_obs,
            right_control=right_cmd,
            left_control=left_cmd,
            speed_control=speed_cmd,
        )
        # Wrap the biological mean in a probability distribution object.
        return self.distribution(mean_action)