-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotSaliency.py
More file actions
50 lines (40 loc) · 1.47 KB
/
plotSaliency.py
File metadata and controls
50 lines (40 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gi
gi.require_version('Gtk', '2.0')
import numpy as np
import matplotlib.pylab as plt
import matplotlib.colors as colors
class MidpointNormalize(colors.Normalize):
    """Normalize data so that ``midpoint`` maps to the centre (0.5) of a colormap.

    Values in ``[vmin, midpoint]`` are mapped linearly onto ``[0, 0.5]`` and
    values in ``[midpoint, vmax]`` linearly onto ``[0.5, 1]``. A diverging
    colormap therefore keeps its neutral colour at ``midpoint`` even when the
    data range is asymmetric. Values outside ``[vmin, vmax]`` are clamped to
    0 or 1.

    Parameters
    ----------
    vmin, vmax : float or None
        Data limits. ``None`` means "fill in from the data on first call"
        (via ``autoscale_None``).
    midpoint : float or None
        Data value mapped to 0.5. ``None`` uses the centre of ``[vmin, vmax]``.
    clip : bool
        Forwarded to ``colors.Normalize``. ``__call__`` always clamps to
        ``[0, 1]`` regardless of this flag.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        colors.Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Map ``value`` into ``[0, 1]``.

        Returns a masked array of floats with the same shape and mask as
        ``value``. NaN inputs map to NaN. ``clip`` is accepted for API
        compatibility but ignored; the output is always clamped.
        """
        data = np.ma.asarray(value, dtype=float)
        if data.size == 0:
            # Nothing to scale, and autoscaling cannot derive limits from no data.
            return np.ma.masked_array(np.zeros(data.shape), mask=np.ma.getmask(data))

        # Fill in missing limits from the data, as the base Normalize does.
        # Otherwise np.interp would receive None coordinates.
        self.autoscale_None(data)
        vmin = float(self.vmin)
        vmax = float(self.vmax)
        if self.midpoint is None:
            midpoint = 0.5 * (vmin + vmax)
        else:
            midpoint = float(self.midpoint)

        # np.interp requires increasing x-coordinates. If the midpoint is
        # outside the data range (e.g. all-positive data with midpoint=0),
        # widen the range so the midpoint still lands on 0.5.
        vmin = min(vmin, midpoint)
        vmax = max(vmax, midpoint)

        raw = np.ma.getdata(data)
        scaled = np.full(raw.shape, 0.5)
        with np.errstate(invalid='ignore'):
            below = raw < midpoint
            above = raw > midpoint
        # Interpolate each half on its own. A zero-width half (vmin == midpoint
        # or midpoint == vmax) can then only be reached by out-of-range values,
        # which clamp to 0 or 1.
        if midpoint > vmin:
            lower = np.interp(raw, [vmin, midpoint], [0.0, 0.5])
        else:
            lower = 0.0
        if vmax > midpoint:
            upper = np.interp(raw, [midpoint, vmax], [0.5, 1.0])
        else:
            upper = 1.0
        scaled = np.where(below, lower, scaled)
        scaled = np.where(above, upper, scaled)
        # Keep NaN as NaN (the comparisons above are False for NaN).
        scaled = np.where(np.isnan(raw), np.nan, scaled)
        return np.ma.masked_array(scaled, mask=np.ma.getmask(data))
def plotHeatMapExampleWise(input,
                           title,
                           saveLocation,
                           greyScale=False,
                           flip=False,
                           x_axis=None,
                           y_axis=None,
                           show=True,
                           xlabel="Time",
                           ylabel="ResNeT Features",
                           midpoint=0):
    """Draw ``input`` as a heat map, save it as PNG and optionally display it.

    The colour scale uses ``MidpointNormalize`` so that ``midpoint`` sits at the
    centre of the colormap.

    Parameters
    ----------
    input : array-like
        Data handed to ``ax.imshow``. Presumably a 2-D saliency map; the
        default labels suggest rows are features and columns are time steps
        (TODO confirm against callers).
    title : object
        Converted with ``str()`` and appended to ``saveLocation`` to form the
        file name. It is not drawn on the figure.
    saveLocation : str or os.PathLike
        Prefix of the output path. The file is written to
        ``str(saveLocation) + str(title) + '.png'``.
    greyScale : bool
        Use the ``'gray'`` colormap instead of ``'seismic'``.
    flip : bool
        Transpose ``input`` before plotting.
    x_axis, y_axis : str or None
        Extra figure-level captions placed near the bottom/left edge.
    show : bool
        If True, call ``plt.show()``. Otherwise the figure is closed after
        saving.
    xlabel, ylabel : str
        Axis labels. Defaults reproduce the previously hard-coded labels.
    midpoint : float or None
        Data value mapped to the colormap centre. ``None`` uses the centre of
        the data range.
    """
    data = np.transpose(input) if flip else input
    fig, ax = plt.subplots()
    cmap = 'gray' if greyScale else 'seismic'
    ax.imshow(data, interpolation='nearest', cmap=cmap,
              norm=MidpointNormalize(midpoint=midpoint))
    if x_axis is not None:
        fig.text(0.5, 0.01, x_axis, ha='center', fontsize=14)
    if y_axis is not None:
        fig.text(0.05, 0.5, y_axis, va='center', rotation='vertical', fontsize=14)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Lay out only after the axis labels exist. Otherwise tight_layout cannot
    # reserve room for them and they may be clipped in the saved image.
    fig.tight_layout()
    fig.savefig(str(saveLocation) + str(title) + '.png')
    if show:
        plt.show()
    else:
        # pyplot keeps every figure alive until closed. Without this, repeated
        # calls with show=False accumulate open figures and leak memory.
        plt.close(fig)