Add binary logistic regression
This commit is contained in:
parent
ad57158ad2
commit
8e79c80354
3 changed files with 202 additions and 0 deletions
BIN
logistic_regression/binary.png
Normal file
BIN
logistic_regression/binary.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 317 KiB |
102
logistic_regression/binary.py
Executable file
102
logistic_regression/binary.py
Executable file
|
@ -0,0 +1,102 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
|
||||
# Load the data
|
||||
csv="../data/binary_logistic.csv"
|
||||
data=pd.read_csv(csv)
|
||||
x_1=np.array(data[data.columns[0]])
|
||||
x_2=np.array(data[data.columns[1]])
|
||||
y=np.array(data[data.columns[2]])
|
||||
|
||||
w1=w2=w3=-8
|
||||
|
||||
# Define our model
|
||||
def h(x_1,x_2):
|
||||
global w1,w2,w3
|
||||
model=w1+w2*x_1+w3*x_2
|
||||
return(1/(1+np.exp(-model)))
|
||||
|
||||
|
||||
def dw1():
|
||||
global x_1,x_2,y
|
||||
return(1/len(x_1)*(sum(h(x_1,x_2)-y)))
|
||||
def dw2():
|
||||
global x_1,x_2,y
|
||||
return(1/len(x_1)*sum(x_1*(h(x_1,x_2)-y)))
|
||||
def dw3():
|
||||
global x_1,x_2,y
|
||||
return(1/len(x_1)*sum(x_2*(h(x_1,x_2)-y)))
|
||||
|
||||
|
||||
# Perform the gradient decent
|
||||
#fig, ax = plt.subplots(dpi=300)
|
||||
alpha=0.01 # Proportion of the gradient to take into account
|
||||
accuracy=0.0001 # Accuracy of the decent
|
||||
done=False
|
||||
def decent():
|
||||
global w1,w2,w3,x,y
|
||||
skip_frame=0 # Current frame (plot animation)
|
||||
while True:
|
||||
w1_old=w1
|
||||
w1_new=w1-alpha*dw1()
|
||||
w2_old=w2
|
||||
w2_new=w2-alpha*dw2()
|
||||
w3_old=w3
|
||||
w3_new=w3-alpha*dw3()
|
||||
w1=w1_new
|
||||
w2=w2_new
|
||||
w3=w3_new
|
||||
|
||||
if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w2_new-w2_old) <= accuracy:
|
||||
break
|
||||
skip_frame+=1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
decent()
|
||||
fig=plt.figure()
|
||||
|
||||
#print(np.round(h(x_1,x_2)))
|
||||
#pred=np.round(h(x_1,x_2))
|
||||
|
||||
# Plot data
|
||||
ax = fig.add_subplot(2,2,1)
|
||||
ax.set_title("Original Data")
|
||||
ax.set_xlabel("X")
|
||||
ax.set_ylabel("Y")
|
||||
scatter=plt.scatter(x_1,x_2,c=y,marker="o")
|
||||
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
|
||||
legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
|
||||
|
||||
|
||||
# Plot model
|
||||
ax = fig.add_subplot(2,2,2,projection='3d')
|
||||
ax.set_title("Model")
|
||||
X,Y= np.meshgrid(np.sort(x_1), np.sort(x_2))
|
||||
ax.set_xlabel("X")
|
||||
ax.set_ylabel("Y")
|
||||
ax.set_zlabel("Probability")
|
||||
surf = ax.plot_wireframe(X,Y, h(X,Y),rstride=10,cstride=10)
|
||||
|
||||
# Plot prediction
|
||||
ax = fig.add_subplot(2,1,2)
|
||||
ax.set_title("Predictions")
|
||||
ax.set_xlabel("X")
|
||||
ax.set_ylabel("Y")
|
||||
scatter=plt.scatter(x_1,x_2,c=np.round(h(x_1,x_2)),marker="o")
|
||||
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
|
||||
legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
|
||||
|
||||
# Save
|
||||
plt.tight_layout()
|
||||
plt.savefig("binary.png",dpi=300)
|
Loading…
Add table
Add a link
Reference in a new issue