6
\$\begingroup\$

I made a Connect Four AI using the minimax algorithm. It's my first larger JavaFX project, so any suggestions for improvement would be much appreciated.

screenshot

Board:

import java.util.ArrayList;

/**
 * An immutable (row, col) cell coordinate on the Connect Four grid.
 * As used by {@code Board}, row 0 is the top row and pieces settle in the highest row index.
 */
class Move {

    private final int row;
    private final int col;

    /**
     * @param row zero-based row index
     * @param col zero-based column index
     */
    public Move(int row, int col) {
        this.row = row;
        this.col = col;
    }

    /** @return the zero-based row index */
    public int getRow() {
        return row;
    }

    /** @return the zero-based column index */
    public int getCol() {
        return col;
    }

    @Override
    public String toString() {
        return "Move{" +
                "row=" + row +
                ", col=" + col +
                '}';
    }
}


/**
 * Connect Four board state, plus the {@link Ai} that plays on it.
 *
 * <p>Cells hold {@code 0} (empty), {@code 1} (player X) or {@code 2} (player O). Row {@code 0}
 * is the top row: a piece dropped into a column lands in the lowest empty row.
 */
public class Board {

    private static final int EMPTY = 0;
    private static final int PLAYER_X = 1;

    /** {@link #checkBoardState} result when the board is full and nobody has won. */
    private static final int DRAW = 0;
    /** {@link #checkBoardState} result while the game is still running. */
    private static final int IN_PROGRESS = -1;

    /** Number of pieces in a row needed to win. */
    private static final int WIN_LENGTH = 4;

    /** {row step, col step} for the four line axes: horizontal, vertical and both diagonals. */
    private static final int[][] AXES = {{0, 1}, {1, 0}, {1, 1}, {1, -1}};

    private final int rows;
    private final int cols;
    private final int[][] board;
    /** Moves actually played; hypothetical moves made by the AI search are not counted. */
    private int moves;
    private Move lastMove;
    private final Ai ai;

    /**
     * @param rows  number of rows
     * @param cols  number of columns
     * @param depth search depth handed to the AI
     */
    public Board(int rows, int cols, int depth) {
        this.rows = rows;
        this.cols = cols;
        this.board = new int[rows][cols];
        this.moves = 0;
        this.lastMove = new Move(0, 0);
        this.ai = new Ai(this, depth);
    }

    /** @return the live grid (not a copy); writes to it change the board */
    public int[][] getBoard() {
        return board;
    }

    /** @return number of moves played through {@link #playerMove} and {@link #doAiMove} */
    public int getMoves() {
        return moves;
    }

    /** @return the most recently played move ({@code (0, 0)} before any move) */
    public Move getLastMove() {
        return lastMove;
    }

    /** Places {@code player} at {@code m} without updating the move count (used by the search). */
    public void doMove(Move m, int player) {
        board[m.getRow()][m.getCol()] = player;
    }

    /** Clears the cell at {@code m} (reverses {@link #doMove}). */
    public void undoMove(Move m) {
        board[m.getRow()][m.getCol()] = EMPTY;
    }

    /**
     * Asks the AI for a move and plays it for {@code player}.
     * Does nothing if the AI finds no legal move (for example, on a full board).
     */
    public void doAiMove(int player) {
        Move best = ai.getBestMove();
        if (best == null) {
            return;
        }
        doMove(best, player);
        moves++;
        lastMove = best;
    }

    /** @return the landing cell for a piece dropped into {@code col}, or {@code null} if it is full */
    private Move generateMove(int col) {
        for (int r = rows - 1; r >= 0; r--) {
            if (board[r][col] == EMPTY) {
                return new Move(r, col);
            }
        }
        return null;
    }

    /** @return the landing cell of every non-full column, ordered by column */
    public ArrayList<Move> generateMoves() {
        ArrayList<Move> legalMoves = new ArrayList<>(cols);
        for (int c = 0; c < cols; c++) {
            Move m = generateMove(c);
            if (m != null) {
                legalMoves.add(m);
            }
        }
        return legalMoves;
    }

    /** Drops an X into {@code col}; ignored if the column is full. */
    public void playerMove(int col) {
        Move m = generateMove(col);
        if (m != null) {
            doMove(m, PLAYER_X);
            moves++;
            lastMove = m;
        }
    }

