-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtrain.py
More file actions
60 lines (51 loc) · 1.92 KB
/
train.py
File metadata and controls
60 lines (51 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import cv2
from data_downloader import Div2k
from glob import glob
from models import edsr, srgan_discriminator
from model_trainer import SrganTrainer
from pathlib import Path
from preprocessing import data_preprocessing
import sys
# Relative paths used by the download / preprocess pipeline. They are resolved
# against the current working directory at run time.
_DATASET_ROOT = r"./datasets"
_HR_GLOB = r"datasets/div2k/DIV2K_train_HR/*.png"
_LR_GLOB = r"datasets/div2k/DIV2K_train_LR_bicubic/X4/*.png"
_PREPROCESSED_DIR = r"datasets/preprocessed_data/"


def _download_div2k():
    """Download the DIV2K "train" and "valid" splits into ``_DATASET_ROOT``."""
    div2k = Div2k(_DATASET_ROOT, ["train", "valid"])
    div2k.download()


def _preprocess_div2k():
    """Run ``data_preprocessing`` over every DIV2K HR / x4-LR training pair.

    Outputs are written under ``_PREPROCESSED_DIR`` into ``HR/`` and ``LR/``.

    Returns:
        ``_PREPROCESSED_DIR`` (with trailing slash), for use as the trainer's
        ``data_path``.

    Raises:
        FileNotFoundError: no HR images match ``_HR_GLOB`` (download missing
            or failed).
        ValueError: the HR and LR image counts differ, so they cannot be
            paired by sorted order.
        OSError: OpenCV could not read one of the images.
    """
    # Both lists are sorted so the i-th HR file lines up with the i-th LR file.
    hr_images = sorted(glob(_HR_GLOB))
    lr_images = sorted(glob(_LR_GLOB))
    if not hr_images:
        raise FileNotFoundError(f"no HR images match {_HR_GLOB!r}")
    if len(hr_images) != len(lr_images):
        raise ValueError(
            f"found {len(hr_images)} HR images but {len(lr_images)} LR images;"
            " cannot pair them by sorted order"
        )

    write_hr = _PREPROCESSED_DIR + "HR/"
    write_lr = _PREPROCESSED_DIR + "LR/"
    # exist_ok=True: re-running the script must not die with FileExistsError
    # just because a previous run already created the output directories.
    Path(write_hr).mkdir(parents=True, exist_ok=True)
    Path(write_lr).mkdir(parents=True, exist_ok=True)

    for i, (hr_path, lr_path) in enumerate(zip(hr_images, lr_images)):
        print(f"image #{i}:", end="")
        # 10 for images 0-99, 20 for 100-199, ... This matches the original
        # "add 10 whenever i % 100 == 0" counter.
        # NOTE(review): the meaning of `degree` is defined by
        # data_preprocessing (presumably an augmentation angle) -- confirm.
        degree = 10 * (i // 100 + 1)
        hr_img = cv2.imread(hr_path)
        lr_img = cv2.imread(lr_path)
        # cv2.imread reports a failed read by returning None, not by raising.
        if hr_img is None or lr_img is None:
            raise OSError(f"could not read image pair {hr_path!r} / {lr_path!r}")
        # (256, 256) / (64, 64) keep the 4:1 ratio of the X4 LR set; how these
        # sizes are used (patch vs. resize) is up to data_preprocessing.
        d = data_preprocessing(hr_img, lr_img,
                               write_hr, write_lr,
                               i, i,
                               (256, 256), (64, 64),
                               degree=degree, threshold=0.6)
        d.generate_images()
        print("done")
    return _PREPROCESSED_DIR


def _train(data_path):
    """Train the generator alone, then run GAN training from its weights.

    Args:
        data_path: directory passed to ``SrganTrainer`` as ``data_path``.
    """
    generator = edsr()
    discriminator = srgan_discriminator()
    gan = SrganTrainer(generator,
                       discriminator,
                       data_path=data_path,
                       load_all_data=False)
    # trainGenerator returns the weights path that train_gan starts from.
    weights_path = gan.trainGenerator(epochs=150,
                                      batch_size=32)
    # Pass an int step count: the literal 2e5 is a float, which int-only
    # consumers (range(), slicing, tf.data take()) reject.
    gan.train_gan(weights_path,
                  steps=200_000,
                  batch_size=16)


def main(argv=None):
    """Optionally build the DIV2K training set, then train the SRGAN.

    Usage:
        python train.py              # download + preprocess DIV2K, then train
        python train.py DATA_PATH    # train on an existing preprocessed dir

    Args:
        argv: argument vector to use instead of ``sys.argv`` (for programmatic
            calls). Only ``argv[1]`` is used; any further arguments are ignored.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) == 1:
        _download_div2k()
        data_path = _preprocess_div2k()
    else:
        data_path = argv[1]
    _train(data_path)
if __name__ == "__main__":
    # Only start the pipeline when run as a script, never on import.
    main()