Minor changes
This commit is contained in:
parent
65280331d3
commit
657578547f
2 changed files with 39 additions and 14 deletions
47
qlearning.py
47
qlearning.py
|
@ -27,15 +27,16 @@ class QTable:
|
|||
##### Reward +1 when the snake eats an apple
|
||||
##### Reward -10 when the snake hits an obstacle
|
||||
"""
|
||||
def __init__(self, file, save_every=10):
|
||||
def __init__(self, file, save_every=5000):
|
||||
self.file=file
|
||||
self.save_every=save_every
|
||||
self.update_counter=0
|
||||
self.save_counter=0
|
||||
if os.path.exists(file):
|
||||
self.qtable=np.loadtxt(file)
|
||||
else:
|
||||
self.qtable=np.zeros((2**13, 4))
|
||||
|
||||
with open(file+"_generation","w") as f:
|
||||
f.write("0")
|
||||
|
||||
def isWall(self,h,game):
|
||||
if h[0]<0 or h[1]<0 or h[0] >= game.grid_width or h[1] >= game.grid_height:
|
||||
|
@ -108,10 +109,18 @@ class QTable:
|
|||
alpha=0.5
|
||||
gamma=0.9
|
||||
self.qtable[state,action]=self.qtable[state,action]+alpha*(reward+gamma*np.max(self.qtable[new_state])-self.qtable[state,action])
|
||||
self.update_counter+=1
|
||||
if self.update_counter>=self.save_every:
|
||||
self.save_counter+=1
|
||||
if self.save_counter>=self.save_every:
|
||||
np.savetxt(self.file,self.qtable)
|
||||
self.update_counter=0
|
||||
if os.path.exists(self.file+"_generation"):
|
||||
generation=0
|
||||
with open(self.file+"_generation","r") as f:
|
||||
generation=int(f.readline().rstrip())
|
||||
generation+=self.save_every
|
||||
with open(self.file+"_generation","w") as f:
|
||||
f.write(str(generation))
|
||||
print("Checkpointing generation "+str(generation))
|
||||
self.save_counter=0
|
||||
|
||||
def get_action(self,state):
|
||||
# Choose an action
|
||||
|
@ -122,13 +131,16 @@ class QTable:
|
|||
#action = random.choice(options)
|
||||
action=np.argmax(self.qtable[state])
|
||||
return(action)
|
||||
|
||||
def get_random_action(self):
    """Return a randomly chosen action index.

    The action is drawn with ``random.choice`` from the four indices
    0, 1, 2 and 3; no instance state is read.
    """
    candidate_actions = (0, 1, 2, 3)
    return random.choice(candidate_actions)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# Perform learning
|
||||
width,height=10,10
|
||||
width,height=50,30
|
||||
perf=0
|
||||
last_state=None
|
||||
last_action=None
|
||||
|
@ -138,25 +150,34 @@ while True:
|
|||
result=0
|
||||
stuck=0
|
||||
stuck_tolerance=1
|
||||
stuck_count=0
|
||||
state=qtable.get_state(game)
|
||||
while result >= 0:
|
||||
action=qtable.get_action(state)
|
||||
result=game.play3(action)
|
||||
new_state=qtable.get_state(game)
|
||||
|
||||
# Agent is stuck
|
||||
if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
|
||||
game.new_game()
|
||||
break
|
||||
|
||||
# Compute reward and update stuck
|
||||
reward=0
|
||||
if result==-1:
|
||||
reward=-10
|
||||
stuck=0
|
||||
stuck_count=0
|
||||
elif result==1:
|
||||
reward=50
|
||||
reward=1
|
||||
stuck=0
|
||||
stuck_count=0
|
||||
|
||||
# Agent is stuck
|
||||
if stuck>=(game.grid_width*game.grid_height)/stuck_tolerance:
|
||||
stuck=0
|
||||
stuck_count+=1
|
||||
action=qtable.get_random_action()
|
||||
print("Stuck!")
|
||||
if stuck_count>2:
|
||||
stuck_count=0
|
||||
game.new_game()
|
||||
break
|
||||
|
||||
# Apply learning
|
||||
qtable.apply_bellman(state,action,new_state,reward)
|
||||
|
|
6
snake.py
6
snake.py
|
@ -215,4 +215,8 @@ class Snake:
|
|||
|
||||
if self.play(self.direction,handle_quit=False) <0:
|
||||
break
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: only runs when this module is executed directly,
    # not when it is imported. Builds a Snake with length=50 and then calls
    # its play_with_keyboard() method.
    game = Snake(length=50)
    game.play_with_keyboard()
|
Loading…
Add table
Reference in a new issue