Add binary logistic regression
This commit is contained in:
parent
ad57158ad2
commit
8e79c80354
3 changed files with 202 additions and 0 deletions
100
data/binary_logistic.csv
Normal file
100
data/binary_logistic.csv
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
4.8550642421469092,9.6399615658447146,1
|
||||||
|
8.6254397593438625,0.058926530182361603,0
|
||||||
|
3.8281915383413434,0.72319923434406519,0
|
||||||
|
7.1509548369795084,3.899420415982604,1
|
||||||
|
6.4779004408046603,8.1981805479153991,1
|
||||||
|
1.9222695007920265,1.3314272649586201,0
|
||||||
|
8.9782158890739083,0.99343751091510057,1
|
||||||
|
6.6356030758470297,8.5428026784211397,1
|
||||||
|
7.6723589515313506,5.4163997946307063,1
|
||||||
|
4.8660153336822987,2.0426712930202484,0
|
||||||
|
6.8614049674943089,9.655309715308249,1
|
||||||
|
8.5404213238507509,2.5903742294758558,1
|
||||||
|
3.7178806541487575,5.3816621145233512,0
|
||||||
|
9.1812971234321594,0.1714746467769146,1
|
||||||
|
9.5601400220766664,0.02494648564606905,0
|
||||||
|
5.9713694732636213,4.1883018705993891,1
|
||||||
|
9.4382026931270957,1.9438124401494861,1
|
||||||
|
4.3575510196387768,9.8879833146929741,1
|
||||||
|
4.5403319643810391,6.7138733575120568,1
|
||||||
|
1.5491016302257776,9.3751321639865637,0
|
||||||
|
8.0819737119600177,9.8422068124637008,1
|
||||||
|
9.6204650029540062,2.0993275381624699,1
|
||||||
|
8.8347709784284234,3.1522041233256459,1
|
||||||
|
1.753448536619544,4.2192426044493914,0
|
||||||
|
1.0432128375396132,2.6097651151940227,0
|
||||||
|
1.1963831819593906,7.4757448583841324,0
|
||||||
|
8.9100698800757527,8.2329279417172074,1
|
||||||
|
4.5296187419444323,4.9055115412920713,1
|
||||||
|
1.8591124145314097,5.6918675592169166,0
|
||||||
|
9.5571788400411606,1.6446719132363796,1
|
||||||
|
7.1547012263908982,8.0147901969030499,1
|
||||||
|
2.3436185251921415,2.9587068501859903,0
|
||||||
|
2.922684489749372,8.2175949169322848,1
|
||||||
|
6.333096232265234,7.240304984152317,1
|
||||||
|
0.92562817502766848,3.4212671080604196,0
|
||||||
|
7.8743905667215586,7.7910933550447226,1
|
||||||
|
8.3477510465309024,1.8608125066384673,1
|
||||||
|
5.5810611322522163,2.4961292929947376,0
|
||||||
|
5.0910290936008096,9.8731340887024999,1
|
||||||
|
4.5301713701337576,3.7617589998990297,0
|
||||||
|
1.4237779891118407,0.22859792690724134,0
|
||||||
|
9.0959601290524006,1.0679170489311218,1
|
||||||
|
4.0066159190610051,9.792127856053412,1
|
||||||
|
8.9765674341470003,3.9351597707718611,1
|
||||||
|
0.098052877001464367,7.2145125409588218,0
|
||||||
|
0.45238867402076721,2.7746942453086376,0
|
||||||
|
3.8630462670698762,3.9132022904232144,0
|
||||||
|
7.863850174471736,7.7263833675533533,1
|
||||||
|
8.9227064093574882,7.7542167110368609,1
|
||||||
|
7.4643678776919842,9.9451762065291405,1
|
||||||
|
1.3419292913749814,2.3428780445829034,0
|
||||||
|
5.9409695956856012,4.6206316258758307,1
|
||||||
|
0.90407765936106443,9.4209287827834487,0
|
||||||
|
7.750530056655407,9.0571718849241734,1
|
||||||
|
9.5179252931848168,1.3011859031394124,0
|
||||||
|
7.7437867131084204,1.1544216889888048,0
|
||||||
|
7.691923058591783,8.2982278196141124,1
|
||||||
|
7.0922730304300785,2.3574569076299667,0
|
||||||
|
6.9444390805438161,6.4847038919106126,1
|
||||||
|
0.045024724677205086,3.346005929633975,0
|
||||||
|
1.5459691314026713,7.5677000870928168,1
|
||||||
|
5.2723831683397293,9.1496153734624386,1
|
||||||
|
0.86040707770735025,8.9881881373003125,0
|
||||||
|
7.2534389328211546,1.762510621920228,1
|
||||||
|
7.5173089792951941,2.2489292873069644,0
|
||||||
|
9.0816271863877773,1.4373503997921944,1
|
||||||
|
0.45567818451672792,4.7222974756732583,0
|
||||||
|
6.9497054163366556,1.411293363198638,0
|
||||||
|
9.2821425152942538,8.5293305432423949,1
|
||||||
|
7.180812694132328,3.6107634194195271,1
|
||||||
|
1.1324883857741952,3.2649118127301335,0
|
||||||
|
7.7465284522622824,3.6430192459374666,1
|
||||||
|
7.0653604483231902,1.1213281331583858,0
|
||||||
|
6.5058174915611744,0.86310222744941711,0
|
||||||
|
5.7005291106179357,7.0835442328825593,1
|
||||||
|
6.6604666877537966,2.2539557795971632,0
|
||||||
|
1.0914720175787807,7.0843769749626517,0
|
||||||
|
4.9030876159667969,6.0254777781665325,0
|
||||||
|
3.4431093418970704,7.0663468586280942,0
|
||||||
|
8.1829780619591475,0.97498656250536442,1
|
||||||
|
9.00037647690624,9.5493278605863452,1
|
||||||
|
9.6831041388213634,9.5070497319102287,1
|
||||||
|
2.991911475546658,5.2992104599252343,0
|
||||||
|
2.2381834778934717,4.5348437037318945,0
|
||||||
|
0.66547832917422056,9.782636440359056,1
|
||||||
|
6.7793187126517296,2.0281807519495487,0
|
||||||
|
9.9478409299626946,1.0264578135684133,1
|
||||||
|
3.2148492243140936,0.48505899496376514,0
|
||||||
|
8.8516463106498122,1.0279159573838115,1
|
||||||
|
0.20005786791443825,4.834059551358223,0
|
||||||
|
5.1854695053771138,0.73263081256300211,0
|
||||||
|
7.2600881475955248,3.9741338323801756,1
|
||||||
|
9.1505161253735423,2.5623337319120765,1
|
||||||
|
6.4608960598707199,7.0762926898896694,1
|
||||||
|
4.7785724932327867,8.2828713255003095,1
|
||||||
|
0.022279573604464531,2.6584278885275126,0
|
||||||
|
7.6306369295343757,7.4053513957187533,1
|
||||||
|
3.6849974654614925,5.0499651208519936,0
|
||||||
|
7.4842595355585217,6.0593958059325814,1
|
||||||
|
2.0307079795747995,3.9372665341943502,0
|
|
BIN
logistic_regression/binary.png
Normal file
BIN
logistic_regression/binary.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 317 KiB |
102
logistic_regression/binary.py
Executable file
102
logistic_regression/binary.py
Executable file
|
@ -0,0 +1,102 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
|
|
||||||
|
|
||||||
|
# Load the data
|
||||||
|
csv="../data/binary_logistic.csv"
|
||||||
|
data=pd.read_csv(csv)
|
||||||
|
x_1=np.array(data[data.columns[0]])
|
||||||
|
x_2=np.array(data[data.columns[1]])
|
||||||
|
y=np.array(data[data.columns[2]])
|
||||||
|
|
||||||
|
w1=w2=w3=-8
|
||||||
|
|
||||||
|
# Define our model
|
||||||
|
def h(x_1,x_2):
|
||||||
|
global w1,w2,w3
|
||||||
|
model=w1+w2*x_1+w3*x_2
|
||||||
|
return(1/(1+np.exp(-model)))
|
||||||
|
|
||||||
|
|
||||||
|
def dw1():
|
||||||
|
global x_1,x_2,y
|
||||||
|
return(1/len(x_1)*(sum(h(x_1,x_2)-y)))
|
||||||
|
def dw2():
|
||||||
|
global x_1,x_2,y
|
||||||
|
return(1/len(x_1)*sum(x_1*(h(x_1,x_2)-y)))
|
||||||
|
def dw3():
|
||||||
|
global x_1,x_2,y
|
||||||
|
return(1/len(x_1)*sum(x_2*(h(x_1,x_2)-y)))
|
||||||
|
|
||||||
|
|
||||||
|
# Perform the gradient decent
|
||||||
|
#fig, ax = plt.subplots(dpi=300)
|
||||||
|
alpha=0.01 # Proportion of the gradient to take into account
|
||||||
|
accuracy=0.0001 # Accuracy of the decent
|
||||||
|
done=False
|
||||||
|
def decent():
|
||||||
|
global w1,w2,w3,x,y
|
||||||
|
skip_frame=0 # Current frame (plot animation)
|
||||||
|
while True:
|
||||||
|
w1_old=w1
|
||||||
|
w1_new=w1-alpha*dw1()
|
||||||
|
w2_old=w2
|
||||||
|
w2_new=w2-alpha*dw2()
|
||||||
|
w3_old=w3
|
||||||
|
w3_new=w3-alpha*dw3()
|
||||||
|
w1=w1_new
|
||||||
|
w2=w2_new
|
||||||
|
w3=w3_new
|
||||||
|
|
||||||
|
if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w2_new-w2_old) <= accuracy:
|
||||||
|
break
|
||||||
|
skip_frame+=1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
decent()
|
||||||
|
fig=plt.figure()
|
||||||
|
|
||||||
|
#print(np.round(h(x_1,x_2)))
|
||||||
|
#pred=np.round(h(x_1,x_2))
|
||||||
|
|
||||||
|
# Plot data
|
||||||
|
ax = fig.add_subplot(2,2,1)
|
||||||
|
ax.set_title("Original Data")
|
||||||
|
ax.set_xlabel("X")
|
||||||
|
ax.set_ylabel("Y")
|
||||||
|
scatter=plt.scatter(x_1,x_2,c=y,marker="o")
|
||||||
|
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
|
||||||
|
legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
|
||||||
|
|
||||||
|
|
||||||
|
# Plot model
|
||||||
|
ax = fig.add_subplot(2,2,2,projection='3d')
|
||||||
|
ax.set_title("Model")
|
||||||
|
X,Y= np.meshgrid(np.sort(x_1), np.sort(x_2))
|
||||||
|
ax.set_xlabel("X")
|
||||||
|
ax.set_ylabel("Y")
|
||||||
|
ax.set_zlabel("Probability")
|
||||||
|
surf = ax.plot_wireframe(X,Y, h(X,Y),rstride=10,cstride=10)
|
||||||
|
|
||||||
|
# Plot prediction
|
||||||
|
ax = fig.add_subplot(2,1,2)
|
||||||
|
ax.set_title("Predictions")
|
||||||
|
ax.set_xlabel("X")
|
||||||
|
ax.set_ylabel("Y")
|
||||||
|
scatter=plt.scatter(x_1,x_2,c=np.round(h(x_1,x_2)),marker="o")
|
||||||
|
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
|
||||||
|
legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
|
||||||
|
|
||||||
|
# Save
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("binary.png",dpi=300)
|
Loading…
Add table
Reference in a new issue