1
\$\begingroup\$

There are two main classes: Game and MCTSPlayer. The first is just an abstract class with a few methods and some hints on how to implement actual games. The second implements the simplest form of the MCTS algorithm (it does not use any expert knowledge) and stores a graph of all known game positions as a dict[Game, NodeInfo]. That graph is trimmed to save memory whenever the player moves to a new position: positions that are unreachable from the current one are deleted. For that, a class that inherits from Game can implement __lt__ so that, for two instances a and b of the same game, a < b means that a is unreachable from b.

game.py:

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class GameState(Enum):
    """Status of a game: whose turn it is, or how it ended.

    ``WhitesMove`` / ``BlacksMove`` mean the game is still in progress and name
    the side to move; ``WhiteWon`` / ``BlackWon`` / ``Draw`` are final results
    (the Othello example's ``finished`` treats exactly these three as terminal).
    """
    WhitesMove = 1
    BlacksMove = 2
    WhiteWon = 3
    BlackWon = 4
    Draw = 5


class MoveType(Enum):
    """Kind of a ``Move``; it decides which of the move's position fields are set."""
    Put = 1    # single square in Move.pos
    Jump = 2   # source in Move.pfrom, destination in Move.pto
    Slide = 3  # source in Move.pfrom, destination in Move.pto
    Take = 4   # single square in Move.pos


@dataclass
class Move:
    """One move in a board game.

    Which position fields are meaningful depends on ``type``: ``Put`` and
    ``Take`` use ``pos``; ``Jump`` and ``Slide`` use ``pfrom`` and ``pto``.
    Unused fields are ``None``. Prefer the named constructors below.
    """
    type: MoveType
    pos: Optional[tuple[int, int]]
    pfrom: Optional[tuple[int, int]]
    pto: Optional[tuple[int, int]]

    @staticmethod
    def _on_square(kind, pos):
        """Build a move that targets a single square."""
        return Move(type=kind, pos=pos, pfrom=None, pto=None)

    @staticmethod
    def _from_to(kind, pfrom, pto):
        """Build a move that goes from one square to another."""
        return Move(type=kind, pos=None, pfrom=pfrom, pto=pto)

    @staticmethod
    def Put(pos: tuple[int, int]):
        """Move that places a piece on ``pos``."""
        return Move._on_square(MoveType.Put, pos)

    @staticmethod
    def Jump(pfrom, pto):
        """Move of kind ``Jump`` from ``pfrom`` to ``pto``."""
        return Move._from_to(MoveType.Jump, pfrom, pto)

    @staticmethod
    def Slide(pfrom, pto):
        """Move of kind ``Slide`` from ``pfrom`` to ``pto``."""
        return Move._from_to(MoveType.Slide, pfrom, pto)

    @staticmethod
    def Take(pos: tuple[int, int]):
        """Move of kind ``Take`` targeting ``pos``."""
        return Move._on_square(MoveType.Take, pos)


class Game(ABC):
    """Interface MCTSPlayer needs from a game.

    MCTSPlayer stores positions as dictionary keys, so implementations must be
    hashable with an ``__eq__`` consistent with ``__hash__``. ``state`` holds
    whose turn it is or the final result.
    """
    state: 'GameState'

    @abstractmethod
    def with_move(self, m: 'Move') -> 'Game':
        """Return the position after playing ``m``.

        MCTSPlayer calls this on positions that are dictionary keys, so it
        should return a new object rather than modify ``self``.
        """

    @abstractmethod
    def possible(self) -> 'list[Move]':
        """Return the legal moves in this position.

        Must be non-empty whenever ``finished()`` is False: random playouts
        pick one of these moves until the game is finished.
        """

    @abstractmethod
    def finished(self) -> bool:
        """Return True once the game is over."""

    def __lt__(self, other):
        """Return True if ``self`` can no longer be reached from ``other``.

        MCTSPlayer uses this to prune its tree. Overriding it is optional: the
        default, False, never claims a position is unreachable, so it is
        always safe (nothing gets pruned).
        """
        return False

