Cleaning
parent f913e808df
commit ad57158ad2
4 changed files with 21 additions and 20 deletions
Binary image file not shown (412 KiB before, 2.1 MiB after).
@@ -15,19 +15,21 @@ h_w(x) = w_1 + w_2x + w_3x^2
 Then, we should define a cost function. A common approach is to use the *Mean Square Error*
 cost function:
 \begin{equation}\label{eq:cost}
-J(w) = \frac{1}{2n} \sum_{i=0}^n (h_w(x^{(i)}) - \hat{y}^{(i)})^2
+J(w) = \frac{1}{2n} \sum_{i=0}^n (h_w(x^{(i)}) - y^{(i)})^2
 \end{equation}
 
-Note that in Equation \ref{eq:cost} we average by $2n$ and not $n$. This is because it get simplify
-while doing the partial derivatives as we will see below. This is a pure cosmetic approach which do
-not impact the gradient decent (see [[https://math.stackexchange.com/questions/884887/why-divide-by-2m][here]] for more informations). The next step is to $min_w J(w)$
-for each weight $w_i$ (performing the gradient decent). Thus we compute each partial derivatives:
+With $n$ the number of observations and $x^{(i)}$ the value of the independent variable associated
+with the observation $y^{(i)}$. Note that in Equation \ref{eq:cost} we average by $2n$ and not $n$.
+This is because it simplifies the partial derivative expressions, as we will see below. It is a
+purely cosmetic choice which does not impact the gradient descent (see [[https://math.stackexchange.com/questions/884887/why-divide-by-2m][here]] for more information).
+The next step is to $\min_w J(w)$ for each weight $w_i$ (performing the gradient descent, see
+[[https://towardsdatascience.com/gradient-descent-demystified-bc30b26e432a][here]]). Thus we compute each partial derivative:
 \begin{align}
 \frac{\partial J(w)}{\partial w_1}&=\frac{\partial J(w)}{\partial h_w(x)}\frac{\partial h_w(x)}{\partial w_1}\nonumber\\
-&= \frac{1}{n} \sum_{i=0}^n (h_w(x^{(i)}) - \hat{y}^{(i)})\\
+&= \frac{1}{n} \sum_{i=0}^n (h_w(x^{(i)}) - y^{(i)})\\
 \text{similarly:}\nonumber\\
-\frac{\partial J(w)}{\partial w_2}&= \frac{1}{n} \sum_{i=0}^n x^{(i)}(h_w(x^{(i)}) - \hat{y}^{(i)})\\
+\frac{\partial J(w)}{\partial w_2}&= \frac{1}{n} \sum_{i=0}^n x^{(i)}(h_w(x^{(i)}) - y^{(i)})\\
-\frac{\partial J(w)}{\partial w_3}&= \frac{1}{n} \sum_{i=0}^n (x^{(i)})^2(h_w(x^{(i)}) - \hat{y}^{(i)})
+\frac{\partial J(w)}{\partial w_3}&= \frac{1}{n} \sum_{i=0}^n (x^{(i)})^2(h_w(x^{(i)}) - y^{(i)})
 \end{align}
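Since $\frac{\partial h_w(x)}{\partial w_1} = 1$, $\frac{\partial h_w(x)}{\partial w_2} = x$ and $\frac{\partial h_w(x)}{\partial w_3} = x^2$, each gradient is simply the mean residual weighted by the corresponding power of $x$. As a minimal numpy sketch of the resulting descent step (the data, initial weights, and names below are illustrative, not taken from the repository's script):

    import numpy as np

    # Illustrative data (assumed for the example)
    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    y = np.array([0.6, 1.9, 3.1, 3.8, 4.5])
    w1, w2, w3 = 0.0, 0.0, 0.0  # initial weights
    alpha = 0.005               # learning rate

    def h(x):
        # Hypothesis h_w(x) = w1 + w2*x + w3*x^2
        return w1 + w2*x + w3*x**2

    for _ in range(100000):
        r = h(x) - y             # residuals h_w(x^(i)) - y^(i)
        g1 = np.mean(r)          # dJ/dw1
        g2 = np.mean(r * x)      # dJ/dw2
        g3 = np.mean(r * x**2)   # dJ/dw3
        w1, w2, w3 = w1 - alpha*g1, w2 - alpha*g2, w3 - alpha*g3

    print(w1, w2, w3)  # fitted weights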
Binary file not shown.
@@ -27,13 +27,17 @@ def dh3():
     return(1/len(x)*np.sum((h(x)-y)*(x**2)))
 
 # Perform the gradient descent
-fig, ax = plt.subplots()
-frame=0 # Current frame (plot animation)
+fig, ax = plt.subplots(dpi=300)
+ax.set_xlim([0, 7])
+ax.set_ylim([0, 5])
+ax.plot(x,y,"ro")
+h_data,=ax.plot(x,h(x))
 alpha=0.005 # Proportion of the gradient to take into account
 accuracy=0.000001 # Accuracy of the descent
 done=False
 def decent(i):
-    global w1,w2,w3,x,y,frame
+    global w1,w2,w3,x,y,done
+    skip_frame=0 # Iterations since the last plot update (plot animation)
     while True:
         w1_old=w1
         w1_new=w1-alpha*dh1()
@@ -47,14 +51,9 @@ def decent(i):
 
         if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w3_new-w3_old) <= accuracy:
             done=True
-        frame+=1
-        if frame >=1000:
-            frame=0
-            ax.clear()
-            ax.set_xlim([0, 7])
-            ax.set_ylim([0, 5])
-            ax.plot(x,y,"ro")
-            ax.plot(x,h(x))
+        skip_frame+=1
+        if skip_frame >=1000:
+            h_data.set_ydata(h(x))
             break
 
 def IsDone():
@@ -65,5 +64,5 @@ def IsDone():
         yield i
 
 anim=FuncAnimation(fig,decent,frames=IsDone,repeat=False)
-anim.save('polynomial.gif',dpi=80,writer="imagemagick")
+anim.save('polynomial.gif',writer="imagemagick",dpi=300)
 
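The last two hunks replace the per-frame ax.clear() and full replot with an in-place h_data.set_ydata() update, and drive FuncAnimation from a generator that stops yielding frames once the descent has converged. A self-contained sketch of that pattern, with a dummy update standing in for the descent (names and data here are illustrative):

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    fig, ax = plt.subplots()
    xs = np.linspace(0, 2*np.pi, 100)
    line, = ax.plot(xs, np.sin(xs))  # keep a handle to the artist, like h_data

    done = False

    def frames():
        # Yield frame indices until the stop flag flips (the IsDone idea)
        i = 0
        while not done:
            yield i
            i += 1

    def update(i):
        global done
        line.set_ydata(np.sin(xs + 0.1*i))  # update the artist in place
        if i >= 50:                         # stand-in for the convergence test
            done = True

    anim = FuncAnimation(fig, update, frames=frames, repeat=False)
    anim.save('demo.gif', writer="imagemagick")  # requires ImageMagick installed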