I'm interested in the number of ways you can reach a given distance from the origin on square and hexagonal lattices. If L is the square of the distance and the integer pair (i, j) gives the lattice coordinates (with lattice directions separated by 90° and 60°, respectively), then L_square = i² + j² and L_hex = i² + i*j + j².
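For reference, both expressions are the law of cosines applied to the lattice vector i e₁ + j e₂. Here e₁ and e₂ are assumed to be unit lattice vectors separated by an angle θ:

```latex
L = \lvert i\,\mathbf{e}_1 + j\,\mathbf{e}_2 \rvert^2 = i^2 + j^2 + 2ij\cos\theta,
\qquad
L_\text{square} = i^2 + j^2 \quad (\theta = 90^\circ,\ \cos\theta = 0),
\qquad
L_\text{hex} = i^2 + ij + j^2 \quad (\theta = 60^\circ,\ \cos\theta = \tfrac{1}{2}).
```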
Counting only positive coordinates (as the script below does), the smallest L that can be reached in two different ways is 50 on the square grid, via (5, 5) and (7, 1), and 91 on the hexagonal grid, via (6, 5) and (9, 1).
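Those two values can be checked by brute force. In this minimal sketch, the function `smallest_degenerate` and its `limit` parameter are hypothetical helpers that are not part of the script below. The search only considers pairs with a ≥ b ≥ 1 and a ≤ `limit`. That is enough whenever the answer is at most limit², because any pair with a > limit already has L > limit².

```python
def smallest_degenerate(norm, limit=20):
    """Smallest L reached by two or more pairs (a, b) with limit >= a >= b >= 1.

    Correct whenever the answer is <= limit**2, since any pair with
    a > limit has L > limit**2 and so cannot affect smaller L values.
    """
    counts = {}
    for a in range(1, limit + 1):
        for b in range(1, a + 1):
            L = norm(a, b)
            counts[L] = counts.get(L, 0) + 1
    return min(L for L, c in counts.items() if c >= 2)


def square(a, b):
    return a*a + b*b


def hexagonal(a, b):
    return a*a + a*b + b*b


print(smallest_degenerate(square))     # 50
print(smallest_degenerate(hexagonal))  # 91
```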
The Python script below builds dictionaries with L values as keys and lists of (i, j) pairs as values. Building (and of course saving) these dictionaries is my goal. I understand there is a lot of mathematics associated with this problem, but for the purposes of this question I simply want to speed up this calculation.
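One detail about the saving step: `json.dump` writes every dictionary key as a string and every tuple as a list. Reading the files back therefore needs a small conversion. Here is a sketch with a tiny dictionary in the same shape as `dict_square`:

```python
import json

# Tiny example in the same shape as dict_square: L -> list of (a, b) tuples
dict_square = {50: [(5, 5), (7, 1)], 2: [(1, 1)]}

text = json.dumps(dict_square)
print(text)  # {"50": [[5, 5], [7, 1]], "2": [[1, 1]]}

# JSON object keys are always strings and tuples become arrays,
# so restore int keys and tuple pairs after loading.
loaded = {int(L): [tuple(p) for p in pairs]
          for L, pairs in json.loads(text).items()}
assert loaded == dict_square
```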
The calculation is limited by nmax, the highest integer value considered for i and j.
The first plot shows that the time.process_time() spent in get_dicts grows faster than a power law in nmax, i.e. the curve is convex on a log-log plot. For example, doubling nmax from 5,000 to 10,000 increases the process time of get_dicts by a factor of 20 (60 vs 1200 sec).
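For comparison, the script keeps the nmax(nmax + 1)/2 pairs with a ≥ b. Doubling nmax therefore roughly quadruples the number of pairs, and it would quadruple the run time if the cost per pair stayed constant. An observed factor of 20 suggests that something other than the arithmetic dominates at 10,000. Memory pressure is a plausible candidate, but that is an interpretation rather than a measurement. `n_pairs` is a hypothetical helper:

```python
def n_pairs(nmax):
    """Number of pairs (a, b) with nmax >= a >= b >= 1, as kept by the script."""
    return nmax * (nmax + 1) // 2


for nmax in (5000, 10000):
    print(nmax, n_pairs(nmax))
# 5000 12502500
# 10000 50005000

print(n_pairs(10000) / n_pairs(5000))  # just under 4
```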
My current solution is admittedly "loopy" and grows dictionaries inside the loop. It was fast and easy to write, but I have a hunch that a smarter approach could run substantially faster. That could come from finding an O(n²) or even an O(n log n) solution. It could also come from reducing the roughly 2.5 microsecond coefficient in front of nmax² (60 s / 5,000² ≈ 2.4 µs), either with clever use of existing Python methods or by avoiding some memory issues.
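Here is one low-effort variant to try, offered as a sketch rather than a measured improvement. It keeps the same output and the same O(nmax²) work, since the output itself holds about nmax²/2 pairs and any method that stores every pair must visit each one. Two things change. The a ≥ b pairs are generated directly, instead of building all nmax² tuples from `product` and discarding half. `collections.defaultdict(list)` also replaces the try/except blocks. `get_dicts_fast` is a hypothetical name:

```python
from collections import defaultdict


def get_dicts_fast(nmax):
    """Same result as get_dicts(pairs) for the pairs nmax >= a >= b >= 1.

    Pairs are generated directly (a outer, b inner, same order as the
    original filtered product), and defaultdict(list) replaces try/except.
    """
    dict_hex, dict_square = defaultdict(list), defaultdict(list)
    for a in range(1, nmax + 1):
        a2 = a * a
        for b in range(1, a + 1):
            pair = (a, b)
            dict_hex[a2 + a*b + b*b].append(pair)
            dict_square[a2 + b*b].append(pair)
    # convert back to plain dicts so the result compares and serializes like the original
    return dict(dict_hex), dict(dict_square)


dict_hex, dict_square = get_dicts_fast(10)
print(dict_square[50])  # [(5, 5), (7, 1)]
print(dict_hex[91])     # [(6, 5), (9, 1)]
```

Whether this helps noticeably on a given machine would need to be timed. The peak memory saving from never materializing the full `product` list is the more certain gain.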
Question: Can my script to identify distance degeneracies on square and hexagonal lattices be substantially faster?
nmax = 5,000 is the largest problem I can comfortably do on my laptop. At 10,000 it takes half an hour, eats 17 GB of RAM plus disk swap, and makes my laptop hard to use for anything except reheating my coffee. The script is set at 2,000, which only takes about a minute.
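If memory is the bottleneck, one option is to store only the coordinate a for each L. This assumes a change to the saved format is acceptable. The coordinate b follows from the definitions of L. For the square lattice, b = √(L − a²). For the hexagonal lattice, 4L − 3a² = a² + 4ab + 4b² = (a + 2b)², so b = (√(4L − 3a²) − a)/2. This avoids keeping one tuple object per pair alive. The function names below are hypothetical:

```python
from collections import defaultdict
from math import isqrt


def get_compact_dicts(nmax):
    """Map L -> list of a values only (a >= b >= 1); b is recovered on demand."""
    hex_a, square_a = defaultdict(list), defaultdict(list)
    for a in range(1, nmax + 1):
        for b in range(1, a + 1):
            hex_a[a*a + a*b + b*b].append(a)
            square_a[a*a + b*b].append(a)
    return dict(hex_a), dict(square_a)


def b_square(L, a):
    # L = a² + b²  =>  b = sqrt(L - a²)
    return isqrt(L - a*a)


def b_hex(L, a):
    # 4L - 3a² = (a + 2b)²  =>  b = (sqrt(4L - 3a²) - a) / 2
    return (isqrt(4*L - 3*a*a) - a) // 2


hex_a, square_a = get_compact_dicts(10)
print([(a, b_square(50, a)) for a in square_a[50]])  # [(5, 5), (7, 1)]
print([(a, b_hex(91, a)) for a in hex_a[91]])        # [(6, 5), (9, 1)]
```

The JSON files also shrink, since each entry becomes a single integer instead of a two-element list.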
The script:
```python
import matplotlib.pyplot as plt
from itertools import product
import json
from math import ceil
from time import process_time


def get_dicts(pairs):
    """Group (a, b) pairs by their squared distance L on each lattice."""
    dict_hex, dict_square = dict(), dict()
    for pair in pairs:
        a, b = pair
        l_hex = a**2 + a*b + b**2
        l_square = a**2 + b**2
        try:
            dict_hex[l_hex].append(pair)
        except KeyError:  # first pair with this L
            dict_hex[l_hex] = [pair]
        try:
            dict_square[l_square].append(pair)
        except KeyError:
            dict_square[l_square] = [pair]
    return dict_hex, dict_square


nmaxes = 10, 20, 50, 100, 200, 500, 1000, 2000
# 5000 takes several minutes on a laptop
# 10000 not laptop resource-friendly

results = []
for nmax in nmaxes:
    t_start = process_time()
    # positive integers only (we are limited to one quadrant or sextant)
    integers = range(1, nmax+1)
    pairs = list(product(integers, repeat=2))
    # exclude trivial differences (e.g. (2, 7) vs (7, 2)); they can be recovered later
    pairs = [(a, b) for (a, b) in pairs if a >= b]
    t1 = process_time() - t_start       # time to build the pairs
    dict_hex, dict_square = get_dicts(pairs)
    t2 = process_time() - t_start - t1  # time spent in get_dicts only
    results.append((t1, t2))

# save the dictionaries for the largest nmax
# (JSON turns the int keys into strings and the tuples into lists)
if True:
    names = 'dict_hex', 'dict_square'
    dicts = dict_hex, dict_square
    for name, dictio in zip(names, dicts):
        with open(name + '.json', 'w') as outfile:
            json.dump(dictio, outfile)

# timing plot
if True:
    timeses = zip(*results)
    labels = 'get pairs', 'get dicts'
    fig, ax = plt.subplots(1, 1)
    for times, label in zip(timeses, labels):
        ax.plot(nmaxes, times, 'ok')
        ax.plot(nmaxes, times, label=label)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('nmax')
    ax.set_ylabel('process_time (sec)')
    ax.legend()
    plt.show()

# number of ways to reach each L, for the largest nmax
if True:
    fig, axes = plt.subplots(2, 1, figsize=[6, 7])
    dicts = dict_hex, dict_square
    titles = 'hex', 'square'
    xmaxes, ymaxes = [], []
    for dictio, ax, title in zip(dicts, axes, titles):
        x = list(dictio.keys())
        y = [len(thing) for thing in dictio.values()]
        ax.plot(x, y, '.k', ms=2)
        ax.set_xscale('log')
        ax.set_title(title)
        xmaxes.append(ax.get_xlim()[1])
        ymaxes.append(ax.get_ylim()[1])
    xmax = max(xmaxes)
    ymax = 2 * ((ceil(max(ymaxes)) >> 1) + 1)  # next even number above the tallest y limit
    for iax, ax in enumerate(axes):
        ax.set_xlim(0.9, xmax)
        ax.set_ylim(0, ymax)
        ax.set_ylabel('number of ways')
        if iax == len(axes) - 1:
            ax.set_xlabel('L value')
    plt.suptitle('nmax: ' + str(nmax), fontsize=14)
    plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.3)
    plt.show()
```