-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathBatch_Gradient_Descent.py
More file actions
79 lines (66 loc) · 1.59 KB
/
Batch_Gradient_Descent.py
File metadata and controls
79 lines (66 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from mpl_toolkits.mplot3d import Axes3D
# Two-feature linear regression on data.xlsx: load the sheet, z-score every
# column, fit the weights with gradient descent, plot the cost history, and
# print the final weights.  Each step is a function so it can be reused and
# tested without the spreadsheet; running the file as a script calls main().


def load_data(path="data.xlsx"):
    """Read a header-less spreadsheet and return its contents as an ndarray.

    The DataFrame shape is printed first, as the original script did.
    """
    frame = pd.read_excel(path, header=None)
    print(frame.shape)
    return np.array(frame)


def standardize(values):
    """Return *values* z-scored with their mean and population standard deviation.

    A constant input (std == 0) is only centred, which avoids the 0/0 that
    would otherwise fill the result with NaN.
    """
    mean = np.mean(values)
    std = np.std(values)
    centred = values - mean
    return centred if std == 0 else centred / std


def build_design_matrix(data):
    """Split *data* into a scaled design matrix and a scaled target vector.

    Columns 0 and 1 of *data* are the two features, and column 2 is the target.
    Each column is standardized independently.

    Returns:
        (X, y): ``X`` has shape (m, 3), with a leading column of ones for the
        bias term. ``y`` has shape (m,).
    """
    x1 = standardize(data[:, [0]])
    x2 = standardize(data[:, [1]])
    y = standardize(data[:, 2])
    bias = np.ones((len(x1), 1))
    return np.hstack([bias, x1, x2]), y


def fit(X, y, alpha=0.01, tol=0.00001, max_epochs=13):
    """Fit linear-regression weights with per-sample gradient-descent updates.

    NOTE(review): despite the file name, the weights are updated after every
    individual sample (stochastic/online gradient descent), as the original
    loop did. A true batch step would combine the gradient over all samples
    and update once per epoch.

    Args:
        X: design matrix of shape (m, n). Column 0 is expected to be all ones,
            so that w[0] acts as the bias.
        y: targets, shape (m,).
        alpha: learning rate.
        tol: stop once the change in cost between epochs is <= tol.
        max_epochs: hard cap on the number of passes over the data.

    Returns:
        (w, costs, weights): the final weight vector, shape (n,); the cost
        J = sum(residual**2) / (2m) after each epoch; and the weight vector
        after each epoch, stacked into shape (epochs, n).

    Raises:
        ValueError: if there are no samples.
    """
    m = len(y)
    if m == 0:
        raise ValueError("fit() needs at least one sample")
    w = np.zeros(X.shape[1])
    costs = []
    history = []
    prev_cost = 0.0  # the first epoch's change is measured against 0
    change = float("inf")  # ensures training starts whenever max_epochs > 0
    epoch = 0
    while change > tol and epoch < max_epochs:
        for row, target in zip(X, y):
            # The error uses the weights from before this sample's step, so
            # every component is updated simultaneously.
            error = row @ w - target
            w = w - alpha * error * row
        history.append(w.copy())
        residual = X @ w - y
        cost = float(residual @ residual) / (2 * m)
        costs.append(cost)
        change = abs(cost - prev_cost)
        prev_cost = cost
        epoch += 1
    weights = np.array(history, dtype=float).reshape(len(history), X.shape[1])
    return w, costs, weights


def plot_history(costs, weights):
    """Plot the cost per epoch, then the (w1, w2, cost) trajectory in 3-D."""
    plt.plot(list(range(len(costs))), costs)
    plt.xlabel('Epoch')
    plt.ylabel('Cost function')
    fig = plt.figure()
    # Figure.gca() stopped accepting keyword arguments in Matplotlib 3.6;
    # add_subplot(..., projection='3d') is the supported way to get 3-D axes.
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(weights[:, 1], weights[:, 2], costs)
    ax.set_title('Cost function vs Weight Values')
    ax.set_xlabel('W1')
    ax.set_ylabel('W2')
    ax.set_zlabel('Cost function')
    plt.show()


def main(path="data.xlsx"):
    """Run the full pipeline: load, scale, train, plot, then print the final weights."""
    data = load_data(path)
    X, y = build_design_matrix(data)
    w, costs, weights = fit(X, y)
    plot_history(costs, weights)
    print(*w)


if __name__ == "__main__":
    main()