Skip to content

Commit 61fb93b

Browse files
committed
Added functionality to show only raw value levels and utilities, but no log-scale versions of these subplots. Added functionality to normalise all subplots to same scale.
1 parent 70ab54f commit 61fb93b

File tree

1 file changed

+43
-27
lines changed

1 file changed

+43
-27
lines changed

ValueGraphBalancing3_3values_2humans.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ def tiebreaking_argmax(arr):
141141
return result
142142

143143

144-
def plot_agent_history(subplots, plots_column, agent_name, values_history, utilities_history):
144+
def plot_agent_history(subplots, plots_row, plot_columns, agent_name, values_history, utilities_history):
145145

146146
linewidth = 0.75 # TODO: config
147147

148148

149-
subplot = subplots[plots_column, 0]
149+
subplot = subplots[plots_row, plot_columns[0]] if add_logscale_plots else subplots[plot_columns[0]]
150150
for index, value_name in enumerate(value_names):
151151
subplot.plot(
152152
values_history[:, index],
@@ -159,20 +159,21 @@ def plot_agent_history(subplots, plots_column, agent_name, values_history, utili
159159
subplot.legend()
160160

161161

162-
subplot = subplots[plots_column, 1]
163-
for index, value_name in enumerate(value_names):
164-
subplot.plot(
165-
custom_sigmoid10(values_history[:, index]),
166-
label=value_name,
167-
linewidth=linewidth,
168-
)
162+
if plot_columns[1] != -1:
163+
subplot = subplots[plots_row, plot_columns[1]]
164+
for index, value_name in enumerate(value_names):
165+
subplot.plot(
166+
custom_sigmoid10(values_history[:, index]),
167+
label=value_name,
168+
linewidth=linewidth,
169+
)
169170

170-
subplot.set_title(f"{agent_name} - Sigmoid10 of Value level")
171-
subplot.set(xlabel="step", ylabel="custom_sigmoid10(raw value level)")
172-
subplot.legend()
171+
subplot.set_title(f"{agent_name} - Sigmoid10 of Value level")
172+
subplot.set(xlabel="step", ylabel="custom_sigmoid10(raw value level)")
173+
subplot.legend()
173174

174175

175-
subplot = subplots[plots_column, 2]
176+
subplot = subplots[plots_row, plot_columns[2]] if add_logscale_plots else subplots[plot_columns[2]]
176177
for index, value_name in enumerate(value_names):
177178
subplot.plot(
178179
utilities_history[:, index],
@@ -185,27 +186,37 @@ def plot_agent_history(subplots, plots_column, agent_name, values_history, utili
185186
subplot.legend()
186187

187188

188-
subplot = subplots[plots_column, 3]
189-
for index, value_name in enumerate(value_names):
190-
subplot.plot(
191-
custom_sigmoid10(utilities_history[:, index]),
192-
label=value_name,
193-
linewidth=linewidth,
194-
)
189+
if plot_columns[3] != -1:
190+
subplot = subplots[plots_row, plot_columns[3]]
191+
for index, value_name in enumerate(value_names):
192+
subplot.plot(
193+
custom_sigmoid10(utilities_history[:, index]),
194+
label=value_name,
195+
linewidth=linewidth,
196+
)
195197

196-
subplot.set_title(f"{agent_name} - Sigmoid10 of Utilities")
197-
subplot.set(xlabel="step", ylabel="custom_sigmoid10(utility level)")
198-
subplot.legend()
198+
subplot.set_title(f"{agent_name} - Sigmoid10 of Utilities")
199+
subplot.set(xlabel="step", ylabel="custom_sigmoid10(utility level)")
200+
subplot.legend()
199201

200202

201203
# TODO: std or gini index over values per timestep plot
202204

203-
#/ def plot_agent_history(values_history, utilities_history, utility_function_mode, rebalancing_mode):
205+
#/ def plot_agent_history(subplots, plots_row, plot_columns, agent_name, values_history, utilities_history):
204206

205207

206208
def plot_history(values_history_dict, utilities_history_dict, utility_function_mode, rebalancing_mode):
207209

208-
fig, subplots = plt.subplots(2, 4)
210+
if add_logscale_plots:
211+
fig, subplots = plt.subplots(2, 4) # top row - alice, bottom row - bob
212+
else:
213+
fig, subplots = plt.subplots(1, 4) # 2 left-side plots - alice, 2 right-side plots - bob
214+
215+
216+
if use_same_axis_limits_for_all_subplots:
217+
axis_min = min([x.min() for x in values_history_dict.values()])
218+
axis_max = max([x.max() for x in values_history_dict.values()])
219+
plt.setp(subplots, ylim=(axis_min, axis_max)) # setting the values for all axes.
209220

210221

211222
fig.suptitle(f"Value graph balancing - utility function: {utility_function_mode} - rebalancing: {rebalancing_mode}")
@@ -214,7 +225,8 @@ def plot_history(values_history_dict, utilities_history_dict, utility_function_m
214225
agent_name = agent_names[0]
215226
plot_agent_history(
216227
subplots,
217-
0, # plots_column
228+
0, # plots_row
229+
[0, 1, 2, 3] if add_logscale_plots else [0, -1, 1, -1], # plot_columns
218230
agent_name.upper(),
219231
values_history_dict[agent_name],
220232
utilities_history_dict[agent_name],
@@ -223,7 +235,8 @@ def plot_history(values_history_dict, utilities_history_dict, utility_function_m
223235
agent_name = agent_names[1]
224236
plot_agent_history(
225237
subplots,
226-
1, # plots_column
238+
1 if add_logscale_plots else 0, # plots_row
239+
[0, 1, 2, 3] if add_logscale_plots else [2, -1, 3, -1], # plot_columns
227240
agent_name.upper(),
228241
values_history_dict[agent_name],
229242
utilities_history_dict[agent_name],
@@ -1014,6 +1027,9 @@ def main(utility_function_mode, rebalancing_mode):
10141027
max_tokens = get_max_tokens_for_model(model_name)
10151028
num_trials = 1 # 10 # how many simulations to run (how many resets?)
10161029

1030+
add_logscale_plots = False
1031+
use_same_axis_limits_for_all_subplots = True
1032+
10171033

10181034
# utility function mode and rebalancing mode
10191035

0 commit comments

Comments
 (0)