    /**
     * Checks whether {@code player} has a line of {@link #WIN_LENGTH} through (row, col).
     * Counts in both directions along each axis, so a piece completing the middle of a line
     * is detected as well as one at an end.
     *
     * @return {@code player} if such a line exists, otherwise {@code 0}
     */
    private int checkWin(int row, int col, int player) {
        if (board[row][col] != player) {
            return EMPTY;
        }
        for (int[] axis : AXES) {
            int run = 1
                    + countRun(row, col, axis[0], axis[1], player)
                    + countRun(row, col, -axis[0], -axis[1], player);
            if (run >= WIN_LENGTH) {
                return player;
            }
        }
        return EMPTY;
    }

    /** Counts consecutive {@code player} pieces beyond (row, col) in direction (dRow, dCol). */
    private int countRun(int row, int col, int dRow, int dCol, int player) {
        int count = 0;
        int r = row + dRow;
        int c = col + dCol;
        while (r >= 0 && r < rows && c >= 0 && c < cols && board[r][c] == player) {
            count++;
            r += dRow;
            c += dCol;
        }
        return count;
    }

    /**
     * Reports the game state after {@code lastMove}.
     *
     * @param lastMove   the move to check for a completed line
     * @param unusedGrid ignored; this board's own grid is always used (kept for compatibility)
     * @return the winning player ({@code 1} or {@code 2}), {@code 0} for a draw,
     *         or {@code -1} while the game is still in progress
     */
    public int checkBoardState(Move lastMove, int[][] unusedGrid) {
        int row = lastMove.getRow();
        int col = lastMove.getCol();
        // Only the player who owns the last-moved cell can have just completed a line.
        int mover = board[row][col];
        if (mover != EMPTY && checkWin(row, col, mover) == mover) {
            return mover;
        }
        return moves == rows * cols ? DRAW : IN_PROGRESS;
    }
}

Ai:

import java.util.ArrayList;

/**
 * Chooses moves for the computer player using minimax search with alpha-beta pruning.
 *
 * <p>All scores are from X's point of view: positive values favour X ({@code 1}) and negative
 * values favour O ({@code 2}). X is therefore the maximizing side and O the minimizing side.
 * The search plays candidate moves directly on the shared {@link Board} via
 * {@code doMove}/{@code undoMove}, undoing each one, so the board is unchanged after a search.
 * Not thread-safe: the board must not be modified while a search is running.
 */
public class Ai {

    private static final int PLAYER_X = 1;
    private static final int PLAYER_O = 2;

    /** Pieces in a row needed to win; also the length of each scored segment. */
    private static final int SEGMENT_LENGTH = 4;

    /**
     * Score for a position where a move has just completed four in a row. It is far larger than
     * any heuristic total, so a forced win always outranks a merely good position.
     */
    private static final int WIN_SCORE = 1_000_000;

    /**
     * Heuristic value of a segment, indexed by how many pieces of a single colour it holds
     * (segments containing both colours score 0). Length must be SEGMENT_LENGTH + 1.
     */
    private static final int[] SEGMENT_WEIGHTS = {0, 1, 10, 50, 1000};

    /** {row step, col step}: right, down, down-right, down-left. Each segment is scored once. */
    private static final int[][] DIRECTIONS = {{0, 1}, {1, 0}, {1, 1}, {1, -1}};

    private final Board board;
    private final int depth;
    private Move bestMove;

    public Ai(Board board, int depth) {
        this.board = board;
        this.depth = depth;
    }

    /**
     * Searches for the best move for O, the side the computer plays in this game
     * ({@code Main} calls {@code doAiMove(playerO)}).
     *
     * @return the chosen move, or {@code null} if no column has room
     */
    public Move getBestMove() {
        return getBestMove(PLAYER_O);
    }

    /**
     * Searches for the best move for {@code player}.
     *
     * @param player {@code 1} (X, maximizer) or {@code 2} (O, minimizer)
     * @return the chosen move, or {@code null} if no column has room
     * @throws IllegalArgumentException if {@code player} is neither 1 nor 2
     */
    public Move getBestMove(int player) {
        if (player != PLAYER_X && player != PLAYER_O) {
            throw new IllegalArgumentException("player must be 1 (X) or 2 (O): " + player);
        }
        bestMove = null; // never return a move left over from a previous search
        if (player == PLAYER_X) {
            max(depth, Integer.MIN_VALUE, Integer.MAX_VALUE);
        } else {
            min(depth, Integer.MIN_VALUE, Integer.MAX_VALUE);
        }
        if (bestMove == null) {
            // A non-positive depth never records a root move: fall back to any legal move.
            ArrayList<Move> moves = board.generateMoves();
            bestMove = moves.isEmpty() ? null : moves.get(0);
        }
        return bestMove;
    }