mctsplayer.py:

from game import Game, GameState, Move
from typing import Optional, Union
from dataclasses import dataclass
from math import sqrt, log
from random import choice
from time import time


@dataclass
class NodeInfo:
    """Search statistics MCTSPlayer keeps for one known position."""
    whiteWon: int    # playouts through this node that ended in GameState.WhiteWon
    blackWon: int    # playouts through this node that ended in GameState.BlackWon
    visitCount: int  # playouts through this node (draws only count here); MCTSPlayer creates nodes with 1, so it is never 0 when used as a divisor


class MCTSPlayer:
    """Game-independent Monte Carlo Tree Search player (UCT with random playouts).

    Statistics for every known position live in ``self.tree``, keyed by the
    ``Game`` object itself, so games must be hashable. ``move_to`` prunes the
    positions that ``Game.__lt__`` reports as unreachable from the new one.
    """

    def __init__(self, game, c=0.7, warm_up=1):
        """Start a search rooted at ``game``.

        :param game: the current position.
        :param c: exploration constant of the UCT formula.
        :param warm_up: seconds spent building the tree before returning.
        """
        self.game: Game = game
        self.tree: dict[Game, NodeInfo] = {self.game: NodeInfo(0, 0, 1)}
        # (move, resulting game) pairs per position, so successors are generated
        # once instead of on every visit -- with_move is the expensive call here.
        self._child_cache: dict[Game, list[tuple[Move, Game]]] = {}
        self.cur = self.game
        self.C = c

        start = time()
        while time() - start < warm_up:
            self.build()

    def move_to(self, game: Game):
        """Re-root the search at ``game`` (call this after either side moves)."""
        if game not in self.tree:
            self.tree[game] = NodeInfo(0, 0, 1)

        # ``key < game`` means key is unreachable from game: its data is dead weight.
        for cache in (self.tree, self._child_cache):
            for key in [k for k in cache if k < game]:
                del cache[key]

        self.cur = game

    def best_move(self, time_limit):
        """Search for ``time_limit`` seconds from the current position and return the chosen move."""
        start = time()
        while time() - start < time_limit:
            self.build()
        return self.best()

    def build(self):
        """Run one MCTS iteration: selection, expansion, random playout, backpropagation."""
        path: list[Game] = []
        node: Optional[Game] = self.cur
        # Selection: follow the UCT choice until reaching a node with no known children.
        while node is not None:
            path.append(node)
            node = self.select(node)

        leaf = path[-1]
        self.expand(leaf)
        result = self.simulate(leaf)

        for visited in path:
            self.backpropagate(visited, result)

    def value(self, game: Game, perspective: Optional[GameState] = None) -> int:
        """Return the net number of playouts won through ``game``.

        The result is white wins minus black wins, negated when ``perspective``
        is ``GameState.BlacksMove`` so that larger is better for that player.
        ``perspective`` defaults to the state of the current root position.
        """
        if perspective is None:
            perspective = self.cur.state

        info = self.tree[game]
        ret = info.whiteWon - info.blackWon
        if perspective == GameState.BlacksMove:
            ret *= -1
        return ret

    def select(self, game: Game) -> Optional[Game]:
        """Return the known child of ``game`` with the highest UCT score, or None if it has none."""
        children = [child for _, child in self._children(game) if child in self.tree]
        if not children:
            return None

        parent_log = log(self.tree[game].visitCount)

        def uct(child: Game) -> float:
            visits = self.tree[child].visitCount
            # Exploitation: the child's *average* result from the point of view of
            # the player choosing at ``game`` -- not the root player, and not the
            # parent's score (which would be identical for every child).
            mean = self.value(child, game.state) / visits
            return mean + self.C * sqrt(parent_log / visits)

        # max() keeps the first of equal scores, like the previous stable sort.
        return max(children, key=uct)

    def expand(self, game):
        """Add every successor of ``game`` to the tree, each with a prior of one visit."""
        for _, child in self._children(game):
            if child not in self.tree:
                self.tree[child] = NodeInfo(0, 0, 1)

    def _children(self, game: Game) -> list[tuple[Move, Game]]:
        """Return the (move, resulting game) pairs of ``game``, generating them only once."""
        pairs = self._child_cache.get(game)
        if pairs is None:
            pairs = [(m, game.with_move(m)) for m in game.possible()]
            self._child_cache[game] = pairs
        return pairs

    @staticmethod
    def simulate(game: Game) -> GameState:
        """Play uniformly random moves from ``game`` until it ends; return the final state."""
        while not game.finished():
            game = game.with_move(choice(game.possible()))
        return game.state

    def backpropagate(self, game: Game, state: GameState):
        """Record one playout through ``game`` that ended in ``state``."""
        info = self.tree[game]
        info.visitCount += 1
        if state == GameState.BlackWon:
            info.blackWon += 1
        if state == GameState.WhiteWon:
            info.whiteWon += 1

    def best(self) -> Move:
        """Return the most visited move from the current position.

        Ties are broken by net wins from the current player's point of view.

        :raises ValueError: if the current position has no legal moves.
        """
        # Make sure the root's children are known even if no search ran yet.
        self.expand(self.cur)
        candidates = self._children(self.cur)
        if not candidates:
            raise ValueError('no legal moves in the current position')

        # The most visited child is the robust choice: the mean value of a
        # rarely visited child is mostly noise.
        move, _ = max(candidates,
                      key=lambda pair: (self.tree[pair[1]].visitCount, self.value(pair[1])))
        return move

