
Commit c13e705

Merge pull request #32 from Emerge-Lab/hr_rl
Update main
2 parents ae2a064 + a3134f9

20 files changed: 13805 additions & 6291 deletions

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -73,6 +73,7 @@ temp/
 # Waymo dataset
 data/train_no_tl/
 data/valid_no_tl/
+data/test_no_tl/
 
 # Logging output
 logs/
@@ -87,4 +88,4 @@ paper/
 
 # Videos and scene info dicts
 videos/
-scene_info/
+scene_info/

configs/bc_config.yaml

Lines changed: 3 additions & 3 deletions
@@ -2,8 +2,8 @@
 save_model: true # Save model after training
 model_name: human_policy # Name of saved model
 save_model_path: ./models/il # Path to save model
+save_data_path: ./data_il/train_no_tl # Path to save training data
 
 # Train
-total_samples: 80_000 # Number of obs-act-next_obs-done pairs to generate
-n_epochs: 20 # Training epochs
-net_arch: [256, 128] # Network architecture
+total_samples: 2_000_000 # Number of obs-act-next_obs-done pairs to generate
+net_arch: [128, 64] # Network architecture
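The retuned BC settings trade a much larger dataset (2M samples instead of 80K) for a smaller network ([128, 64] instead of [256, 128]), and the n_epochs key is dropped. As a rough illustration of how these keys could be consumed, here is a minimal sketch; only the YAML keys come from the diff, while the loader, dimensions, and Tanh activations are assumptions:

```python
# Minimal sketch of consuming configs/bc_config.yaml. Only the YAML keys come
# from the diff; the dimensions and activation choice are illustrative guesses.
import yaml
import torch.nn as nn

with open("configs/bc_config.yaml") as f:
    cfg = yaml.safe_load(f)

def build_mlp(in_dim, out_dim, net_arch):
    """Stack Linear+Tanh layers following net_arch, e.g. [128, 64]."""
    layers, last = [], in_dim
    for hidden in net_arch:
        layers += [nn.Linear(last, hidden), nn.Tanh()]
        last = hidden
    layers.append(nn.Linear(last, out_dim))
    return nn.Sequential(*layers)

# Hypothetical observation/action sizes, just to make the sketch runnable.
policy = build_mlp(in_dim=64, out_dim=63, net_arch=cfg["net_arch"])
print(policy)
```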

configs/env_config.yaml

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@ dt: 0.1
 sims_per_step: 10
 discretize_actions: true
 include_head_angle: false # Whether to include the head tilt/angle as part of a vehicle's action
-accel_discretization: 15
+accel_discretization: 21
 accel_lower_bound: -4.0 # decelerate
 accel_upper_bound: 4.0 # accelerate
 steering_lower_bound: -0.3 # steer right
@@ -88,18 +88,18 @@ normalize_state: true
 # Ego feature names + max values in each category
 ego_state_feat_min: -30
 ego_state_feat_max:
-  veh_len: 16
-  veh_width: 4
+  veh_len: 25
+  veh_width: 5
   speed: 100
-  target_dist: 300
+  target_dist: 350
   target_azimuth: 3.14
   target_heading: 3.14
   rel_target_speed_dist: 40
   curr_accel: 5 # Vehicle acceleration
   curr_steering: 3
   curr_head_angle: 0.00001 # Not used at the moment
 
-vis_obs_max: 100 # The maximum value across visible state elements
+vis_obs_max: 110 # The maximum value across visible state elements
 vis_obs_min: -10 # The minimum value across visible state elements
 
 # # # # Agent settings # # # #
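The acceleration grid widens from 15 to 21 bins over the same [-4, 4] range, shrinking the step from roughly 0.57 to 0.4, and the normalization bounds grow to accommodate longer vehicles and more distant goals. A small sketch of how these values plausibly translate into an action grid and a min-max normalizer (the env's actual implementation may differ):

```python
# Sketch only: how the discretization and normalization settings in
# env_config.yaml could be used. The real Nocturne env code may differ.
import numpy as np

# 21 evenly spaced accelerations over [-4, 4] -> step of 0.4, with 0.0 on the grid.
accel_grid = np.linspace(-4.0, 4.0, 21)
assert np.isclose(accel_grid, 0.0).any()

# Min-max normalization of the visible state using the updated bounds.
vis_obs_min, vis_obs_max = -10.0, 110.0

def normalize_vis_obs(obs):
    """Scale visible-state features into [0, 1]."""
    return (np.asarray(obs) - vis_obs_min) / (vis_obs_max - vis_obs_min)

print(accel_grid[:3], normalize_vis_obs([0.0, 100.0]))
```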

configs/exp_config.yaml

Lines changed: 5 additions & 4 deletions
@@ -1,9 +1,9 @@
 project: new_env
-group: fixed_experts
+group: base_S100_A1
 env_id: Nocturne
 seed: 42
 track_wandb: true
-wandb_init_videos: ['expert']
+wandb_init_videos: []
 where_am_i: headless_machine # Change to "headless_machine" when you're on a cluster or desktop
 exp_name: Nocturne
 verbose: 0
@@ -18,7 +18,7 @@ ma_callback:
 save_video: true
 model_save_freq: 250 # In iterations (one iter ~ (num_agents x n_steps))
 record_n_scenes: 10 # Number of different scenes to render
-video_save_freq: 20 # Make a video every k iterations (100 iters ~ 1M steps)
+video_save_freq: 50_000 # Make a video every k iterations (100 iters ~ 1M steps)
 video_deterministic: true
 eval_freq: 100 # Evaluate full RL task in deterministic mode (turn off intermediate goals)
 
@@ -35,7 +35,8 @@ learn:
 
 # human-regularized RL
 reg_weight: 0.0
-human_policy_path: models/il/human_policy_D403_S500_02_08_21_30.pt
+human_policy_path: models/il/human_policy_D651_S500_02_18_20_05_AV_ONLY.pt
+reg_weight_decay_schedule: None
 
 # Model arch
 model_config:
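The experiment now points at a reference policy trained on AV-only data and introduces a reg_weight_decay_schedule key, suggesting the regularization weight can be annealed over training. For intuition, a hedged sketch of how reg_weight typically enters a human-regularized PPO objective, written here as a KL penalty toward the frozen human policy; the repo's exact loss term is not shown in this diff:

```python
# Illustrative only: one common form of human-regularized PPO, where the agent
# is penalized for diverging from a frozen human (BC) policy. The actual loss
# used in this repo may be defined differently.
import torch
import torch.nn.functional as F

def hr_ppo_loss(ppo_loss, agent_logits, human_logits, reg_weight=0.01):
    # KL(pi_agent || pi_human) over the discrete action distribution.
    kl = F.kl_div(
        F.log_softmax(human_logits, dim=-1),  # input: log pi_human
        F.log_softmax(agent_logits, dim=-1),  # target: log pi_agent
        log_target=True,
        reduction="batchmean",
    )
    return ppo_loss + reg_weight * kl

loss = hr_ppo_loss(torch.tensor(0.5), torch.randn(8, 63), torch.randn(8, 63))
```

With reg_weight: 0.0, as set above, the term vanishes and training reduces to plain PPO.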

configs/model_config.yaml

Lines changed: 27 additions & 9 deletions
@@ -1,42 +1,60 @@
 # Base paths
 bc_models_dir: models/il
-hr_ppo_models_dir_self_play: models/hr_rl/self_play
+hr_ppo_models_dir_self_play: models/hr_rl/self_play_0221
 
 # The human reference policy used for HR-PPO
 used_human_policy:
-- name: human_policy_D403_S500_02_08_21_30
+- name: human_policy_D651_S500_02_18_20_05_AV_ONLY
   agent: BC
   train_agent: '-'
   wandb_run: '-'
 
 # HR-PPO models
 best_overall_models:
-- name: policy_L0.0_S100_I1750.zip
+- name: policy_L0.0_S100_I2500.zip
   agent: PPO
   reg_weight: 0.0
   train_agent: Self-play
   wandb_run:
 
-- name: policy_L0.01_S100_I2071.zip
+- name: policy_L0.01_S100_I2579.zip
   reg_weight: 0.01
   agent: HR-PPO
   train_agent: Self-play
   wandb_run:
 
+- name: policy_L0.02_S100_I2585.zip
+  reg_weight: 0.02
+  agent: HR-PPO
+  train_agent: Self-play
+  wandb_run:
+
+- name: policy_L0.03_S100_I2611.zip
+  reg_weight: 0.03
+  agent: HR-PPO
+  train_agent: Self-play
+  wandb_run:
+
+- name: policy_L0.04_S100_I2500.zip
+  reg_weight: 0.04
+  agent: HR-PPO
+  train_agent: Self-play
+  wandb_run:
+
 - name: policy_L0.05_S100_I2000.zip
   reg_weight: 0.05
   agent: HR-PPO
   train_agent: Self-play
   wandb_run:
 
-- name: policy_L0.005_S100_I2059.zip
-  reg_weight: 0.005
+- name: policy_L0.1_S100_I2000.zip
+  reg_weight: 0.1
   agent: HR-PPO
   train_agent: Self-play
   wandb_run:
 
-- name: policy_L0.025_S100_I2109.zip
-  reg_weight: 0.025
+- name: policy_L0.2_S100_I2000.zip
+  reg_weight: 0.2
   agent: HR-PPO
   train_agent: Self-play
-  wandb_run:
+  wandb_run:
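The registry now sweeps reg_weight from 0.0 (plain PPO) through 0.2, replacing the earlier 0.005/0.025 checkpoints. A sketch of how such a registry might be walked to load checkpoints for evaluation; the SB3-style PPO.load call for the .zip policies is an assumption, the keys and paths mirror the YAML above:

```python
# Sketch: walking configs/model_config.yaml to load evaluation checkpoints.
# Keys/paths mirror the YAML above; the stable-baselines3 load call is assumed.
import yaml
from stable_baselines3 import PPO

with open("configs/model_config.yaml") as f:
    models_config = yaml.safe_load(f)

base_dir = models_config["hr_ppo_models_dir_self_play"]
for entry in models_config["best_overall_models"]:
    policy = PPO.load(f"{base_dir}/{entry['name']}")
    print(f"{entry['agent']:>7} reg_weight={entry.get('reg_weight')} -> {entry['name']}")
```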

configs/model_quality.yaml

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# Base paths
+bc_models_dir:
+hr_ppo_models_dir_self_play:
+
+# The human reference policy used for HR-PPO
+human_policies:
+- name:
+  agent: BC
+  train_agent: '-'
+  wandb_run: '-'
+
+# HR-PPO models
+hr_ppo_models:
+- name:
+  agent: PPO
+  reg_weight: 0.0
+  train_agent: Self-play
+  wandb_run:
+
+

evaluation/gen_res_x_intersection_df.py

Lines changed: 3 additions & 3 deletions
@@ -58,8 +58,8 @@ def gen_and_save_res_df(
     eval_dataset = data_sets[dataset]
     scene_to_paths_dict = intersection_dicts[dataset] if intersection_dicts is not None else None
 
-    if num_controlled_agents >= 50:
-        eval_episodes = num_scenes_to_select_from
+    if num_controlled_agents > 1:
+        eval_episodes = num_scenes_to_select_from
     else:
         eval_episodes = num_eval_episodes
 
@@ -154,7 +154,7 @@ def gen_and_save_res_df(
 # Generate dataframe
 gen_and_save_res_df(
     num_scenes_to_select_from=100,
-    num_eval_episodes=1000,
+    num_eval_episodes=4000,
     env_config=env_config,
     intersection_dicts={'Train': train_scene_to_paths_dict},
     model_config=models_config,
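The multi-agent branch now triggers for any run with more than one controlled agent (previously only for 50+), so every multi-agent evaluation covers each candidate scene exactly once, while single-agent runs sample num_eval_episodes episodes (raised here from 1000 to 4000). Distilled into a standalone helper for clarity; the parameter names follow the diff, the helper itself is illustrative:

```python
def pick_eval_episodes(num_controlled_agents, num_scenes_to_select_from, num_eval_episodes):
    """Episode-count rule after this change: one pass per scene for
    multi-agent evaluation, sampled episodes for single-agent runs."""
    if num_controlled_agents > 1:
        return num_scenes_to_select_from
    return num_eval_episodes

assert pick_eval_episodes(50, 100, 4000) == 100   # multi-agent: one pass per scene
assert pick_eval_episodes(1, 100, 4000) == 4000   # single-agent: sampled episodes
```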