    /** Minimizing node (O to move). Records the root move when depth == this.depth. */
    private int min(int depth, int alpha, int beta) {
        ArrayList<Move> moves = board.generateMoves();
        if (depth <= 0 || moves.isEmpty()) {
            return evaluate(board.getBoard());
        }
        int best = Integer.MAX_VALUE;
        for (Move m : moves) {
            board.doMove(m, PLAYER_O);
            // A completed four ends the game: score it instead of searching past it.
            // Adding the remaining depth makes quicker wins score better for the winner.
            int v = completesFour(m, PLAYER_O)
                    ? -(WIN_SCORE + depth)
                    : max(depth - 1, alpha, beta);
            board.undoMove(m);

            if (v < best) {
                best = v;
                if (depth == this.depth) {
                    bestMove = m;
                }
            }
            beta = Math.min(beta, best);
            if (alpha >= beta) {
                break; // X already has a better alternative elsewhere
            }
        }
        return best;
    }

    /** Maximizing node (X to move). Records the root move when depth == this.depth. */
    private int max(int depth, int alpha, int beta) {
        ArrayList<Move> moves = board.generateMoves();
        if (depth <= 0 || moves.isEmpty()) {
            return evaluate(board.getBoard());
        }
        int best = Integer.MIN_VALUE;
        for (Move m : moves) {
            board.doMove(m, PLAYER_X);
            int v = completesFour(m, PLAYER_X)
                    ? WIN_SCORE + depth
                    : min(depth - 1, alpha, beta);
            board.undoMove(m);

            if (v > best) {
                best = v;
                if (depth == this.depth) {
                    bestMove = m;
                }
            }
            alpha = Math.max(alpha, best);
            if (alpha >= beta) {
                break; // O already has a better alternative elsewhere
            }
        }
        return best;
    }

    /** @return true if the piece just placed at {@code m} gives {@code player} four in a row */
    private boolean completesFour(Move m, int player) {
        int[][] cells = board.getBoard();
        for (int[] d : DIRECTIONS) {
            int run = 1
                    + countRun(cells, m.getRow(), m.getCol(), d[0], d[1], player)
                    + countRun(cells, m.getRow(), m.getCol(), -d[0], -d[1], player);
            if (run >= SEGMENT_LENGTH) {
                return true;
            }
        }
        return false;
    }

    /** Counts consecutive {@code player} pieces beyond (row, col) in direction (dRow, dCol). */
    private static int countRun(int[][] cells, int row, int col, int dRow, int dCol, int player) {
        int count = 0;
        int r = row + dRow;
        int c = col + dCol;
        while (r >= 0 && r < cells.length && c >= 0 && c < cells[r].length
                && cells[r][c] == player) {
            count++;
            r += dRow;
            c += dCol;
        }
        return count;
    }

    /**
     * Scores one SEGMENT_LENGTH-long segment starting at (row, col) in direction (dRow, dCol).
     * Returns a positive weight if it holds only X pieces, a negative weight if it holds only O
     * pieces, and 0 if it is empty or mixed.
     */
    private int evaluateSegment(int[][] cells, int row, int col, int dRow, int dCol) {
        int countX = 0;
        int countO = 0;
        for (int i = 0; i < SEGMENT_LENGTH; i++) {
            int cell = cells[row + i * dRow][col + i * dCol];
            if (cell == PLAYER_X) {
                countX++;
            } else if (cell == PLAYER_O) {
                countO++;
            }
        }
        if (countX == 0) {
            return -SEGMENT_WEIGHTS[countO];
        }
        if (countO == 0) {
            return SEGMENT_WEIGHTS[countX];
        }
        return 0;
    }

    /**
     * Heuristic value of the whole position: the sum over every segment that fits on the board.
     * Bounds are computed from the actual board size, so any rows x cols grid works.
     */
    private int evaluate(int[][] cells) {
        int rows = cells.length;
        int cols = cells[0].length;
        int score = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                for (int[] d : DIRECTIONS) {
                    int endRow = r + (SEGMENT_LENGTH - 1) * d[0];
                    int endCol = c + (SEGMENT_LENGTH - 1) * d[1];
                    if (endRow < rows && endCol >= 0 && endCol < cols) {
                        score += evaluateSegment(cells, r, c, d[0], d[1]);
                    }
                }
            }
        }
        return score;
    }
}

Main:

