-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_setup.py
More file actions
81 lines (65 loc) · 2.99 KB
/
data_setup.py
File metadata and controls
81 lines (65 loc) · 2.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import zipfile
from pathlib import Path
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
def unzip_file_dir(source: str, destination: str):
    """Extract every member of a zip archive into ``destination``.

    Parameters
    ----------
    source : str
        Path to the ``.zip`` archive to extract (the archive file itself,
        not the folder that contains it).
    destination : str
        Directory the archive's contents are extracted into; intermediate
        directories are created as needed by ``ZipFile.extractall``.

    Returns
    -------
    str
        A confirmation message naming ``destination``.
    """
    archive = zipfile.ZipFile(source, mode='r')
    with archive:
        archive.extractall(path=destination)
    message = f'[INFO]: your file has been unzipped to {destination} directory'
    return message
def walkthrough_dir(dir_path: str):
    """Print a one-line summary for every directory under ``dir_path``.

    The tree is walked top-down with ``os.walk``. For each directory visited,
    the count of its immediate subdirectories and files is printed to stdout.
    Nothing is returned (the result is ``None``).

    Note: the message labels every file as an "image", whatever its type.

    Parameters
    ----------
    dir_path : str
        Root of the directory tree to summarise.
    """
    summary_lines = (
        f'there are {len(subdirs)} directories and {len(files)} images in {current}'
        for current, subdirs, files in os.walk(dir_path)
    )
    for line in summary_lines:
        print(line)
# Sentinel meaning "caller did not pass test_val_transform". Using a private
# object (rather than None) keeps an explicit ``test_val_transform=None``
# meaning "no transform", exactly as before, while the default transform is
# now built per call instead of once at import time and shared by every call.
_DEFAULT_TEST_VAL_TRANSFORM = object()


def _default_test_val_transform():
    """Build the default evaluation transform (resize, to-tensor, normalise)."""
    return transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        # Per-channel mean/std: the standard ImageNet statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _make_loader(dataset, batchsize, num_worker, shuffle):
    """Wrap ``dataset`` in a DataLoader using the shared batch/worker settings."""
    return DataLoader(dataset=dataset,
                      batch_size=batchsize,
                      shuffle=shuffle,
                      num_workers=num_worker)


def dataloader_function(train_dir: Path,
                        test_dir: Path,
                        val_dir: Path,
                        transform: transforms.Compose,
                        batchsize: int,
                        num_worker: int,
                        test_val_transform=_DEFAULT_TEST_VAL_TRANSFORM):
    """Build train/test/val DataLoaders from ``ImageFolder``-style directories.

    Parameters
    ----------
    train_dir : Path
        Root of the training images (one subdirectory per class).
    test_dir : Path
        Root of the test images.
    val_dir : Path
        Root of the validation images.
    transform : transforms.Compose
        Transform applied to the training images only.
    batchsize : int
        Batch size used by all three loaders.
    num_worker : int
        Number of DataLoader worker processes for all three loaders.
    test_val_transform : callable or None, optional
        Transform shared by the test and validation datasets. When omitted,
        a fresh default is built on each call: resize to 224x224, convert to
        a tensor, then normalise with the standard ImageNet mean/std.
        Passing ``None`` explicitly applies no transform.

    Returns
    -------
    tuple
        ``(train_dataloader, test_dataloader, val_dataloader, class_names)``.
        Only the training loader shuffles. ``class_names`` is the
        ``classes`` list that ``ImageFolder`` derives from ``train_dir``.
    """
    if test_val_transform is _DEFAULT_TEST_VAL_TRANSFORM:
        test_val_transform = _default_test_val_transform()

    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=test_val_transform)
    val_data = datasets.ImageFolder(val_dir, transform=test_val_transform)

    train_dataloader = _make_loader(train_data, batchsize, num_worker, shuffle=True)
    # Evaluation loaders keep a fixed order so metrics are reproducible.
    test_dataloader = _make_loader(test_data, batchsize, num_worker, shuffle=False)
    val_dataloader = _make_loader(val_data, batchsize, num_worker, shuffle=False)

    return train_dataloader, test_dataloader, val_dataloader, train_data.classes