
I am interested in the number of ways you can reach a given distance from the origin on both square and hexagonal lattices. If L is the square of the distance and integer pairs (i, j) are the lattice coordinates (lattice directions separated by 90° and 60°, respectively), then L_square = i² + j² and L_hex = i² + i*j + j².

The shortest distance you can get to in two different ways on a square grid is 50 via (5, 5) and (7, 1), and on the hexagonal grid it's 91 via (6, 5) and (9, 1).
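A quick sanity check of those two degeneracies (a minimal snippet; the helper names are mine, just for illustration):

def l_square(i, j):
    return i**2 + j**2

def l_hex(i, j):
    return i**2 + i*j + j**2

assert l_square(5, 5) == l_square(7, 1) == 50
assert l_hex(6, 5) == l_hex(9, 1) == 91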

The Python script below builds dictionaries with L values as keys and lists of (i, j) pairs as values. Making (and of course saving) these dictionaries is my goal. I understand there is a lot of mathematics associated with this problem, but for the purposes of this question I simply want to speed up the calculation.

The calculation is bounded by nmax, the highest integer value considered for i and j.

The first plot shows that time.process_time() increases with nmax with faster-than-power-law dependence, i.e. it is convex on a log-log plot. For example, doubling nmax from 5,000 to 10,000 increases the process time of get_dicts by a factor of 20 (60 s vs. 1,200 s).

My current solution is admittedly "loopy" and grows dictionaries within the loop. It was fast and easy to write, but I have a hunch that a smarter approach might run substantially faster: by finding an O(n²) or even an O(n·log(n)) solution, by decreasing that 2.5 microsecond coefficient with clever existing Python methods, or by avoiding some memory issues.

Question: Can my script for identifying distance degeneracies on square and hexagonal lattices be made substantially faster?

nmax = 5,000 is the largest problem I can comfortably run on my laptop. At 10,000 it takes half an hour, eats 17 GB of RAM plus disk swap, and makes my laptop hard to use for anything except reheating my coffee. The script below is set to 2,000, which takes only about a minute.

[Figure: process time for various nmax]

[Figure: number of different ways vs L]

The script:

import matplotlib.pyplot as plt
from itertools import product
import json
from math import ceil
from time import process_time


def get_dicts(pairs):
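    # Map each squared distance L to the list of all (a, b) pairs that
    # produce it, for the hexagonal and square lattices respectively.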

    dict_hex, dict_square = dict(), dict()

    for pair in pairs:
        a, b = pair
        l_hex = a**2 + a*b + b**2
        l_square = a**2 + b**2
        try:
            dict_hex[l_hex].append(pair)
        except KeyError:
            dict_hex[l_hex] = [pair]
        try:
            dict_square[l_square].append(pair)
        except KeyError:
            dict_square[l_square] = [pair]

    return dict_hex, dict_square

nmaxes = 10, 20, 50, 100, 200, 500, 1000, 2000
# 5000 takes several minutes on a laptop
# 10000 not laptop resource-friendly

results = []

for nmax in nmaxes:
    t_start = process_time()
    
    # positive integers only (we are limited to one quadrant or sextant)
    integers = range(1, nmax+1)
    
    pairs = list(product(integers, repeat=2))
    
    # exclude trivial duplicates (e.g. (2, 7) vs (7, 2)); they can be recovered later
    pairs = [(a, b) for (a, b) in pairs if a >= b]
    
    t_mid = process_time()
    t1 = t_mid - t_start  # time spent building the list of pairs

    dict_hex, dict_square = get_dicts(pairs)

    t2 = process_time() - t_mid  # time spent inside get_dicts

    results.append((t1, t2))

if True:
    names = 'dict_hex', 'dict_square'
    dicts = dict_hex, dict_square
    for name, dictio in zip(names, dicts):
        with open(name + '.json', 'w') as outfile:
            json.dump(dictio, outfile)
    

if True:
    timeses = zip(*results)
    labels = 'get pairs', 'get dicts'
    fig, ax = plt.subplots(1, 1)
    for times, label in zip(timeses, labels):
        ax.plot(nmaxes, times, 'ok')
        ax.plot(nmaxes, times, label=label)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('nmax')
    ax.set_ylabel('process_time (sec)')
    ax.legend()
    plt.show()


if True:
    fig, axes = plt.subplots(2, 1, figsize=[6, 7])
    dicts = dict_hex, dict_square
    titles = 'hex', 'square'
    xmaxes, ymaxes = [], []
    for dictio, ax, title in zip(dicts, axes, titles):
        x = dictio.keys()
        y = [len(thing) for thing in dictio.values()]
        ax.plot(x, y, '.k', ms=2)
        ax.set_xscale('log')
        ax.set_title(title)
        xmaxes.append(ax.get_xlim()[1])
        ymaxes.append(ax.get_ylim()[1])
    xmax = max(xmaxes)
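    # push the shared y-limit up to an even integer with a little headroom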
    ymax = 2 * ((ceil(max(ymaxes)) >> 1) + 1)
    for iax, ax in enumerate(axes):
        ax.set_xlim(0.9, xmax)
        ax.set_ylim(0, ymax)
        ax.set_ylabel('number of ways')
        if iax == len(axes) - 1:
            ax.set_xlabel('L value')
    plt.suptitle('nmax: ' + str(nmax), fontsize=14)
    plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.3)
    plt.show()

1 Answer


The core algorithm is this:

integers = range(1, nmax+1)
pairs = list(product(integers, repeat=2))
pairs = [(a, b) for (a, b) in pairs if a >= b]
dict_hex, dict_square = get_dicts(pairs)  # loops over pairs

So you first create a list with all possible pairs, then you prune half of these pairs, then you loop over the pairs and do some computation with them. The code is compact and Pythonic, but not efficient because of the very large lists being created. In theory, creating these lists is O(nmax²) (linear in the length of the list), but in practice very large lists will bog down the system, which introduces a non-linear cost.

I would write code like this:

for b in range(1, nmax+1):
    for a in range(b, nmax+1):
        # do work with this pair

Yes, it's less pretty code and you don't make use of the fancy list comprehensions, but it's more efficient because you don't create any large lists; you just generate the pairs and use them directly.
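For concreteness, get_dicts rebuilt around those loops might look like the sketch below (get_dicts_loop is my name for it, and dict.setdefault stands in for the first-insertion handling discussed further down):

def get_dicts_loop(nmax):
    dict_hex, dict_square = {}, {}
    for b in range(1, nmax + 1):
        for a in range(b, nmax + 1):
            pair = (a, b)
            # setdefault returns the existing list, or inserts a fresh one
            dict_hex.setdefault(a*a + a*b + b*b, []).append(pair)
            dict_square.setdefault(a*a + b*b, []).append(pair)
    return dict_hex, dict_square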

I think (but I'm not very versed in this aspect of Python) that using generators instead of lists would accomplish the same thing:

integers = range(1, nmax+1)
pairs = product(integers, repeat=2)  # a lazy iterator
pairs = ((a, b) for (a, b) in pairs if a >= b)  # a generator
dict_hex, dict_square = get_dicts(pairs)  # loops over pairs

If I'm not mistaken, pairs is a generator, which means you can iterate over it like a list, but it's never actually stored in memory like a list. The for loop inside get_dicts would cause each new pair to be computed.
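You can convince yourself with a small demonstration (not part of the original script; the sizes are approximate):

import sys
from itertools import product

nmax = 2000
integers = range(1, nmax + 1)
pairs_list = [(a, b) for (a, b) in product(integers, repeat=2) if a >= b]
pairs_gen = ((a, b) for (a, b) in product(integers, repeat=2) if a >= b)

print(sys.getsizeof(pairs_list))  # ~16 MB of pointers alone; the tuples cost more
print(sys.getsizeof(pairs_gen))   # a few hundred bytes, independent of nmax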


This smells:

try:
    dict_hex[l_hex].append(pair)
except KeyError:
    dict_hex[l_hex] = [pair]

A better approach is using defaultdict:

from collections import defaultdict

dict_hex = defaultdict(list)
# ...
dict_hex[l_hex].append(pair)

I don't know how much faster this is, but I'm guessing it's more efficient than catching exceptions.
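Rather than guess, you could measure with a micro-benchmark along these lines (a sketch; keys is a made-up workload with many repeated keys, like the L values):

from collections import defaultdict
from timeit import timeit

keys = [i % 1000 for i in range(100_000)]

def with_try_except():
    d = {}
    for k in keys:
        try:
            d[k].append(k)
        except KeyError:
            d[k] = [k]

def with_defaultdict():
    d = defaultdict(list)
    for k in keys:
        d[k].append(k)

print(timeit(with_try_except, number=10))
print(timeit(with_defaultdict, number=10))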


Storing a tuple in a list is more expensive than storing a single integer (internally, more Python objects are created for the tuple). You could consider encoding each pair as a single integer to save memory (if you want to increase nmax further), for example as a + b * (nmax+1).
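A sketch of that encoding and its inverse (valid because both coordinates lie in 1..nmax, i.e. strictly below the base nmax + 1):

def encode(a, b, nmax):
    return a + b * (nmax + 1)

def decode(code, nmax):
    b, a = divmod(code, nmax + 1)
    return a, b

nmax = 5000
assert decode(encode(7, 1, nmax), nmax) == (7, 1)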

  • Oh this is wonderful; I've learned several very helpful things here! Indeed type(pairs) returns <class 'generator'>. I'll give this all a spin within the day, but I have a hunch this is going to be a lot more friendly to my laptop.
    – uhoh
    Commented Jul 2, 2023 at 6:04
