Add binary logistic regression

This commit is contained in:
manzerbredes 2021-02-20 16:07:06 +01:00
parent ad57158ad2
commit 8e79c80354
3 changed files with 202 additions and 0 deletions

100
data/binary_logistic.csv Normal file
View file

@ -0,0 +1,100 @@
4.8550642421469092,9.6399615658447146,1
8.6254397593438625,0.058926530182361603,0
3.8281915383413434,0.72319923434406519,0
7.1509548369795084,3.899420415982604,1
6.4779004408046603,8.1981805479153991,1
1.9222695007920265,1.3314272649586201,0
8.9782158890739083,0.99343751091510057,1
6.6356030758470297,8.5428026784211397,1
7.6723589515313506,5.4163997946307063,1
4.8660153336822987,2.0426712930202484,0
6.8614049674943089,9.655309715308249,1
8.5404213238507509,2.5903742294758558,1
3.7178806541487575,5.3816621145233512,0
9.1812971234321594,0.1714746467769146,1
9.5601400220766664,0.02494648564606905,0
5.9713694732636213,4.1883018705993891,1
9.4382026931270957,1.9438124401494861,1
4.3575510196387768,9.8879833146929741,1
4.5403319643810391,6.7138733575120568,1
1.5491016302257776,9.3751321639865637,0
8.0819737119600177,9.8422068124637008,1
9.6204650029540062,2.0993275381624699,1
8.8347709784284234,3.1522041233256459,1
1.753448536619544,4.2192426044493914,0
1.0432128375396132,2.6097651151940227,0
1.1963831819593906,7.4757448583841324,0
8.9100698800757527,8.2329279417172074,1
4.5296187419444323,4.9055115412920713,1
1.8591124145314097,5.6918675592169166,0
9.5571788400411606,1.6446719132363796,1
7.1547012263908982,8.0147901969030499,1
2.3436185251921415,2.9587068501859903,0
2.922684489749372,8.2175949169322848,1
6.333096232265234,7.240304984152317,1
0.92562817502766848,3.4212671080604196,0
7.8743905667215586,7.7910933550447226,1
8.3477510465309024,1.8608125066384673,1
5.5810611322522163,2.4961292929947376,0
5.0910290936008096,9.8731340887024999,1
4.5301713701337576,3.7617589998990297,0
1.4237779891118407,0.22859792690724134,0
9.0959601290524006,1.0679170489311218,1
4.0066159190610051,9.792127856053412,1
8.9765674341470003,3.9351597707718611,1
0.098052877001464367,7.2145125409588218,0
0.45238867402076721,2.7746942453086376,0
3.8630462670698762,3.9132022904232144,0
7.863850174471736,7.7263833675533533,1
8.9227064093574882,7.7542167110368609,1
7.4643678776919842,9.9451762065291405,1
1.3419292913749814,2.3428780445829034,0
5.9409695956856012,4.6206316258758307,1
0.90407765936106443,9.4209287827834487,0
7.750530056655407,9.0571718849241734,1
9.5179252931848168,1.3011859031394124,0
7.7437867131084204,1.1544216889888048,0
7.691923058591783,8.2982278196141124,1
7.0922730304300785,2.3574569076299667,0
6.9444390805438161,6.4847038919106126,1
0.045024724677205086,3.346005929633975,0
1.5459691314026713,7.5677000870928168,1
5.2723831683397293,9.1496153734624386,1
0.86040707770735025,8.9881881373003125,0
7.2534389328211546,1.762510621920228,1
7.5173089792951941,2.2489292873069644,0
9.0816271863877773,1.4373503997921944,1
0.45567818451672792,4.7222974756732583,0
6.9497054163366556,1.411293363198638,0
9.2821425152942538,8.5293305432423949,1
7.180812694132328,3.6107634194195271,1
1.1324883857741952,3.2649118127301335,0
7.7465284522622824,3.6430192459374666,1
7.0653604483231902,1.1213281331583858,0
6.5058174915611744,0.86310222744941711,0
5.7005291106179357,7.0835442328825593,1
6.6604666877537966,2.2539557795971632,0
1.0914720175787807,7.0843769749626517,0
4.9030876159667969,6.0254777781665325,0
3.4431093418970704,7.0663468586280942,0
8.1829780619591475,0.97498656250536442,1
9.00037647690624,9.5493278605863452,1
9.6831041388213634,9.5070497319102287,1
2.991911475546658,5.2992104599252343,0
2.2381834778934717,4.5348437037318945,0
0.66547832917422056,9.782636440359056,1
6.7793187126517296,2.0281807519495487,0
9.9478409299626946,1.0264578135684133,1
3.2148492243140936,0.48505899496376514,0
8.8516463106498122,1.0279159573838115,1
0.20005786791443825,4.834059551358223,0
5.1854695053771138,0.73263081256300211,0
7.2600881475955248,3.9741338323801756,1
9.1505161253735423,2.5623337319120765,1
6.4608960598707199,7.0762926898896694,1
4.7785724932327867,8.2828713255003095,1
0.022279573604464531,2.6584278885275126,0
7.6306369295343757,7.4053513957187533,1
3.6849974654614925,5.0499651208519936,0
7.4842595355585217,6.0593958059325814,1
2.0307079795747995,3.9372665341943502,0
1 4.8550642421469092 9.6399615658447146 1
2 8.6254397593438625 0.058926530182361603 0
3 3.8281915383413434 0.72319923434406519 0
4 7.1509548369795084 3.899420415982604 1
5 6.4779004408046603 8.1981805479153991 1
6 1.9222695007920265 1.3314272649586201 0
7 8.9782158890739083 0.99343751091510057 1
8 6.6356030758470297 8.5428026784211397 1
9 7.6723589515313506 5.4163997946307063 1
10 4.8660153336822987 2.0426712930202484 0
11 6.8614049674943089 9.655309715308249 1
12 8.5404213238507509 2.5903742294758558 1
13 3.7178806541487575 5.3816621145233512 0
14 9.1812971234321594 0.1714746467769146 1
15 9.5601400220766664 0.02494648564606905 0
16 5.9713694732636213 4.1883018705993891 1
17 9.4382026931270957 1.9438124401494861 1
18 4.3575510196387768 9.8879833146929741 1
19 4.5403319643810391 6.7138733575120568 1
20 1.5491016302257776 9.3751321639865637 0
21 8.0819737119600177 9.8422068124637008 1
22 9.6204650029540062 2.0993275381624699 1
23 8.8347709784284234 3.1522041233256459 1
24 1.753448536619544 4.2192426044493914 0
25 1.0432128375396132 2.6097651151940227 0
26 1.1963831819593906 7.4757448583841324 0
27 8.9100698800757527 8.2329279417172074 1
28 4.5296187419444323 4.9055115412920713 1
29 1.8591124145314097 5.6918675592169166 0
30 9.5571788400411606 1.6446719132363796 1
31 7.1547012263908982 8.0147901969030499 1
32 2.3436185251921415 2.9587068501859903 0
33 2.922684489749372 8.2175949169322848 1
34 6.333096232265234 7.240304984152317 1
35 0.92562817502766848 3.4212671080604196 0
36 7.8743905667215586 7.7910933550447226 1
37 8.3477510465309024 1.8608125066384673 1
38 5.5810611322522163 2.4961292929947376 0
39 5.0910290936008096 9.8731340887024999 1
40 4.5301713701337576 3.7617589998990297 0
41 1.4237779891118407 0.22859792690724134 0
42 9.0959601290524006 1.0679170489311218 1
43 4.0066159190610051 9.792127856053412 1
44 8.9765674341470003 3.9351597707718611 1
45 0.098052877001464367 7.2145125409588218 0
46 0.45238867402076721 2.7746942453086376 0
47 3.8630462670698762 3.9132022904232144 0
48 7.863850174471736 7.7263833675533533 1
49 8.9227064093574882 7.7542167110368609 1
50 7.4643678776919842 9.9451762065291405 1
51 1.3419292913749814 2.3428780445829034 0
52 5.9409695956856012 4.6206316258758307 1
53 0.90407765936106443 9.4209287827834487 0
54 7.750530056655407 9.0571718849241734 1
55 9.5179252931848168 1.3011859031394124 0
56 7.7437867131084204 1.1544216889888048 0
57 7.691923058591783 8.2982278196141124 1
58 7.0922730304300785 2.3574569076299667 0
59 6.9444390805438161 6.4847038919106126 1
60 0.045024724677205086 3.346005929633975 0
61 1.5459691314026713 7.5677000870928168 1
62 5.2723831683397293 9.1496153734624386 1
63 0.86040707770735025 8.9881881373003125 0
64 7.2534389328211546 1.762510621920228 1
65 7.5173089792951941 2.2489292873069644 0
66 9.0816271863877773 1.4373503997921944 1
67 0.45567818451672792 4.7222974756732583 0
68 6.9497054163366556 1.411293363198638 0
69 9.2821425152942538 8.5293305432423949 1
70 7.180812694132328 3.6107634194195271 1
71 1.1324883857741952 3.2649118127301335 0
72 7.7465284522622824 3.6430192459374666 1
73 7.0653604483231902 1.1213281331583858 0
74 6.5058174915611744 0.86310222744941711 0
75 5.7005291106179357 7.0835442328825593 1
76 6.6604666877537966 2.2539557795971632 0
77 1.0914720175787807 7.0843769749626517 0
78 4.9030876159667969 6.0254777781665325 0
79 3.4431093418970704 7.0663468586280942 0
80 8.1829780619591475 0.97498656250536442 1
81 9.00037647690624 9.5493278605863452 1
82 9.6831041388213634 9.5070497319102287 1
83 2.991911475546658 5.2992104599252343 0
84 2.2381834778934717 4.5348437037318945 0
85 0.66547832917422056 9.782636440359056 1
86 6.7793187126517296 2.0281807519495487 0
87 9.9478409299626946 1.0264578135684133 1
88 3.2148492243140936 0.48505899496376514 0
89 8.8516463106498122 1.0279159573838115 1
90 0.20005786791443825 4.834059551358223 0
91 5.1854695053771138 0.73263081256300211 0
92 7.2600881475955248 3.9741338323801756 1
93 9.1505161253735423 2.5623337319120765 1
94 6.4608960598707199 7.0762926898896694 1
95 4.7785724932327867 8.2828713255003095 1
96 0.022279573604464531 2.6584278885275126 0
97 7.6306369295343757 7.4053513957187533 1
98 3.6849974654614925 5.0499651208519936 0
99 7.4842595355585217 6.0593958059325814 1
100 2.0307079795747995 3.9372665341943502 0

