@@ -141,12 +141,12 @@ def tiebreaking_argmax(arr):
141141 return result
142142
143143
144- def plot_agent_history (subplots , plots_column , agent_name , values_history , utilities_history ):
144+ def plot_agent_history (subplots , plots_row , plot_columns , agent_name , values_history , utilities_history ):
145145
146146 linewidth = 0.75 # TODO: config
147147
148148
149- subplot = subplots [plots_column , 0 ]
149+ subplot = subplots [plots_row , plot_columns [ 0 ]] if add_logscale_plots else subplots [ plot_columns [ 0 ] ]
150150 for index , value_name in enumerate (value_names ):
151151 subplot .plot (
152152 values_history [:, index ],
@@ -159,20 +159,21 @@ def plot_agent_history(subplots, plots_column, agent_name, values_history, utili
159159 subplot .legend ()
160160
161161
162- subplot = subplots [plots_column , 1 ]
163- for index , value_name in enumerate (value_names ):
164- subplot .plot (
165- custom_sigmoid10 (values_history [:, index ]),
166- label = value_name ,
167- linewidth = linewidth ,
168- )
162+ if plot_columns [1 ] != - 1 :
163+ subplot = subplots [plots_row , plot_columns [1 ]]
164+ for index , value_name in enumerate (value_names ):
165+ subplot .plot (
166+ custom_sigmoid10 (values_history [:, index ]),
167+ label = value_name ,
168+ linewidth = linewidth ,
169+ )
169170
170- subplot .set_title (f"{ agent_name } - Sigmoid10 of Value level" )
171- subplot .set (xlabel = "step" , ylabel = "custom_sigmoid10(raw value level)" )
172- subplot .legend ()
171+ subplot .set_title (f"{ agent_name } - Sigmoid10 of Value level" )
172+ subplot .set (xlabel = "step" , ylabel = "custom_sigmoid10(raw value level)" )
173+ subplot .legend ()
173174
174175
175- subplot = subplots [plots_column , 2 ]
176+ subplot = subplots [plots_row , plot_columns [ 2 ]] if add_logscale_plots else subplots [ plot_columns [ 2 ] ]
176177 for index , value_name in enumerate (value_names ):
177178 subplot .plot (
178179 utilities_history [:, index ],
@@ -185,27 +186,37 @@ def plot_agent_history(subplots, plots_column, agent_name, values_history, utili
185186 subplot .legend ()
186187
187188
188- subplot = subplots [plots_column , 3 ]
189- for index , value_name in enumerate (value_names ):
190- subplot .plot (
191- custom_sigmoid10 (utilities_history [:, index ]),
192- label = value_name ,
193- linewidth = linewidth ,
194- )
189+ if plot_columns [3 ] != - 1 :
190+ subplot = subplots [plots_row , plot_columns [3 ]]
191+ for index , value_name in enumerate (value_names ):
192+ subplot .plot (
193+ custom_sigmoid10 (utilities_history [:, index ]),
194+ label = value_name ,
195+ linewidth = linewidth ,
196+ )
195197
196- subplot .set_title (f"{ agent_name } - Sigmoid10 of Utilities" )
197- subplot .set (xlabel = "step" , ylabel = "custom_sigmoid10(utility level)" )
198- subplot .legend ()
198+ subplot .set_title (f"{ agent_name } - Sigmoid10 of Utilities" )
199+ subplot .set (xlabel = "step" , ylabel = "custom_sigmoid10(utility level)" )
200+ subplot .legend ()
199201
200202
201203 # TODO: std or gini index over values per timestep plot
202204
203- #/ def plot_agent_history(values_history, utilities_history, utility_function_mode, rebalancing_mode ):
205+ #/ def plot_agent_history(subplots, plots_row, plot_columns, agent_name, values_history, utilities_history ):
204206
205207
206208def plot_history (values_history_dict , utilities_history_dict , utility_function_mode , rebalancing_mode ):
207209
208- fig , subplots = plt .subplots (2 , 4 )
210+ if add_logscale_plots :
211+ fig , subplots = plt .subplots (2 , 4 ) # top row - alice, bottom row - bob
212+ else :
213+ fig , subplots = plt .subplots (1 , 4 ) # 2 left-side plots - alice, 2 right-side plots - bob
214+
215+
216+ if use_same_axis_limits_for_all_subplots :
217+ axis_min = min ([x .min () for x in values_history_dict .values ()])
218+ axis_max = max ([x .max () for x in values_history_dict .values ()])
219+ plt .setp (subplots , ylim = (axis_min , axis_max )) # setting the values for all axes.
209220
210221
211222 fig .suptitle (f"Value graph balancing - utility function: { utility_function_mode } - rebalancing: { rebalancing_mode } " )
@@ -214,7 +225,8 @@ def plot_history(values_history_dict, utilities_history_dict, utility_function_m
214225 agent_name = agent_names [0 ]
215226 plot_agent_history (
216227 subplots ,
217- 0 , # plots_column
228+ 0 , # plots_row
229+ [0 , 1 , 2 , 3 ] if add_logscale_plots else [0 , - 1 , 1 , - 1 ], # plot_columns
218230 agent_name .upper (),
219231 values_history_dict [agent_name ],
220232 utilities_history_dict [agent_name ],
@@ -223,7 +235,8 @@ def plot_history(values_history_dict, utilities_history_dict, utility_function_m
223235 agent_name = agent_names [1 ]
224236 plot_agent_history (
225237 subplots ,
226- 1 , # plots_column
238+ 1 if add_logscale_plots else 0 , # plots_row
239+ [0 , 1 , 2 , 3 ] if add_logscale_plots else [2 , - 1 , 3 , - 1 ], # plot_columns
227240 agent_name .upper (),
228241 values_history_dict [agent_name ],
229242 utilities_history_dict [agent_name ],
@@ -1014,6 +1027,9 @@ def main(utility_function_mode, rebalancing_mode):
10141027 max_tokens = get_max_tokens_for_model (model_name )
10151028 num_trials = 1 # 10 # how many simulations to run (how many resets?)
10161029
1030+ add_logscale_plots = False
1031+ use_same_axis_limits_for_all_subplots = True
1032+
10171033
10181034 # utility function mode and rebalancing mode
10191035
0 commit comments