I have created a trie in Python and am looking for feedback. Specifically, I am looking for feedback on:
- Whether my code is 'Pythonic'
- Whether my logic is correct when inserting into and retrieving from the trie
I have implemented the following logic:
- Insert a word into the trie
- Check if a word exists in the trie
- Given a prefix, get all possible words
- Get all the words in the trie
My purpose for coding the trie is to improve my Python and algorithm knowledge.
from collections import deque
class Node:
    """A single trie node holding one lower-cased character.

    Attributes:
        character: The node's character, lower-cased, or ``None`` when no
            character was supplied (the trie root is built that way).
        parent: The node passed in as this node's parent (``None`` for a root).
        children: Mapping from a child's character to the child node.
        terminus: ``True`` when the path ending at this node spells a
            complete word; starts out ``False``.
    """

    # A trie allocates one node per stored character, so drop the
    # per-instance __dict__ to keep each node small.
    __slots__ = ("character", "parent", "children", "terminus")

    def __init__(self, character, parent):
        """Create a node for *character* (lower-cased) under *parent*."""
        self.character = character.lower() if character is not None else None
        self.parent = parent
        self.children = {}
        self.terminus = False

    def __repr__(self):
        return "{}(character={!r}, terminus={!r}, children={!r})".format(
            type(self).__name__,
            self.character,
            self.terminus,
            sorted(self.children),
        )

    def add(self, child_node):
        """Register *child_node* under its character.

        An existing child with the same character is replaced.
        """
        self.children[child_node.character] = child_node
class Trie:
    """Case-insensitive prefix tree of words.

    Every word is normalised (surrounding whitespace stripped, then
    lower-cased) before it is stored, and every lookup normalises its
    argument the same way, so all returned words are lower-case.
    """

    def __init__(self):
        # Sentinel root: it carries no character and stands for the empty prefix.
        self._root = Node(None, None)

    def insert(self, word):
        """Add *word* to the trie.

        Falsy input, and input that is empty once normalised (for example a
        whitespace-only string), is ignored. Inserting a word that is already
        present leaves the trie unchanged.
        """
        if not word:
            return
        normalized = self._normalize_word(word)
        if not normalized:
            # Without this guard the root itself would be flagged as a word.
            return
        current_node = self._root
        for character in normalized:
            child_node = current_node.children.get(character)
            if child_node is None:
                child_node = Node(character, current_node)
                current_node.add(child_node)
            current_node = child_node
        current_node.terminus = True

    def __contains__(self, item):
        """Return ``True`` if *item* (after normalisation) is a stored word."""
        node = self._find_node(self._normalize_word(item))
        return node is not None and node.terminus

    def _normalize_word(self, word):
        """Return *word* stripped of surrounding whitespace and lower-cased."""
        return word.strip().lower()

    def _find_node(self, text):
        """Return the node reached by following *text* from the root.

        *text* must already be normalised. Returns ``None`` when the path
        does not exist; an empty *text* yields the root.
        """
        current_node = self._root
        for symbol in text:
            current_node = current_node.children.get(symbol)
            if current_node is None:
                return None
        return current_node

    def _get_all_words(self, prefix, node, word_list):
        """Append every word in the subtree rooted at *node* to *word_list*.

        *prefix* is a mutable sequence (supporting ``append``/``pop``) holding
        the characters on the path down to, but not including, *node*. It is
        used as a stack during the traversal and is left exactly as it was
        received. Words found below *node* are appended before the word that
        ends at *node* itself.
        """
        has_character = node.character is not None
        if has_character:
            prefix.append(node.character)
        for child in node.children.values():
            self._get_all_words(prefix, child, word_list)
        if node.terminus:
            word_list.append("".join(prefix))
        # Pop only what this call pushed, so the caller's prefix stays intact.
        if has_character:
            prefix.pop()

    def get_possible_words(self, prefix):
        """Return all stored words that start with *prefix*.

        *prefix* is normalised like inserted words. An empty prefix matches
        every word; a prefix that is not in the trie yields an empty list.
        """
        normalized = self._normalize_word(prefix)
        start_node = self._find_node(normalized)
        if start_node is None:
            return []
        word_list = []
        # _get_all_words pushes start_node's own character (the last prefix
        # character), so seed the stack with everything before it. This makes
        # the traversal build complete words with no post-processing.
        self._get_all_words(list(normalized[:-1]), start_node, word_list)
        return word_list

    def get_all_words(self):
        """Return a list of every word stored in the trie."""
        word_list = []
        self._get_all_words([], self._root, word_list)
        return word_list