Binary file not shown.

After

Width:  |  Height:  |  Size: 317 KiB

102
logistic_regression/binary.py Executable file
View file

@ -0,0 +1,102 @@
#!/usr/bin/env python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
# Load the data
csv="../data/binary_logistic.csv"
data=pd.read_csv(csv)
x_1=np.array(data[data.columns[0]])
x_2=np.array(data[data.columns[1]])
y=np.array(data[data.columns[2]])
w1=w2=w3=-8
# Define our model
def h(x_1,x_2):
global w1,w2,w3
model=w1+w2*x_1+w3*x_2
return(1/(1+np.exp(-model)))
def dw1():
global x_1,x_2,y
return(1/len(x_1)*(sum(h(x_1,x_2)-y)))
def dw2():
global x_1,x_2,y
return(1/len(x_1)*sum(x_1*(h(x_1,x_2)-y)))
def dw3():
global x_1,x_2,y
return(1/len(x_1)*sum(x_2*(h(x_1,x_2)-y)))
# Perform the gradient decent
#fig, ax = plt.subplots(dpi=300)
alpha=0.01 # Proportion of the gradient to take into account
accuracy=0.0001 # Accuracy of the decent
done=False
def decent():
global w1,w2,w3,x,y
skip_frame=0 # Current frame (plot animation)
while True:
w1_old=w1
w1_new=w1-alpha*dw1()
w2_old=w2
w2_new=w2-alpha*dw2()
w3_old=w3
w3_new=w3-alpha*dw3()
w1=w1_new
w2=w2_new
w3=w3_new
if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w2_new-w2_old) <= accuracy:
break
skip_frame+=1
decent()
fig=plt.figure()
#print(np.round(h(x_1,x_2)))
#pred=np.round(h(x_1,x_2))
# Plot data
ax = fig.add_subplot(2,2,1)
ax.set_title("Original Data")
ax.set_xlabel("X")
ax.set_ylabel("Y")
scatter=plt.scatter(x_1,x_2,c=y,marker="o")
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
# Plot model
ax = fig.add_subplot(2,2,2,projection='3d')
ax.set_title("Model")
X,Y= np.meshgrid(np.sort(x_1), np.sort(x_2))
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Probability")
surf = ax.plot_wireframe(X,Y, h(X,Y),rstride=10,cstride=10)
# Plot prediction
ax = fig.add_subplot(2,1,2)
ax.set_title("Predictions")
ax.set_xlabel("X")
ax.set_ylabel("Y")
scatter=plt.scatter(x_1,x_2,c=np.round(h(x_1,x_2)),marker="o")
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
# Save
plt.tight_layout()
plt.savefig("binary.png",dpi=300)