/* sat-solver/avl.hpp — AVL tree container template (C++). */

#ifndef CLASS_AVL_TREE_T
#define CLASS_AVL_TREE_T
#include <assert.h>
#include <errno.h>
#include <stddef.h>
#include <compare>
/*
 * Per-node bookkeeping.  Each node can cache the size of its subtree
 * (AVL_COUNT), the height of its subtree (AVL_DEPTH), or both.  Balancing
 * needs at least one of the two, so when the includer selects neither,
 * both are switched on.
 */
#if !(defined(AVL_COUNT) || defined(AVL_DEPTH))
#define AVL_COUNT
#define AVL_DEPTH
#endif

#ifdef AVL_COUNT
/* Size of the subtree rooted at n; an absent (NULL) subtree has size 0. */
#define NODE_COUNT(n) (!(n) ? 0 : (n)->count)
/* Sizes of the left / right child subtrees of n. */
#define L_COUNT(n) NODE_COUNT((n)->left)
#define R_COUNT(n) NODE_COUNT((n)->right)
/* Size of n's subtree recomputed from its children: n itself plus both sides. */
#define CALC_COUNT(n) (1 + L_COUNT(n) + R_COUNT(n))
#endif

#ifdef AVL_DEPTH
/* Height of the subtree rooted at n; an absent (NULL) subtree has height 0. */
#define NODE_DEPTH(n) (!(n) ? 0 : (n)->depth)
/* Heights of the left / right child subtrees of n. */
#define L_DEPTH(n) NODE_DEPTH((n)->left)
#define R_DEPTH(n) NODE_DEPTH((n)->right)
/* Height of n's subtree recomputed from its children: one more than the taller side. */
#define CALC_DEPTH(n) (1 + (R_DEPTH(n) < L_DEPTH(n) ? L_DEPTH(n) : R_DEPTH(n)))
#endif
template <
typename I> class avl_tree_t
{
struct avl_node_t
{
struct avl_node_t *next;
struct avl_node_t *prev;
struct avl_node_t *parent;
struct avl_node_t *left;
struct avl_node_t *right;
I item;
#ifdef AVL_COUNT
unsigned int count;
#endif
#ifdef AVL_DEPTH
int depth;
#endif
avl_node_t(I _item): item(_item)
{
;
}
void clear()
{
this->left = this->right = NULL;
#ifdef AVL_COUNT
this->count = 1;
#endif
#ifdef AVL_DEPTH
this->depth = 1;
#endif
}
};
public:
struct avl_node_t *top = NULL;
struct avl_node_t *head = NULL;
struct avl_node_t *tail = NULL;
private:
#ifndef AVL_DEPTH
/* Also known as ffs() (from BSD) */
static int lg(unsigned int u)
{
int r = 1;
if (!u) return 0;
if (u & 0xffff0000) { u >>= 16; r += 16; }
if (u & 0x0000ff00) { u >>= 8; r += 8; }
if (u & 0x000000f0) { u >>= 4; r += 4; }
if (u & 0x0000000c) { u >>= 2; r += 2; }
if (u & 0x00000002) r++;
return r;
}
#endif
static int check_balance(avl_node_t *avlnode)
{
#ifdef AVL_DEPTH
int d = R_DEPTH(avlnode) - L_DEPTH(avlnode);
return d < -1 ? -1 : d > 1 ? 1 : 0;
#else
/* int d;
* d = lg(R_COUNT(avlnode)) - lg(L_COUNT(avlnode));
* d = d<-1?-1:d>1?1:0;
*/
#ifdef AVL_COUNT
int pl, r;
pl = lg(L_COUNT(avlnode));
r = R_COUNT(avlnode);
if(r>>pl+1)
return 1;
if(pl<2 || r>>pl-2)
return 0;
return -1;
#else
#error No balancing possible.
#endif
#endif
}
int search_closest(
const I& item,
avl_node_t **avlnode) const
{
avl_node_t *node;
if (!avlnode)
{
avlnode = &node;
}
node = this->top;
if (!node)
{
return *avlnode = NULL, 0;
}
for (;;)
{
auto c = (item <=> node->item);
if (c < 0) {
if (node->left)
node = node->left;
else
return *avlnode = node, -1;
} else if (c > 0) {
if (node->right)
node = node->right;
else
return *avlnode = node, 1;
} else {
return *avlnode = node, 0;
}
}
}
struct avl_node_t *insert_top(avl_node_t *newnode)
{
newnode->clear();
newnode->prev = newnode->next = newnode->parent = NULL;
this->head = this->tail = this->top = newnode;
return newnode;
}
struct avl_node_t *insert_before(avl_node_t *node, avl_node_t *newnode)
{
if (!node)
{
return this->tail
? insert_after(this->tail, newnode)
: insert_top(newnode);
}
if (node->left)
{
return insert_after(node->prev, newnode);
}
newnode->clear();
newnode->next = node;
newnode->parent = node;
newnode->prev = node->prev;
if (node->prev)
node->prev->next = newnode;
else
this->head = newnode;
node->prev = newnode;
node->left = newnode;
rebalance(node);
return newnode;
}
struct avl_node_t *insert_after(avl_node_t *node, avl_node_t *newnode)
{
if (!node)
return this->head
? insert_before(this->head, newnode)
: insert_top(newnode);
if (node->right)
return insert_before(node->next, newnode);
newnode->clear();
newnode->prev = node;
newnode->parent = node;
newnode->next = node->next;
if (node->next)
node->next->prev = newnode;
else
this->tail = newnode;
node->next = newnode;
node->right = newnode;
rebalance(node);
return newnode;
}
avl_node_t *insert_node(avl_node_t *newnode)
{
avl_node_t *node;
if (!this->top)
return insert_top(newnode);
switch (search_closest(newnode->item, &node))
{
case -1:
return insert_before(node, newnode);
case 1:
return insert_after(node, newnode);
}
return NULL;
}
void rebalance(avl_node_t *avlnode)
{
avl_node_t *child;
avl_node_t *gchild;
avl_node_t *parent;
avl_node_t **superparent;
parent = avlnode;
while (avlnode)
{
parent = avlnode->parent;
superparent = parent
? avlnode == parent->left
? &parent->left
: &parent->right
: &this->top;
switch (check_balance(avlnode))
{
case -1:
{
child = avlnode->left;
#ifdef AVL_DEPTH
if (L_DEPTH(child) >= R_DEPTH(child)) {
#else
#ifdef AVL_COUNT
if (L_COUNT(child) >= R_COUNT(child)) {
#else
#error No balancing possible.
#endif
#endif
avlnode->left = child->right;
if (avlnode->left)
avlnode->left->parent = avlnode;
child->right = avlnode;
avlnode->parent = child;
*superparent = child;
child->parent = parent;
#ifdef AVL_COUNT
avlnode->count = CALC_COUNT(avlnode);
child->count = CALC_COUNT(child);
#endif
#ifdef AVL_DEPTH
avlnode->depth = CALC_DEPTH(avlnode);
child->depth = CALC_DEPTH(child);
#endif
} else {
gchild = child->right;
avlnode->left = gchild->right;
if (avlnode->left)
avlnode->left->parent = avlnode;
child->right = gchild->left;
if (child->right)
child->right->parent = child;
gchild->right = avlnode;
if (gchild->right)
gchild->right->parent = gchild;
gchild->left = child;
if (gchild->left)
gchild->left->parent = gchild;
*superparent = gchild;
gchild->parent = parent;
#ifdef AVL_COUNT
avlnode->count = CALC_COUNT(avlnode);
child->count = CALC_COUNT(child);
gchild->count = CALC_COUNT(gchild);
#endif
#ifdef AVL_DEPTH
avlnode->depth = CALC_DEPTH(avlnode);
child->depth = CALC_DEPTH(child);
gchild->depth = CALC_DEPTH(gchild);
#endif
}
break;
}
case 1:
{
child = avlnode->right;
#ifdef AVL_DEPTH
if (R_DEPTH(child) >= L_DEPTH(child)) {
#else
#ifdef AVL_COUNT
if (R_COUNT(child) >= L_COUNT(child)) {
#else
#error No balancing possible.
#endif
#endif
avlnode->right = child->left;
if (avlnode->right)
avlnode->right->parent = avlnode;
child->left = avlnode;
avlnode->parent = child;
*superparent = child;
child->parent = parent;
#ifdef AVL_COUNT
avlnode->count = CALC_COUNT(avlnode);
child->count = CALC_COUNT(child);
#endif
#ifdef AVL_DEPTH
avlnode->depth = CALC_DEPTH(avlnode);
child->depth = CALC_DEPTH(child);
#endif
} else {
gchild = child->left;
avlnode->right = gchild->left;
if (avlnode->right)
avlnode->right->parent = avlnode;
child->left = gchild->right;
if (child->left)
child->left->parent = child;
gchild->left = avlnode;
if (gchild->left)
gchild->left->parent = gchild;
gchild->right = child;
if (gchild->right)
gchild->right->parent = gchild;
*superparent = gchild;
gchild->parent = parent;
#ifdef AVL_COUNT
avlnode->count = CALC_COUNT(avlnode);
child->count = CALC_COUNT(child);
gchild->count = CALC_COUNT(gchild);
#endif
#ifdef AVL_DEPTH
avlnode->depth = CALC_DEPTH(avlnode);
child->depth = CALC_DEPTH(child);
gchild->depth = CALC_DEPTH(gchild);
#endif
}
break;
}
default:
{
#ifdef AVL_COUNT
avlnode->count = CALC_COUNT(avlnode);
#endif
#ifdef AVL_DEPTH
avlnode->depth = CALC_DEPTH(avlnode);
#endif
break;
}
}
avlnode = parent;
}
}
void unlink_node(avl_node_t *avlnode)
{
avl_node_t *parent;
avl_node_t **superparent;
avl_node_t *subst, *left, *right;
avl_node_t *balnode;
if (avlnode->prev)
avlnode->prev->next = avlnode->next;
else
this->head = avlnode->next;
if (avlnode->next)
avlnode->next->prev = avlnode->prev;
else
this->tail = avlnode->prev;
parent = avlnode->parent;
superparent = parent
? avlnode == parent->left
? &parent->left
: &parent->right
: &this->top;
left = avlnode->left;
right = avlnode->right;
if (!left)
{
*superparent = right;
if (right)
right->parent = parent;
balnode = parent;
}
else if (!right)
{
*superparent = left;
left->parent = parent;
balnode = parent;
}
else
{
subst = avlnode->prev;
if (subst == left)
{
balnode = subst;
}
else
{
balnode = subst->parent;
balnode->right = subst->left;
if (balnode->right)
balnode->right->parent = balnode;
subst->left = left;
left->parent = subst;
}
subst->right = right;
subst->parent = parent;
right->parent = subst;
*superparent = subst;
}
rebalance(balnode);
}
void delete_node(
avl_node_t *avlnode)
{
assert(avlnode);
unlink_node(avlnode);
delete avlnode;
}
void free_nodes()
{
for (avl_node_t *node = head, *next; node; node = next)
{
next = node->next;
delete node;
}
clear_tree();
}
void clear_tree()
{
this->top = this->head = this->tail = NULL;
}
public:
class iterable
{
private:
struct avl_node_t* node;
public:
iterable(
struct avl_node_t* _node):
node(_node)
{
;
}
I& operator *()
{
assert(node);
return node->item;
}
void operator ++()
{
node = node->next;
}
bool operator != (const iterable& other)
{
return node != other.node;
}
};
iterable begin() const
{
return iterable(head);
}
iterable end() const
{
return iterable(NULL);
}
public:
avl_tree_t()
{
;
}
struct avl_node_t* search(const I& item) const
{
avl_node_t *node;
return search_closest(item, &node) ? NULL : node;
}
struct avl_node_t* insert(I item)
{
avl_node_t *newnode = new avl_node_t(item);
if (newnode)
{
if (insert_node(newnode))
{
return newnode;
}
else
{
delete newnode;
errno = EEXIST;
}
}
return NULL;
}
bool delete_item(I item)
{
auto node = search(item);
if (node)
return delete_node(node), true;
return false;
}
~avl_tree_t()
{
free_nodes();
}
};
#endif