author     Loic Guegan <manzerbredes@mailbox.org>  2022-11-02 12:18:07 +0100
committer  Loic Guegan <manzerbredes@mailbox.org>  2022-11-02 12:18:07 +0100
commit     657578547f60673e227a5e97833b137083e71c2d (patch)
tree       47923d68098a0de0e8c6c1ec783ae36edbf66605
parent     65280331d3b2b3d32375df783ab89f30706e17da (diff)
Minor changes
-rwxr-xr-x  qlearning.py  47
-rwxr-xr-x  snake.py       6
2 files changed, 39 insertions, 14 deletions
diff --git a/qlearning.py b/qlearning.py
index ee0b0a9..3c8111d 100755
--- a/qlearning.py
+++ b/qlearning.py
@@ -27,15 +27,16 @@ class QTable:
##### Reward +1 when an apple is eaten
##### Reward -10 when an obstacle is hit
"""
- def __init__(self, file, save_every=10):
+ def __init__(self, file, save_every=5000):
self.file=file
self.save_every=save_every
- self.update_counter=0
+ self.save_counter=0
if os.path.exists(file):
self.qtable=np.loadtxt(file)
else:
self.qtable=np.zeros((2**13, 4))
-
+ with open(file+"_generation","w") as f:
+ f.write("0")
def isWall(self,h,game):
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
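Note: the table is allocated as np.zeros((2**13, 4)), which suggests get_state packs 13 boolean observations of the snake's surroundings into a single row index. A minimal sketch of such an encoding (the helper below is illustrative and not part of the repository):

# Illustrative only: pack 13 boolean observations into a Q-table row index.
# The real feature set lives in QTable.get_state in qlearning.py.
def encode_state(bits):
    assert len(bits) == 13
    index = 0
    for b in bits:
        index = (index << 1) | int(b)
    return index  # always in range [0, 2**13)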
@@ -108,10 +109,18 @@ class QTable:
alpha=0.5
gamma=0.9
self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
- self.update_counter+=1
- if self.update_counter>=self.save_every:
+ self.save_counter+=1
+ if self.save_counter>=self.save_every:
np.savetxt(self.file,self.qtable)
- self.update_counter=0
+ if os.path.exists(self.file+"_generation"):
+ generation=0
+ with open(self.file+"_generation","r") as f:
+ generation=int(f.readline().rstrip())
+ generation+=self.save_every
+ with open(self.file+"_generation","w") as f:
+ f.write(str(generation))
+ print("Checkpointing generation "+str(generation))
+ self.save_counter=0
def get_action(self,state):
# Choose an action
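The line updating qtable[state,action] above is the standard tabular Q-learning rule, Q(s,a) <- Q(s,a) + alpha*(reward + gamma*max(Q(s',.)) - Q(s,a)), with alpha=0.5 and gamma=0.9. A worked numeric example (the values are made up for illustration):

import numpy as np

alpha, gamma = 0.5, 0.9
q_sa = 0.0                                 # current estimate Q(s, a)
reward = 1.0                               # apple just eaten
q_next = np.array([0.2, 0.0, -0.1, 0.4])   # row Q(s', .) of the table
q_sa = q_sa + alpha * (reward + gamma * q_next.max() - q_sa)
print(q_sa)                                # 0.5 * (1 + 0.9 * 0.4) = 0.68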
@@ -122,13 +131,16 @@ class QTable:
#action = random.choice(options)
action=np.argmax(self.qtable[state])
return(action)
+
+ def get_random_action(self):
+ return(random.choice((0,1,2,3)))
# Perform learning
-width,height=10,10
+width,height=50,30
perf=0
last_state=None
last_action=None
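get_random_action picks one of the four directions uniformly. In this commit it is only used to unstick the agent (see the next hunk), but the same helper would also fit an epsilon-greedy policy of the kind the commented-out random.choice hints at. A sketch of that variant (the epsilon value and the wrapper are assumptions, not part of the commit):

import random
import numpy as np

# Hypothetical epsilon-greedy selector built on the two methods above.
def epsilon_greedy(qtable, state, epsilon=0.1):
    if random.random() < epsilon:
        return qtable.get_random_action()          # explore
    return int(np.argmax(qtable.qtable[state]))    # exploit, as in get_action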
@@ -138,25 +150,34 @@ while True:
result=0
stuck=0
stuck_tolerance=1
+ stuck_count=0
state=qtable.get_state(game)
while result >= 0:
action=qtable.get_action(state)
result=game.play3(action)
new_state=qtable.get_state(game)
- # Agent is stuck
- if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
- game.new_game()
- break
-
# Compute reward and update stuck
reward=0
if result==-1:
reward=-10
stuck=0
+ stuck_count=0
elif result==1:
- reward=50
+ reward=1
stuck=0
+ stuck_count=0
+
+ # Agent is stuck
+ if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
+ stuck=0
+ stuck_count+=1
+ action=qtable.get_random_action()
+ print("Stuck!")
+ if stuck_count>2:
+ stuck_count=0
+ game.new_game()
+ break
# Apply learning
qtable.apply_bellman(state,action,new_state,reward)
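With the new 50x30 grid and stuck_tolerance=1, the agent is declared stuck after grid_width*grid_height steps without a reward event; the first two times this happens it is forced onto a random action, and the third time (stuck_count>2) the episode is restarted. The threshold arithmetic:

grid_width, grid_height = 50, 30
stuck_tolerance = 1
stuck_threshold = (grid_width * grid_height) / stuck_tolerance
print(stuck_threshold)   # 1500.0 steps without reward before a forced random move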
diff --git a/snake.py b/snake.py
index 7526731..330ba12 100755
--- a/snake.py
+++ b/snake.py
@@ -215,4 +215,8 @@ class Snake:
if self.play(self.direction,handle_quit=False) <0:
break
- \ No newline at end of file
+
+
+if __name__ == "__main__":
+ game=Snake(length=50)
+ game.play_with_keyboard() \ No newline at end of file
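The new __main__ guard keeps snake.py importable from qlearning.py while also letting it run standalone for manual play. Assuming both files sit in the same directory, the equivalent usage from another script would be:

# Same effect as running: python snake.py
from snake import Snake

game = Snake(length=50)
game.play_with_keyboard()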