I'm writing tutorial code to help high school students understand the MuZero algorithm. The code has two main requirements:
- The code needs to be simple and easy for any student to understand.
- The code needs to be flexible enough that students can adapt it to their own game projects without changing the core code.
Here is the code repository: https://github.com/fnclovers/Minimal-AlphaZero/tree/master
To write this code, I followed these criteria:
For flexibility: to apply the MuZero algorithm to their own game, students modify three classes: the Game class defines the game environment, the Action class defines in-game actions, and the Network class handles inference. (A minimal example subclass follows this list.)
- Game class: this core class must provide four functions:
  1. `terminal(self) -> bool`: check for a terminal condition.
  2. `legal_actions(self) -> List[Action]`: return the list of legal actions.
  3. `apply(self, action: Action)`: perform an action in the game.
  4. `make_image(self, state_index: int)`: build an image of the board state for inference and training.
- Network class: the Network class infers four things: value, reward, policy, and hidden state. Value is the expected sum of future rewards, reward is the immediate reward from the last action, policy is the probability of choosing each action, and the hidden state is the learned representation from which the next step is inferred. Students need to implement two functions: `initial_inference`, which is given the game image (and the player to move), and `recurrent_inference`, which is given a hidden state and the next action.
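
For concreteness, here is the kind of Game subclass a student might end up writing. This is a hedged sketch, not code from the repo: CountingGame and its rules are invented, it assumes the base Game.__init__ can run without a game-specific Environment, and it assumes Action exposes its integer index as `.index`.

```python
from typing import List

class CountingGame(Game):
    """Toy game: reach exactly 10 by adding 1, 2, or 3 per move."""

    TARGET = 10

    def __init__(self):
        # Three actions (add 1, add 2, add 3); no discounting for a short game.
        super().__init__(action_space_size=3, discount=1.0)

    def _total(self) -> int:
        # Running total implied by the actions taken so far.
        return sum(a.index + 1 for a in self.history)

    def terminal(self) -> bool:
        return self._total() >= self.TARGET

    def legal_actions(self) -> List[Action]:
        return [Action(i) for i in range(3)]

    def apply(self, action: Action):
        self.history.append(action)
        # Reward 1 only for landing exactly on the target.
        self.rewards.append(1 if self._total() == self.TARGET else 0)

    def make_image(self, state_index: int):
        # One feature: the running total after state_index moves, normalized.
        total = sum(a.index + 1 for a in self.history[:state_index])
        return [total / self.TARGET]
```

Note that `apply` here maintains `history` and `rewards` directly instead of delegating to `self.environment.step`; whether that shortcut is safe depends on how the rest of the repo uses `self.environment`.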
For simplicity: to keep the code easy to understand, we kept only the core algorithm (900 lines of code). At the same time, we support multiple CPUs and GPUs so that training runs that take hours can finish in a feasible amount of time. Still, I would like to find a way to write the MuZero algorithm even more concisely and clearly.
Please review the code from the following perspectives:
- Is the code simple enough for high school students to read and understand the MuZero algorithm?
- Is the code flexible enough to allow students to add their own games to the project?
(Sample Game and Network classes for illustration; see the full code on GitHub.)
```python
from typing import Any, List, Optional
import typing

import numpy as np
import torch.optim as optim

# Action, ActionHistory, Environment, MuZeroConfig, Node, and Player are
# defined elsewhere in the repo.


class Game(object):
"""A single episode of interaction with the environment."""
def __init__(self, action_space_size: int, discount: float):
self.environment = Environment() # Game specific environment.
        self.history = []  # history of prev actions; used for recurrent inference in training
        self.rewards = []  # rewards of prev actions; used for training the dynamics network
        self.child_visits = []  # child visit probabilities; used for training the policy network
        self.root_values = []  # MCTS root values; used for training the value network
self.action_space_size = action_space_size
self.discount = discount
def terminal(self) -> bool:
# Game specific termination rules.
pass
def legal_actions(self) -> List[Action]:
# Game specific calculation of legal actions.
return []
def apply(self, action: Action):
reward = self.environment.step(action)
self.rewards.append(reward)
self.history.append(action)
def store_search_statistics(self, root: Node):
sum_visits = sum(child.visit_count for child in root.children.values())
action_space = (Action(index) for index in range(self.action_space_size))
self.child_visits.append(
[
root.children[a].visit_count / sum_visits if a in root.children else 0
for a in action_space
]
)
self.root_values.append(root.value())
def make_image(self, state_index: int):
# Game specific feature planes.
return []
def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int):
# The value target is the discounted root value of the search tree N steps
# into the future, plus the discounted sum of all rewards until then.
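        # In symbols, for each unrolled position t:
        #   value_target[t] = sum_{i < td_steps} discount**i * rewards[t + i]
        #                     + discount**td_steps * root_values[t + td_steps]
        # where the bootstrap term is dropped once t + td_steps runs past the game.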
targets = []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(self.root_values):
value = self.root_values[bootstrap_index] * self.discount**td_steps
else:
value = 0
            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                value += reward * self.discount**i  # pytype: disable=unsupported-operands
if current_index > 0 and current_index <= len(self.rewards):
last_reward = self.rewards[current_index - 1]
else:
last_reward = 0
if current_index < len(self.root_values):
# 1) image[n] --> pred[n], value[n], hidden_state[n]
# 2) hidden_state[n] + action[n] --> reward[n], pred[n+1], value[n+1], hidden_state[n+1]
targets.append(
(value, last_reward, self.child_visits[current_index], True)
)
else:
# States past the end of games are treated as absorbing states.
targets.append(
(value, last_reward, [0] * self.action_space_size, False)
)
return targets
    def to_play(self, state_index: Optional[int] = None) -> Player:
return Player()
def action_history(self) -> ActionHistory:
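        # Wraps the actions taken so far for use inside the MCTS search.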
return ActionHistory(self.history, self.action_space_size)
def print_game(self, state_index: int):
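        # Game specific rendering of the state at state_index.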
pass
def get_score(self, state_index: int):
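        # Game specific score; this default just counts the actions taken.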
return len(self.history)
class NetworkOutput(typing.NamedTuple):
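    """One network evaluation: value and reward estimates, policy logits, and the hidden state."""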
value: np.ndarray
reward: np.ndarray
policy_logits: np.ndarray
hidden_state: Any
class Network(object):
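    """Learned model: produces values, rewards, policies, and hidden states for the search."""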
def __init__(self):
self.n_training_steps = 0
def initial_inference(self, image, player) -> NetworkOutput:
# representation + prediction function
return NetworkOutput(0, 0, {}, [])
def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
# dynamics + prediction function
return NetworkOutput(0, 0, {}, [])
def get_weights(self):
# Returns the weights of this network.
return []
def set_weights(self, weights):
# Sets the weights of this network.
pass
def training_steps(self) -> int:
# How many steps / batches the network has been trained for.
return self.n_training_steps
def increment_training_steps(self):
self.n_training_steps += 1
def update_weights(
self,
config: MuZeroConfig,
optimizer: optim.Optimizer,
batch,
):
# Update the weights of this network given a batch of data.
        return 0
```
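
To make the contract concrete for reviewers (and students), here is a hedged sketch of what a filled-in Network could look like in PyTorch. TinyMuZeroNetwork and all layer sizes are invented for illustration; it assumes flat observation vectors, that Action exposes an integer `.index`, that NetworkOutput is defined as above, and it inherits `nn.Module` (rather than the repo's Network) so its parameters register with the optimizer.

```python
# Hypothetical concrete network (not from the repo), for a flat observation
# vector and a small discrete action space.
import torch
import torch.nn as nn

class TinyMuZeroNetwork(nn.Module):
    def __init__(self, obs_size: int, action_space_size: int, hidden_size: int = 64):
        super().__init__()
        self.action_space_size = action_space_size
        # Representation: observation -> hidden state.
        self.representation = nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ReLU())
        # Dynamics: hidden state + one-hot action -> next hidden state.
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_size + action_space_size, hidden_size), nn.ReLU()
        )
        # Prediction heads: hidden state -> reward, policy logits, value.
        self.reward_head = nn.Linear(hidden_size, 1)
        self.policy_head = nn.Linear(hidden_size, action_space_size)
        self.value_head = nn.Linear(hidden_size, 1)

    def initial_inference(self, image, player) -> NetworkOutput:
        # Representation + prediction; there is no reward at the root.
        obs = torch.as_tensor(image, dtype=torch.float32).flatten()
        hidden = self.representation(obs)
        return NetworkOutput(
            value=self.value_head(hidden),
            reward=torch.zeros(1),
            policy_logits=self.policy_head(hidden),
            hidden_state=hidden,
        )

    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        # Dynamics + prediction on the imagined next state.
        one_hot = torch.zeros(self.action_space_size)
        one_hot[action.index] = 1.0
        hidden = self.dynamics(torch.cat([hidden_state, one_hot]))
        return NetworkOutput(
            value=self.value_head(hidden),
            reward=self.reward_head(hidden),
            policy_logits=self.policy_head(hidden),
            hidden_state=hidden,
        )

# Usage sketch (reusing the hypothetical CountingGame from above):
# net = TinyMuZeroNetwork(obs_size=1, action_space_size=3)
# out = net.initial_inference(CountingGame().make_image(0), player=None)
```

A real implementation would also keep the repo's bookkeeping methods (get_weights, set_weights, training_steps) so the training loop and multi-process workers keep working unchanged.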