diff options
| author | manzerbredes <manzerbredes@mailbox.org> | 2021-02-20 16:07:06 +0100 |
|---|---|---|
| committer | manzerbredes <manzerbredes@mailbox.org> | 2021-02-20 16:07:06 +0100 |
| commit | 8e79c803545ff937b529505aca938f62c4825ba3 (patch) | |
| tree | 4ff7361f339f3de2bc0bee2aa6e6feb00d6b9987 | |
| parent | ad57158ad2f99f95483792c4e7d135e6458a26fa (diff) | |
Add binary logistic regression
| -rw-r--r-- | data/binary_logistic.csv | 100 | ||||
| -rw-r--r-- | logistic_regression/binary.png | bin | 0 -> 324322 bytes | |||
| -rwxr-xr-x | logistic_regression/binary.py | 102 |
3 files changed, 202 insertions, 0 deletions
diff --git a/data/binary_logistic.csv b/data/binary_logistic.csv new file mode 100644 index 0000000..5c623a9 --- /dev/null +++ b/data/binary_logistic.csv @@ -0,0 +1,100 @@ +4.8550642421469092,9.6399615658447146,1
+8.6254397593438625,0.058926530182361603,0
+3.8281915383413434,0.72319923434406519,0
+7.1509548369795084,3.899420415982604,1
+6.4779004408046603,8.1981805479153991,1
+1.9222695007920265,1.3314272649586201,0
+8.9782158890739083,0.99343751091510057,1
+6.6356030758470297,8.5428026784211397,1
+7.6723589515313506,5.4163997946307063,1
+4.8660153336822987,2.0426712930202484,0
+6.8614049674943089,9.655309715308249,1
+8.5404213238507509,2.5903742294758558,1
+3.7178806541487575,5.3816621145233512,0
+9.1812971234321594,0.1714746467769146,1
+9.5601400220766664,0.02494648564606905,0
+5.9713694732636213,4.1883018705993891,1
+9.4382026931270957,1.9438124401494861,1
+4.3575510196387768,9.8879833146929741,1
+4.5403319643810391,6.7138733575120568,1
+1.5491016302257776,9.3751321639865637,0
+8.0819737119600177,9.8422068124637008,1
+9.6204650029540062,2.0993275381624699,1
+8.8347709784284234,3.1522041233256459,1
+1.753448536619544,4.2192426044493914,0
+1.0432128375396132,2.6097651151940227,0
+1.1963831819593906,7.4757448583841324,0
+8.9100698800757527,8.2329279417172074,1
+4.5296187419444323,4.9055115412920713,1
+1.8591124145314097,5.6918675592169166,0
+9.5571788400411606,1.6446719132363796,1
+7.1547012263908982,8.0147901969030499,1
+2.3436185251921415,2.9587068501859903,0
+2.922684489749372,8.2175949169322848,1
+6.333096232265234,7.240304984152317,1
+0.92562817502766848,3.4212671080604196,0
+7.8743905667215586,7.7910933550447226,1
+8.3477510465309024,1.8608125066384673,1
+5.5810611322522163,2.4961292929947376,0
+5.0910290936008096,9.8731340887024999,1
+4.5301713701337576,3.7617589998990297,0
+1.4237779891118407,0.22859792690724134,0
+9.0959601290524006,1.0679170489311218,1
+4.0066159190610051,9.792127856053412,1
+8.9765674341470003,3.9351597707718611,1
+0.098052877001464367,7.2145125409588218,0
+0.45238867402076721,2.7746942453086376,0
+3.8630462670698762,3.9132022904232144,0
+7.863850174471736,7.7263833675533533,1
+8.9227064093574882,7.7542167110368609,1
+7.4643678776919842,9.9451762065291405,1
+1.3419292913749814,2.3428780445829034,0
+5.9409695956856012,4.6206316258758307,1
+0.90407765936106443,9.4209287827834487,0
+7.750530056655407,9.0571718849241734,1
+9.5179252931848168,1.3011859031394124,0
+7.7437867131084204,1.1544216889888048,0
+7.691923058591783,8.2982278196141124,1
+7.0922730304300785,2.3574569076299667,0
+6.9444390805438161,6.4847038919106126,1
+0.045024724677205086,3.346005929633975,0
+1.5459691314026713,7.5677000870928168,1
+5.2723831683397293,9.1496153734624386,1
+0.86040707770735025,8.9881881373003125,0
+7.2534389328211546,1.762510621920228,1
+7.5173089792951941,2.2489292873069644,0
+9.0816271863877773,1.4373503997921944,1
+0.45567818451672792,4.7222974756732583,0
+6.9497054163366556,1.411293363198638,0
+9.2821425152942538,8.5293305432423949,1
+7.180812694132328,3.6107634194195271,1
+1.1324883857741952,3.2649118127301335,0
+7.7465284522622824,3.6430192459374666,1
+7.0653604483231902,1.1213281331583858,0
+6.5058174915611744,0.86310222744941711,0
+5.7005291106179357,7.0835442328825593,1
+6.6604666877537966,2.2539557795971632,0
+1.0914720175787807,7.0843769749626517,0
+4.9030876159667969,6.0254777781665325,0
+3.4431093418970704,7.0663468586280942,0
+8.1829780619591475,0.97498656250536442,1
+9.00037647690624,9.5493278605863452,1
+9.6831041388213634,9.5070497319102287,1
+2.991911475546658,5.2992104599252343,0
+2.2381834778934717,4.5348437037318945,0
+0.66547832917422056,9.782636440359056,1
+6.7793187126517296,2.0281807519495487,0
+9.9478409299626946,1.0264578135684133,1
+3.2148492243140936,0.48505899496376514,0
+8.8516463106498122,1.0279159573838115,1
+0.20005786791443825,4.834059551358223,0
+5.1854695053771138,0.73263081256300211,0
+7.2600881475955248,3.9741338323801756,1
+9.1505161253735423,2.5623337319120765,1
+6.4608960598707199,7.0762926898896694,1
+4.7785724932327867,8.2828713255003095,1
+0.022279573604464531,2.6584278885275126,0
+7.6306369295343757,7.4053513957187533,1
+3.6849974654614925,5.0499651208519936,0
+7.4842595355585217,6.0593958059325814,1
+2.0307079795747995,3.9372665341943502,0
diff --git a/logistic_regression/binary.png b/logistic_regression/binary.png Binary files differnew file mode 100644 index 0000000..b64b647 --- /dev/null +++ b/logistic_regression/binary.png diff --git a/logistic_regression/binary.py b/logistic_regression/binary.py new file mode 100755 index 0000000..a8a11e2 --- /dev/null +++ b/logistic_regression/binary.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation +from mpl_toolkits.mplot3d import Axes3D + + +# Load the data +csv="../data/binary_logistic.csv" +data=pd.read_csv(csv) +x_1=np.array(data[data.columns[0]]) +x_2=np.array(data[data.columns[1]]) +y=np.array(data[data.columns[2]]) + +w1=w2=w3=-8 + +# Define our model +def h(x_1,x_2): + global w1,w2,w3 + model=w1+w2*x_1+w3*x_2 + return(1/(1+np.exp(-model))) + + +def dw1(): + global x_1,x_2,y + return(1/len(x_1)*(sum(h(x_1,x_2)-y))) +def dw2(): + global x_1,x_2,y + return(1/len(x_1)*sum(x_1*(h(x_1,x_2)-y))) +def dw3(): + global x_1,x_2,y + return(1/len(x_1)*sum(x_2*(h(x_1,x_2)-y))) + + +# Perform the gradient decent +#fig, ax = plt.subplots(dpi=300) +alpha=0.01 # Proportion of the gradient to take into account +accuracy=0.0001 # Accuracy of the decent +done=False +def decent(): + global w1,w2,w3,x,y + skip_frame=0 # Current frame (plot animation) + while True: + w1_old=w1 + w1_new=w1-alpha*dw1() + w2_old=w2 + w2_new=w2-alpha*dw2() + w3_old=w3 + w3_new=w3-alpha*dw3() + w1=w1_new + w2=w2_new + w3=w3_new + + if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w2_new-w2_old) <= accuracy: + break + skip_frame+=1 + + + + + + + +decent() +fig=plt.figure() + +#print(np.round(h(x_1,x_2))) +#pred=np.round(h(x_1,x_2)) + +# Plot data +ax = fig.add_subplot(2,2,1) +ax.set_title("Original Data") +ax.set_xlabel("X") +ax.set_ylabel("Y") +scatter=plt.scatter(x_1,x_2,c=y,marker="o") +handles, labels = scatter.legend_elements(prop="colors", alpha=0.6) +legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend") + + +# Plot model +ax = fig.add_subplot(2,2,2,projection='3d') +ax.set_title("Model") +X,Y= np.meshgrid(np.sort(x_1), np.sort(x_2)) +ax.set_xlabel("X") +ax.set_ylabel("Y") +ax.set_zlabel("Probability") +surf = ax.plot_wireframe(X,Y, h(X,Y),rstride=10,cstride=10) + +# Plot prediction +ax = fig.add_subplot(2,1,2) +ax.set_title("Predictions") +ax.set_xlabel("X") +ax.set_ylabel("Y") +scatter=plt.scatter(x_1,x_2,c=np.round(h(x_1,x_2)),marker="o") +handles, labels = scatter.legend_elements(prop="colors", alpha=0.6) +legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend") + +# Save +plt.tight_layout() +plt.savefig("binary.png",dpi=300) |