To test it, I wrote an Othello (Reversi) game with a pygame visualization. I did not put much effort into it, and I am posting it for demonstration rather than for an actual review:

from game import Game, GameState, Move, MoveType
from mctsplayer import MCTSPlayer
from copy import copy, deepcopy
from random import shuffle
from enum import Enum
import pygame


def draw_board(game, display, last, plr):
    """Render ``game`` onto ``display`` as an 8x8 grid of 50-pixel squares.

    :param game: position to draw; its ``board`` maps squares to pieces.
    :param display: pygame surface to draw on.
    :param last: move highlighted with a red dot (only when ``plr`` is true).
    :param plr: true when the human is to move; legal moves then get orange dots.
    """
    black = (0, 0, 0)
    display.fill((0, 100, 0))

    # Nine lines in each direction bound the eight rows and columns.
    for offset in range(0, 450, 50):
        pygame.draw.line(display, black, (offset, 0), (offset, 400), 1)
        pygame.draw.line(display, black, (0, offset), (400, offset), 1)

    legal = game.possible() if plr else []
    for x in range(8):
        for y in range(8):
            square = (x, y)
            centre = ((x + 0.5) * 50, (y + 0.5) * 50)

            if square in game.board:
                colour = (200, 200, 200) if game.board[square] == OthelloPieces.White else black
                pygame.draw.circle(display, colour, centre, 20)

            if not plr:
                continue

            marker = Move.Put(square)
            if marker in legal:
                pygame.draw.circle(display, (255, 125, 0), centre, 5)
            if marker == last:
                pygame.draw.circle(display, (255, 0, 0), centre, 5)


