diff options
| -rw-r--r-- | src/hash_table.c | 33 | ||||
| -rw-r--r-- | src/hash_table.h | 79 | ||||
| -rw-r--r-- | src/type_alias.h | 3 | ||||
| -rw-r--r-- | tests/test_htable.c | 48 |
4 files changed, 113 insertions, 50 deletions
diff --git a/src/hash_table.c b/src/hash_table.c index 376d712..8d97875 100644 --- a/src/hash_table.c +++ b/src/hash_table.c @@ -3,17 +3,25 @@ #include <stdlib.h> #include <string.h> +#include "basic_traits.h" + #define HTFL_NUL 0 #define HTFL_VAL 1 #define HTFL_DEL 2 +HASH_TABLE_IMPL(String, Int); +HASH_TABLE_IMPL(String, String); +HASH_TABLE_IMPL(String, Double); +HASH_TABLE_IMPL(Int, Int); +HASH_TABLE_IMPL(Int, Double); + -static void rebuild(HashTable *ht) { +static void rebuild(HashTable *ht, VoidHashFn hash, VoidEqFn eq) { HashTable newht; - init_hash_table(&newht, ht->elemsz, ht->size * 6, ht->hash, ht->eq); + init_hash_table(&newht, ht->elemsz, ht->size * 6); void *iter = hash_table_begin(ht); while (iter != NULL) { - hash_table_insert(&newht, iter); + hash_table_insert(&newht, iter, hash, eq); iter = hash_table_next(ht, iter); } free(ht->buf); @@ -21,8 +29,7 @@ static void rebuild(HashTable *ht) { *ht = newht; } -void init_hash_table(HashTable *ht, int64_t elemsz, int64_t cap, - uint64_t (*hash)(void *), bool (*eq)(void *, void *)) { +void init_hash_table(HashTable *ht, int64_t elemsz, int64_t cap) { if (cap < 16) cap = 16; ht->buf = malloc(cap * elemsz); ht->flagbuf = malloc(cap); @@ -32,20 +39,18 @@ void init_hash_table(HashTable *ht, int64_t elemsz, int64_t cap, ht->cap = cap; ht->taken = 0; ht->elemsz = elemsz; - ht->hash = hash; - ht->eq = eq; } -bool hash_table_insert(HashTable *ht, void *elem) { +bool hash_table_insert(HashTable *ht, void *elem, VoidHashFn hash, VoidEqFn eq) { if (ht->taken + 1 > ht->cap / 2) { - rebuild(ht); + rebuild(ht, hash, eq); } ht->taken++; ht->size++; - int64_t pos = ht->hash(elem) % ht->cap; + int64_t pos = hash(elem) % ht->cap; while (ht->flagbuf[pos] != HTFL_NUL) { if (ht->flagbuf[pos] == HTFL_VAL - && ht->eq(ht->buf + pos * ht->elemsz, elem)) { + && eq(ht->buf + pos * ht->elemsz, elem)) { return false; } pos++; @@ -68,11 +73,11 @@ void *hash_table_ref(HashTable *ht, int64_t pos) { return ht->buf + pos * ht->elemsz; } -void *hash_table_find(HashTable *ht, void *elem) { - int64_t pos = ht->hash(elem) % ht->cap; +void *hash_table_find(HashTable *ht, void *elem, VoidHashFn hash, VoidEqFn eq) { + int64_t pos = hash(elem) % ht->cap; while (ht->flagbuf[pos] != HTFL_NUL) { if (ht->flagbuf[pos] == HTFL_VAL - && ht->eq(hash_table_ref(ht, pos), elem)) { + && eq(hash_table_ref(ht, pos), elem)) { return hash_table_ref(ht, pos); } pos++; diff --git a/src/hash_table.h b/src/hash_table.h index 549e261..857c688 100644 --- a/src/hash_table.h +++ b/src/hash_table.h @@ -4,6 +4,8 @@ #include <stdbool.h> #include <stdint.h> +#include "type_alias.h" + struct hash_table { void *buf; char *flagbuf; @@ -11,18 +13,79 @@ struct hash_table { int64_t cap; int64_t taken; int64_t elemsz; - uint64_t (*hash)(void *); - bool (*eq)(void *, void *); }; typedef struct hash_table HashTable; -void init_hash_table(HashTable *ht, int64_t elemsz, int64_t cap, - uint64_t (*hash)(void *), bool (*eq)(void *, void *)); -bool hash_table_insert(HashTable *ht, void *elem); -void hash_table_remove(HashTable *ht, void *iter); +#define HASH_TABLE_DEF(K, V) \ + typedef struct { \ + K key; \ + V val; \ + } K##2##V##HashTableEntry; \ + typedef K##2##V##HashTableEntry *K##2##V##HashTableIter; \ + typedef struct { \ + HashTable ht; \ + } K##2##V##HashTable; \ + void K##2##V##HashTable_init(K##2##V##HashTable *self); \ + bool K##2##V##HashTable_insert(K##2##V##HashTable *self, K *key, V *value); \ + void K##2##V##HashTable_remove(K##2##V##HashTable *ht, K##2##V##HashTableIter iter); \ + V* K##2##V##HashTable_get(K##2##V##HashTable *self, K *key); \ + K##2##V##HashTableIter K##2##V##HashTable_find(K##2##V##HashTable *self, K *key); \ + K##2##V##HashTableIter K##2##V##HashTable_begin(K##2##V##HashTable *self); \ + K##2##V##HashTableIter K##2##V##HashTable_next(K##2##V##HashTable *self, K##2##V##HashTableIter iter); \ + void K##2##V##HashTable_free(K##2##V##HashTable *self); \ + K##2##V##HashTable K##2##V##HashTable_move(K##2##V##HashTable *self); \ + +#define HASH_TABLE_IMPL(K, V) \ + void K##2##V##HashTable_init(K##2##V##HashTable *self) { \ + init_hash_table(&self->ht, sizeof(K##2##V##HashTableEntry), 16); \ + } \ + bool K##2##V##HashTable_insert(K##2##V##HashTable *self, K *key, V *value) { \ + K##2##V##HashTableEntry entry; \ + memcpy(&entry.key, key, sizeof(K)); \ + memcpy(&entry.val, value, sizeof(K)); \ + return hash_table_insert(&self->ht, &entry, (VoidHashFn)K##_hash, (VoidEqFn)K##_eq); \ + } \ + K##2##V##HashTableIter K##2##V##HashTable_find(K##2##V##HashTable *self, K *key) { \ + return hash_table_find(&self->ht, key, (VoidHashFn)K##_hash, (VoidEqFn)K##_eq); \ + } \ + V* K##2##V##HashTable_get(K##2##V##HashTable *self, K *key) { \ + K##2##V##HashTableEntry* entry = hash_table_find(&self->ht, key, (VoidHashFn)K##_hash, (VoidEqFn)K##_eq); \ + if (entry == NULL) return NULL; \ + return &(entry->val); \ + } \ + void K##2##V##HashTable_remove(K##2##V##HashTable *self, K##2##V##HashTableIter iter) { \ + hash_table_remove(&self->ht, iter); \ + } \ + K##2##V##HashTableIter K##2##V##HashTable_begin(K##2##V##HashTable *self) { \ + return hash_table_begin(&self->ht); \ + } \ + K##2##V##HashTableIter K##2##V##HashTable_next(K##2##V##HashTable *self, K##2##V##HashTableIter iter) { \ + return hash_table_next(&self->ht, iter); \ + } \ + void K##2##V##HashTable_free(K##2##V##HashTable *self) { \ + destroy_hash_table(&self->ht); \ + } \ + K##2##V##HashTable K##2##V##HashTable_move(K##2##V##HashTable *self) { \ + K##2##V##HashTable dup; \ + dup.ht = self->ht; \ + self->ht.buf = NULL; \ + self->ht.flagbuf = NULL; \ + self->ht.size = 0; \ + self->ht.cap = 0; \ + self->ht.taken = 0; \ + return dup; \ + } \ -// return a iterator -void *hash_table_find(HashTable *ht, void *elem); +HASH_TABLE_DEF(String, Int); +HASH_TABLE_DEF(String, String); +HASH_TABLE_DEF(String, Double); +HASH_TABLE_DEF(Int, Int); +HASH_TABLE_DEF(Int, Double); + +void init_hash_table(HashTable *ht, int64_t elemsz, int64_t cap); +bool hash_table_insert(HashTable *ht, void *elem, VoidHashFn hash, VoidEqFn eq); +void hash_table_remove(HashTable *ht, void *iter); +void *hash_table_find(HashTable *ht, void *elem, VoidHashFn hash, VoidEqFn eq); void *hash_table_begin(HashTable *ht); void *hash_table_next(HashTable *ht, void *iter); void destroy_hash_table(HashTable *ht); diff --git a/src/type_alias.h b/src/type_alias.h index 9d40684..ae6b7ff 100644 --- a/src/type_alias.h +++ b/src/type_alias.h @@ -14,5 +14,8 @@ typedef float Float; typedef double Double; typedef const char *String; +typedef uint64_t (*VoidHashFn)(void*); +typedef bool (*VoidEqFn)(void*, void*); +typedef int (*VoidCmpFn)(void*, void*); #endif diff --git a/tests/test_htable.c b/tests/test_htable.c index 8b93af0..3608b5c 100644 --- a/tests/test_htable.c +++ b/tests/test_htable.c @@ -4,69 +4,61 @@ #include <string.h> #include "hash_table.h" -#include "mmhash.h" - -static uint64_t hash(void *i) { return mmhash(i, sizeof(int), 0); } - -static bool eq(void *x, void *y) { - int *a = x, *b = y; - return *a == *b; -} bool found[10000]; int main() { printf("[TEST] htable\n"); - HashTable ht; - init_hash_table(&ht, sizeof(int), -1, hash, eq); + Int2IntHashTable ht; + Int2IntHashTable_init(&ht); for (int i = 0; i < 10000; i++) { - hash_table_insert(&ht, &i); - assert(ht.size == i + 1); - assert(ht.taken == i + 1); - assert(ht.cap >= i + 1); + Int2IntHashTable_insert(&ht, &i, &i); + assert(ht.ht.size == i + 1); + assert(ht.ht.taken == i + 1); + assert(ht.ht.cap >= i + 1); } for (int i = 0; i < 10000; i++) { - assert(hash_table_find(&ht, &i) != NULL); + assert(Int2IntHashTable_get(&ht, &i) != NULL); int t = 10000 + i; - assert(hash_table_find(&ht, &t) == NULL); + assert(Int2IntHashTable_get(&ht, &t) == NULL); } memset(found, 0, sizeof(bool) * 10000); - int *iter = hash_table_begin(&ht); + Int2IntHashTableIter iter = Int2IntHashTable_begin(&ht); while (iter != NULL) { - found[*iter] = true; - iter = hash_table_next(&ht, iter); + found[iter->key] = true; + iter = Int2IntHashTable_next(&ht, iter); } for (int i = 0; i < 10000; i++) { assert(found[i]); } for (int i = 0; i < 5000; i++) { - int *iter = hash_table_find(&ht, &i); - hash_table_remove(&ht, iter); + Int2IntHashTableIter iter = Int2IntHashTable_find(&ht, &i); + Int2IntHashTable_remove(&ht, iter); } for (int i = 0; i < 5000; i++) { - assert(hash_table_find(&ht, &i) == NULL); + assert(Int2IntHashTable_find(&ht, &i) == NULL); int t = 5000 + i; - assert(hash_table_find(&ht, &t) != NULL); + assert(Int2IntHashTable_find(&ht, &t) != NULL); } for (int i = 0; i < 5000; i++) { - hash_table_insert(&ht, &i); + Int2IntHashTable_insert(&ht, &i, &i); } memset(found, 0, sizeof(bool) * 10000); - iter = hash_table_begin(&ht); + iter = Int2IntHashTable_begin(&ht); while (iter != NULL) { - found[*iter] = true; - iter = hash_table_next(&ht, iter); + found[iter->key] = true; + iter = Int2IntHashTable_next(&ht, iter); } for (int i = 0; i < 10000; i++) { assert(found[i]); } - destroy_hash_table(&ht); + Int2IntHashTable_free(&ht); printf("[PASS] htable\n"); } |
