-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgaussian_analysis.py
More file actions
64 lines (53 loc) · 2.25 KB
/
gaussian_analysis.py
File metadata and controls
64 lines (53 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Dict
from os import makedirs
import matplotlib.pyplot as plt
# from ogc import dimensionality_reduction as dr
from ogc.classifiers import mvg
from ogc import utilities
import numpy.typing as npt
import numpy as np
from project import TRAINING_DATA, ROOT_PATH
import logging
from pprint import pprint
# Logging: configure the root logger at DEBUG, then make this module's own
# logger emit DEBUG records as well.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Destination folders for figures and result tables. They are built by plain
# string concatenation, so ROOT_PATH presumably ends with a path separator
# (TODO confirm against the `project` module).
OUTPUT_PATH = ROOT_PATH + "../images/gaussian_analysis/"
TABLES_OUTPUT_PATH = ROOT_PATH + "../tables/gaussian_analysis/"

# Create both folders up front; exist_ok makes re-runs harmless.
for _output_dir in (OUTPUT_PATH, TABLES_OUTPUT_PATH):
    makedirs(_output_dir, exist_ok=True)
del _output_dir  # keep the module namespace identical to the original
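# Evaluate every combination of MVG variant (standard, naive, tied, tied naive),
# class prior, preprocessing (raw / Z-norm) and PCA (none / 5 components); the grid itself is defined under __main__.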
def mvg_callback(prior, mvg_params: dict, dimred, dataset_type):
    """Cross-validate one MVG configuration on the training set.

    Parameters
    ----------
    prior : float
        Prior of the first class. The model is built with class priors
        ``[prior, 1 - prior]`` and the same value is forwarded to ``Kfold``.
    mvg_params : dict
        Keyword arguments forwarded unchanged to ``mvg.MVG``
        (e.g. ``{"naive": True}`` or ``{"tied": True}``).
    dimred : int or None
        Number of PCA components to keep; ``None`` skips PCA.
    dataset_type : str or None
        ``"Z-Norm"`` to Z-normalize the features first; any other value
        (e.g. ``None``) uses the raw features.

    Returns
    -------
    The result of ``ogc.utilities.Kfold`` with 5 folds for this
    configuration (presumably the min DCF, given the "MinDCF" column the
    script writes — confirm against ``ogc.utilities``).
    """
    from ogc.utilities import Kfold

    DTR, LTR = TRAINING_DATA()
    model = mvg.MVG(prior_probability=[prior, 1 - prior], **mvg_params)

    # NOTE(review): Z-normalization and PCA are applied to the whole training
    # set before Kfold splits it. If these helpers estimate statistics from the
    # data (as their names suggest), the validation folds leak into the
    # preprocessing — confirm this is intended.
    if dataset_type == "Z-Norm":
        from ogc.utilities import ZNormalization as znorm

        DTR = znorm(DTR)[0]
    if dimred is not None:
        from ogc import dimensionality_reduction as dr

        DTR = dr.PCA(DTR, dimred)[0]

    return Kfold(DTR, LTR, model, 5, prior=prior)
if __name__ == "__main__":
    import time
    from os.path import isfile

    start = time.perf_counter()

    # Set to True for a quick smoke run over a single configuration.
    fast_test = False
    if not fast_test:
        priors = [("0.5", 0.5), ("0.1", 0.1), ("0.9", 0.9)]
        mvg_params = [
            ("Standard MVG", {}),
            ("Naive MVG", {"naive": True}),
            ("Tied MVG", {"tied": True}),
            ("Tied Naive MVG", {"naive": True, "tied": True}),
        ]
        dataset_types = [("RAW", None), ("Z-Norm", "Z-Norm")]
        dimred = [("No PCA", None), ("PCA 5", 5)]
    else:
        priors = [("0.5", 0.5)]
        dataset_types = [("RAW", None)]
        mvg_params = [("Standard MVG", {})]
        dimred = [("No PCA", None)]

    results_csv = TABLES_OUTPUT_PATH + "mvg_results.csv"

    # Reuse cached results when requested and available; otherwise run the
    # (slow) grid search and persist its table.
    use_csv = True
    if use_csv and isfile(results_csv):
        table = utilities.load_from_csv(results_csv)
    else:
        if use_csv:
            logger.info("%s not found; running the grid search instead", results_csv)
        _, table = utilities.grid_search(
            mvg_callback, priors, mvg_params, dimred, dataset_types
        )
        # Only write freshly computed results. Re-saving a table that was just
        # loaded from this same file is redundant, and because savetxt prefixes
        # the header with "# ", it could duplicate the header on every run if
        # load_from_csv does not skip comment lines.
        np.savetxt(
            results_csv,
            table,
            delimiter=";",
            fmt="%s",
            header=";".join(["Prior", "MVG", "PCA", "Dataset", "MinDCF"]),
        )

    # perf_counter is monotonic, so the measurement is immune to clock changes.
    print(f"Time elapsed: {time.perf_counter() - start} seconds")