There are two main classes: `Game` and `MCTSPlayer`. The first one is just an abstract class with a couple of methods and some hints on how to implement actual games. The second implements the simplest form of the MCTS algorithm (it does not use any expert knowledge), storing a graph of all known games as a `dict[Game, NodeInfo]`. That graph is occasionally trimmed to save memory by deleting games that are unreachable from the current one. To support this, a class that inherits from `Game` implements `__lt__` so that, if `a` and `b` are instances of the same game, `a < b` means that `a` is unreachable from `b`.
game.py:
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class GameState(Enum):
    """Status of a game: whose turn it is, or how the game ended."""

    # The game is still running and the named side is to move.
    WhitesMove = 1
    BlacksMove = 2

    # Terminal outcomes.
    WhiteWon = 3
    BlackWon = 4
    Draw = 5
class MoveType(Enum):
    """Kind of action a ``Move`` describes."""

    # Actions addressed by a single square.
    Put = 1
    # Actions addressed by a source and a destination square.
    Jump = 2
    Slide = 3
    # Single-square action, like ``Put``.
    Take = 4
@dataclass
class Move:
    """A single game action.

    ``Put`` and ``Take`` moves fill in ``pos`` only; ``Jump`` and ``Slide``
    moves fill in ``pfrom`` and ``pto`` only.  Unused fields are ``None``.
    Use the factory methods below rather than the raw constructor.
    """

    type: MoveType
    pos: Optional[tuple[int, int]]
    pfrom: Optional[tuple[int, int]]
    pto: Optional[tuple[int, int]]

    @staticmethod
    def Put(pos: tuple[int, int]):
        """Build a ``Put`` move on square *pos*."""
        return Move(type=MoveType.Put, pos=pos, pfrom=None, pto=None)

    @staticmethod
    def Jump(pfrom, pto):
        """Build a ``Jump`` move from *pfrom* to *pto*."""
        return Move(type=MoveType.Jump, pos=None, pfrom=pfrom, pto=pto)

    @staticmethod
    def Slide(pfrom, pto):
        """Build a ``Slide`` move from *pfrom* to *pto*."""
        return Move(type=MoveType.Slide, pos=None, pfrom=pfrom, pto=pto)

    @staticmethod
    def Take(pos: tuple[int, int]):
        """Build a ``Take`` move on square *pos*."""
        return Move(type=MoveType.Take, pos=pos, pfrom=None, pto=None)
class Game(ABC):
    """Abstract interface every concrete game implements.

    ``MCTSPlayer`` uses ``Game`` objects as dictionary keys, so concrete
    games also need ``__eq__`` and ``__hash__``.
    """

    # Whose turn it is, or the final result once the game is over.
    state: GameState

    @abstractmethod
    def with_move(self, m: Move) -> 'Game':
        """Return the game that results from playing *m* in this one."""

    @abstractmethod
    def possible(self) -> list[Move]:
        """Return the moves that are legal in this game."""

    @abstractmethod
    def finished(self) -> bool:
        """Return whether the game is over."""

    @abstractmethod
    def __lt__(self, other):
        """Return true when ``self`` is unreachable from *other*.

        ``MCTSPlayer.move_to`` deletes every stored game ``g`` for which
        ``g < new_current_game`` holds.
        """
mctsplayer.py:
from game import Game, GameState, Move
from typing import Optional, Union
from dataclasses import dataclass
from math import sqrt, log
from random import choice
from time import time
@dataclass
class NodeInfo:
    """Search statistics stored for one game position."""

    # Number of playouts through this node that ended in a white win.
    whiteWon: int
    # Number of playouts through this node that ended in a black win.
    blackWon: int
    # Created as 1 by the player and incremented on every backpropagation.
    visitCount: int
