-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathpreprocessandtrain.py
More file actions
47 lines (41 loc) · 1.67 KB
/
preprocessandtrain.py
File metadata and controls
47 lines (41 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import cv2
from model import generate_model
TRAINING_DATA_DIR = "training_data"
IMG_SHAPE = (225, 225)
# here write code to pre process the images as required
# but make sure to have a copy of the original training data
# or that you write your code in such a way that you pre-process on the fly
# pre-processing on the fly is possible because this is small project and
# not a lot of images will be there as well.
def preprocess(img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, IMG_SHAPE)
return img
# NOTE: This method will only work if you have followed the same folder strucutre as I mentioned
# if not update this code below
def get_dataset():
# load images from the training data directory
dataset = []
for label_dir in os.listdir(TRAINING_DATA_DIR):
# iterating each item in the training data directory
path = os.path.join(TRAINING_DATA_DIR, label_dir)
if not os.path.isdir(path):
continue
# iterating threw each file in the sub directory
for image_file in os.listdir(path):
# loading each image
img = cv2.imread(os.path.join(path, image_file))
img = preprocess(img)
# adding them to the dataset
dataset.append([img, label_dir])
return zip(*dataset)
def main():
# this line will get the data pre-process the image
# as per preprocess() method and store the results in X, y
# which will be used in model generation
X, y = get_dataset()
# this line will generate a tensorflow model and save it as rock-paper-scissors-model.h5 file
generate_model(X, y)
if __name__ == "__main__":
main()