Monday, 22 August 2011

Solving Causes' Levenshtein Distance challenge in 28 seconds. (Python)

Causes has recently put up a coding challenge for prospective engineers who want to work with them. Here is their description, short and sweet:

Two words are friends if they have a Levenshtein distance of 1. That is, you can add, remove, or substitute exactly one letter in word X to create word Y. A word’s social network consists of all of its friends, plus all of their friends, and all of their friends’ friends, and so on. Write a program to tell us how big the social network for the word “causes” is, using this word list. Have fun!
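To make the definition concrete, here is a minimal sketch of a pairwise "are these two words friends?" check. The name is_friend is my own and is not part of the challenge or of the solution further down; it only illustrates what a Levenshtein distance of exactly 1 means.

```python
def is_friend(x, y):
    """Return True if x and y are at Levenshtein distance exactly 1."""
    if x == y or abs(len(x) - len(y)) > 1:
        return False
    if len(x) > len(y):
        x, y = y, x  # make x the shorter (or equal-length) word
    # Skip the common prefix.
    i = 0
    while i < len(x) and x[i] == y[i]:
        i += 1
    if len(x) == len(y):
        # Substitution: everything after the first mismatch must agree.
        return x[i + 1:] == y[i + 1:]
    # Insertion: dropping y[i] must leave the remainder of x.
    return x[i:] == y[i + 1:]
```

With this, "causes" is a friend of "pauses" (substitute), "cause" (remove) and "clauses" (add), but not of itself.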

The word list contains 264,061 words. A friend wrote a C solution that solved it in under 300 seconds and challenged me to do better. Looking at it, it seemed like it would benefit from Python's amazing set() data structure, so I gave it a shot. After some horrible bugs, I got the right answer: 78482 (including the original word 'causes').

My Python solution is 28 lines long, including comments. It also runs in 28 seconds on my antiquated Core Duo. Someone on Hacker News found that PyPy speeds this up by a further 70%. Not bad at all.

Other solutions I saw online take longer, which is why I decided to post mine here. Other attempts (claimed runtime in parentheses): Ruby (50 sec), Ruby again (?), PHP (1.5 hours), Python (4 hours, 1 hour), Python again (?), Python redux (?), newLISP (?), C++ (1 min).

The secret behind my version's performance is Python's incredible set(), which does membership testing in O(1) on average. Generating every possible string at Levenshtein distance 1 and testing each for membership also helps performance while keeping the code small. It is much cheaper to generate a few hundred candidate next steps per word than to test each of the roughly 264,000 words in the list to see whether it is a next step.
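To put a number on "the small number of next steps", here is a rough sketch that builds every string one edit away from a word, whether or not it is a real word. The function name candidates is an illustrative assumption, not something from the solution itself. For a word of n lowercase letters there are at most 25n substitutions, 26(n+1) insertions and n deletions, so at most 52n + 26 strings in total (338 for a six-letter word, fewer once duplicates collapse).

```python
import string

def candidates(word, alphabet=string.ascii_lowercase):
    """All strings at edit distance 1 from word, real words or not."""
    out = set()
    for i in range(len(word) + 1):
        for c in alphabet:
            out.add(word[:i] + c + word[i:])          # insert c at position i
            if i < len(word) and c != word[i]:
                out.add(word[:i] + c + word[i + 1:])  # substitute word[i] with c
        if i < len(word):
            out.add(word[:i] + word[i + 1:])          # delete word[i]
    return out

print(len(candidates("causes")))  # a few hundred, versus ~264,000 words to scan
```

Each of those few hundred strings costs one set lookup, which is why this beats comparing the word against the whole list.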

Without further ado, the code:

import string

w = set(open("00wordlist.txt").read().splitlines())
f, nf = set(), set(["causes"])

#from b, yield all unused words where levdist==1
def nextgen(b):
    for i in range(len(b)): #for each index in b
        for c in string.ascii_lowercase: #for letters [a..z]
            if c != b[i]:
                #substitute b[i] with c
                if b[:i] + c + b[i+1:] in w:
                    yield b[:i] + c + b[i+1:]
            #inject c before b[i]
            if b[:i] + c + b[i:] in w:
                yield b[:i] + c + b[i:]
        #remove b[i]
        if b[:i] + b[i+1:] in w: yield b[:i] + b[i+1:]

    for c in string.ascii_lowercase: #for letters [a..z]
        if b + c in w: yield b + c #append c after b

while len(nf):
    cf = nf
    nf = set([j for i in cf for j in nextgen(i) if j not in f])
    w -= nf
    f |= nf

print(len(f))
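If you want to check the approach without downloading the word list, the following is a self-contained sketch of the same generate-and-look-up search wrapped in a function and run on a tiny made-up word list. The names neighbours, network_size and toy are my own for illustration, and the toy list is invented, not taken from the challenge.

```python
import string

def neighbours(word, remaining):
    """Yield words in `remaining` that are one edit away from `word`."""
    for i in range(len(word) + 1):
        for c in string.ascii_lowercase:
            cand = word[:i] + c + word[i:]            # insertion at position i
            if cand in remaining:
                yield cand
            if i < len(word) and c != word[i]:
                cand = word[:i] + c + word[i + 1:]    # substitution at position i
                if cand in remaining:
                    yield cand
        if i < len(word):
            cand = word[:i] + word[i + 1:]            # deletion of word[i]
            if cand in remaining:
                yield cand

def network_size(start, words):
    """Size of the social network of `start`, counting `start` itself."""
    remaining = set(words) - {start}   # words not yet claimed by the network
    found = {start}
    frontier = {start}
    while frontier:
        nxt = set()
        for w in frontier:
            nxt.update(neighbours(w, remaining))
        remaining -= nxt               # never revisit a word
        found |= nxt
        frontier = nxt
    return len(found)

toy = ["causes", "cause", "pauses", "pause", "paused", "clauses", "zebra"]
print(network_size("causes", toy))
```

On this toy list the network has six members: everything except "zebra", which has no friends. Like the full solution, the count includes the starting word.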