class MCTSPlayer:
    """Game-independent Monte Carlo Tree Search player with UCT selection.

    Statistics are kept per position in ``self.tree``, a dict keyed by
    hashable ``Game`` objects, so positions reached through different move
    orders share one ``NodeInfo``.
    """

    def __init__(self, game, c=0.7, warm_up=1):
        """Create a player rooted at *game*.

        Args:
            game: starting position; must be hashable.
            c: exploration constant of the UCT formula.
            warm_up: number of seconds spent searching before returning.
        """
        self.game: Game = game
        self.tree: dict[Game, NodeInfo] = dict()
        self.tree[self.game] = NodeInfo(0, 0, 1)
        self.cur = self.game
        self.C = c
        start = time()
        while time() - start < warm_up:
            self.build()

    def move_to(self, game: Game):
        """Make *game* the current position and drop unreachable positions.

        A stored position ``key`` is dropped when ``key < game`` holds.
        """
        if game not in self.tree:
            self.tree[game] = NodeInfo(0, 0, 1)
        # Iterate over a snapshot because entries are deleted while looping.
        for key in list(self.tree):
            if key < game:
                del self.tree[key]
        self.cur = game

    def best_move(self, time_limit):
        """Search for *time_limit* seconds and return the preferred move."""
        start = time()
        while time() - start < time_limit:
            self.build()
        return self.best()

    def build(self):
        """Run one MCTS iteration: select, expand, simulate, backpropagate."""
        path: list[Game] = []
        cur: Optional[Game] = self.cur
        # Descend while the position is known; ``select`` returns None once a
        # node has no known children, which ends the descent at a leaf.
        while cur is not None and cur in self.tree:
            path.append(cur)
            cur = self.select(cur)
        leaf = path[-1]
        self.expand(leaf)
        res = self.simulate(leaf)
        # Update from the leaf back up to the current position.
        for node in reversed(path):
            self.backpropagate(node, res)

    def value(self, game: Game, perspective: Optional[GameState] = None) -> int:
        """Return the win difference recorded for *game*.

        The result is white wins minus black wins, negated when
        *perspective* is ``GameState.BlacksMove``.

        Args:
            game: a position present in ``self.tree``.
            perspective: state whose side to move does the judging; defaults
                to the state of the current position.
        """
        if perspective is None:
            perspective = self.cur.state
        node = self.tree[game]
        diff = node.whiteWon - node.blackWon
        return -diff if perspective == GameState.BlacksMove else diff

    def select(self, game: Game) -> Optional[Game]:
        """Return the known child of *game* with the highest UCT score.

        Each child is scored by its own mean result, judged by the side to
        move in *game*, plus ``C * sqrt(ln(parent visits) / child visits)``.
        Ties go to the earliest move in ``game.possible()``.  Returns
        ``None`` when no child of *game* is in the tree.
        """
        log_parent = log(self.tree[game].visitCount)
        best_child: Optional[Game] = None
        best_score = 0.0
        for m in game.possible():
            child = game.with_move(m)
            info = self.tree.get(child)
            if info is None:
                continue
            exploit = self.value(child, game.state) / info.visitCount
            explore = self.C * sqrt(log_parent / info.visitCount)
            score = exploit + explore
            if best_child is None or score > best_score:
                best_child, best_score = child, score
        return best_child

    def expand(self, game):
        """Add every not-yet-known child of *game* to the tree."""
        for m in game.possible():
            child = game.with_move(m)
            if child not in self.tree:
                self.tree[child] = NodeInfo(0, 0, 1)

    @staticmethod
    def simulate(game: Game) -> GameState:
        """Play uniformly random moves from *game* and return the final state."""
        while not game.finished():
            next_moves = game.possible()
            game = game.with_move(choice(next_moves))
        return game.state

    def backpropagate(self, game: Game, state: GameState):
        """Record one playout that ended in *state* on the node for *game*."""
        node = self.tree[game]
        node.visitCount += 1
        if state == GameState.BlackWon:
            node.blackWon += 1
        elif state == GameState.WhiteWon:
            node.whiteWon += 1

    def best(self) -> Move:
        """Return the move from the current position with the highest value.

        Only moves whose resulting position is in the tree are compared;
        ties go to the earliest move.  If nothing has been searched yet the
        first legal move is returned.

        Raises:
            IndexError: if the current position has no legal moves.
        """
        candidates = [(m, self.cur.with_move(m)) for m in self.cur.possible()]
        if not candidates:
            raise IndexError('no legal moves in the current position')
        scored = [(m, self.value(child)) for m, child in candidates
                  if child in self.tree]
        if not scored:
            return candidates[0][0]
        return max(scored, key=lambda p: p[1])[0]
To test it, I wrote an Othello (Reversi) game with pygame visualization (I did not put much effort into it; I am posting it as a demonstration rather than for an actual review):
from game import Game, GameState, Move, MoveType
from mctsplayer import MCTSPlayer
from copy import copy, deepcopy
from random import shuffle
from enum import Enum
import pygame
def draw_board(game, display, last, plr):
    """Paint the 8x8 board onto *display*.

    Args:
        game: position to draw; its ``board`` and ``possible()`` are read.
        display: pygame surface to draw on (400x400 pixels are used).
        last: move to highlight with a red dot.
        plr: when true, legal moves are marked with orange dots.
    """
    black = (0, 0, 0)
    display.fill((0, 100, 0))
    for i in range(9):
        offset = i * 50
        pygame.draw.line(display, black, (offset, 0), (offset, 400), 1)
        pygame.draw.line(display, black, (0, offset), (400, offset), 1)

    # Legal moves are only needed when hints are requested.
    hints = game.possible() if plr else []
    for x in range(8):
        for y in range(8):
            square = (x, y)
            centre = ((x + 0.5) * 50, (y + 0.5) * 50)
            put = Move.Put(square)
            if square in game.board:
                is_white = game.board[square] == OthelloPieces.White
                colour = (200, 200, 200) if is_white else black
                pygame.draw.circle(display, colour, centre, 20)
            if put in hints:
                pygame.draw.circle(display, (255, 125, 0), centre, 5)
            if put == last:
                pygame.draw.circle(display, (255, 0, 0), centre, 5)
