-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path main.py
More file actions
55 lines (47 loc) · 1.45 KB
/
main.py
File metadata and controls
55 lines (47 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from env import SchedulingEnv
import gym
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import A2C, DDPG, DQN, TD3
import argparse
def parse_arguments():
    """Build the command-line argument parser for the training script.

    Returns:
        argparse.ArgumentParser: parser exposing options for the number of
        queues, the boost value, the RL algorithm name, and the output
        file the trained model is saved to.
    """
    ap = argparse.ArgumentParser(
        description="Train an RL agent on the scheduling environment")
    ap.add_argument("-n", "--numQueues", default=3, type=int,
                    help="number of queues")
    # Original left help empty; value is forwarded as the first
    # SchedulingEnv constructor argument (boost) -- TODO confirm semantics.
    ap.add_argument("-b", "--boost", default=0, type=int,
                    help="boost value passed to SchedulingEnv")
    ap.add_argument("-a", "--agent", default="A2C", type=str,
                    help="Agent")
    ap.add_argument("-f", "--file", default="rl_model", type=str,
                    help="File to store model")
    return ap
# ---- Script body: train, evaluate, and save an RL agent ----
arg_parser = parse_arguments()
args = vars(arg_parser.parse_args())

# Third positional argument is False in the original script; its meaning is
# defined by SchedulingEnv -- TODO confirm against env.py.
env = SchedulingEnv(args["boost"], args["numQueues"], False)

# Map CLI agent names to stable-baselines3 classes.  All four algorithms are
# imported at the top of the file, so register them all (the original only
# registered A2C, making "--agent DQN" etc. crash with a bare KeyError).
agent_dict = {"A2C": A2C, "DDPG": DDPG, "DQN": DQN, "TD3": TD3}
print("Agent: ", args["agent"])
if args["agent"] not in agent_dict:
    raise SystemExit(
        "Unknown agent %r; choose one of %s" % (args["agent"], sorted(agent_dict)))

model = agent_dict[args["agent"]]('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=100000)  # bound on iterations = 100000

# Roll out at most one episode (capped at 100000 steps) with the trained
# policy and accumulate the total reward.
obs = env.reset()
r = 0
for i in range(100000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    r += reward
    if done:
        obs = env.reset()
        break
print("reward = ", r)

# Persist the trained model to the path given by --file.
model.save(args["file"])