Here is my unit test:
import unittest
from algorithms import trie
class TestTrie(unittest.TestCase):
    """Unit tests for ``trie.Trie``: insertion, membership and word listing."""

    # Words stored in the shared fixture. Laid out as a trie, with "*"
    # marking a node that ends a word:
    #
    #   root
    #     a *
    #       d
    #         d *
    #       n *
    #         d *
    #         y *
    #     b
    #       a
    #         g *
    #           e
    #             l *
    #           s *
    #         t *
    #           h *
    #         y *
    _WORDS = (
        "a",
        "add",
        "an",
        "and",
        "any",
        "bagel",
        "bag",
        "bags",
        "bat",
        "bath",
        "bay",
    )

    @classmethod
    def setUpClass(cls):
        """Build the trie shared by the tests that only read from it."""
        cls._trie = trie.Trie()
        for word in cls._WORDS:
            cls._trie.insert(word)
        # Derived from the word list so the two cannot drift apart.
        cls._trie_length = len(cls._WORDS)

    def test(self):
        """Every inserted word is found, and nothing extra is stored."""
        self.assertEqual(len(self._trie.get_all_words()), self._trie_length)
        for word in self._WORDS:
            self.assertIn(word, self._trie)

    def test_duplicate_entries(self):
        """Adding a word that already exists should not create a new word in the trie"""
        t = self._trie
        # Touches the shared fixture, but re-inserting a stored word is
        # expected to be a no-op, which is exactly what is asserted here.
        t.insert("bag")
        self.assertEqual(len(t.get_all_words()), self._trie_length)
        self.assertIn("bag", t)

    def test_mixed_case(self):
        """insert and retrieval are case insensitive"""
        t = trie.Trie()
        t.insert("APPLE")
        t.insert("oRANge")
        for word in ("apple", "orange", "APPLE", "ORANGE", "aPpLe", "oRangE"):
            self.assertIn(word, t)

    def test_hyphenated_words(self):
        """Words containing hyphens are stored and found like any other word."""
        words = (
            "e-mail",
            "above-said",
            "above-water",
            "above-written",
            "above",
            "abode",
            "exit",
        )
        t = trie.Trie()
        for word in words:
            t.insert(word)
        self.assertEqual(len(t.get_all_words()), 7)
        for word in words:
            self.assertIn(word, t)

    def test_empty_trie(self):
        """A freshly created trie contains no words."""
        t = trie.Trie()
        self.assertEqual(len(t.get_all_words()), 0)

    def test_first_symbol_is_a_word(self):
        """A one-letter word and a longer word sharing that letter coexist."""
        t = trie.Trie()
        t.insert("a")
        t.insert("apple")
        self.assertIn("a", t)
        self.assertIn("apple", t)
        words = t.get_all_words()
        self.assertEqual(len(words), 2)
        self.assertIn("a", words)
        self.assertIn("apple", words)

    def test_get_possible_words(self):
        """A prefix returns exactly the stored words that start with it."""
        # assertCountEqual ignores order but also catches duplicated results.
        self.assertCountEqual(
            self._trie.get_possible_words("an"),
            ["an", "and", "any"],
        )
        self.assertCountEqual(
            self._trie.get_possible_words("ba"),
            ["bagel", "bag", "bags", "bat", "bath", "bay"],
        )

    def test_get_possible_words_no_more_words(self):
        """test that given a prefix that is a terminus with no children in the trie, returns that one word"""
        prefix = "any"
        actual_words = self._trie.get_possible_words(prefix)
        self.assertEqual(len(actual_words), 1)
        self.assertIn(prefix, actual_words)

    def test_get_possible_words_prefix_not_in_trie(self):
        """A prefix that matches nothing yields an empty list."""
        actual_words = self._trie.get_possible_words("z")
        self.assertEqual(len(actual_words), 0)