-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_maml.py
More file actions
67 lines (42 loc) · 2.18 KB
/
eval_maml.py
File metadata and controls
67 lines (42 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import numpy as np
import warnings
import random
from utils.args import parse_args
from utils.datasets import get_meta_dataset
from torchmeta.utils.data import BatchMetaDataLoader
from method.test import Meta_Test
# Suppress every warning of category UserWarning for the rest of the process
# so that they do not clutter the evaluation report printed below.
warnings.filterwarnings(action="ignore", category=UserWarning)
def eval_maml(args, dataloader, model, num_test=1):
    """Run MAML meta-testing ``num_test`` times and report averaged metrics.

    Every repetition reseeds torch (CPU and all CUDA devices), NumPy and the
    ``random`` module with a fresh seed in ``[0, 1000]`` and then calls
    ``Meta_Test(model).maml_test(args, dataloader)``. The per-run accuracy and
    loss are averaged and printed together with a normal-approximation 95%
    interval, ``1.96 * std / sqrt(n)``, where ``std`` is ``np.std`` with its
    default ``ddof=0``.

    Args:
        args: Parsed options. ``datasets``, ``num_ways`` and ``num_shots`` are
            read here for the report; the whole object is forwarded to
            ``maml_test``.
        dataloader: Loader forwarded unchanged to ``maml_test``.
        model: Model handed to ``Meta_Test``.
        num_test: Number of evaluation runs to average over; must be >= 1.

    Returns:
        Tuple ``(avg_acc, avg_loss)`` holding the mean accuracy and mean loss
        over all runs.

    Raises:
        ValueError: If ``num_test`` is smaller than 1.
    """
    if num_test < 1:
        # With no runs the means below would be NaN and the report would read
        # "nan +- nan"; reject the request instead of printing garbage.
        raise ValueError("num_test must be >= 1, got {}".format(num_test))

    acc_list = []
    loss_list = []
    test = Meta_Test(model)

    # Draw the per-run seeds from a dedicated generator. random.seed(seed)
    # below resets the global generator, so taking the next seed from it again
    # would make each seed a function of the previous one (unless something
    # else consumes it in between) and, with only 1001 possible values, the
    # runs could start repeating. Seeding the dedicated generator from the
    # global one keeps results reproducible for callers that seeded ``random``
    # before calling this function.
    seed_source = random.Random(random.getrandbits(64))

    for _ in range(num_test):
        seed = seed_source.randint(0, 1000)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

        test_acc, test_loss = test.maml_test(args, dataloader)
        acc_list.append(test_acc)
        loss_list.append(test_loss)

    # Convert once and reuse for both the mean and the spread.
    accs = np.array(acc_list)
    losses = np.array(loss_list)

    avg_acc = accs.mean(axis=0)
    avg_loss = losses.mean(axis=0)

    # NOTE(review): the "{:.2f}" formats below presumably rely on maml_test
    # returning scalar metrics - confirm against Meta_Test.
    acc_ci95 = 1.96 * np.std(accs, 0) / np.sqrt(len(acc_list))
    loss_ci95 = 1.96 * np.std(losses, 0) / np.sqrt(len(loss_list))

    print("{} with ci95: Average Accuracy -> {:.2f} +- {:.2f} Average Loss -> {:.2f} +- {:.2f}".format(args.datasets, avg_acc, acc_ci95, avg_loss, loss_ci95))
    print("{}: {}-ways {}-shots".format(args.datasets, args.num_ways, args.num_shots))
    print("Average Accuracy with ci95 -> {:.2f} +- {:.2f}".format(avg_acc, acc_ci95))
    print("Average Loss with ci95 -> {:.2f} +- {:.2f}".format(avg_loss, loss_ci95))

    return avg_acc, avg_loss
if __name__ == "__main__":
    # Script entry point: parse the options, select the GPU, load the saved
    # checkpoint and run a single evaluation pass over the test split.
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)

    SAVE_PATH = "/data2/jjlee_datasets/model_ckpt/single/" #write your own path

    # The checkpoint file name is assembled from the run configuration.
    checkpoint_name = "MAML_5-{}_{}_{}_{}_{}_best.pt".format(
        str(args.num_shots),
        args.datasets,
        args.update,
        args.model,
        args.version,
    )
    checkpoint = torch.load(SAVE_PATH + checkpoint_name)

    test_sets = get_meta_dataset(args, dataset=args.datasets, only_test=True)
    test_loader = BatchMetaDataLoader(
        test_sets,
        batch_size=1,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )

    # The object stored under the checkpoint's 'model' key is what gets evaluated.
    eval_maml(args, test_loader, checkpoint['model'], num_test=1)