-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathiql.py
More file actions
371 lines (313 loc) · 11.9 KB
/
iql.py
File metadata and controls
371 lines (313 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
from collections import namedtuple
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial
import os
import warnings
import distrax
import d4rl
import flax.linen as nn
from flax.training.train_state import TrainState
import gym
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import tyro
import wandb
# Allow XLA's GPU backend to use Triton for any GEMM.
# NOTE(review): this runs after `import jax`; it relies on the XLA backend reading
# XLA_FLAGS lazily at first use rather than at import time — confirm for the jax
# version in use.
os.environ.update({"XLA_FLAGS": "--xla_gpu_triton_gemm_any=True"})
@dataclass
class Args:
    """Configuration for offline IQL training; parsed from the command line by tyro."""

    # --- Experiment ---
    seed: int = 0  # seed for the top-level jax.random.PRNGKey
    dataset: str = "halfcheetah-medium-v2"  # gym/D4RL environment and dataset id
    algorithm: str = "iql"  # tag used in the final-returns filename
    num_updates: int = 1_000_000  # total gradient steps; also the LR-decay horizon
    eval_interval: int = 2500  # gradient steps between evaluations
    eval_workers: int = 8  # number of vectorized evaluation environments
    eval_final_episodes: int = 1000  # final-evaluation episodes; 0 skips it

    # --- Logging ---
    log: bool = False  # enable Weights & Biases logging
    wandb_project: str = "unifloral"
    wandb_team: str = "flair"
    wandb_group: str = "debug"

    # --- Generic optimization ---
    lr: float = 3e-4  # initial learning rate of the cosine-decay schedule
    batch_size: int = 256  # transitions sampled per gradient step
    gamma: float = 0.99  # discount factor in the Q target
    polyak_step_size: float = 0.005  # target-critic soft-update rate

    # --- IQL ---
    beta: float = 3.0  # inverse temperature on the advantage in the actor loss
    iql_tau: float = 0.7  # expectile used by the value loss
    exp_adv_clip: float = 100.0  # upper bound on exp(beta * advantage)
r"""
|\ __
\| /_/
\|
___|_____
\ /
\ /
\___/ Preliminaries
"""
# Train states for every network the agent owns; the target critic is kept
# alongside the online critic so it can be Polyak-averaged each step.
AgentTrainState = namedtuple(
    "AgentTrainState", ["actor", "dual_q", "dual_q_target", "value"]
)
# One offline transition, or a batch of them when each field is a stacked array.
Transition = namedtuple("Transition", ["obs", "action", "reward", "next_obs", "done"])
class SoftQNetwork(nn.Module):
    """Single critic Q(s, a): an MLP over normalized observations and actions.

    Attributes:
        obs_mean: Observation mean subtracted before the network.
        obs_std: Observation std; 1e-3 is added before dividing to avoid blow-ups.
    """

    obs_mean: jax.Array
    obs_std: jax.Array

    @nn.compact
    def __call__(self, obs, action):
        normed_obs = (obs - self.obs_mean) / (self.obs_std + 1e-3)
        hidden = jnp.concatenate([normed_obs, action], axis=-1)
        # Two ReLU hidden layers of width 256, then a scalar head.
        for _layer in range(2):
            hidden = nn.relu(nn.Dense(256)(hidden))
        return jnp.squeeze(nn.Dense(1)(hidden), axis=-1)
class DualQNetwork(nn.Module):
    """Two independently initialized SoftQNetworks evaluated in a single call.

    Both critics receive the same (obs, action); their outputs are stacked on the
    last axis, so the result has one trailing entry per critic.
    """

    obs_mean: jax.Array
    obs_std: jax.Array

    @nn.compact
    def __call__(self, obs, action):
        critic_ensemble = nn.vmap(
            SoftQNetwork,
            in_axes=None,  # broadcast the same inputs to every critic
            out_axes=-1,  # stack per-critic outputs on the trailing axis
            axis_size=2,  # two critics
            variable_axes={"params": 0},  # separate parameters per critic
            split_rngs={"params": True, "dropout": True},  # distinct initializations
        )
        return critic_ensemble(self.obs_mean, self.obs_std)(obs, action)
class StateValueFunction(nn.Module):
    """State-value network V(s): an MLP over normalized observations.

    Attributes:
        obs_mean: Observation mean subtracted before the network.
        obs_std: Observation std; 1e-3 is added before dividing.
    """

    obs_mean: jax.Array
    obs_std: jax.Array

    @nn.compact
    def __call__(self, x):
        hidden = (x - self.obs_mean) / (self.obs_std + 1e-3)
        for _layer in range(2):
            hidden = nn.relu(nn.Dense(256)(hidden))
        return jnp.squeeze(nn.Dense(1)(hidden), axis=-1)
class TanhGaussianActor(nn.Module):
    """Gaussian policy with a tanh-squashed mean and a state-independent std.

    Only the mean passes through tanh; the returned ``distrax.Normal`` itself is
    not squashed, so its samples are unbounded.

    Attributes:
        num_actions: Action dimensionality.
        obs_mean: Observation mean subtracted before the network.
        obs_std: Observation std; 1e-3 is added before dividing.
        log_std_max: Upper clip for the learned log-std.
        log_std_min: Lower clip for the learned log-std.
    """

    num_actions: int
    obs_mean: jax.Array
    obs_std: jax.Array
    log_std_max: float = 2.0
    log_std_min: float = -20.0

    @nn.compact
    def __call__(self, x, eval=False):
        """Return the action distribution for ``x``.

        With ``eval=True`` this is a ``distrax.Deterministic`` at the squashed mean;
        otherwise a ``distrax.Normal`` with the learned, clipped std.
        """
        hidden = (x - self.obs_mean) / (self.obs_std + 1e-3)
        for _layer in range(2):
            hidden = nn.relu(nn.Dense(256)(hidden))
        mean = nn.tanh(nn.Dense(self.num_actions)(hidden))
        if eval:
            return distrax.Deterministic(mean)
        # Declared after the eval branch, so it is only created by a non-eval call.
        log_std = self.param(
            "logstd",
            init_fn=lambda _key: jnp.zeros(self.num_actions, dtype=jnp.float32),
        )
        clipped_log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
        return distrax.Normal(mean, jnp.exp(clipped_log_std))
def create_train_state(args, rng, network, dummy_input):
    """Initialize ``network`` and wrap it in a flax ``TrainState``.

    Args:
        args: Config providing ``lr`` and ``num_updates``.
        rng: PRNG key for parameter initialization.
        network: Flax module to initialize.
        dummy_input: Sequence of example positional inputs for ``network.init``.

    Returns:
        A ``TrainState`` using Adam (eps=1e-5) with a cosine-decay learning rate
        that starts at ``args.lr`` and decays over ``args.num_updates`` steps.
    """
    schedule = optax.cosine_decay_schedule(args.lr, args.num_updates)
    optimizer = optax.adam(schedule, eps=1e-5)
    params = network.init(rng, *dummy_input)
    return TrainState.create(apply_fn=network.apply, params=params, tx=optimizer)
def eval_agent(args, rng, env, agent_state):
    """Roll out the deterministic policy for one episode per vectorized worker.

    Args:
        args: Config providing ``eval_workers`` (number of sub-envs in ``env``).
        rng: PRNG key; split once per step and passed to the policy's ``sample``.
        env: Vectorized env with ``eval_workers`` sub-envs; it is reset here.
        agent_state: Agent state whose ``actor`` provides ``apply_fn`` and ``params``.

    Returns:
        numpy array of shape ``(eval_workers,)`` with each worker's undiscounted
        return. All workers keep being stepped until every one is done, so rewards
        after a worker's first ``done`` are masked out.

    Warns:
        UserWarning if the step limit is hit before every episode has ended.
    """
    # --- Read the step limit from a throwaway sub-env, then release it ---
    # (previously this env was created on every call and never closed)
    probe_env = env.env_fns[0]()
    try:
        max_episode_steps = probe_env.spec.max_episode_steps
    finally:
        probe_env.close()

    # --- Reset environment ---
    returned = onp.zeros(args.eval_workers, dtype=bool)
    cum_reward = onp.zeros(args.eval_workers)
    obs = env.reset()

    # --- Rollout agent ---
    @jax.jit
    @jax.vmap
    def _policy_step(rng, obs):
        pi = agent_state.actor.apply_fn(agent_state.actor.params, obs, eval=True)
        action = pi.sample(seed=rng)
        return jnp.nan_to_num(action)

    step = 0
    while step < max_episode_steps and not returned.all():
        # --- Take step in environment ---
        step += 1
        rng, rng_step = jax.random.split(rng)
        rng_step = jax.random.split(rng_step, args.eval_workers)
        action = _policy_step(rng_step, jnp.array(obs))
        obs, reward, done, _info = env.step(onp.array(action))

        # --- Track cumulative reward (only until each worker's first done) ---
        cum_reward += reward * ~returned
        returned |= done

    # The loop only exits with unfinished workers once the step limit is reached.
    if not returned.all():
        warnings.warn("Maximum steps reached before all episodes terminated")
    return cum_reward
r"""
__/)
.-(__(=:
|\ | \)
\ ||
\||
\|
___|_____
\ /
\ /
\___/ Agent
"""
def make_train_step(args, actor_apply_fn, q_apply_fn, value_apply_fn, dataset):
    """Make JIT-compatible agent train step.

    Args:
        args: Config providing ``batch_size``, ``polyak_step_size``, ``gamma``,
            ``iql_tau``, ``beta`` and ``exp_adv_clip``.
        actor_apply_fn: ``apply`` of the actor network.
        q_apply_fn: ``apply`` of the twin-critic network.
        value_apply_fn: ``apply`` of the state-value network.
        dataset: ``Transition`` of stacked arrays; each minibatch is drawn
            uniformly, with replacement, along the leading axis.

    Returns:
        ``_train_step((rng, agent_state), _) -> ((rng, agent_state), losses)``,
        suitable for ``jax.lax.scan``; ``losses`` holds ``value_loss``,
        ``q_loss`` and ``actor_loss``.
    """
    num_transitions = len(dataset.obs)

    def _sample_batch(key):
        """Gather ``args.batch_size`` uniformly sampled transitions."""
        idx = jax.random.randint(key, (args.batch_size,), 0, num_transitions)
        return jax.tree_util.tree_map(lambda leaf: leaf[idx], dataset)

    def _soft_update_target(agent_state):
        """Polyak-average the online critic into the target critic."""
        target = agent_state.dual_q_target
        new_params = optax.incremental_update(
            agent_state.dual_q.params, target.params, args.polyak_step_size
        )
        return agent_state._replace(
            dual_q_target=target.replace(step=target.step + 1, params=new_params)
        )

    def _train_step(runner_state, _):
        rng, agent_state = runner_state
        rng, rng_batch = jax.random.split(rng)
        batch = _sample_batch(rng_batch)

        # The target critic is refreshed *before* the targets below are computed,
        # so this step's value target already uses the freshly averaged params.
        agent_state = _soft_update_target(agent_state)

        # --- Regression targets (constants w.r.t. the params being trained) ---
        target_q = q_apply_fn(
            agent_state.dual_q_target.params, batch.obs, batch.action
        )
        v_target = target_q.min(-1)  # minimum over the two target critics
        next_v = value_apply_fn(agent_state.value.params, batch.next_obs)
        q_targets = batch.reward + args.gamma * (1 - batch.done) * next_v

        # --- Critic: squared error to the TD target, averaged over both critics ---
        def _q_loss_fn(q_params):
            q_pred = q_apply_fn(q_params, batch.obs, batch.action)
            return jnp.square(q_pred - jnp.expand_dims(q_targets, axis=-1)).mean()

        # --- Value: asymmetric L2 (expectile) loss toward the target-Q minimum ---
        def _value_loss_fn(v_params):
            adv = v_target - value_apply_fn(v_params, batch.obs)
            weight = jnp.abs(args.iql_tau - (adv < 0.0).astype(float))
            return jnp.mean(weight * (adv**2)), adv

        q_loss, q_grad = jax.value_and_grad(_q_loss_fn)(agent_state.dual_q.params)
        (value_loss, adv), value_grad = jax.value_and_grad(
            _value_loss_fn, has_aux=True
        )(agent_state.value.params)
        agent_state = agent_state._replace(
            dual_q=agent_state.dual_q.apply_gradients(grads=q_grad),
            value=agent_state.value.apply_gradients(grads=value_grad),
        )

        # --- Actor: advantage-weighted behaviour cloning ---
        # `adv` was computed with the value params from before this step's update.
        adv_weights = jnp.exp(adv * args.beta).clip(max=args.exp_adv_clip)

        def _actor_loss_fn(actor_params):
            def _per_sample(transition, weight):
                pi = actor_apply_fn(actor_params, transition.obs)
                bc_loss = -pi.log_prob(transition.action)
                return weight * bc_loss.sum()

            return jax.vmap(_per_sample)(batch, adv_weights).mean()

        actor_loss, actor_grad = jax.value_and_grad(_actor_loss_fn)(
            agent_state.actor.params
        )
        agent_state = agent_state._replace(
            actor=agent_state.actor.apply_gradients(grads=actor_grad)
        )

        losses = {
            "value_loss": value_loss,
            "q_loss": q_loss,
            "actor_loss": actor_loss,
        }
        return (rng, agent_state), losses

    return _train_step
if __name__ == "__main__":
    # --- Parse arguments ---
    args = tyro.cli(Args)
    rng = jax.random.PRNGKey(args.seed)

    # --- Initialize logger ---
    if args.log:
        wandb.init(
            config=asdict(args),  # plain dict: a documented config type
            project=args.wandb_project,
            entity=args.wandb_team,
            group=args.wandb_group,
            job_type="train_agent",
        )

    # --- Initialize environment and dataset ---
    env = gym.vector.make(args.dataset, num_envs=args.eval_workers)
    # This env is only needed to fetch the offline dataset; release it afterwards.
    data_env = gym.make(args.dataset)
    raw_dataset = d4rl.qlearning_dataset(data_env)
    data_env.close()
    dataset = Transition(
        obs=jnp.array(raw_dataset["observations"]),
        action=jnp.array(raw_dataset["actions"]),
        reward=jnp.array(raw_dataset["rewards"]),
        next_obs=jnp.array(raw_dataset["next_observations"]),
        done=jnp.array(raw_dataset["terminals"]),
    )

    # --- Initialize agent and value networks ---
    num_actions = env.single_action_space.shape[0]
    obs_mean = dataset.obs.mean(axis=0)
    obs_std = jnp.nan_to_num(dataset.obs.std(axis=0), nan=1.0)
    dummy_obs = jnp.zeros(env.single_observation_space.shape)
    dummy_action = jnp.zeros(num_actions)
    actor_net = TanhGaussianActor(num_actions, obs_mean, obs_std)
    q_net = DualQNetwork(obs_mean, obs_std)
    value_net = StateValueFunction(obs_mean, obs_std)

    # The target critic reuses rng_q so it starts identical to the online critic.
    rng, rng_actor, rng_q, rng_value = jax.random.split(rng, 4)
    q_inputs = [dummy_obs, dummy_action]
    agent_state = AgentTrainState(
        actor=create_train_state(args, rng_actor, actor_net, [dummy_obs]),
        dual_q=create_train_state(args, rng_q, q_net, q_inputs),
        dual_q_target=create_train_state(args, rng_q, q_net, q_inputs),
        value=create_train_state(args, rng_value, value_net, [dummy_obs]),
    )

    # --- Make train step ---
    _agent_train_step_fn = make_train_step(
        args, actor_net.apply, q_net.apply, value_net.apply, dataset
    )

    num_evals, leftover_updates = divmod(args.num_updates, args.eval_interval)
    if leftover_updates:
        # Training runs in whole eval_interval chunks; make the truncation visible.
        warnings.warn(
            f"num_updates ({args.num_updates}) is not a multiple of eval_interval "
            f"({args.eval_interval}); the last {leftover_updates} updates are skipped."
        )

    for eval_idx in range(num_evals):
        # --- Execute train loop ---
        (rng, agent_state), loss = jax.lax.scan(
            _agent_train_step_fn,
            (rng, agent_state),
            None,
            args.eval_interval,
        )

        # --- Evaluate agent ---
        rng, rng_eval = jax.random.split(rng)
        returns = eval_agent(args, rng_eval, env, agent_state)
        scores = d4rl.get_normalized_score(args.dataset, returns) * 100.0

        # --- Log metrics ---
        step = (eval_idx + 1) * args.eval_interval
        print("Step:", step, f"\t Score: {scores.mean():.2f}")
        if args.log:
            log_dict = {
                "return": returns.mean(),
                "score": scores.mean(),
                "score_std": scores.std(),
                "num_updates": step,
                # Losses from the last update of this chunk.
                **{name: values[-1] for name, values in loss.items()},
            }
            wandb.log(log_dict)

    # --- Evaluate final agent ---
    if args.eval_final_episodes > 0:
        # Rounded up, so at least eval_final_episodes episodes are run.
        final_iters = int(onp.ceil(args.eval_final_episodes / args.eval_workers))
        print(f"Evaluating final agent for {final_iters} iterations...")
        final_keys = jax.random.split(rng, final_iters)
        rets = onp.array(
            [eval_agent(args, key, env, agent_state) for key in final_keys]
        )
        scores = d4rl.get_normalized_score(args.dataset, rets) * 100.0

        def _summarize(values, name):
            """Return the raw values plus their mean and std under ``name`` keys."""
            return {
                name: values,
                f"{name}_mean": values.mean(),
                f"{name}_std": values.std(),
            }

        info = _summarize(rets, "final_returns") | _summarize(scores, "final_scores")

        # --- Write final returns to file ---
        out_dir = "final_returns"
        os.makedirs(out_dir, exist_ok=True)
        time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        filename = f"{args.algorithm}_{args.dataset}_{time_str}.npz"
        out_path = os.path.join(out_dir, filename)
        with open(out_path, "wb") as f:
            onp.savez_compressed(f, **info, args=asdict(args))

        if args.log:
            wandb.save(out_path)

    env.close()
    if args.log:
        wandb.finish()