diff --git a/README.rst b/README.rst index ac6cea1..3105514 100644 --- a/README.rst +++ b/README.rst @@ -34,6 +34,8 @@ Currently implemented methods are: * setdefault() * keys() * iterkeys() +* keys_with_prefix() +* iterkeys_with_prefix() Other methods are not implemented. @@ -112,6 +114,7 @@ Authors & Contributors ---------------------- * Mikhail Korobov +* Felix Liu This module is based on `hat-trie`_ C library by Daniel Jones & contributors. diff --git a/hat-trie/.gitignore b/hat-trie/.gitignore new file mode 100644 index 0000000..323a599 --- /dev/null +++ b/hat-trie/.gitignore @@ -0,0 +1,11 @@ +*.m4 +*.o +*.lo +*.a +*.la +config.* +.DS_Store +.deps +.libs +Makefile + diff --git a/hat-trie/m4/.gitkeep b/hat-trie/m4/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/hat-trie/src/ahtable.c b/hat-trie/src/ahtable.c index 01bb4a9..81aea64 100644 --- a/hat-trie/src/ahtable.c +++ b/hat-trie/src/ahtable.c @@ -14,7 +14,7 @@ const double ahtable_max_load_factor = 100000.0; /* arbitrary large number => don't resize */ -const const size_t ahtable_initial_size = 4096; +const size_t ahtable_initial_size = 4096; static const uint16_t LONG_KEYLEN_MASK = 0x7fff; static size_t keylen(slot_t s) { diff --git a/hat-trie/src/hat-trie.c b/hat-trie/src/hat-trie.c index 8b87752..3f9b5cd 100644 --- a/hat-trie/src/hat-trie.c +++ b/hat-trie/src/hat-trie.c @@ -63,10 +63,10 @@ static trie_node_t* alloc_trie_node(hattrie_t* T, node_ptr child) trie_node_t* node = malloc_or_die(sizeof(trie_node_t)); node->flag = NODE_TYPE_TRIE; node->val = 0; - + /* pass T to allow custom allocator for trie. 
*/ HT_UNUSED(T); /* unused now */ - + size_t i; for (i = 0; i < NODE_CHILDS; ++i) node->xs[i] = child; return node; @@ -120,7 +120,7 @@ static node_ptr hattrie_find(hattrie_t* T, const char **key, size_t *len) if (*len == 0) return parent; node_ptr node = hattrie_consume(&parent, key, len, 1); - + /* if the trie node consumes value, use it */ if (*node.flag & NODE_TYPE_TRIE) { if (!(node.t->flag & NODE_HAS_VAL)) { @@ -131,10 +131,10 @@ static node_ptr hattrie_find(hattrie_t* T, const char **key, size_t *len) /* pure bucket holds only key suffixes, skip current char */ if (*node.flag & NODE_TYPE_PURE_BUCKET) { - *key += 1; + *key += 1; *len -= 1; } - + /* do not scan bucket, it's not needed for this operation */ return node; } @@ -180,6 +180,13 @@ void hattrie_free(hattrie_t* T) free(T); } + +size_t hattrie_size(hattrie_t* T) +{ + return T->m; +} + + /* Perform one split operation on the given node with the given parent. */ static void hattrie_split(hattrie_t* T, node_ptr parent, node_ptr node) @@ -391,16 +398,70 @@ value_t* hattrie_tryget(hattrie_t* T, const char* key, size_t len) if (node.flag == NULL) { return NULL; } - + /* if the trie node consumes value, use it */ if (*node.flag & NODE_TYPE_TRIE) { return &node.t->val; } - + return ahtable_tryget(node.b, key, len); } +void hattrie_walk (hattrie_t* T, const char* key, size_t len, void* user_data, hattrie_walk_cb cb) { + unsigned char* k = (unsigned char*)key; + node_ptr node = T->root; + size_t i, j; + ahtable_iter_t* it; + + /* go down until a bucket is reached */ + for (i = 0; i < len; i++, k++) { + if (!(*node.flag & NODE_TYPE_TRIE)) + break; + node = node.t->xs[*k]; + if (*node.flag & NODE_HAS_VAL) { + if (hattrie_walk_stop == cb(key, i, &node.t->val, user_data)) + return; + } + } + if (i == len) + return; + + assert(i); + if (*node.flag & NODE_TYPE_HYBRID_BUCKET) { + i--; + k--; + } else { + assert(*node.flag & NODE_TYPE_PURE_BUCKET); + } + + /* dict order ensured short => long */ + it = 
ahtable_iter_begin(node.b, true); + for(; !ahtable_iter_finished(it); ahtable_iter_next(it)) { + size_t stored_len; + unsigned char* stored_key = (unsigned char*)ahtable_iter_key(it, &stored_len); + int matched = 1; + if (stored_len + i > len) { + continue; + } + for (j = 0; j < stored_len; j++) { + if (stored_key[j] != k[j]) { + matched = 0; + break; + } + } + if (matched) { + value_t* val = ahtable_iter_val(it); + if (hattrie_walk_stop == cb(key, i + stored_len, val, user_data)) { + ahtable_iter_free(it); + return; + } + } + } + ahtable_iter_free(it); +} + + int hattrie_del(hattrie_t* T, const char* key, size_t len) { node_ptr parent = T->root; @@ -411,7 +472,7 @@ int hattrie_del(hattrie_t* T, const char* key, size_t len) if (node.flag == NULL) { return -1; } - + /* if consumed on a trie node, clear the value */ if (*node.flag & NODE_TYPE_TRIE) { return hattrie_clrval(T, node); @@ -421,10 +482,10 @@ int hattrie_del(hattrie_t* T, const char* key, size_t len) size_t m_old = ahtable_size(node.b); int ret = ahtable_del(node.b, key, len); T->m -= (m_old - ahtable_size(node.b)); - + /* merge empty buckets */ /*! 
\todo */ - + return ret; } @@ -460,6 +521,11 @@ struct hattrie_iter_t_ bool sorted; ahtable_iter_t* i; hattrie_node_stack_t* stack; + + // subtree inside a table + // store remaining prefix for filtering nodes not matching it + char* prefix; + size_t prefix_len; }; @@ -507,7 +573,7 @@ static void hattrie_iter_nextnode(hattrie_iter_t* i) /* push all child nodes from right to left */ int j; for (j = NODE_MAXCHAR; j >= 0; --j) { - + /* skip repeated pointers to hybrid bucket */ if (j < NODE_MAXCHAR && node.t->xs[j].t == node.t->xs[j + 1].t) continue; @@ -524,17 +590,66 @@ static void hattrie_iter_nextnode(hattrie_iter_t* i) if (*node.flag & NODE_TYPE_PURE_BUCKET) { hattrie_iter_pushchar(i, level, c); } - else { + else if (level) { i->level = level - 1; } - i->i = ahtable_iter_begin(node.b, i->sorted); + } } +/** next non-nil-key node + * TODO pick a better name + */ +static void hattrie_iter_step(hattrie_iter_t* i) +{ + while (((i->i == NULL || ahtable_iter_finished(i->i)) && !i->has_nil_key) && + i->stack != NULL ) { + + ahtable_iter_free(i->i); + i->i = NULL; + hattrie_iter_nextnode(i); + } + + if (i->i != NULL && ahtable_iter_finished(i->i)) { + ahtable_iter_free(i->i); + i->i = NULL; + } +} + +static bool hattrie_iter_prefix_not_match(hattrie_iter_t* i) +{ + if (hattrie_iter_finished(i)) { + return false; // can not advance the iter + } + if (i->level >= i->prefix_len) { + return memcmp(i->key, i->prefix, i->prefix_len); + } else if (i->has_nil_key) { + return true; // subkey too short + } + + size_t sublen; + const char* subkey; + subkey = ahtable_iter_key(i->i, &sublen); + if (i->level + sublen < i->prefix_len) { + return true; // subkey too short + } + return memcmp(i->key, i->prefix, i->level) || + memcmp(subkey, i->prefix + i->level, (i->prefix_len - i->level)); +} + + hattrie_iter_t* hattrie_iter_begin(const hattrie_t* T, bool sorted) { + return hattrie_iter_with_prefix(T, sorted, NULL, 0); +} + + +hattrie_iter_t* hattrie_iter_with_prefix(const hattrie_t* T, 
bool sorted, const char* prefix, size_t prefix_len) +{ + node_ptr node = hattrie_find((hattrie_t*)T, &prefix, &prefix_len); + hattrie_iter_t* i = malloc_or_die(sizeof(hattrie_iter_t)); i->T = T; i->sorted = sorted; @@ -545,24 +660,23 @@ hattrie_iter_t* hattrie_iter_begin(const hattrie_t* T, bool sorted) i->has_nil_key = false; i->nil_val = 0; + i->prefix_len = prefix_len; + if (prefix_len) { + i->prefix = (char*)malloc_or_die(prefix_len); + memcpy(i->prefix, prefix, prefix_len); + } else { + i->prefix = NULL; + } + i->stack = malloc_or_die(sizeof(hattrie_node_stack_t)); i->stack->next = NULL; - i->stack->node = T->root; + i->stack->node = node; i->stack->c = '\0'; i->stack->level = 0; - - while (((i->i == NULL || ahtable_iter_finished(i->i)) && !i->has_nil_key) && - i->stack != NULL ) { - - ahtable_iter_free(i->i); - i->i = NULL; - hattrie_iter_nextnode(i); - } - - if (i->i != NULL && ahtable_iter_finished(i->i)) { - ahtable_iter_free(i->i); - i->i = NULL; + hattrie_iter_step(i); + if (i->prefix_len && hattrie_iter_prefix_not_match(i)) { + hattrie_iter_next(i); } return i; @@ -571,29 +685,20 @@ hattrie_iter_t* hattrie_iter_begin(const hattrie_t* T, bool sorted) void hattrie_iter_next(hattrie_iter_t* i) { - if (hattrie_iter_finished(i)) return; - - if (i->i != NULL && !ahtable_iter_finished(i->i)) { - ahtable_iter_next(i->i); - } - else if (i->has_nil_key) { - i->has_nil_key = false; - i->nil_val = 0; - hattrie_iter_nextnode(i); - } + do { + if (hattrie_iter_finished(i)) return; - while (((i->i == NULL || ahtable_iter_finished(i->i)) && !i->has_nil_key) && - i->stack != NULL ) { - - ahtable_iter_free(i->i); - i->i = NULL; - hattrie_iter_nextnode(i); - } + if (i->i != NULL && !ahtable_iter_finished(i->i)) { + ahtable_iter_next(i->i); + } + else if (i->has_nil_key) { + i->has_nil_key = false; + i->nil_val = 0; + hattrie_iter_nextnode(i); + } - if (i->i != NULL && ahtable_iter_finished(i->i)) { - ahtable_iter_free(i->i); - i->i = NULL; - } + hattrie_iter_step(i); + } 
while (i->prefix_len && hattrie_iter_prefix_not_match(i)); } @@ -615,6 +720,10 @@ void hattrie_iter_free(hattrie_iter_t* i) i->stack = next; } + if (i->prefix_len) { + free(i->prefix); + } + free(i->key); free(i); } @@ -641,8 +750,8 @@ const char* hattrie_iter_key(hattrie_iter_t* i, size_t* len) memcpy(i->key + i->level, subkey, sublen); i->key[i->level + sublen] = '\0'; - *len = i->level + sublen; - return i->key; + *len = i->level + sublen - i->prefix_len; + return i->key + i->prefix_len; } @@ -654,6 +763,3 @@ value_t* hattrie_iter_val(hattrie_iter_t* i) return ahtable_iter_val(i->i); } - - - diff --git a/hat-trie/src/hat-trie.h b/hat-trie/src/hat-trie.h index d8439b6..1b8a432 100644 --- a/hat-trie/src/hat-trie.h +++ b/hat-trie/src/hat-trie.h @@ -33,6 +33,9 @@ void hattrie_free (hattrie_t*); //< Free all memory used by a trie hattrie_t* hattrie_dup (const hattrie_t*); //< Duplicate an existing trie. void hattrie_clear (hattrie_t*); //< Remove all entries. +/** number of inserted keys + */ +size_t hattrie_size (hattrie_t*); /** Find the given key in the trie, inserting it if it does not exist, and * returning a pointer to it's key. @@ -43,11 +46,22 @@ void hattrie_clear (hattrie_t*); //< Remove all entries. */ value_t* hattrie_get (hattrie_t*, const char* key, size_t len); - /** Find a given key in the table, returning a NULL pointer if it does not * exist. */ value_t* hattrie_tryget (hattrie_t*, const char* key, size_t len); +/** hattrie_walk callback signature */ +typedef int (*hattrie_walk_cb)(const char* key, size_t len, value_t* val, void* user_data); + +/** hattrie_walk callback return values, controlling whether the walk should stop or continue */ +#define hattrie_walk_stop 0 +#define hattrie_walk_continue 1 + +/** Find stored keys which are prefixes of key, and invoke callback for every found key and val. + * The invocation order is: short key to long key. 
+ */ +void hattrie_walk (hattrie_t*, const char* key, size_t len, void* user_data, hattrie_walk_cb); + /** Delete a given key from trie. Returns 0 if successful or -1 if not found. */ int hattrie_del(hattrie_t* T, const char* key, size_t len); @@ -61,10 +75,12 @@ void hattrie_iter_free (hattrie_iter_t*); const char* hattrie_iter_key (hattrie_iter_t*, size_t* len); value_t* hattrie_iter_val (hattrie_iter_t*); +/** Note the hattrie_iter_key() for prefixed search gets the suffix instead of the whole key + */ +hattrie_iter_t* hattrie_iter_with_prefix(const hattrie_t*, bool sorted, const char* prefix, size_t prefix_len); + #ifdef __cplusplus } #endif #endif - - diff --git a/hat-trie/test/check_hattrie.c b/hat-trie/test/check_hattrie.c index 797a981..a2ca712 100644 --- a/hat-trie/test/check_hattrie.c +++ b/hat-trie/test/check_hattrie.c @@ -26,7 +26,7 @@ char** ds; hattrie_t* T; str_map* M; - +int have_error = 0; void setup() { @@ -88,6 +88,7 @@ void test_hattrie_insert() if (*u != v) { fprintf(stderr, "[error] tally mismatch (reported: %lu, correct: %lu)\n", *u, v); + have_error = 1; } } @@ -99,6 +100,7 @@ void test_hattrie_insert() if (u) { fprintf(stderr, "[error] item %zu still found in trie after delete\n", j); + have_error = 1; } } @@ -131,9 +133,11 @@ void test_hattrie_iteration() if (*u != v) { if (v == 0) { fprintf(stderr, "[error] incorrect iteration (%lu, %lu)\n", *u, v); + have_error = 1; } else { fprintf(stderr, "[error] incorrect iteration tally (%lu, %lu)\n", *u, v); + have_error = 1; } } @@ -147,6 +151,7 @@ void test_hattrie_iteration() if (count != M->m) { fprintf(stderr, "[error] iterated through %zu element, expected %zu\n", count, M->m); + have_error = 1; } hattrie_iter_free(i); @@ -187,12 +192,13 @@ void test_hattrie_sorted_iteration() ++count; key = hattrie_iter_key(i, &len); - + /* memory for key may be changed on iter, copy it */ strncpy(key_copy, key, len); if (prev_key != NULL && cmpkey(prev_key, prev_len, key, len) > 0) { fprintf(stderr, 
"[error] iteration is not correctly ordered.\n"); + have_error = 1; } u = hattrie_iter_val(i); @@ -201,9 +207,11 @@ void test_hattrie_sorted_iteration() if (*u != v) { if (v == 0) { fprintf(stderr, "[error] incorrect iteration (%lu, %lu)\n", *u, v); + have_error = 1; } else { fprintf(stderr, "[error] incorrect iteration tally (%lu, %lu)\n", *u, v); + have_error = 1; } } @@ -217,6 +225,7 @@ void test_hattrie_sorted_iteration() if (count != M->m) { fprintf(stderr, "[error] iterated through %zu element, expected %zu\n", count, M->m); + have_error = 1; } hattrie_iter_free(i); @@ -248,11 +257,66 @@ void test_trie_non_ascii() } +typedef struct { + int size; + size_t lens[10]; + value_t vals[10]; +} trie_walk_data_t; + + +static int trie_walk_cb(const char* key __attribute__((unused)), size_t len, value_t* val, void* data) { + trie_walk_data_t* d = data; + d->lens[d->size] = len; + d->vals[d->size] = *val; + d->size++; + return hattrie_walk_continue; +} + + +void test_trie_walk() +{ + fprintf(stderr, "checking tryget_longest_match... 
\n"); + + hattrie_t* T = hattrie_create(); + char* txt1 = "hello world1"; + char* txt2 = "hello world2"; + char* txt3 = "hello"; + value_t* val; + + val = hattrie_get(T, txt1, strlen(txt1)); + *val = 1; + val = hattrie_get(T, txt2, strlen(txt2)); + *val = 2; + val = hattrie_get(T, txt3, strlen(txt3)); + *val = 3; + +#define EXPECT(check) \ + if (!(check)) {\ + fprintf(stderr, "[error] %s:%d: expect failure\n", __FILE__, __LINE__);\ + have_error = 1;\ + } + + trie_walk_data_t data = { + .size = 0 + }; + char* txt = "hello world20"; + hattrie_walk(T, txt, strlen(txt), &data, trie_walk_cb); + EXPECT(data.size == 2); + EXPECT(data.lens[0] = strlen(txt3)); + EXPECT(data.vals[0] == 3); + EXPECT(data.lens[1] = strlen(txt2)); + EXPECT(data.vals[1] == 2); +#undef EXPECT + + hattrie_free(T); +} + int main() { test_trie_non_ascii(); + test_trie_walk(); setup(); test_hattrie_insert(); @@ -264,10 +328,8 @@ int main() test_hattrie_sorted_iteration(); teardown(); + if (have_error) { + return -1; + } return 0; } - - - - - diff --git a/src/chat_trie.pxd b/src/chat_trie.pxd index f4e2c47..204bc04 100644 --- a/src/chat_trie.pxd +++ b/src/chat_trie.pxd @@ -25,12 +25,12 @@ cdef extern from "../hat-trie/src/hat-trie.h": ctypedef struct hattrie_iter_t: pass - hattrie_iter_t* hattrie_iter_begin (hattrie_t*, bint sorted) - void hattrie_iter_next (hattrie_iter_t*) - bint hattrie_iter_finished (hattrie_iter_t*) - void hattrie_iter_free (hattrie_iter_t*) - char* hattrie_iter_key (hattrie_iter_t*, size_t* len) - value_t* hattrie_iter_val (hattrie_iter_t*) + hattrie_iter_t* hattrie_iter_with_prefix (hattrie_t*, bint sorted, char* prefix, size_t prefix_len) + void hattrie_iter_next (hattrie_iter_t*) + bint hattrie_iter_finished (hattrie_iter_t*) + void hattrie_iter_free (hattrie_iter_t*) + char* hattrie_iter_key (hattrie_iter_t*, size_t* len) + value_t* hattrie_iter_val (hattrie_iter_t*) cdef struct hattrie_t_: void* root diff --git a/src/hat_trie.pyx b/src/hat_trie.pyx index 5f6c6c1..0150bf7 
100644 --- a/src/hat_trie.pyx +++ b/src/hat_trie.pyx @@ -32,12 +32,12 @@ cdef class BaseTrie: def setdefault(self, bytes key, int value): return self._setdefault(key, value) - def keys(self): - return list(self.iterkeys()) + def keys(self, prefix = ''): + return list(self.iterkeys(prefix)) - def iterkeys(self): + def iterkeys(self, prefix = ''): cdef: - hattrie_iter_t* it = hattrie_iter_begin(self._trie, 0) + hattrie_iter_t* it = hattrie_iter_with_prefix(self._trie, 0, prefix, len(prefix)) char* c_key size_t val size_t length @@ -53,7 +53,6 @@ cdef class BaseTrie: finally: hattrie_iter_free(it) - cdef int _getitem(self, char* key) except -1: cdef value_t* value_ptr = hattrie_tryget(self._trie, key, len(key)) if value_ptr == NULL: @@ -113,7 +112,7 @@ cdef class Trie(BaseTrie): def setdefault(self, unicode key, int value): cdef bytes bkey = key.encode('utf8') - self._setdefault(bkey, value) + return self._setdefault(bkey, value) - def keys(self): - return [key.decode('utf8') for key in self.iterkeys()] + def keys(self, prefix = ''): + return [key.decode('utf8') for key in self.iterkeys(prefix)]