import javafx.application.Application;
import javafx.application.Platform;
import javafx.event.ActionEvent;
import javafx.event.EventHandler;
import javafx.geometry.Insets;
import javafx.scene.Scene;
import javafx.scene.canvas.Canvas;
import javafx.scene.canvas.GraphicsContext;
import javafx.scene.control.Button;
import javafx.scene.layout.GridPane;
import javafx.scene.layout.VBox;
import javafx.scene.paint.Color;
import javafx.stage.Stage;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


/**
 * JavaFX front end for a human-versus-computer Connect Four game.
 *
 * <p>The human plays X (drawn red) by pressing a column's "Input" button; the computer answers
 * as O (drawn blue). The AI search runs on a background thread so the window stays responsive,
 * and every UI update is posted back to the JavaFX Application Thread with Platform.runLater.
 */
public class Main extends Application {

    private final int rows = 6;
    private final int cols = 7;

    private final int playerX = 1;
    private final int playerO = 2;

    private final int canvasHeight = 600;
    private final int canvasWidth = 700;
    private final double cellHeight = (double) canvasHeight / rows;
    private final double cellWidth = (double) canvasWidth / cols;

    /** Value returned by Board.checkBoardState while nobody has won and the board is not full. */
    private static final int IN_PROGRESS = -1;

    /**
     * Runs AI searches one at a time. The worker is a daemon thread, so a search still running
     * cannot keep the JVM alive after the window is closed.
     */
    private final ExecutorService aiExecutor = Executors.newSingleThreadExecutor(task -> {
        Thread worker = new Thread(task, "connect-four-ai");
        worker.setDaemon(true);
        return worker;
    });

    private Board board;
    private int depth = 8;

    public Main() {
        board = new Board(rows, cols, depth);
    }

    @Override
    public void start(Stage primaryStage) {
        primaryStage.setTitle("Connect Four");

        GridPane gridPane = new GridPane();
        gridPane.setPadding(new Insets(20.0, 0.0, 20.0, 0.0));

        VBox box = new VBox();
        box.setPadding(new Insets(20.0, 20.0, 20.0, 50.0));

        Canvas canvas = new Canvas(canvasWidth, canvasHeight);
        GraphicsContext gc = canvas.getGraphicsContext2D();
        drawBoard(gc);

        Button[] buttonArray = new Button[cols];
        for (int i = 0; i < cols; i++) {
            final int col = i;
            Button button = new Button("Input");
            button.setMinWidth(cellWidth);
            button.setMaxWidth(cellWidth);
            button.setId("" + i);
            button.setOnAction(event -> onColumnClicked(col, gc, buttonArray));
            buttonArray[i] = button;
            gridPane.add(button, i, 0);
        }

        VBox b = new VBox();
        b.setPadding(new Insets(20.0, 20.0, 20.0, 350.0));

        Button restart = new Button("Restart");
        restart.setOnAction(event -> {
            board = new Board(rows, cols, depth);
            repaintCanvas(gc);
            // Buttons stay disabled after a game ends, so a new game must re-enable them.
            setDisabled(buttonArray, false);
        });
        b.getChildren().add(restart);

        box.getChildren().addAll(gridPane, canvas, b);
        primaryStage.setScene(new Scene(box, 800, 800));
        primaryStage.setResizable(false);
        primaryStage.show();
    }

    /**
     * Handles the human dropping a piece into {@code col}, then starts the AI reply in the
     * background. Runs on the JavaFX Application Thread.
     */
    private void onColumnClicked(int col, GraphicsContext gc, Button[] buttons) {
        // Accept input only on the human's turn (even move count) while the game is running.
        if (board.getMoves() % 2 != 0 || isGameOver(board)) {
            return;
        }
        int movesBefore = board.getMoves();
        board.playerMove(col);
        if (board.getMoves() == movesBefore) {
            return; // column is full: nothing was played, so leave the buttons usable
        }
        repaintCanvas(gc);
        setDisabled(buttons, true);
        if (isGameOver(board)) {
            return; // the human's move ended the game; buttons stay off until Restart
        }

        Board current = board;
        aiExecutor.execute(() -> {
            try {
                current.doAiMove(playerO);
            } finally {
                // UI changes must happen on the JavaFX Application Thread. The finally block
                // means the UI is refreshed even if the search throws.
                Platform.runLater(() -> {
                    if (board != current) {
                        return; // Restart was pressed during the search; drop the stale result
                    }
                    repaintCanvas(gc);
                    setDisabled(buttons, isGameOver(current));
                });
            }
        });
    }

    /** @return true once {@code b} has a winner or is full */
    private boolean isGameOver(Board b) {
        return b.checkBoardState(b.getLastMove(), b.getBoard()) != IN_PROGRESS;
    }