def play_othello():
    """Play Othello against MCTSPlayer in a pygame window; colours are assigned at random."""
    g = Othello()
    p = MCTSPlayer(g)

    pygame.init()
    display = pygame.display.set_mode((400, 400))
    clock = pygame.time.Clock()

    finished = False
    last_move = (-1, -1)

    colors = [GameState.WhitesMove, GameState.BlacksMove]
    shuffle(colors)
    print(colors)
    cmp, plr = colors

    while not finished:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                finished = True
            elif event.type == pygame.MOUSEBUTTONDOWN and g.state == plr:
                x, y = event.pos
                move = Move.Put((x // 50, y // 50))
                if move in g.possible():
                    g = g.with_move(move)
                    p.move_to(g)

        # Draw before the computer thinks so the human's move is visible meanwhile.
        draw_board(g, display, last_move, g.state == plr)
        pygame.display.update()

        # The computer's turn is handled outside the event loop: previously it
        # only moved when some input event happened to arrive.
        if not finished and g.state == cmp:
            bm = p.best_move(2)
            last_move = bm
            g = g.with_move(bm)
            p.move_to(g)

        # Cap the frame rate instead of busy-spinning while waiting for input.
        clock.tick(60)

    pygame.quit()


class OthelloPieces(Enum):
    """Contents of an Othello square."""
    White = 1
    Black = 2
    Empty = 3

    def flip(self):
        """Return the opposite colour; ``Empty`` has no opposite and maps to itself."""
        if self is OthelloPieces.Empty:
            return self
        return OthelloPieces.White if self is OthelloPieces.Black else OthelloPieces.Black


class Othello(Game):
    """Othello (Reversi) on an 8x8 board, implementing the ``Game`` interface.

    ``board`` maps occupied ``(x, y)`` squares to ``OthelloPieces``; empty
    squares are absent from it. ``with_move`` returns a new game and leaves
    ``self`` untouched, which is what lets instances serve as dict keys in
    ``MCTSPlayer``.
    """

    # The eight directions along which a placed piece can capture.
    _DIRECTIONS = ((1, 0), (-1, 0), (0, 1), (0, -1),
                   (1, 1), (1, -1), (-1, 1), (-1, -1))

    def __init__(self, base: 'Othello' = None):
        """Create the standard opening position, or a copy of ``base``."""
        if base is not None:
            self.board = base.board.copy()
            self.state = base.state
            # Moves are never mutated, so copying the list is enough; the
            # previous deepcopy was pure overhead on the search's hot path.
            self.possible_memo = copy(base.possible_memo)
        else:
            self.board = {
                (3, 3): OthelloPieces.White,
                (4, 4): OthelloPieces.White,
                (4, 3): OthelloPieces.Black,
                (3, 4): OthelloPieces.Black,
            }
            self.state = GameState.BlacksMove
            self.possible_memo = None
            self.possible()  # fills possible_memo

        self.memo_hash = None
        # Set by with_move when the side that should have moved next had no
        # legal move (a pass, or the end of the game).
        self.changed = False

    def _piece_to_move(self):
        """Colour of the side to move (Black for every non-White state, as before)."""
        return OthelloPieces.White if self.state == GameState.WhitesMove else OthelloPieces.Black

    def _captures(self, x_0, y_0, piece):
        """Return the opponent squares that ``piece`` placed at (x_0, y_0) would flip.

        A direction contributes only when a contiguous run of opponent pieces
        is closed off by one of ``piece``'s own pieces.
        """
        opposite = piece.flip()
        captured = []
        for dx, dy in self._DIRECTIONS:
            run = []
            x, y = x_0 + dx, y_0 + dy
            while 0 <= x < 8 and 0 <= y < 8 and self.board.get((x, y)) == opposite:
                run.append((x, y))
                x, y = x + dx, y + dy
            # Off-board or empty squares give None here, so they never close a run.
            if run and self.board.get((x, y)) == piece:
                captured.extend(run)
        return captured

    def with_move(self, m: Move) -> 'Othello':
        """Return the game after the side to move plays ``m``; ``self`` is unchanged.

        If the opponent then has no legal reply, the turn returns to the mover;
        if neither side can move, the game is scored.

        :raises ValueError: if ``m`` is not a legal ``Put`` in this position.
        """
        if m.type != MoveType.Put:
            raise ValueError('Othello only supports Put moves: ' + str(m))
        if m not in self.possible():
            raise ValueError('Invalid move: ' + str(m) + ' in game \n' + str(self))

        ret = Othello(self)
        x_0, y_0 = m.pos
        piece = ret._piece_to_move()

        for square in ret._captures(x_0, y_0, piece):
            ret.board[square] = piece
        ret.board[(x_0, y_0)] = piece

        ret.change_state()
        if not ret.possible():
            # The opponent must pass, so the turn comes back to the mover.
            ret.change_state()
            ret.changed = True
            if not ret.possible():
                # Nobody can move: the game is over.
                ret.change_state(True)

        return ret

    def change_state(self, force_end=False):
        """Hand the turn to the other side, or score the game.

        The game is scored (WhiteWon / BlackWon / Draw by piece count) when the
        board is full or ``force_end`` is set.
        """
        # Legal moves and the hash both depend on ``state``: drop both caches.
        self.possible_memo = None
        self.memo_hash = None

        if len(self.board) < 64 and not force_end:
            if self.state == GameState.WhitesMove:
                self.state = GameState.BlacksMove
            elif self.state == GameState.BlacksMove:
                self.state = GameState.WhitesMove
            return

        pieces = list(self.board.values())
        w = pieces.count(OthelloPieces.White)
        b = pieces.count(OthelloPieces.Black)

        if w > b:
            self.state = GameState.WhiteWon
        elif w < b:
            self.state = GameState.BlackWon
        else:
            self.state = GameState.Draw

    def possible(self) -> list[Move]:
        """Return the legal ``Put`` moves for the side to move (cached until the state changes)."""
        # ``is not None`` so that an empty move list is cached too.
        if self.possible_memo is not None:
            return self.possible_memo

        piece = self._piece_to_move()
        self.possible_memo = [
            Move.Put((x_0, y_0))
            for x_0 in range(8)
            for y_0 in range(8)
            if (x_0, y_0) not in self.board and self._captures(x_0, y_0, piece)
        ]
        return self.possible_memo

    def finished(self):
        """Return True once the game has been scored."""
        return self.state in [GameState.BlackWon, GameState.WhiteWon, GameState.Draw]

    def __repr__(self):
        """Board as eight text rows: 'O' white, 'X' black, '.' legal move, '-' empty."""
        # possible() holds Move objects, so compare positions rather than tuples
        # against Moves (which never matched, hiding the '.' markers).
        legal = {m.pos for m in self.possible()}
        symbols = {OthelloPieces.White: 'O', OthelloPieces.Black: 'X'}

        lines = []
        for y in range(8):
            row = []
            for x in range(8):
                if (x, y) in legal:
                    row.append('.')
                else:
                    row.append(symbols.get(self.board.get((x, y)), '-'))
            lines.append(''.join(row) + '\n')
        return ''.join(lines)

    def __str__(self):
        """Board with column indices on top and row indices on the left."""
        rows = repr(self)[:-1].split('\n')
        numbered = ''.join(str(i) + row + '\n' for i, row in enumerate(rows))
        return ' 01234567\n' + numbered

    def __eq__(self, other):
        """Compare with another Othello by position, or with a board string by its rendering."""
        if isinstance(other, str):
            return repr(self).replace('\n', '') == other.replace('\n', '').replace(' ', '')

        if isinstance(other, Othello):
            return self.board == other.board and self.state == other.state

        return NotImplemented

    def __hash__(self):
        """Hash of (board contents, state), computed once and cached."""
        if self.memo_hash is None:
            self.memo_hash = hash((frozenset(self.board.items()), self.state))
        return self.memo_hash

    def __lt__(self, other):
        """``a < b`` iff ``a`` has fewer pieces.

        Pieces are never removed from the board, so such a position cannot be
        reached from ``b`` any more.
        """
        if not isinstance(other, Othello):
            return NotImplemented
        return len(self.board) < len(other.board)

What are some ways to make this code better? It seems reasonably readable, but I feel it could be improved. I also feel that the design of the actual "framework" part could be improved.

Also, what are some ways to make MCTSPlayer stronger? Right now it plays well only when it has a lot of time to think, and even then it is not very strong. I know that expert knowledge improves playing strength, but I don't want to use it, so that MCTSPlayer stays game-independent.

\$\endgroup\$

0

Browse other questions tagged or ask your own question.