# NOTE(review): SOURCE is a unified git patch collapsed onto a few physical
# lines.  This span also carries hunks that are not Python definitions; they
# are summarized here rather than guessed at:
#   * .gitignore (new file): __pycache__/, .ipynb_checkpoints, build/*,
#     dist/*, model_tools.egg-info/*
#   * model_tools/brain_transformation/__init__.py: adds
#     `from .search import VisualSearchObjArray, VisualSearch`, widens
#     ModelCommitment.__init__ with search_target_model_param /
#     search_stimuli_model_param (both defaulting to None), and registers the
#     two search models in the BehaviorArbiter under
#     BrainModel.Task.object_search and BrainModel.Task.visual_search.
#     The middle of that __init__ lies outside the visible hunks, so it is
#     not reconstructed as code here.

# --- file: model_tools/brain_transformation/search.py (new file) ---
import logging

import cv2
import numpy as np
from tqdm import tqdm

from brainscore.model_interface import BrainModel
from brainscore.utils import fullname


class VisualSearchObjArray(BrainModel):
    """Visual-search behavior over object-array stimuli.

    The target image's activation map is used as the kernel of a 1-output
    convolution over the search image's activation map (template matching);
    the resulting "attention" map drives a sequence of simulated saccades.
    `look_at` returns per-trial success indicators and the saccade traces.
    """

    def __init__(self, identifier, target_model_param, stimuli_model_param):
        self.current_task = None
        self.identifier = identifier
        # activation models and the layer each one is read from
        self.target_model = target_model_param['target_model']
        self.stimuli_model = stimuli_model_param['stimuli_model']
        self.target_layer = target_model_param['target_layer']
        self.stimuli_layer = stimuli_model_param['stimuli_layer']
        self.search_image_size = stimuli_model_param['search_image_size']
        self._logger = logging.getLogger(fullname(self))

    def start_task(self, task: BrainModel.Task, **kwargs):
        self.fix = kwargs['fix']  # pixel coordinates of the array positions (center first)
        self.max_fix = kwargs['max_fix']  # max fixations allowed, excluding the very first (center) one
        self.data_len = kwargs['data_len']  # number of search stimuli
        self.current_task = task

    def look_at(self, stimuli_set):
        """Run the search simulation over `stimuli_set`.

        Returns `(score, data)`: `score` has shape (data_len, max_fix + 1) with
        a 1 at the fixation index where the target was found; `data` has shape
        (data_len, max_fix + 2, 2) — rows 0..max_fix hold saccade positions,
        the last row holds the number of fixations used (broadcast into both
        columns).
        """
        # one binary ground-truth mask per array position, embedded in a 3x canvas
        self.gt_array = []
        mask_set = stimuli_set[stimuli_set['image_label'] == 'mask']
        gt_paths = list(mask_set.image_paths.values())[
            int(mask_set.index.values[0]):int(mask_set.index.values[-1] + 1)]

        for pos in range(6):  # assumes 6 object positions besides center -- TODO confirm against stimulus set
            mask = cv2.imread(gt_paths[pos], 0)
            mask = cv2.resize(mask, (self.search_image_size, self.search_image_size),
                              interpolation=cv2.INTER_AREA)
            _, mask = cv2.threshold(mask, 125, 255, cv2.THRESH_BINARY)
            mask = self._embed_center(mask) / 255  # scale to {0, 1}
            self.gt_array.append(mask)

        self.gt_total = np.copy(self.gt_array[0])
        for pos in range(1, 6):
            self.gt_total += self.gt_array[pos]

        self.score = np.zeros((self.data_len, self.max_fix + 1))
        self.data = np.zeros((self.data_len, self.max_fix + 2, 2), dtype=int)
        # FIX: these were hard-coded to shapes (300, 7, ...); now derived from the
        # start_task parameters, consistent with the sibling VisualSearch class
        S_data = np.zeros((self.data_len, self.max_fix + 1, 2), dtype=int)
        I_data = np.zeros((self.data_len, 1), dtype=int)

        data_cnt = 0

        target = stimuli_set[stimuli_set['image_label'] == 'target']
        target_features = self.target_model(target, layers=[self.target_layer], stimuli_identifier=False)
        if target_features.shape[0] == target_features['neuroid_num'].shape[0]:
            target_features = target_features.T  # ensure presentation-major orientation

        stimuli = stimuli_set[stimuli_set['image_label'] == 'stimuli']
        stimuli_features = self.stimuli_model(stimuli, layers=[self.stimuli_layer], stimuli_identifier=False)
        if stimuli_features.shape[0] == stimuli_features['neuroid_num'].shape[0]:
            stimuli_features = stimuli_features.T

        import torch  # deferred import: torch is only needed when this task actually runs

        for i in tqdm(range(self.data_len), desc="visual search stimuli: "):
            # target activation map becomes the conv kernel (template matching)
            op_target = self.unflat(target_features[i:i + 1])
            MMconv = torch.nn.Conv2d(op_target.shape[1], 1,
                                     kernel_size=(op_target.shape[2], op_target.shape[3]),
                                     stride=1, bias=False)
            MMconv.weight = torch.nn.Parameter(torch.Tensor(op_target))

            gt_idx = target_features.tar_obj_pos.values[i]
            gt = self.gt_array[gt_idx]

            op_stimuli = self.unflat(stimuli_features[i:i + 1])
            out = MMconv(torch.Tensor(op_stimuli)).detach().numpy()
            out = out.reshape(out.shape[2:])

            # normalize to uint8, upsample to image size, smooth
            out = out - np.min(out)
            out = out / np.max(out)
            out = np.uint8(out * 255)
            out = cv2.resize(out, (self.search_image_size, self.search_image_size),
                             interpolation=cv2.INTER_AREA)
            out = cv2.GaussianBlur(out, (7, 7), 3)

            # restrict attention to the valid object positions
            attn = np.copy(self._embed_center(out) * self.gt_total)

            saccade = []
            (x, y) = int(attn.shape[0] / 2), int(attn.shape[1] / 2)
            saccade.append((x, y))  # first fixation: canvas center

            for k in range(self.max_fix):
                (x, y) = np.unravel_index(np.argmax(attn), attn.shape)

                # clamp the fixation into the central (real-image) cell of the canvas
                fxn_x, fxn_y = x, y
                fxn_x, fxn_y = max(fxn_x, self.search_image_size), max(fxn_y, self.search_image_size)
                fxn_x, fxn_y = min(fxn_x, (attn.shape[0] - self.search_image_size)), min(fxn_y, (
                    attn.shape[1] - self.search_image_size))

                saccade.append((fxn_x, fxn_y))

                attn, t = self.remove_attn(attn, saccade[-1][0], saccade[-1][1])

                if t == gt_idx:  # fixated the target's position: success at fixation k+1
                    self.score[data_cnt, k + 1] = 1
                    data_cnt += 1
                    break

            saccade = np.asarray(saccade)
            j = saccade.shape[0]

            # snap pixel fixations to the discrete array positions supplied in start_task
            for k in range(j):
                tar_id = self.get_pos(saccade[k, 0], saccade[k, 1], 0)
                saccade[k, 0] = self.fix[tar_id][0]
                saccade[k, 1] = self.fix[tar_id][1]

            I_data[i, 0] = min(self.max_fix + 1, j)
            S_data[i, :j, 0] = saccade[:, 0].reshape((-1,))[:self.max_fix + 1]
            S_data[i, :j, 1] = saccade[:, 1].reshape((-1,))[:self.max_fix + 1]

        self.data[:, :self.max_fix + 1, :] = S_data
        self.data[:, self.max_fix + 1, :] = I_data

        return (self.score, self.data)

    def _embed_center(self, img):
        # Place `img` in the middle cell of a 3x3 blank canvas so that saccade
        # clamping near the borders cannot index out of bounds.
        canvas = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size)))
        canvas[self.search_image_size:2 * self.search_image_size,
               self.search_image_size:2 * self.search_image_size] = np.copy(img)
        return canvas

    def remove_attn(self, img, x, y):
        """Zero out the fixated object's mask region; return (img, position index or -1)."""
        t = -1
        for i in range(5, -1, -1):
            if self.gt_array[i][x, y] > 0:
                t = i
                break

        if t > -1:
            img[self.gt_array[t] == 1] = 0

        return img, t

    def get_pos(self, x, y, t):
        """Return the 1-based array position whose mask contains pixel (x, y), or `t` if none."""
        for i in range(5, -1, -1):
            if self.gt_array[i][int(x), int(y)] > 0:
                t = i + 1
                break
        return t

    def unflat(self, X):
        """Reshape a flat (presentation, neuroid) assembly row into a (1, C, H, W) array."""
        channel_names = ['channel', 'channel_x', 'channel_y']
        assert all(hasattr(X, coord) for coord in channel_names)
        shapes = [len(set(X[channel].values)) for channel in channel_names]
        X = np.reshape(X.values, [X.shape[0]] + shapes)
        # assumes the flat order is (channel_x, channel_y, channel); move channel
        # to the front -- TODO confirm against the activations wrapper
        X = np.transpose(X, axes=[0, 3, 1, 2])
        return X
