-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmain.py
More file actions
100 lines (91 loc) · 3.66 KB
/
main.py
File metadata and controls
100 lines (91 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python
#from comet_ml import Experiment
import h5py
import matplotlib.pyplot as plt
import numpy as np
import argparse
import importlib
import random
import os
from FLAlgorithms.servers.serverADMM import ADMM
from FLAlgorithms.servers.serverGrassmann import Grassmann
from FLAlgorithms.servers.centralisedPCA import Centralised
from utils.model_utils import read_data
from FLAlgorithms.trainmodel.models import *
from utils.plot_utils import *
import torch
torch.manual_seed(0)
from utils.options import args_parser
# import comet_ml at the top of your file
#
# Create an experiment with your api key:
def main(experiment, dataset, algorithm, batch_size, learning_rate, ro, num_glob_iters,
         local_epochs, numusers,dim, times, gpu):
    """Build the requested federated-PCA server, train it, and evaluate it.

    Args:
        experiment: comet_ml Experiment handle, or a falsy placeholder (0)
            when Comet logging is disabled.
        dataset: dataset name forwarded to ``read_data``.
        algorithm: one of "FAPL" (ADMM server), "FGPL" (Grassmann server),
            or "Centralised" (centralised PCA baseline).
        batch_size: accepted for interface compatibility; not used here.
        learning_rate: local learning rate (zeroed for "Centralised").
        ro: ADMM/averaging penalty parameter (zeroed for "Centralised").
        num_glob_iters: number of global communication rounds.
        local_epochs: number of local epochs per round.
        numusers: fraction/number of users sampled per round.
        dim: target subspace dimension.
        times: number of repeated runs.
        gpu: CUDA device index, or -1 to force CPU.

    Raises:
        ValueError: if ``algorithm`` is not one of the supported names.
    """
    # Get device status: Check GPU or CPU
    device = torch.device("cuda:{}".format(gpu) if torch.cuda.is_available() and gpu != -1 else "cpu")
    # Pair the loaded client data with the dataset name; the server classes
    # appear to expect a (clients_data, dataset_name) tuple -- TODO confirm
    # against the server constructors.
    data = read_data(dataset) , dataset
    if algorithm == "FAPL":
        server = ADMM(experiment, device, data, learning_rate, ro, num_glob_iters, local_epochs, numusers, dim, times)
    elif algorithm == "FGPL":
        server = Grassmann(experiment, device, data, learning_rate, ro, num_glob_iters, local_epochs, numusers, dim, times)
    elif algorithm == "Centralised":
        # The centralised baseline ignores the federated hyper-parameters,
        # so they are zeroed before constructing the server.
        learning_rate = 0
        ro = 0
        num_glob_iters = 0
        local_epochs = 0
        server = Centralised(experiment, device, data, learning_rate, ro, num_glob_iters, local_epochs, numusers, dim, times)
    else:
        # Previously an unknown algorithm left `server` unbound and crashed
        # below with a confusing NameError; fail fast and clearly instead.
        raise ValueError("Unknown algorithm: {!r} (expected 'FAPL', 'FGPL' or 'Centralised')".format(algorithm))
    server.train()
    # test performance of consensus principal subspace
    global_epoch = 50  # hard-coded evaluation round -- TODO: make configurable
    server.test("dnn", dataset, global_epoch)
if __name__ == "__main__":
    # Parse CLI options and print a human-readable run summary.
    args = args_parser()
    print("=" * 80)
    print("Summary of training process:")
    print("Algorithm: {}".format(args.algorithm))
    print("Batch size: {}".format(args.batch_size))
    print("Learning rate : {}".format(args.learning_rate))
    print("Average Moving : {}".format(args.ro))
    print("Subset of users : {}".format(args.subusers))
    print("Number of global rounds : {}".format(args.num_global_iters))
    print("Number of local rounds : {}".format(args.local_epochs))
    print("Dataset : {}".format(args.dataset))
    print("=" * 80)
    if args.commet:
        # BUG FIX: the top-of-file `from comet_ml import Experiment` is
        # commented out, so `Experiment` was an unresolved name whenever
        # --commet was set. Import lazily here so comet_ml is only required
        # when Comet logging is actually requested.
        from comet_ml import Experiment

        # SECURITY NOTE(review): the API key is hard-coded in source; it
        # should be read from an environment variable or a config file.
        experiment = Experiment(
            api_key="VtHmmkcG2ngy1isOwjkm5sHhP",
            project_name="multitask-for-test",
            workspace="federated-learning-exp",
        )
        hyper_params = {
            "dataset": args.dataset,
            "algorithm": args.algorithm,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "ro": args.ro,
            "dim": args.dim,
            "num_glob_iters": args.num_global_iters,
            "local_epochs": args.local_epochs,
            "numusers": args.subusers,
            "times": args.times,
            "gpu": args.gpu,
            "cut-off": args.cutoff
        }
        experiment.log_parameters(hyper_params)
    else:
        # Falsy placeholder kept as 0 (not None) to preserve the original
        # value passed through to the server classes.
        experiment = 0
    main(
        experiment=experiment,
        dataset=args.dataset,
        algorithm=args.algorithm,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        ro=args.ro,
        num_glob_iters=args.num_global_iters,
        local_epochs=args.local_epochs,
        numusers=args.subusers,
        dim=args.dim,
        times=args.times,
        gpu=args.gpu
    )