-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_dql.py
More file actions
289 lines (267 loc) · 10.1 KB
/
model_dql.py
File metadata and controls
289 lines (267 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import jax,flax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state
from typing import Any, Callable, Optional
from functools import partial
from util_dql import (
extract,
linear_beta_schedule,
cosine_beta_schedule,
vp_beta_schedule,
SinusoidalPosEmb,
mish,
)
# -------------------- Diffusion model -------------------- #
class Diffusion(nn.Module):
    """DDPM-style conditional diffusion model that samples actions given states.

    Sampling runs the learned reverse process starting from pure Gaussian
    noise; training uses the standard denoising objective. Intended as a
    policy (Diffusion-QL style): ``__call__(state, rng)`` returns an action.

    Attributes:
        state_dim: Dimensionality of the conditioning state.
        action_dim: Dimensionality of the generated action.
        max_action: Sampled actions are clipped to [-max_action, max_action].
        beta_schedule: Noise schedule; one of 'linear', 'cosine', 'vp'.
        n_timesteps: Number of diffusion timesteps.
        loss_type: Training loss; currently 'l2' or 'l1'.
        clip_denoised: If True, clip the reconstructed x0 during sampling.
        predict_epsilon: If True the network predicts the noise; otherwise
            it predicts x0 directly.
    """

    state_dim: int
    action_dim: int
    max_action: float
    beta_schedule: str = "vp"  # 'linear', 'cosine', or 'vp'
    n_timesteps: int = 100
    loss_type: str = "l2"  # 'l2' or 'l1'
    clip_denoised: bool = True
    predict_epsilon: bool = True  # True: model outputs noise, else outputs x0

    def setup(self):
        # Noise-prediction network f(x_t, t, state).
        self.model = ActionPredictorMLP(self.state_dim, self.action_dim)

        # Weighted element-wise loss, reduced with a mean.
        if self.loss_type == "l2":
            self.loss_fn = lambda pred, target, weights: jnp.mean(
                ((pred - target) ** 2) * weights
            )
        elif self.loss_type == "l1":
            self.loss_fn = lambda pred, target, weights: jnp.mean(
                jnp.abs(pred - target) * weights
            )
        else:
            raise NotImplementedError(f"loss type {self.loss_type} not implemented")

        # Beta (noise-variance) schedule.
        if self.beta_schedule == "linear":
            betas = linear_beta_schedule(self.n_timesteps)
        elif self.beta_schedule == "cosine":
            betas = cosine_beta_schedule(self.n_timesteps)
        elif self.beta_schedule == "vp":
            betas = vp_beta_schedule(self.n_timesteps)
        else:
            # Fix: error message previously read "unkown".
            raise ValueError(f"unknown beta_schedule: {self.beta_schedule}")

        # Forward-process constants derived from the schedule.
        alphas = 1.0 - betas
        alphas_cumprod = jnp.cumprod(alphas, axis=0)
        # alpha-bar_{t-1}, with alpha-bar at t=-1 defined as 1.
        alphas_cumprod_prev = jnp.concatenate(
            [jnp.ones((1,), dtype=alphas_cumprod.dtype), alphas_cumprod[:-1]], axis=0
        )
        sqrt_alphas_cumprod = jnp.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = jnp.sqrt(1.0 - alphas_cumprod)
        log_one_minus_alphas_cumprod = jnp.log(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = jnp.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = jnp.sqrt(1.0 / alphas_cumprod - 1.0)

        # Posterior q(x_{t-1} | x_t, x0) variance and mean coefficients.
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # Clip before the log: the posterior variance is 0 at t = 0.
        # Positional lower bound (the `a_min` keyword was removed from
        # jnp.clip in recent JAX releases; positional works everywhere).
        posterior_log_variance_clipped = jnp.log(
            jnp.clip(posterior_variance, 1e-20)
        )
        posterior_mean_coef1 = (
            betas * jnp.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev) * jnp.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

        # Cache all schedule constants on the module. (The two redundant
        # no-op self-assignments from the original were removed.)
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_alphas_cumprod = sqrt_alphas_cumprod
        self.sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod
        self.log_one_minus_alphas_cumprod = log_one_minus_alphas_cumprod
        self.sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod
        self.sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod
        self.posterior_variance = posterior_variance
        self.posterior_log_variance_clipped = posterior_log_variance_clipped
        self.posterior_mean_coef1 = posterior_mean_coef1
        self.posterior_mean_coef2 = posterior_mean_coef2

    # ---------- Sampling (inference) ---------- #
    def predict_start_from_noise(
        self, x_t: jnp.ndarray, t: jnp.ndarray, noise: jnp.ndarray
    ) -> jnp.ndarray:
        """Reconstruct x0 from x_t and the predicted noise.

        When ``predict_epsilon`` is False the network already outputs x0,
        so ``noise`` is returned unchanged.
        """
        if self.predict_epsilon:
            return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
                - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
            )
        else:
            return noise

    def q_posterior(self, x_start: jnp.ndarray, x_t: jnp.ndarray, t: jnp.ndarray):
        """Return mean, variance, and clipped log-variance of
        q(x_{t-1} | x_t, x0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x: jnp.ndarray, t: jnp.ndarray, state: jnp.ndarray):
        """Predict the reverse-step mean and (log-)variance at x_t."""
        noise_pred = self.model(x, t, state)
        x_recon = self.predict_start_from_noise(x, t, noise_pred)
        if self.clip_denoised:
            # Keep the reconstructed x0 inside the valid action range.
            x_recon = jnp.clip(x_recon, -self.max_action, self.max_action)
        model_mean, model_var, model_log_variance = self.q_posterior(x_recon, x, t)
        return model_mean, model_var, model_log_variance

    def p_sample(
        self, x: jnp.ndarray, t: jnp.ndarray, state: jnp.ndarray, rng: jnp.ndarray
    ) -> jnp.ndarray:
        """Sample x_{t-1} given x_t at timestep t."""
        model_mean, _, model_log_variance = self.p_mean_variance(x, t, state)
        rng, subkey = jax.random.split(rng)
        noise = jax.random.normal(subkey, shape=x.shape)
        # No noise is added on the final step (t == 0).
        nonzero_mask = (t != 0).astype(x.dtype)
        nonzero_mask = nonzero_mask.reshape(
            (nonzero_mask.shape[0],) + (1,) * (len(x.shape) - 1)
        )
        return model_mean + nonzero_mask * jnp.exp(0.5 * model_log_variance) * noise

    # TODO: consider rewriting with lax control flow (scan/fori_loop);
    # non-trivial because of the optional diffusion trace.
    def p_sample_loop(
        self,
        state: jnp.ndarray,
        shape: tuple,
        rng: jnp.ndarray,
        verbose: bool = False,
        return_diffusion: bool = False,
    ) -> Any:
        """Run the full reverse process from pure noise down to a sample.

        Returns the final sample, or ``(sample, trajectory)`` when
        ``return_diffusion`` is True.
        """
        rng, subkey = jax.random.split(rng)
        x = jax.random.normal(subkey, shape)
        diffusion = [x] if return_diffusion else None
        for i in reversed(range(self.n_timesteps)):
            t = jnp.full((shape[0],), i, dtype=jnp.int32)
            rng, subkey = jax.random.split(rng)
            x = self.p_sample(x, t, state, subkey)
            if verbose:
                print(f"t = {i}")
            if return_diffusion:
                diffusion.append(x)
        if return_diffusion:
            diffusion = jnp.stack(diffusion, axis=1)
            return x, diffusion
        else:
            return x

    def sample(self, state: jnp.ndarray, rng: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Sample an action conditioned on ``state``, clipped to the valid
        range.

        NOTE(review): passing ``return_diffusion=True`` through ``kwargs``
        would make ``p_sample_loop`` return a tuple and break the clip below;
        callers are expected not to do that here.
        """
        batch_size = state.shape[0]
        shape = (batch_size, self.action_dim)
        rng, subkey = jax.random.split(rng)
        action = self.p_sample_loop(state, shape, subkey, **kwargs)
        return jnp.clip(action, -self.max_action, self.max_action)

    # ---------- Training ---------- #
    def q_sample(
        self,
        x_start: jnp.ndarray,
        t: jnp.ndarray,
        rng: jnp.ndarray,
        noise: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Sample x_t from x0 via the closed-form forward process."""
        if noise is None:
            noise = jax.random.normal(rng, shape=x_start.shape)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(
        self,
        x_start: jnp.ndarray,
        state: jnp.ndarray,
        t: jnp.ndarray,
        rng: jnp.ndarray,
        weights: float = 1.0,
    ) -> jnp.ndarray:
        """Denoising loss at the given timesteps ``t``.

        The target is the injected noise when ``predict_epsilon`` is True,
        otherwise x0 itself.
        """
        rng, subkey = jax.random.split(rng)
        noise = jax.random.normal(subkey, shape=x_start.shape)
        rng, subkey = jax.random.split(rng)
        x_noisy = self.q_sample(x_start, t, subkey, noise)
        x_recon = self.model(x_noisy, t, state)
        if self.predict_epsilon:
            loss = self.loss_fn(x_recon, noise, weights)
        else:
            loss = self.loss_fn(x_recon, x_start, weights)
        return loss

    def loss(
        self, x: jnp.ndarray, state: jnp.ndarray, rng: jnp.ndarray, weights: float = 1.0
    ) -> jnp.ndarray:
        """Draw uniform random timesteps and compute the training loss."""
        batch_size = x.shape[0]
        rng, subkey = jax.random.split(rng)
        t = jax.random.randint(
            subkey, shape=(batch_size,), minval=0, maxval=self.n_timesteps
        )
        return self.p_losses(x, state, t, rng, weights)

    def __call__(self, state: jnp.ndarray, rng: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Default call: sample an action for ``state``."""
        return self.sample(state, rng, **kwargs)
# -------------------- MLP model -------------------- #
class ActionPredictorMLP(nn.Module):
    """MLP noise predictor for the diffusion policy.

    Maps (action, timestep, state) to an action-shaped output,
    i.e. f(a, t, s) -> a. The scalar timestep is lifted to a
    sinusoidal embedding before being fed to the network.
    """

    state_dim: int  # unused inside this module; kept for constructor compatibility
    action_dim: int
    t_dim: int = 16

    def setup(self):
        # Timestep embedding: sinusoidal features passed through a small MLP.
        embed_layers = [
            SinusoidalPosEmb(self.t_dim),
            nn.Dense(self.t_dim * 2),
            mish,
            nn.Dense(self.t_dim),
        ]
        self.time_emb = nn.Sequential(embed_layers)

        # Trunk: three 256-wide hidden layers with Mish, then a linear head.
        trunk_layers = [
            nn.Dense(256),
            mish,
            nn.Dense(256),
            mish,
            nn.Dense(256),
            mish,
            nn.Dense(self.action_dim),
        ]
        self.net = nn.Sequential(trunk_layers)

    def __call__(
        self, x: jnp.ndarray, t: jnp.ndarray, state: jnp.ndarray
    ) -> jnp.ndarray:
        # Embed the timestep, fuse all conditioning inputs, and regress.
        embedded_t = self.time_emb(t)
        joint = jnp.concatenate([x, embedded_t, state], axis=-1)
        return self.net(joint)