# --- file: model_tools/brain_transformation/search.py (continued) ---
class VisualSearch(BrainModel):
    """Visual search on natural images with one ground-truth mask per stimulus.

    Same template-matching + saccade simulation as VisualSearchObjArray, but
    success is any fixation landing inside the per-stimulus mask, and
    inhibition of return is a square window (`ior_size`) rather than
    per-object masks.
    """

    def __init__(self, identifier, target_model_param, stimuli_model_param):
        self.current_task = None
        self.identifier = identifier
        # activation models and the layer each one is read from
        self.target_model = target_model_param['target_model']
        self.stimuli_model = stimuli_model_param['stimuli_model']
        self.target_layer = target_model_param['target_layer']
        self.stimuli_layer = stimuli_model_param['stimuli_layer']
        self.search_image_size = stimuli_model_param['search_image_size']
        self._logger = logging.getLogger(fullname(self))

    def start_task(self, task: BrainModel.Task, **kwargs):
        self.max_fix = kwargs['max_fix']  # max fixations allowed, excluding the very first (center) one
        self.data_len = kwargs['data_len']  # number of search stimuli
        self.current_task = task
        self.ior_size = kwargs['ior_size']  # side length (pixels) of the inhibition-of-return window

    def look_at(self, stimuli_set):
        """Run the search simulation; see VisualSearchObjArray.look_at for the
        (score, data) return layout."""
        self.score = np.zeros((self.data_len, self.max_fix + 1))
        self.data = np.zeros((self.data_len, self.max_fix + 2, 2), dtype=int)
        S_data = np.zeros((self.data_len, self.max_fix + 1, 2), dtype=int)
        I_data = np.zeros((self.data_len, 1), dtype=int)

        data_cnt = 0

        target = stimuli_set[stimuli_set['image_label'] == 'target']
        target_features = self.target_model(target, layers=[self.target_layer], stimuli_identifier=False)
        if target_features.shape[0] == target_features['neuroid_num'].shape[0]:
            target_features = target_features.T  # ensure presentation-major orientation

        stimuli = stimuli_set[stimuli_set['image_label'] == 'stimuli']
        stimuli_features = self.stimuli_model(stimuli, layers=[self.stimuli_layer], stimuli_identifier=False)
        if stimuli_features.shape[0] == stimuli_features['neuroid_num'].shape[0]:
            stimuli_features = stimuli_features.T

        gt_set = stimuli_set[stimuli_set['image_label'] == 'gt']
        gt_paths = list(gt_set.image_paths.values())[
            int(gt_set.index.values[0]):int(gt_set.index.values[-1] + 1)]

        import torch  # deferred import: torch is only needed when this task actually runs

        for i in tqdm(range(self.data_len), desc="visual search stimuli: "):
            # target activation map becomes the conv kernel (template matching)
            op_target = self.unflat(target_features[i:i + 1])
            MMconv = torch.nn.Conv2d(op_target.shape[1], 1,
                                     kernel_size=(op_target.shape[2], op_target.shape[3]),
                                     stride=1, bias=False)
            MMconv.weight = torch.nn.Parameter(torch.Tensor(op_target))

            # per-stimulus binary ground-truth mask, embedded in a 3x canvas
            gt = cv2.imread(gt_paths[i], 0)
            gt = cv2.resize(gt, (self.search_image_size, self.search_image_size),
                            interpolation=cv2.INTER_AREA)
            _, gt = cv2.threshold(gt, 125, 255, cv2.THRESH_BINARY)
            temp_stim = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size)))
            temp_stim[self.search_image_size:2 * self.search_image_size,
                      self.search_image_size:2 * self.search_image_size] = np.copy(gt)
            gt = np.copy(temp_stim) / 255  # scale to {0, 1}

            op_stimuli = self.unflat(stimuli_features[i:i + 1])
            out = MMconv(torch.Tensor(op_stimuli)).detach().numpy()
            out = out.reshape(out.shape[2:])

            # normalize to uint8, upsample to image size, smooth
            out = out - np.min(out)
            out = out / np.max(out)
            out = np.uint8(out * 255)
            out = cv2.resize(out, (self.search_image_size, self.search_image_size),
                             interpolation=cv2.INTER_AREA)
            out = cv2.GaussianBlur(out, (7, 7), 3)

            temp_stim = np.uint8(np.zeros((3 * self.search_image_size, 3 * self.search_image_size)))
            temp_stim[self.search_image_size:2 * self.search_image_size,
                      self.search_image_size:2 * self.search_image_size] = np.copy(out)
            attn = np.copy(temp_stim)

            saccade = []
            (x, y) = int(attn.shape[0] / 2), int(attn.shape[1] / 2)
            saccade.append((x, y))  # first fixation: canvas center

            for k in range(self.max_fix):
                (x, y) = np.unravel_index(np.argmax(attn), attn.shape)

                # clamp the fixation into the central (real-image) cell of the canvas
                fxn_x, fxn_y = x, y
                fxn_x, fxn_y = max(fxn_x, self.search_image_size), max(fxn_y, self.search_image_size)
                fxn_x, fxn_y = min(fxn_x, (attn.shape[0] - self.search_image_size)), min(fxn_y, (
                    attn.shape[1] - self.search_image_size))

                saccade.append((fxn_x, fxn_y))

                attn, found = self.remove_attn(attn, saccade[-1][0], saccade[-1][1], gt)

                if found:  # fixation landed on the mask: success at fixation k+1
                    self.score[data_cnt, k + 1] = 1
                    data_cnt += 1
                    break

            saccade = np.asarray(saccade)
            j = saccade.shape[0]

            I_data[i, 0] = min(self.max_fix + 1, j)
            S_data[i, :j, 0] = saccade[:, 0].reshape((-1,))[:self.max_fix + 1]
            S_data[i, :j, 1] = saccade[:, 1].reshape((-1,))[:self.max_fix + 1]

        self.data[:, :self.max_fix + 1, :] = S_data
        self.data[:, self.max_fix + 1, :] = I_data

        return (self.score, self.data)

    def remove_attn(self, img, x, y, gt):
        """Blank the ior_size window around (x, y); return (img, True if it overlapped gt)."""
        half = int(self.ior_size / 2)
        img[(x - half):(x + half), (y - half):(y + half)] = 0

        fxt_xtop = x - half
        fxt_ytop = y - half
        fxt_place = gt[fxt_xtop:(fxt_xtop + self.ior_size), fxt_ytop:(fxt_ytop + self.ior_size)]

        # NOTE(review): window indices are not clipped to the canvas; presumably
        # the clamping in look_at keeps them in range -- confirm for small ior_size
        if np.sum(fxt_place) > 0:
            return img, True
        else:
            return img, False

    def unflat(self, X):
        """Reshape a flat (presentation, neuroid) assembly row into a (1, C, H, W) array.

        Duplicated from VisualSearchObjArray.unflat; kept local so each class
        stands alone.
        """
        channel_names = ['channel', 'channel_x', 'channel_y']
        assert all(hasattr(X, coord) for coord in channel_names)
        shapes = [len(set(X[channel].values)) for channel in channel_names]
        X = np.reshape(X.values, [X.shape[0]] + shapes)
        # assumes the flat order is (channel_x, channel_y, channel); move channel
        # to the front -- TODO confirm against the activations wrapper
        X = np.transpose(X, axes=[0, 3, 1, 2])
        return X


