/* test_augment_rbtree.c: randomized test for an augmented red-black tree
 * whose entries track the size of their subtree. */
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include "rb_tree.h"
  6. typedef struct {
  7. rb_node_t node;
  8. int key;
  9. int value;
  10. int amount;
  11. } IntIntEntry;
  12. static int cmpfunc(void *x, void *y) {
  13. int *a = x, *b = y;
  14. return *a < *b ? -1 : *a > *b;
  15. }
  16. static void augment(void *n) {
  17. IntIntEntry *node = n;
  18. IntIntEntry *left = rb_tree_left(node);
  19. IntIntEntry *right = rb_tree_right(node);
  20. node->amount = 1;
  21. node->amount += left == NULL ? 0 : left->amount;
  22. node->amount += right == NULL ? 0 : right->amount;
  23. }
  24. static void test_largedata();
  25. static int max(int a, int b) { return a > b ? a : b; }
  26. int depth(void *n) {
  27. rb_node_t *node = n;
  28. if (node == NULL) return 0;
  29. return max(depth(node->entry.rbe_left), depth(node->entry.rbe_right)) + 1;
  30. }
  31. void checkaugment(IntIntEntry *node) {
  32. if (node == NULL) return;
  33. IntIntEntry *left = rb_tree_left(node);
  34. IntIntEntry *right = rb_tree_right(node);
  35. int amount = 1;
  36. amount += left == NULL ? 0 : left->amount;
  37. amount += right == NULL ? 0 : right->amount;
  38. assert(amount == node->amount);
  39. checkaugment(left);
  40. checkaugment(right);
  41. }
  42. int main() {
  43. printf("[TEST] augment rbtree\n");
  44. test_largedata();
  45. printf("[PASS] augment rbtree\n");
  46. return 0;
  47. }
  48. #define TESTSZ 10000
  49. int input[TESTSZ];
  50. void shuffle(int *input, int n) {
  51. for (int i = n - 1; i > 0; i--) {
  52. int j = rand() % i;
  53. int tmp = input[i];
  54. input[i] = input[j];
  55. input[j] = tmp;
  56. }
  57. }
  58. static void test_largedata() {
  59. // generate random input
  60. time_t t;
  61. srand((unsigned)time(&t));
  62. for (int i = 0; i < TESTSZ; i++) {
  63. input[i] = i;
  64. }
  65. shuffle(input, TESTSZ);
  66. // insert
  67. rb_tree_t tree = {NULL, cmpfunc, augment};
  68. IntIntEntry *n;
  69. for (int i = 0; i < TESTSZ; i++) {
  70. n = malloc(sizeof(*n));
  71. n->key = input[i];
  72. n->value = input[i];
  73. n->amount = 1;
  74. rb_tree_insert(&tree, n);
  75. }
  76. // check tree validity
  77. int d = depth(tree.rbh_root);
  78. assert(d >= 13 && d <= 28);
  79. IntIntEntry *root = (IntIntEntry *)(tree.rbh_root);
  80. assert(root->amount == TESTSZ);
  81. checkaugment(root);
  82. IntIntEntry *iter = rb_tree_min(&tree);
  83. int i = 0;
  84. for (; iter != NULL; iter = rb_tree_next(&tree, iter)) {
  85. assert(iter->key == i);
  86. i++;
  87. }
  88. // delete when: key % 3 != 0
  89. memset(input, 0, sizeof(int) * TESTSZ);
  90. int count = 0;
  91. for (int i = 0; i < TESTSZ; i++) {
  92. if (i % 3 != 0) {
  93. input[count] = i;
  94. } else {
  95. continue;
  96. }
  97. count++;
  98. }
  99. shuffle(input, count);
  100. for (int i = 0; i < count; i++) {
  101. IntIntEntry *iter = rb_tree_find(&tree, &input[i]);
  102. assert(iter != NULL);
  103. rb_tree_remove(&tree, iter);
  104. }
  105. // check tree validity
  106. d = depth(tree.rbh_root);
  107. assert(d >= 11 && d <= 24);
  108. root = (IntIntEntry *)(tree.rbh_root);
  109. assert(root->amount == TESTSZ - count);
  110. checkaugment(root);
  111. iter = rb_tree_min(&tree);
  112. i = 0;
  113. for (; iter != NULL; iter = rb_tree_next(&tree, iter)) {
  114. assert(iter->key == i * 3);
  115. i++;
  116. }
  117. }