From bc706018b92df380052fe0ab76af533bc8225fa8 Mon Sep 17 00:00:00 2001 From: Marianne Chevrot Date: Wed, 20 May 2020 17:56:17 +0200 Subject: speed up things --- solver.py | 60 ++++++++++++------------------------------------------------ 1 file changed, 12 insertions(+), 48 deletions(-) (limited to 'solver.py') diff --git a/solver.py b/solver.py index 8735f79..11ecce0 100755 --- a/solver.py +++ b/solver.py @@ -14,15 +14,13 @@ Options: --print Display only the original (unresolved) grid ''' -#import copy +import copy import os import pickle import time -#import zlib import docopt import yaml -import numpy class Kana: @@ -82,11 +80,7 @@ class KanaGrid: self.width = size[0] self.height = size[1] - #self.grid = grid - if type(grid) is list: - self.grid = numpy.array(grid) - else: - self.grid = grid + self.grid = copy.copy(grid) self.action_count = action_count self.score = score self.myst_count = myst_count @@ -95,15 +89,12 @@ class KanaGrid: def copy(self): new_grid = KanaGrid( (self.width, self.height), - None, - #self.grid.copy(), - #copy.copy(self.grid), + grid=self.grid, action_count=self.action_count, score=self.score, myst_count=self.myst_count, parent=self.parent, ) - new_grid.grid = self.grid.copy() return new_grid def is_swappable(self, pos1, pos2): @@ -234,19 +225,8 @@ class KanaGrid: # the bare minimum equal to 1. self.score, self.myst_count = self.longest_chain() - def get_tuple(self): - #data = ''.join(( - # str(self.width), - # str(self.height), - # str(self.grid), - # str(self.action_count), - # )) - #return zlib.crc32(data.encode('utf8')) - #return hash((self.width, self.height, str(self.grid), self.action_count)) - #my_tuple = (self.width, self.height, str(self.grid.all()), self.action_count) - #my_tuple = (self.width, self.height, str(self.grid.all()), self.action_count) - #print(f'my_tuple: {my_tuple}, hash {hash(my_tuple)}') - return (self.width, self.height, tuple(self.grid), self.action_count) + def get_grid_hashable(self): + return tuple(self.grid) def get_kana(self, pos): if pos[0] < 0 or pos[0] >= self.width: @@ -279,8 +259,6 @@ class KanaGrid: vect = (pos2[0] - pos1[0], pos2[1] - pos1[1]) while self.is_swappable(pos_src, pos_dst): - #print("swap between src %s (%s) dst %s (%s)" - # % (kana_src, pos_src, kana_dst, pos_dst)) self.set_kana(pos_src, kana_dst) self.set_kana(pos_dst, kana_src) @@ -400,9 +378,6 @@ def generate_possible_grids(kanagrid): for x in range(kanagrid.width): for action_type in KanaGrid.actions: new_grid = kanagrid.action((x, y), action_type) - #if new_grid: # and new_grid.grid != kanagrid.grid: # syntax in numpy ? - #if new_grid and not (new_grid.grid == kanagrid.grid).all(): # and new_grid.grid != kanagrid.grid: # syntax in numpy ? - #if new_grid and not numpy.array_equal(new_grid.grid, kanagrid.grid): if new_grid: # better perf to have only new_grid is None check and not # comparison if move has given same grid as having exact @@ -412,22 +387,22 @@ def generate_possible_grids(kanagrid): yield (x, y), action_type, new_grid -def generate_all_possible_grids(grid, taboos, max_actions): +def generate_all_possible_grids(grid, bests, max_actions): for pos, action_type, new_grid in generate_possible_grids(grid): - grid_hash = new_grid.get_tuple() - if grid_hash in taboos: + key = new_grid.get_grid_hashable() + if key in bests and bests[key].action_count <= new_grid.action_count: continue - taboos.add(grid_hash) + bests[key] = new_grid new_grid.parent = grid if new_grid.action_count >= max_actions: yield new_grid continue - yield from generate_all_possible_grids(new_grid, taboos, max_actions) + yield from generate_all_possible_grids(new_grid, bests, max_actions) def search_all_solution(kanagrid, target_score, max_actions): - taboos = set() - generator = generate_all_possible_grids(kanagrid, taboos=taboos, max_actions=max_actions) + bests = {} + generator = generate_all_possible_grids(kanagrid, bests=bests, max_actions=max_actions) for grid in generator: grid.update_score() if grid.score >= target_score and grid.myst_count == 0: @@ -448,16 +423,6 @@ def repr_grid_with_parents(grid): return '\n'.join(reversed(items)) -#def print_score_over(node, target_score): -# node.grid.update_score() -# if node.grid.score >= target_score: -# print("="*80) -# print(node_repr_with_parents(node)) -# return -# for child in node.children: -# print_score_over(child, target_score) - - def main(): args = docopt.docopt(__doc__) @@ -516,6 +481,5 @@ def main(): print(f'time taken to calculate: {hours:02d}:{minutes:02d}:{seconds:02d}') - if __name__ == '__main__': main() -- cgit v1.2.3