# --- file: setup.py ---
# NOTE(review): the patch also appends "opencv-contrib-python" to
# install_requires; nothing else in setup.py changes.

# --- file: tests/brain_transformation/test_search.py (new file) ---
import functools
import os

import numpy as np
import pandas as pd
import pytest
from pytest import approx

from model_tools.brain_transformation import ModelCommitment
from model_tools.activations import PytorchWrapper
import brainscore
import brainio_collection
from brainscore.model_interface import BrainModel


def pytorch_custom(image_size):
    """Build a tiny 2-channel conv net wrapped for activation extraction."""
    import torch
    from torch import nn
    from model_tools.activations.pytorch import load_preprocess_images

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3, bias=False)
            self.relu1 = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            x = self.relu1(x)
            return x

    preprocessing = functools.partial(load_preprocess_images, image_size=image_size)
    return PytorchWrapper(model=MyModel(), preprocessing=preprocessing)


class TestObjectSearch:
    def test_model(self):
        target_model_pool = pytorch_custom(28)
        stimuli_model_pool = pytorch_custom(224)
        search_target_model_param = {
            'target_model': target_model_pool,
            'target_layer': 'relu1',
            'target_img_size': 28,
        }
        search_stimuli_model_param = {
            'stimuli_model': stimuli_model_pool,
            'stimuli_layer': 'relu1',
            'search_image_size': 224,
        }

        model = ModelCommitment(identifier=stimuli_model_pool.identifier, activations_model=None,
                                layers=['relu1'],
                                search_target_model_param=search_target_model_param,
                                search_stimuli_model_param=search_stimuli_model_param)
        assemblies = brainscore.get_assembly('klab.Zhang2018search_obj_array')
        stimuli = assemblies.stimulus_set
        fix = [[640, 512],
               [365, 988],
               [90, 512],
               [365, 36],
               [915, 36],
               [1190, 512],
               [915, 988]]
        max_fix = 6
        data_len = 300
        # FIX: was BrainModel.Task.visual_search_obj_arr, a task name that is never
        # registered; the object-array search model is registered under
        # BrainModel.Task.object_search in brain_transformation/__init__.py.
        model.start_task(BrainModel.Task.object_search, fix=fix, max_fix=max_fix, data_len=data_len)
        cumm_perf, saccades = model.look_at(stimuli)

        assert saccades.shape == (300, 8, 2)
        # NOTE(review): VisualSearchObjArray.look_at returns a score of shape
        # (data_len, max_fix + 1) = (300, 7); this (7, 2) assertion presumably
        # relies on aggregation upstream of look_at -- confirm.
        assert cumm_perf.shape == (7, 2)