
I'm writing tutorial code to help high school students understand the MuZero algorithm. There are two main requirements for this code.

  1. The code needs to be simple and easy for any student to understand.
  2. The code needs to be flexible enough so that students can implement it for their own game projects without changing the core code.

Here is the code repository: https://github.com/fnclovers/Minimal-AlphaZero/tree/master

To write this code, I followed these criteria:

  1. For flexibility: Students need to modify the Game class to define their game environment, the Action class to define in-game actions, and the Network class for inference, in order to apply the MuZero algorithm to their own game.

    • Game class: This core class must provide four functions: 1) Check for a terminal condition (def terminal(self) -> bool:), 2) Return a list of possible actions (def legal_actions(self) -> List[Action]:), 3) Perform an action in the game (def apply(self, action: Action):), 4) Create an image that stores the board state for reasoning and learning (def make_image(self, state_index: int):).
    • Network class: The Network class has to infer four things: value, reward, policy, and hidden state. Value means the sum of future rewards, reward is the current reward, policy is the probability of choosing each action, and the hidden state is the learned representation from which the next state is inferred. Students need to implement two functions: initial_inference, which takes the game image, and recurrent_inference, which takes the subsequent actions.
  2. For simplicity: To make the code easy to understand, we kept only the core algorithm (900 lines of code). At the same time, we support multiple CPUs and GPUs so that reinforcement learning runs that require hours of training finish in a feasible time. However, I would like to find a way to write the MuZero algorithm more concisely and simply.

Please review the code from the following perspective:

  • Is the code simple enough for high school students to look at and easily understand the MuZero algorithm?
  • Is the code flexible enough to allow students to add their own games to the project?

(Sample Game and Network classes shown below for explanation; see the full code on GitHub.)

class Game(object):
    """A single episode of interaction with the environment."""

    def __init__(self, action_space_size: int, discount: float):
        self.environment = Environment()  # Game specific environment.
        self.history = (
            []
        )  # history of prev actions; used for recurrent inference for training
        self.rewards = []  # rewards of prev actions; used for training dynamics network
        self.child_visits = (
            []
        )  # child visit probabilities; used for training policy network
        self.root_values = []
        self.action_space_size = action_space_size
        self.discount = discount

    def terminal(self) -> bool:
        # Game specific termination rules.
        pass

    def legal_actions(self) -> List[Action]:
        # Game specific calculation of legal actions.
        return []

    def apply(self, action: Action):
        reward = self.environment.step(action)
        self.rewards.append(reward)
        self.history.append(action)

    def store_search_statistics(self, root: Node):
        sum_visits = sum(child.visit_count for child in root.children.values())
        action_space = (Action(index) for index in range(self.action_space_size))
        self.child_visits.append(
            [
                root.children[a].visit_count / sum_visits if a in root.children else 0
                for a in action_space
            ]
        )
        self.root_values.append(root.value())

    def make_image(self, state_index: int):
        # Game specific feature planes.
        return []

    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int):
        # The value target is the discounted root value of the search tree N steps
        # into the future, plus the discounted sum of all rewards until then.
        targets = []
        for current_index in range(state_index, state_index + num_unroll_steps + 1):
            bootstrap_index = current_index + td_steps
            if bootstrap_index < len(self.root_values):
                value = self.root_values[bootstrap_index] * self.discount**td_steps
            else:
                value = 0

            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                value += (
                    reward * self.discount**i
                )  # pytype: disable=unsupported-operands

            if current_index > 0 and current_index <= len(self.rewards):
                last_reward = self.rewards[current_index - 1]
            else:
                last_reward = 0

            if current_index < len(self.root_values):
                # 1) image[n] --> pred[n], value[n], hidden_state[n]
                # 2) hidden_state[n] + action[n] --> reward[n], pred[n+1], value[n+1], hidden_state[n+1]
                targets.append(
                    (value, last_reward, self.child_visits[current_index], True)
                )
            else:
                # States past the end of games are treated as absorbing states.
                targets.append(
                    (value, last_reward, [0] * self.action_space_size, False)
                )
        return targets

    def to_play(self, state_index: int = None) -> Player:
        return Player()

    def action_history(self) -> ActionHistory:
        return ActionHistory(self.history, self.action_space_size)

    def print_game(self, state_index: int):
        pass

    def get_score(self, state_index: int):
        return len(self.history)


class NetworkOutput(typing.NamedTuple):
    value: np.ndarray
    reward: np.ndarray
    policy_logits: np.ndarray
    hidden_state: Any


class Network(object):
    def __init__(self):
        self.n_training_steps = 0

    def initial_inference(self, image, player) -> NetworkOutput:
        # representation + prediction function
        return NetworkOutput(0, 0, {}, [])

    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        # dynamics + prediction function
        return NetworkOutput(0, 0, {}, [])

    def get_weights(self):
        # Returns the weights of this network.
        return []

    def set_weights(self, weights):
        # Sets the weights of this network.
        pass

    def training_steps(self) -> int:
        # How many steps / batches the network has been trained for.
        return self.n_training_steps

    def increment_training_steps(self):
        self.n_training_steps += 1

    def update_weights(
        self,
        config: MuZeroConfig,
        optimizer: optim.Optimizer,
        batch,
    ):
        # Update the weights of this network given a batch of data.
        return 0

2 Answers


You aim to teach folks to write good code. A noble goal!

deps

Before a student can run or enhance this codebase, they first need to get over some installation hurdles. This is a common hangup for python newbies, and worth devoting some care to.

First we must clone the repo. But it is huge, something I will touch on below.

Then we have some text to wade through. It is well organized and not too long, so we'll soon find the "how to run" section. In my opinion it's better to offer the user short commands which will definitely succeed the first time, and which reveal a path for exploring what the longer interior commands do.

Here, I encourage you to do that with a Makefile, or if need be with a brief bash script that uses set -x so it is nearly as self-explanatory as a make run. In particular, there's no reason for the student to attend to the administrivia of unpacking pickles in the crucial initial minutes where we're trying to show them something cool, quickly, to pull them in. Beware the danger of someone becoming a little frustrated before they're hooked, and wandering off to other pursuits.

The text completely ignores this detail:

$ python -c 'import numpy, torch'

We hope it will silently succeed. If it doesn't, we should offer some sort of venv / poetry advice. Ideally a makefile would automatically notice deficiencies and correct them.

requirements.txt

I don't know what interpreter version(s) you target, and more importantly what versions of numpy and torch you have tested against. You must add a conda environment.yml, a venv requirements.txt, or a poetry.lock file to the repo; else the codebase is unmaintainable.

CI / CD

Consider publishing this package on PyPI, and taking advantage of the badges they can auto-generate via continuous-integration / continuous-deployment workflows. Or leverage GitHub's support for automatic workflows that will run unit tests for you. Remember, this isn't just about making life easy for you as a developer. It's about instilling confidence in a student some months down the road who hits a speedbump and needs the courage to persevere and make things work.

repo URL

Students use copy-n-paste all the time, and don't always include citations. Put at least one homepage URL, such as for the github repo, in your README.md. Then folks using a fork won't be confused.

early win

Consider adding a doc/ directory that includes example run transcripts and images. Invite the student to quickly reproduce such results, then make some trivial change like altering the discount rate, to see revised results. It is in this way that you will pull folks in and send them on their hacking journey.

serialized NN

size

Your muzero_2048.pkl is more than 520 MiB. I claim a github repo is not the appropriate place to put that. Or at least, keep your source repo "small" and banish giant binary pickles to another repo or some other download location. GitHub offers e.g. LFS support. Many pip-installed packages use the requests package to transparently cache binary assets without the end user needing to really worry about them.

evolution

Git deals with diff-able text much better than with opaque binary blobs, especially as they evolve over time. Every student that forks or clones your repo will need to download the giant binary history, even after the original pickles have been rev'd or deleted.

format

Pickle is not my favorite format. Consider adopting PyArrow's .parquet file format, which offers better support for compression and zero-copy access.

python3 classes

class Game(object):

Maybe you had some didactic goal related to inheritance when you wrote that? Or maybe it's just boilerplate copy-n-paste?

In python2 that line distinguishes a "new-style" class MRO from an earlier scheme. In python3 there's no different behavior to distinguish between, so it's better to simply write class Game:

In particular, it is better to train students to write that.

zomg we see this everywhere. Global cleanup across three files, please.

comments

        self.history = (
            []
        )  # history of prev actions; used for recurrent inference for training

Thank you for using black; I truly do appreciate it. But here the formatting is rather mangled. Black is essentially suggesting that you instead author these two lines of code:

        # history of prev actions; used for recurrent inference for training
        self.history = []

Similarly for child_visits.

Consider arranging for $ make lint to run the black formatter.

Also, elide the uninformative "Game specific environment" remark. If need be, beef up this docstring instead:

class Environment(object):
    """The environment MuZero is interacting with."""

Or if you feel the class name is too vague (I don't!) then rename to GameEnvironment.

lint

Recommend you routinely use ruff check *.py.

        if state_index == None:

Prefer is None here in this GameConnect4 to_play() method. Yeah, I know, it's a slightly odd python thing related to singletons, sorry. But linters are here to help!
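That is, the line should read:

        if state_index is None: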

pre-condition

        self.discount = discount

A student might plausibly mess this up, and it's easy to validate that it's in the unit interval. Consider adding an assert or a conditional raise ValueError.
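A minimal sketch of that check, reusing the __init__ signature from the question:

    def __init__(self, action_space_size: int, discount: float):
        # Fail fast on a bad hyperparameter, rather than mid-training.
        if not 0.0 <= discount <= 1.0:
            raise ValueError(f"discount must be in [0, 1], got {discount}")
        self.action_space_size = action_space_size
        self.discount = discount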

pass

    def terminal(self) -> bool:
        # Game specific termination rules.
        pass

Teaching students about the pass keyword should be deferred as long as possible, as it is one of the less intuitive aspects of the python grammar. I can't tell you how many times I have seen it inappropriately embedded in if / else clauses, alone or with other statements, or incorrectly attempting to break out of a loop. If you can satisfy the syntax rules with something else, I recommend doing that. I might even put an ... Ellipsis object there, which we evaluate and then discard. Here, turning the comment into a """docstring""" seems the best option.
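For example, the stub could simply become:

    def terminal(self) -> bool:
        """Game specific termination rules."""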

abc

This raises another question: should we maybe raise a diagnostic, such as a NotImplementedError, that explains the student's obligation to supply some sensible implementation for the method?

Should this class perhaps be an abstract base class?
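A sketch of that direction, marking just a couple of the game-specific methods abstract and reusing the annotations from the question:

from abc import ABC, abstractmethod

class Game(ABC):
    """A single episode of interaction with the environment."""

    @abstractmethod
    def terminal(self) -> bool:
        """Game specific termination rules."""

    @abstractmethod
    def legal_actions(self) -> List[Action]:
        """Game specific calculation of legal actions."""

Then a student who forgets to implement one of these gets a TypeError at instantiation time, with the missing method named in the message.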

parallel lists

This appears to be Fortran-style COMMON declared parallel arrays:

    def apply(self, action: Action):
        reward = ...
        self.rewards.append(reward)
        self.history.append(action)

That is, I think that rewards[i] and history[i] always refer to the same time step. But I'm not sure; the code doesn't make it clear.

Consider using this instead:

from collections import namedtuple

HistoryRecord = namedtuple("HistoryRecord", "action, reward")
...
        self.history.append(HistoryRecord(action, reward))

ZeroDivisionError

        sum_visits = sum(child.visit_count for child in root.children.values())

I imagine it is possible to prove that this sum is always positive. It isn't obvious to me that that will be so. Let's just say that the subsequent ... / sum_visits expression makes me nervous.
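If a proof stays elusive, a loud guard is cheap and documents the invariant; a sketch:

    def store_search_statistics(self, root: Node):
        sum_visits = sum(child.visit_count for child in root.children.values())
        # Make the assumed invariant explicit instead of risking ZeroDivisionError.
        assert sum_visits > 0, "no MCTS visits recorded for this root"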

comments are not docstrings

    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int):
        # The value target is the discounted root value of the search tree N steps
        # into the future, plus the discounted sum of all rewards until then.

Please teach students good habits. This should be a docstring. And it should begin with a sentence which describes its single responsibility, before it goes on to helpfully explain the meaning of "target".
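For instance, something along these lines:

    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int):
        """Build one training target per position in the unroll window.

        The value target is the discounted root value of the search tree
        td_steps into the future, plus the discounted sum of all rewards
        until then.
        """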

Similarly for most of the class Network methods.

Optional[reward]

                value += (
                    reward * self.discount**i
                )  # pytype: disable=unsupported-operands

I imagine that reward can be None here? Consider using a no-op filter or an assert to convince mypy that you've already dealt with that case by the time we start multiplying. Tacking on linter comments is sometimes very useful, but here it seems easier to just deal with the type ambiguity up front. Then students will have fewer things to scratch their heads about.
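One way to deal with it up front, assuming None really can appear in self.rewards:

            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                assert reward is not None  # settles the type before the arithmetic
                value += reward * self.discount**i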

helpful comments

                # States past the end of games are treated as absorbing states.

You included some insightful remarks such as this one, for which I thank you.

parts of the typing module are being phased out

class NetworkOutput(typing.NamedTuple):

Maybe there's a better way to phrase that annotation? I honestly don't know. So consider making this a @dataclass, just because they're easier to annotate.

(Parts of the typing module are already deprecated: since PEP 585 the plain builtins replace aliases such as typing.List, so the module's surface area will keep shrinking. It offered a good bridge to modern practice.)
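If you do go the dataclass route, a frozen one is the closest match to the NamedTuple's immutability; a sketch:

from dataclasses import dataclass
from typing import Any

import numpy as np

@dataclass(frozen=True)
class NetworkOutput:
    value: np.ndarray
    reward: np.ndarray
    policy_logits: np.ndarray
    hidden_state: Any

(Any, at least, is one of the typing names that is not going anywhere.)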

env vars up front

Please use $ isort *.py in addition to black. PEP 8 asks that you organize the following imports in a different way. I find it very helpful to know at a glance that e.g. import resource refers to the standard library module and not to some third-party lookalike (which your source initially made me believe had been pip installed).

# Lint as: python3
"""Pseudocode description of the MuZero algorithm."""
# pylint: disable=unused-argument
# pylint: disable=missing-docstring
# pylint: disable=g-explicit-length-test

import collections
import typing
import os
from typing import Any, Dict, List, Optional

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import numpy as np
import torch
from torch import optim, nn
from torch.nn import functional as F
import pickle as pkl
import os
import torch.multiprocessing as mp
import time
import resource

##########################
####### Helpers ##########

Alas, there is a fly in the ointment! We need three assignments prior to importing, and we can't have isort reorder them.

But why are we even doing it that way? Please prefer the following shebang.

#! /usr/bin/env MKL_NUM_THREADS=1 NUMEXPR_NUM_THREADS=1 OMP_NUM_THREADS=1 python3
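With the thread-count variables hoisted into the shebang, isort and black are then free to produce something like:

#! /usr/bin/env MKL_NUM_THREADS=1 NUMEXPR_NUM_THREADS=1 OMP_NUM_THREADS=1 python3
"""Pseudocode description of the MuZero algorithm."""

import collections
import os
import pickle as pkl
import resource
import time
import typing
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn, optim
from torch.nn import functional as F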

no ASCII art

https://github.com/fnclovers/Minimal-AlphaZero/blob/1d7232bcd/alpha_zero.py#L26

##########################
####### Helpers ##########

MAXIMUM_FLOAT_VALUE = ...
##### End Helpers ########
##########################

##################################
####### Part 1: Self-Play ########


# Each self-play job is independent ...
######### End Self-Play ##########
##################################

##################################
####### Part 2: Training #########

No, please do not train students to write code that needs such visual decoration.

Rather, use the tools python gave you for organizing sections of code.

Use modules, e.g. alphazero_p1_self_play, alphazero_p2_training, etc.

Additionally, rather than # comments it would be much better to introduce a module with a module-level """docstring""" before the first line of code.
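For example, a hypothetical alphazero_p1_self_play module could open with:

"""Part 1: Self-play.

Each self-play job is independent of all others; it plays games against
the latest network weights and contributes the finished games for training.
"""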


  1. Is the code simple enough for high school students to look at and easily understand the MuZero algorithm?

I believe so, yes. I commented on the occasional rough edge which I feel could be smoothed off a bit to serve that goal.

  2. Is the code flexible enough to allow students to add their own games to the project?

I am not in high school. Honestly, I think we need to perform the experiment, observing how the first two or three students interact with the project. There will be lessons learned from that experience, that will allow refining the codebase to make things easier or more engaging for a subsequent cohort of students. Qualifying participants with a simple prerequisite python project, less ambitious than this one, might be helpful, as would steering students toward specific OS platforms / environments that have been well tested.

With all UX studies, we tend to learn a great deal from the first handful of examples, since those users are dissimilar from the authors and will reveal hidden assumptions. As we iterate, we tend to need larger and larger cohorts to test the evolving system. That is, our learning rate slows down as the system matures.


Careful thought went into authoring this high-quality code. It achieves its design goals.

I would be willing to delegate or accept maintenance tasks on it.

  • Regarding "Please prefer the following shebang." -- Is this method of setting environment variables portable across Linux, Windows, and MacOS systems? Commented Mar 22 at 18:08
  • @adabsurdum It is portable across posix systems. It works great on all the target systems I care about, including FreeBSD. I don't use Windows, so I have no personal experience there. In particular, I don't know whether e.g. cygwin ensures that /usr/bin/env exists, or whether the target system might typically make some other arrangement for it. The env program is admirably simple and very portable, so it's not like it's a difficult bar to get over, even for a non-posix system. Launching an app from an env A=b python foo.py makefile target, or from a brief bash script, is also easy.
    – J_H
    Commented Mar 22 at 18:35
  • AFAIK Python supports shebangs with /usr/bin/env in Python scripts on Windows, or rather emulates it in some way, but I'm just not sure if it handles arguments in the same way there. I rarely use Python, but sadly I have been doing some work in Windows of late which made me wonder about portability concerns that OP might need to address in satisfying the needs of students. Commented Mar 22 at 18:45
  • Thank you for your kind and insightful feedback. I completely overlooked the challenges students face during the initial installation process... I also didn't consider that students might frequently fork the repository. I've made some changes to the code based on your feedback, and I'll be happy to address the remaining points as well. If you provide me with your email address, I'm fully willing to give you maintenance permission. Of course, you're welcome to fork the repository and use it for your own assignments.
    – user281935
    Commented Mar 23 at 7:17
  • It's sort of a formulaic response I use. The typical goal of conducting a Code Review near end of a team's sprint is to determine if the newly created code artifact is ready to become a Team artifact which everyone is able to assume maintenance responsibilities for. So we must answer two questions about the code: (1.) Is it correct? (no obvious bugs, conforms to spec, produces expected output), and (2.) Is it maintainable? (in a few months could a junior dev fix bugs, enhance performance, add a feature without speaking with original author). I responded in the affirmative: "yes", "yes"
    – J_H
    Commented Mar 23 at 14:11

The code layout is good, and you used meaningful names for classes, functions and variables.

Assuming this code is to be used as a starting point for the students, I recommend adding comments at the top which summarize the purpose of the code (for example, it is a template).

It is good that the code already has generic comments which lead the student, such as:

    # Game specific termination rules.

But you could formalize that a little bit: adorn the comments with a unique string, like TBD_COMMENT.
The header comments should also instruct the students to delete the TBD comments and replace them with comments and/or docstrings specific to their code.
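For example:

    def terminal(self) -> bool:
        # TBD_COMMENT: Game specific termination rules.
        pass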

Give the student an incentive to remove the TBD comments (like 5 points added to their grade for the assignment if the code has no TBDs). With boilerplate code such as this, people often add generic TBD comments, but the end user never deletes them, leaving behind a lot of clutter.

Here is what the header instructions might look like:

"""
TBD_COMMENT

This is a template you can use for your own game.

Once you have added your own code to this template, you must delete
all comments marked as "TBD_COMMENT", and replace them with
comments or docstrings specific to your code.

Remember to replace this header documentation as well.
"""

If there are some classes or functions which you require to remain in the code, you should also denote that somehow in the comments.

Similarly, if the student has no need to implement some of the functions, recommend that they be deleted.

  • Thank you for sharing the practices of other schools. I agree that including these notes will certainly help them understand what parts they need to revise. I will add the above notes to the assignments before distributing this project to the students.
    – user281935
    Commented Mar 23 at 7:22
