-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexplainer.py
More file actions
181 lines (143 loc) · 6 KB
/
explainer.py
File metadata and controls
181 lines (143 loc) · 6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import torch.nn as nn
import CLIP.clip as clip
from typing import List
import numpy as np
class Hook:
    """Context manager that records a module's forward activation and,
    after a backward pass, the gradient flowing into that activation.

    Usage::

        with Hook(layer) as hook:
            model(x).backward(...)
            act, grad = hook.activation, hook.gradient
    """

    def __init__(self, module: nn.Module):
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        # Force the output to participate in autograd and keep its .grad
        # populated (non-leaf tensors drop gradients by default).
        output.requires_grad_(True)
        output.retain_grad()
        self.data = output

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Detach the forward hook; the last recorded activation stays readable.
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        return self.data.grad
# Colab form parameters controlling how many attention layers the relevance
# propagation in interpret() walks; -1 means "start from the last layer"
# (resolved to len(blocks) - 1 inside interpret()).
#@title Control context expansion (number of attention layers to consider)
#@title Number of layers for image Transformer
start_layer = -1#@param {type:"number"}
#@title Number of layers for text Transformer
start_layer_text = -1#@param {type:"number"}
def interpret(image, texts, model, device, start_layer=-1, start_layer_text=-1):
    """Attention-rollout relevance through CLIP's visual transformer.

    Args:
        image: Single image tensor (1, C, H, W), repeated to match the texts.
        texts: Tokenized text batch, shape (batch, seq).
        model: CLIP model whose visual attention blocks expose ``attn_probs``.
        device: Device on which the relevance accumulator is built.
        start_layer: First visual attention layer to include; -1 means
            "last layer only" (resolved below).
        start_layer_text: Unused here; kept for interface compatibility.

    Returns:
        Tensor (batch, num_patches): relevance of each image patch for the
        matching text in the batch.
    """
    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    logits_per_image, logits_per_text = model(images, texts)

    # Scalar objective: sum of the matched (image i, text i) logits — the
    # diagonal of the similarity matrix.  Equivalent to the original
    # one-hot * logits construction, but without the numpy round-trip and
    # without the hard-coded .cuda() that ignored the `device` argument.
    one_hot = logits_per_image.diagonal().sum()
    model.zero_grad()

    image_attn_blocks = list(model.visual.transformer.resblocks.children())
    if start_layer == -1:
        # Default: consider only the last attention layer.
        start_layer = len(image_attn_blocks) - 1

    num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
    # R starts as identity per batch element and accumulates rollout terms.
    R = torch.eye(num_tokens, num_tokens,
                  dtype=image_attn_blocks[0].attn_probs.dtype).to(device)
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for i, blk in enumerate(image_attn_blocks):
        if i < start_layer:
            continue
        grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1])
        cam = grad * cam
        cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1])
        # Gradient-weighted attention, positives only, averaged over heads.
        cam = cam.clamp(min=0).mean(dim=1)
        R = R + torch.bmm(cam, R)

    # Row 0 is the CLS token; columns 1: are its relevance to image patches.
    image_relevance = R[:, 0, 1:]
    return image_relevance
def batch_gradCAM(
    model: nn.Module,
    inputs: torch.Tensor,
    targets_groups: List[torch.Tensor],
    layer: nn.Module
) -> List[torch.Tensor]:
    """Grad-CAM heatmaps for a batch of images, each with its own target group.

    Each image inputs[i] is paired with every target row in targets_groups[i];
    images are repeated so a single forward pass covers all pairs.

    Args:
        model: Network producing an (N, num_classes) output.
        inputs: Image batch (B, C, H, W).
        targets_groups: One (num_targets_i, num_classes) weight tensor per image.
        layer: Module whose activation (N, C', H', W') is used for the heatmap.

    Returns:
        List of B tensors; heatmaps[i] has shape (num_targets_i, H', W').
    """
    batch_size = inputs.size(0)
    # Repeat each image once per target so pairs align along dim 0.
    inputs_list = []
    targets_list = []
    for i in range(batch_size):
        num_texts = targets_groups[i].size(0)
        inputs_list.append(inputs[i:i + 1].repeat(num_texts, 1, 1, 1))
        targets_list.append(targets_groups[i])
    final_inputs = torch.cat(inputs_list, dim=0)
    final_targets = torch.cat(targets_list, dim=0)

    # Zero out any stale gradients at the input.
    if final_inputs.grad is not None:
        final_inputs.grad.data.zero_()

    # Freeze parameters, remembering their requires_grad flags.
    requires_grad = {}
    for name, param in model.named_parameters():
        requires_grad[name] = param.requires_grad
        param.requires_grad_(False)

    try:
        # NOTE(review): under DataParallel this hooks `model.module.layer`, i.e.
        # an attribute literally named "layer" on the wrapped model, not the
        # `layer` argument — confirm the wrapped model exposes such an attribute.
        with Hook(layer if not isinstance(model, nn.DataParallel) else model.module.layer) as hook:
            output = model(final_inputs)
            loss = (output * final_targets).sum(dim=1).mean()
            loss.backward()

            grad = hook.gradient.float()
            act = hook.activation.float()
            # Channel importance: gradients pooled over the spatial dims.
            alpha = grad.mean(dim=(2, 3), keepdim=True)
            gradcam = torch.sum(act * alpha, dim=1, keepdim=True)
            # Keep only positively contributing regions.
            gradcam = torch.clamp(gradcam, min=0)
    finally:
        # Restore the requires_grad flags even if forward/backward raised.
        for name, param in model.named_parameters():
            param.requires_grad_(requires_grad[name])

    # squeeze(1), not squeeze(): a bare squeeze() also drops the pair
    # dimension when the total number of (image, target) pairs is 1,
    # which corrupts the per-image split below.
    final_gradcam = gradcam.squeeze(1)

    # Split the concatenated heatmaps back into per-image groups.
    heatmaps = []
    start_idx = 0
    for i in range(batch_size):
        num_texts = targets_groups[i].size(0)
        heatmaps.append(final_gradcam[start_idx: start_idx + num_texts])
        start_idx += num_texts
    return heatmaps
# Reference: https://arxiv.org/abs/1610.02391
def gradCAM(
    model: nn.Module,
    input: torch.Tensor,
    target: torch.Tensor,
    layer: nn.Module
) -> torch.Tensor:
    """Grad-CAM heatmap for `input` at `layer`.

    Weights `layer`'s activation maps by their spatially-pooled gradients
    w.r.t. the `target`-weighted output, keeping positive contributions only.
    Reference: https://arxiv.org/abs/1610.02391
    """
    # Clear stale gradients at the input, if any.
    if input.grad is not None:
        input.grad.data.zero_()

    # Remember each parameter's requires_grad flag, then freeze them all.
    saved_flags = {name: p.requires_grad for name, p in model.named_parameters()}
    for p in model.parameters():
        p.requires_grad_(False)

    # Record activations/gradients at the requested layer.
    assert isinstance(layer, nn.Module)
    with Hook(layer) as hook:
        # Forward pass, then backprop the target weighting.
        output = model(input)
        output.backward(target)

        grads = hook.gradient.float()
        acts = hook.activation.float()
        # Importance weights: gradients globally averaged over the
        # spatial dimensions.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        # Channel-weighted sum of activations; clamp to keep only
        # positively influencing neurons.
        cam = torch.clamp((acts * weights).sum(dim=1, keepdim=True), min=0)

    # Restore the original requires_grad flags.
    for name, p in model.named_parameters():
        p.requires_grad_(saved_flags[name])
    return cam.squeeze()