    private static void setDisabled(Button[] buttons, boolean disabled) {
        for (Button button : buttons) {
            button.setDisable(disabled);
        }
    }

    /** Draws the background, the grid lines and every piece on the board. */
    private void drawBoard(GraphicsContext gc) {
        gc.setFill(Color.rgb(128, 255, 0));
        gc.fillRect(0, 0, canvasWidth, canvasHeight);

        // strokeLine uses the stroke paint; setFill has no effect on it.
        gc.setStroke(Color.BLACK);
        for (int i = 0; i <= rows; i++) {
            gc.strokeLine(0, i * cellHeight, canvasWidth, i * cellHeight);
        }
        for (int i = 0; i <= cols; i++) {
            gc.strokeLine(i * cellWidth, 0, i * cellWidth, canvasHeight);
        }

        int offset = 3;
        int[][] cells = board.getBoard();
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                Color color = null;
                if (cells[r][c] == playerX) {
                    color = Color.RED;
                } else if (cells[r][c] == playerO) {
                    color = Color.BLUE;
                }
                if (color != null) {
                    gc.setFill(color);
                    // x comes from the column (cellWidth); y comes from the row (cellHeight).
                    gc.fillOval(c * cellWidth, r * cellHeight,
                            cellWidth - offset, cellHeight - offset);
                }
            }
        }
    }

    private void repaintCanvas(GraphicsContext gc) {
        gc.clearRect(0, 0, canvasWidth, canvasHeight);
        drawBoard(gc);
    }

    public static void main(String[] args) {
        // launch() creates its own Main instance; constructing one here would be wasted work.
        launch(args);
    }
}
\$\endgroup\$
1
  • \$\begingroup\$ Not enough detail for an answer — but those colours are hard on the eyes! \$\endgroup\$
    – AJD
    Commented Apr 18, 2019 at 22:59

1 Answer 1

3
\$\begingroup\$

The Move class could drop its getters by declaring its fields final, since row and col are never reassigned after construction:

/**
 * A board coordinate whose zero-based row and column are exposed directly as
 * final fields, so no getters are needed.
 */
class Move {

    /** Zero-based row index. */
    public final int row;
    /** Zero-based column index. */
    public final int col;

    public Move(int row, int col) {
        this.row = row;
        this.col = col;
    }

    @Override
    public String toString() {
        return new StringBuilder("Move{")
                .append("row=").append(row)
                .append(", col=").append(col)
                .append('}')
                .toString();
    }
}

Then access the fields directly, e.g. move.row. There's a discussion of this idea here. If the fields referred to mutable objects it would be a different story, but final int fields can't be altered.

I also put the declarations on separate lines.


It may be better to use named constants or an enum for the scoring weights applied to playerX and playerO here (I'm assuming they're weights):

// Excerpt from evaluateSegment: a segment holding only O pieces scores negatively,
// because positive scores favour X.
if(countO == 4) return -1000;
if(countO == 3) return -50;
if(countO == 2) return -10;
if(countO == 1) return -1;

Both branches use the same numbers, only negated in the countO case. If you change the weights later, you may forget to update both places and end up with odd, asymmetric behavior. Something like this may be better (ideally with more descriptive names):

// Proposed named weights; the values match the literals in the original evaluateSegment.
final int EXTREME_WEIGHT = 1000; // four pieces of one colour
final int HIGH_WEIGHT = 50;      // three pieces of one colour
final int MEDIUM_WEIGHT = 10;    // two pieces of one colour
final int LOW_WEIGHT = 1;        // one piece of one colour

// Segment holds only O pieces: negate the weight (positive scores favour X).
if(countX == 0) {
    if(countO == 4) return -EXTREME_WEIGHT;
    if(countO == 3) return -HIGH_WEIGHT;
    if(countO == 2) return -MEDIUM_WEIGHT;
    if(countO == 1) return -LOW_WEIGHT;
}

// Segment holds only X pieces: use the weight as-is.
if(countO == 0) {
    if(countX == 4) return EXTREME_WEIGHT;
    if(countX == 3) return HIGH_WEIGHT;
    if(countX == 2) return MEDIUM_WEIGHT;
    if(countX == 1) return LOW_WEIGHT;
}

Now (when the names are corrected), the values will be self-explanatory, and you aren't risking asymmetrical changes in the future.

A switch or a Map would also work here, although the gain would be marginal.

\$\endgroup\$

Not the answer you're looking for? Browse other questions tagged or ask your own question.