-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathutils.py
More file actions
28 lines (21 loc) · 714 Bytes
/
utils.py
File metadata and controls
28 lines (21 loc) · 714 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn
class Logger:
    """Print messages to stdout and optionally append them to a log file.

    Args:
        log_path: Path of the file each message is appended to, or ``None``
            to print to stdout only.
    """

    def __init__(self, log_path):
        self.log_path = log_path  # None disables file logging

    def log(self, str_to_log):
        """Print a message and, if a log path is set, append it as one line.

        Args:
            str_to_log: Message to log. Non-string values are converted with
                ``str()`` so the file receives exactly the text ``print``
                shows (previously a non-string printed fine, then raised
                ``TypeError`` on the file write).
        """
        message = str(str_to_log)
        print(message)
        if self.log_path is not None:
            # Re-opened in append mode per call; leaving the ``with`` block
            # closes the file, which also flushes it.
            with open(self.log_path, 'a') as f:
                f.write(message + '\n')
class SingleChannelModel:
    """Wrap a classifier so it accepts RGB images packed into one channel.

    Inputs are reshaped ``[N, 1, H, W x 3] -> [N, 3, H, W]`` before being
    passed to the wrapped model.

    The reshape reinterprets the input's elements in row-major order; it does
    not cut the width into three side-by-side colour tiles.
    NOTE(review): this presumably inverts a ``view(N, 1, H, 3 * W)`` applied
    by the caller. Confirm against the code that builds the packed input.

    Args:
        model: Callable classifier taking ``[N, 3, H, W]`` tensors. If it is
            an ``nn.Module`` it must already be in eval mode.

    Raises:
        ValueError: If ``model`` is an ``nn.Module`` still in training mode.
    """

    def __init__(self, model):
        # Explicit check rather than ``assert`` so the guard is not stripped
        # when Python runs with ``-O``.
        if isinstance(model, nn.Module) and model.training:
            raise ValueError(
                "SingleChannelModel expects a model in eval mode; "
                "call model.eval() first"
            )
        self.model = model

    def __call__(self, x):
        """Reshape ``x`` from ``[N, 1, H, 3 * W]`` to ``[N, 3, H, W]`` and classify it.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs are
        accepted. When ``view`` would succeed, ``reshape`` returns that same
        view without copying. A width not divisible by 3 still raises, as
        before, because the element counts no longer match.
        """
        return self.model(
            x.reshape(x.shape[0], 3, x.shape[2], x.shape[3] // 3)
        )