-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscript v3.py
More file actions
60 lines (53 loc) · 1.95 KB
/
script v3.py
File metadata and controls
60 lines (53 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import numpy as np
from scipy.stats import linregress
import pandas as pd
from sklearn.linear_model import LinearRegression
# Standardize every series in data.json, write the result to normdata.json,
# then, per refcode, screen the predictors of DESCRIPTOR with single-variable
# linear fits and refit the surviving predictors jointly.
#
# NOTE(review): this file was captured with all indentation stripped; the
# nesting below is the most plausible reading of the flattened statements
# (in particular where `print(inputs)` and the trailing `print()` belong).

DATA_PATH = 'data.json'
NORMDATA_PATH = 'normdata.json'

# Response variable and candidate predictors (keys inside each refcode entry).
DESCRIPTOR = 'dn2'
PREDICTORS = ('dm2', 'cn2_x', 'cn2_y', 'cn2_z', 'm2n2_angle')

# A predictor is dropped when its single-variable |slope| is <= CUTOFF.
CUTOFF = 0.95


def normalize(values):
    """Return *values* z-scored (mean 0, population std 1) as a list of floats.

    An empty input gives ``[]``. A constant series gives all zeros instead of
    the 0/0 NaNs the plain formula would produce (``json.dump`` would write
    those as ``NaN``, which is not valid JSON).
    """
    arr = np.asarray(values, dtype=float)
    if arr.size == 0:
        return []
    std = arr.std()  # ddof=0, same as np.std(val) in the original script
    if std == 0:
        return [0.0] * arr.size
    return ((arr - arr.mean()) / std).tolist()


def screen_predictors(entry, target=DESCRIPTOR, candidates=PREDICTORS,
                      threshold=CUTOFF):
    """Fit *target* against each candidate alone and drop weak predictors.

    Args:
        entry: mapping of series name -> sequence of values for one refcode.
        target: key of the response series.
        candidates: keys of the predictor series to screen, in order.
        threshold: a predictor is removed when ``abs(slope) <= threshold``.

    Returns:
        ``(kept, removed)``: the surviving predictor names (original order)
        and whether any predictor was removed.
    """
    y = entry[target]
    kept = list(candidates)
    removed = False
    for name in candidates:
        x = entry[name]
        if np.ptp(x) == 0:
            # linregress raises on a constant x; a constant predictor has no
            # linear relationship with y, so treat its slope as 0 (dropped).
            slope = 0.0
        else:
            slope = linregress(x, y).slope
        if threshold >= abs(slope):
            print('Removed', name, slope)
            kept.remove(name)
            removed = True
        if abs(slope) > threshold - 0.05:
            # flag slopes above threshold - 0.05 (near or past the cutoff)
            print("*******************************************")
    return kept, removed


def fit_joint(entry, inputs, target=DESCRIPTOR):
    """Fit *target* on all *inputs* at once and print coefficients and R^2."""
    y = entry[target]
    # one column per predictor, one row per observation
    X = np.column_stack([entry[name] for name in inputs])
    reg = LinearRegression().fit(X, y)
    # NOTE: coef_ holds regression coefficients; output label kept as-is.
    print('correlation coefficients: ', reg.coef_)
    print('R^2: ', reg.score(X, y))


def main():
    """Normalize data.json, save it, then screen and refit each refcode."""
    with open(DATA_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # refcode -> {series name -> z-scored list of floats}
    normdata = {
        refcode: {key: normalize(val) for key, val in entry.items()}
        for refcode, entry in data.items()
    }
    with open(NORMDATA_PATH, 'w', encoding='utf-8') as f:
        json.dump(normdata, f)

    for refcode, entry in normdata.items():
        print('Now processing', refcode)
        inputs, removed = screen_predictors(entry)
        print(inputs)

        # check the surviving predictors for orthogonality
        corr = pd.DataFrame(entry, columns=inputs).corr()
        print('correlation matrix:')
        print(corr)

        # NOTE(review): as in the original, the joint fit only runs when the
        # screening dropped at least one predictor -- confirm that is intended.
        if removed and inputs:
            fit_joint(entry, inputs)
        print()


if __name__ == '__main__':
    main()