aboutsummaryrefslogtreecommitdiff
path: root/logistic_regression/binary.py
blob: 1c3f60886f1e00d1ef930343708c8efc3da08c7f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D


# Load the data: the first two CSV columns are the input features and the
# third column is the class label.
csv = "../data/binary_logistic.csv"
data = pd.read_csv(csv)
x_1, x_2, y = [np.array(data[data.columns[i]]) for i in range(3)]

# All three model weights start from the same initial value.
w1, w2, w3 = -8, -8, -8

# Define our model
def h(x_1, x_2):
    """Return the model output for the given feature values.

    The linear combination ``w1 + w2 * x_1 + w3 * x_2`` of the module-level
    weights is passed through the logistic (sigmoid) function, so the result
    lies between 0 and 1.

    Args:
        x_1: First feature, a number or a NumPy array.
        x_2: Second feature, a number or a NumPy array compatible with x_1.

    Returns:
        The sigmoid of the linear model, computed element-wise for arrays.
    """
    logit = w1 + w2 * x_1 + w3 * x_2
    return 1 / (1 + np.exp(-logit))


def dw1():
    """Return the gradient term used to update the bias weight w1.

    This is the mean over the data set of the residual ``h(x_1, x_2) - y``,
    computed from the module-level arrays x_1, x_2 and y.

    Raises:
        ZeroDivisionError: If the data set is empty.
    """
    residual = h(x_1, x_2) - y
    # np.sum reduces the array in native code; the built-in sum() would
    # iterate it element by element in Python on every descent step.
    return 1 / len(x_1) * np.sum(residual)
def dw2():
    """Return the gradient term used to update w2, the weight of x_1.

    This is the mean over the data set of ``x_1 * (h(x_1, x_2) - y)``,
    computed from the module-level arrays x_1, x_2 and y.

    Raises:
        ZeroDivisionError: If the data set is empty.
    """
    residual = h(x_1, x_2) - y
    # np.sum reduces the array in native code; the built-in sum() would
    # iterate it element by element in Python on every descent step.
    return 1 / len(x_1) * np.sum(x_1 * residual)
def dw3():
    """Return the gradient term used to update w3, the weight of x_2.

    This is the mean over the data set of ``x_2 * (h(x_1, x_2) - y)``,
    computed from the module-level arrays x_1, x_2 and y.

    Raises:
        ZeroDivisionError: If the data set is empty.
    """
    residual = h(x_1, x_2) - y
    # np.sum reduces the array in native code; the built-in sum() would
    # iterate it element by element in Python on every descent step.
    return 1 / len(x_1) * np.sum(x_2 * residual)


# Gradient descent settings
#fig, ax = plt.subplots(dpi=300)
alpha = 0.01  # Step size: proportion of the gradient applied on each update
accuracy = 1e-4  # Descent stops once the weights change by no more than this
done = False
def decent(max_iter=None):
    """Run batch gradient descent on the module-level weights w1, w2 and w3.

    Each iteration evaluates dw1(), dw2() and dw3() at the current weights,
    then moves all three weights at once by ``alpha`` times the respective
    gradient term.  The loop ends as soon as none of the three weights
    changed by more than ``accuracy`` during an iteration.

    Args:
        max_iter: Optional upper bound on the number of iterations.  The
            default ``None`` iterates until the convergence test passes.

    Returns:
        None.  The fitted weights are left in the globals w1, w2 and w3.
    """
    global w1, w2, w3
    iterations = 0
    while max_iter is None or iterations < max_iter:
        previous = (w1, w2, w3)
        # Build the whole tuple before assigning so that every gradient is
        # evaluated at the same (old) weights.
        updated = (
            w1 - alpha * dw1(),
            w2 - alpha * dw2(),
            w3 - alpha * dw3(),
        )
        w1, w2, w3 = updated
        iterations += 1

        # Converged only when all three weights (w3 included) have settled.
        if all(abs(new - old) <= accuracy for new, old in zip(updated, previous)):
            break







# Fit the weights, then plot the data, the fitted model and its predictions.
decent()
fig = plt.figure()


def _class_scatter(axes, title, colours):
    """Scatter the samples on `axes`, coloured by `colours`, with a legend."""
    axes.set_title(title)
    axes.set_xlabel("X")
    axes.set_ylabel("Y")
    points = axes.scatter(x_1, x_2, c=colours, marker="o")
    handles, _ = points.legend_elements(prop="colors", alpha=0.6)
    axes.legend(handles, ["Class A", "Class B"], loc="upper right", title="Legend")


# Top-left panel: the samples coloured by their label from the CSV.
_class_scatter(fig.add_subplot(2, 2, 1), "Original Data", y)

# Top-right panel: model output over the grid spanned by the sorted features.
model_ax = fig.add_subplot(2, 2, 2, projection="3d")
model_ax.set_title("Model")
model_ax.set_xlabel("X")
model_ax.set_ylabel("Y")
model_ax.set_zlabel("Probability")
grid_x, grid_y = np.meshgrid(np.sort(x_1), np.sort(x_2))
model_ax.plot_wireframe(grid_x, grid_y, h(grid_x, grid_y), rstride=10, cstride=10)

# Bottom panel: the samples coloured by the rounded model output.
_class_scatter(fig.add_subplot(2, 1, 2), "Predictions", np.round(h(x_1, x_2)))

# NOTE(review): `x` is not used below, and the plot call draws a fixed segment
# from (1, 2) to (2, 2) on pyplot's current axes -- presumably a placeholder
# for a decision-boundary line; confirm whether it is still wanted.
x = np.arange(0, 10, 0.2)
plt.plot([1, 2], [2, 2])

# Save
plt.tight_layout()
#plt.savefig("binary.png",dpi=300)
plt.show()