def play_othello():
    """Run an interactive Othello game between a human and ``MCTSPlayer``.

    Colours are assigned at random and printed.  The human moves by
    clicking a square; the computer searches for two seconds per move.
    The loop ends when the window is closed.
    """
    g = Othello()
    p = MCTSPlayer(g)
    pygame.init()
    display = pygame.display.set_mode((400, 400))
    finished = False
    last_move = (-1, -1)
    colors = [GameState.WhitesMove, GameState.BlacksMove]
    shuffle(colors)
    print(colors)
    cmp, plr = colors
    while not finished:
        draw_board(g, display, last_move, g.state == plr)
        pygame.display.update()

        # The computer moves on its own turn without waiting for an event.
        # It runs right after the redraw so the human's move is shown first.
        if g.state == cmp:
            bm = p.best_move(2)
            last_move = bm
            g = g.with_move(bm)
            p.move_to(g)

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                finished = True
            elif event.type == pygame.MOUSEBUTTONDOWN and g.state == plr:
                x, y = event.pos
                chosen = Move.Put((x // 50, y // 50))
                if chosen in g.possible():
                    g = g.with_move(chosen)
                    p.move_to(g)
    pygame.quit()
class OthelloPieces(Enum):
    """Content of a board square."""

    White = 1
    Black = 2
    Empty = 3

    def flip(self):
        """Return the opposite colour; ``Empty`` stays ``Empty``."""
        if self is OthelloPieces.Empty:
            return self
        if self is OthelloPieces.White:
            return OthelloPieces.Black
        return OthelloPieces.White
class Othello(Game):
    """Othello (Reversi) position on an 8x8 board.

    ``board`` maps ``(x, y)`` to the piece on that square; empty squares
    have no entry.  ``state`` tells whose move it is or how the game ended.
    ``possible_memo`` and ``memo_hash`` cache the legal moves and the hash;
    ``None`` means "not computed yet".
    """

    # The eight directions in which a line of pieces can be captured.
    _DIRECTIONS = ((1, 0), (-1, 0), (0, 1), (0, -1),
                   (1, 1), (1, -1), (-1, 1), (-1, -1))

    def __init__(self, base: 'Othello' = None):
        """Create the starting position, or an independent copy of *base*."""
        # Set to True by ``with_move`` when the side due to move had to pass.
        self.changed = False
        self.memo_hash = None
        if base is not None:
            self.board = base.board.copy()
            self.state = base.state
            self.possible_memo = deepcopy(base.possible_memo)
        else:
            self.board = {
                (3, 3): OthelloPieces.White,
                (4, 4): OthelloPieces.White,
                (4, 3): OthelloPieces.Black,
                (3, 4): OthelloPieces.Black,
            }
            self.state = GameState.BlacksMove
            self.possible_memo = None
            self.possible()  # fill the cache eagerly

    def _spawn(self) -> 'Othello':
        """Return a copy of this position with empty caches."""
        # Bypasses __init__ so the move list, which is about to become
        # stale, is not deep-copied.
        child = Othello.__new__(Othello)
        child.board = self.board.copy()
        child.state = self.state
        child.possible_memo = None
        child.memo_hash = None
        child.changed = False
        return child

    def _captured(self, x_0, y_0, dx, dy, piece) -> list[tuple[int, int]]:
        """Return the squares flipped along one direction by a placement.

        Scans from ``(x_0, y_0)`` in direction ``(dx, dy)``.  The result is
        the run of opposing pieces directly followed by a *piece*, or an
        empty list when the run is not closed off by a *piece*.
        """
        opposite = piece.flip()
        run = []
        for i in range(1, 8):
            cell = (x_0 + dx * i, y_0 + dy * i)
            # Off-board squares are never keys, so they read as None too.
            occupant = self.board.get(cell)
            if occupant == opposite:
                run.append(cell)
            elif occupant == piece:
                return run
            else:
                return []
        return []

    def _mover(self):
        """Return the piece colour used for the side to move."""
        if self.state == GameState.WhitesMove:
            return OthelloPieces.White
        return OthelloPieces.Black

    def with_move(self, m: Move) -> 'Othello':
        """Return the position after the side to move plays *m*.

        If the opponent then has no legal move the turn passes back and
        ``changed`` is set on the result; if neither side can move the game
        is ended and scored.

        Raises:
            AssertionError: if *m* is not a legal ``Put`` move here.
        """
        assert m.type == MoveType.Put
        assert m in self.possible(), 'Invalid move: ' + str(m) + ' in game \n' + str(self)
        x_0, y_0 = m.pos
        piece = self._mover()
        ret = self._spawn()
        # All captures are computed on the unmodified board before flipping.
        to_flip = []
        for dx, dy in self._DIRECTIONS:
            to_flip.extend(self._captured(x_0, y_0, dx, dy, piece))
        for cell in to_flip:
            ret.board[cell] = piece
        ret.board[(x_0, y_0)] = piece
        ret.change_state()
        if not ret.possible():
            # The side due to move has nothing to play: pass.
            ret.change_state()
            ret.changed = True
            if not ret.possible():
                # Neither side can move: the game is over.
                ret.change_state(True)
        return ret

    def change_state(self, force_end=False):
        """Hand the move to the other side, or score the game.

        The game is scored when the board is full or *force_end* is true;
        the side with more pieces wins.  The cached moves and hash are
        dropped because both depend on ``state``.
        """
        self.possible_memo = None
        self.memo_hash = None
        if len(self.board) < 64 and not force_end:
            if self.state == GameState.WhitesMove:
                self.state = GameState.BlacksMove
            elif self.state == GameState.BlacksMove:
                self.state = GameState.WhitesMove
            return
        pieces = list(self.board.values())
        w = pieces.count(OthelloPieces.White)
        b = pieces.count(OthelloPieces.Black)
        if w > b:
            self.state = GameState.WhiteWon
        elif w < b:
            self.state = GameState.BlackWon
        else:
            self.state = GameState.Draw

    def possible(self) -> list[Move]:
        """Return the legal moves for the side to move (cached).

        A square is legal when it is empty and placing there captures at
        least one opposing piece.  Moves are ordered by x, then y.
        """
        if self.possible_memo is not None:
            return self.possible_memo
        piece = self._mover()
        ret: list[Move] = []
        for x_0 in range(8):
            for y_0 in range(8):
                if (x_0, y_0) in self.board:
                    continue
                if any(self._captured(x_0, y_0, dx, dy, piece)
                       for dx, dy in self._DIRECTIONS):
                    ret.append(Move.Put((x_0, y_0)))
        self.possible_memo = ret
        return ret

    def finished(self):
        """Return whether the game has been scored."""
        return self.state in [GameState.BlackWon, GameState.WhiteWon, GameState.Draw]

    def __repr__(self):
        """Return eight newline-terminated rows of eight symbols.

        ``O`` is white, ``X`` is black, ``.`` is a legal move for the side
        to move and ``-`` is any other empty square.
        """
        hints = {m.pos for m in self.possible()}
        rows = []
        for y in range(8):
            symbols = []
            for x in range(8):
                key = (x, y)
                occupant = self.board.get(key)
                if occupant == OthelloPieces.White:
                    symbols.append('O')
                elif occupant == OthelloPieces.Black:
                    symbols.append('X')
                elif key in hints:
                    symbols.append('.')
                else:
                    symbols.append('-')
            rows.append(''.join(symbols) + '\n')
        return ''.join(rows)

    def __str__(self):
        """Return ``repr`` with a column header and a row index per line."""
        rows = repr(self)[:-1].split('\n')
        body = ''.join([str(i) + row + '\n' for i, row in enumerate(rows)])
        return ' 01234567\n' + body

    @staticmethod
    def _flatten(text):
        """Strip layout from a board drawing and treat hints as empty."""
        return text.replace('\n', '').replace(' ', '').replace('.', '-')

    def __eq__(self, other):
        """Compare with another position or with a board drawing.

        Two positions are equal when board and state match.  A string is
        compared against ``repr(self)`` ignoring newlines and spaces; ``.``
        and ``-`` are interchangeable, so drawings with or without move
        hints both match.
        """
        if isinstance(other, str):
            return self._flatten(repr(self)) == self._flatten(other)
        if isinstance(other, Othello):
            return self.board == other.board and self.state == other.state
        return NotImplemented

    def __hash__(self):
        """Return a cached hash of the board contents and the state."""
        if self.memo_hash is None:
            self.memo_hash = hash((frozenset(self.board.items()), self.state))
        return self.memo_hash

    def __lt__(self, other):
        """Return whether this position has fewer pieces than *other*.

        Every move adds a piece, so a position with fewer pieces cannot be
        reached from one with more.
        """
        if not isinstance(other, Othello):
            return NotImplemented
        return len(self.board) < len(other.board)
What are some ways to make this code better? It seems reasonably readable, but I feel it could be improved. I also feel that the design of the actual "framework" part could be improved.
Also, what are some ways to make `MCTSPlayer` stronger? Right now it plays well only when it has a lot of time to think, and it is not very strong. I know that expert knowledge does improve playing strength, but I don't want to use it, in order to keep `MCTSPlayer` game-independent.