#!/usr/bin/env python

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

# Read the sample points from the "x" and "y" columns of the CSV file
# (the path is resolved relative to the current working directory).
csv = "../data/polynomial.csv"
data = pd.read_csv(csv)
x = np.array(data["x"])
y = np.array(data["y"])

# Starting values for the coefficients of the model w1 + w2*x + w3*x**2
w1, w2, w3 = 10, 10, 10

def h(x):
    """Evaluate the quadratic model w1 + w2*x + w3*x**2 at ``x``.

    The coefficients are the module-level weights w1, w2 and w3, so the
    result follows whatever values gradient descent has assigned to them.
    ``x`` may be a scalar or a NumPy array (evaluated element-wise).
    """
    squared = x ** 2
    return w1 + w2 * x + w3 * squared

# Partial derivatives of the mean-squared-error cost
# (1/(2m)) * sum((h(x) - y)**2) with respect to each weight, m = len(x).
def _residual():
    """Return the model error h(x) - y at every data point."""
    return h(x) - y


def dh1():
    """Return the partial derivative of the cost with respect to w1."""
    return (1 / len(x)) * np.sum(_residual())


def dh2():
    """Return the partial derivative of the cost with respect to w2."""
    return (1 / len(x)) * np.sum(_residual() * x)


def dh3():
    """Return the partial derivative of the cost with respect to w3."""
    return (1 / len(x)) * np.sum(_residual() * (x ** 2))

# Gradient descent: draw the data points and the initial model curve
fig, ax = plt.subplots(dpi=300)
ax.set_xlim([0, 7])
ax.set_ylim([0, 5])
ax.plot(x, y, "ro")                 # data points as red circles
h_data = ax.plot(x, h(x))[0]        # model curve, redrawn by decent()
alpha = 0.005       # learning rate: fraction of the gradient applied per step
accuracy = 1e-6     # threshold on the per-step change of each weight
done = False        # flag read by IsDone() to stop producing frames
def decent(i, steps_per_frame=1000):
    """Run one animation frame's worth of gradient-descent steps.

    FuncAnimation calls this once per frame. It performs up to
    ``steps_per_frame`` simultaneous updates of the module-level weights
    w1, w2, w3 (learning rate ``alpha``) and then redraws the model curve
    ``h_data``. When a single step changes every weight by at most
    ``accuracy``, the module-level ``done`` flag is set so that IsDone()
    stops yielding frames and the animation ends.

    Args:
        i: Frame number passed in by FuncAnimation; not used.
        steps_per_frame: Maximum number of descent steps between redraws
            (default 1000).

    Returns:
        Tuple holding the updated line artist, the form FuncAnimation
        expects from its frame function when blitting.

    Raises:
        FloatingPointError: If a weight becomes NaN or infinite, i.e. the
            descent diverged (try a smaller ``alpha``). Without this check
            the convergence test could never pass and the animation would
            never end.
    """
    # `done` must be listed here as well: assigning it without the global
    # declaration would only create a local, and IsDone() would never stop.
    global w1, w2, w3, done
    for _ in range(steps_per_frame):
        # All three gradients are evaluated at the current weights before
        # any weight is reassigned, so the update is simultaneous.
        w1_new = w1 - alpha * dh1()
        w2_new = w2 - alpha * dh2()
        w3_new = w3 - alpha * dh3()
        converged = (
            abs(w1_new - w1) <= accuracy
            and abs(w2_new - w2) <= accuracy
            and abs(w3_new - w3) <= accuracy
        )
        w1, w2, w3 = w1_new, w2_new, w3_new
        if not np.all(np.isfinite([w1, w2, w3])):
            raise FloatingPointError(
                "gradient descent diverged; try a smaller alpha")
        if converged:
            done = True
            break
    h_data.set_ydata(h(x))
    return (h_data,)

def IsDone():
    """Yield frame numbers 1, 2, 3, ... while the module-level ``done`` is False.

    The flag is checked before every yield. Passed as ``frames`` to
    FuncAnimation, this ends the animation once ``done`` becomes True.
    """
    frame = 0
    while not done:  # read-only access, so no `global` declaration is needed
        frame += 1
        yield frame
        
# Build the animation: IsDone() supplies the frame numbers and decent() is
# called for each one. Then render every frame to a GIF.
anim = FuncAnimation(
    fig,
    decent,
    frames=IsDone,
    repeat=False,
)
anim.save("polynomial.gif", writer="imagemagick", dpi=300)