#!/usr/bin/env python
"""Fit a two-feature binary logistic regression with batch gradient descent.

Reads ``../data/binary_logistic.csv``. The first two columns are the
features x_1 and x_2, and the third is the class label. The gradient
formulas below assume the label is 0 or 1.

The script fits the weights w1 (bias), w2 and w3 by gradient descent on the
mean cross-entropy loss. It then plots the original data, the fitted
probability surface, and the predicted classes with the linear decision
boundary.
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation  # currently unused
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

# Load the data
csv = "../data/binary_logistic.csv"
data = pd.read_csv(csv)
x_1 = np.array(data[data.columns[0]])
x_2 = np.array(data[data.columns[1]])
y = np.array(data[data.columns[2]])

# Initial weights: w1 is the bias, w2 and w3 multiply x_1 and x_2.
w1 = w2 = w3 = -8


# Define our model
def h(x_1, x_2):
    """Return sigmoid(w1 + w2*x_1 + w3*x_2) using the module-level weights.

    Works elementwise on scalars or NumPy arrays.
    """
    model = w1 + w2 * x_1 + w3 * x_2
    return 1 / (1 + np.exp(-model))


def dw1():
    """Return d(mean cross-entropy)/d(w1) over the loaded data set."""
    return np.mean(h(x_1, x_2) - y)


def dw2():
    """Return d(mean cross-entropy)/d(w2) over the loaded data set."""
    return np.mean(x_1 * (h(x_1, x_2) - y))


def dw3():
    """Return d(mean cross-entropy)/d(w3) over the loaded data set."""
    return np.mean(x_2 * (h(x_1, x_2) - y))


# Perform the gradient descent
alpha = 0.01  # Learning rate: fraction of the gradient applied per step
accuracy = 0.0001  # Stop once no weight changes by more than this in one step
max_iterations = 1_000_000  # Safety cap so a non-converging run cannot loop forever


def decent():
    """Fit w1, w2 and w3 in place by batch gradient descent.

    The name is a typo for "descent"; it is kept so existing callers still
    work. All three gradients are evaluated with the current weights before
    any weight is updated (a simultaneous update). Iteration stops once every
    weight moves by at most ``accuracy`` in a single step, or after
    ``max_iterations`` steps.

    Returns:
        True if the weights converged, False if the iteration cap was hit.
    """
    global w1, w2, w3
    for _ in range(max_iterations):
        w1_new = w1 - alpha * dw1()
        w2_new = w2 - alpha * dw2()
        w3_new = w3 - alpha * dw3()
        converged = (
            abs(w1_new - w1) <= accuracy
            and abs(w2_new - w2) <= accuracy
            and abs(w3_new - w3) <= accuracy  # previously w2 was checked twice
        )
        w1, w2, w3 = w1_new, w2_new, w3_new
        if converged:
            return True
    return False


if not decent():
    print(f"Gradient descent did not converge within {max_iterations} iterations")

fig = plt.figure()

# Plot data
ax = fig.add_subplot(2, 2, 1)
ax.set_title("Original Data")
ax.set_xlabel("X")
ax.set_ylabel("Y")
scatter = ax.scatter(x_1, x_2, c=y, marker="o")
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
ax.legend(handles, ["Class A", "Class B"], loc="upper right", title="Legend")

# Plot model: the fitted probability surface over a fixed-resolution grid that
# spans the data. A grid built from the raw samples would need N*N points.
ax = fig.add_subplot(2, 2, 2, projection="3d")
ax.set_title("Model")
X, Y = np.meshgrid(
    np.linspace(x_1.min(), x_1.max(), 100),
    np.linspace(x_2.min(), x_2.max(), 100),
)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Probability")
ax.plot_wireframe(X, Y, h(X, Y), rstride=10, cstride=10)

# Plot prediction: threshold the probability at 0.5 by rounding.
ax = fig.add_subplot(2, 1, 2)
ax.set_title("Predictions")
ax.set_xlabel("X")
ax.set_ylabel("Y")
scatter = ax.scatter(x_1, x_2, c=np.round(h(x_1, x_2)), marker="o")
handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
ax.legend(handles, ["Class A", "Class B"], loc="upper right", title="Legend")

# Plot decision boundary: w1 + w2*x_1 + w3*x_2 = 0  =>  x_2 = -(w1 + w2*x_1) / w3.
# Separate names are used so the label array `y` is not overwritten.
boundary_x = np.linspace(x_1.min(), x_1.max(), 1000)
boundary_y = -w1 / w3 + (w2 / -w3) * boundary_x
ax.fill_between(boundary_x, boundary_y, np.min(boundary_y), alpha=0.2)
ax.fill_between(boundary_x, boundary_y, np.max(boundary_y), alpha=0.2)
ax.plot(boundary_x, boundary_y, "--")

# Save
plt.tight_layout()
# Uncomment to also write the figure to disk:
# plt.savefig("binary.png", dpi=300)
plt.show()