Recursive is Top-down is Bottom-up

Compiling functional balanced search trees to imperative code



Anton Lorenzen, University of Edinburgh

Top-down vs bottom-up insertion

Top-down vs bottom-up insertion (Sleator & Tarjan)

Q: Do these methods yield the same trees?

A: Yes! You can transform them into each other!

Overview


-- Singly linked list with named fields; record fields are comma-separated.
data List a = Cons { item :: a, next :: List a } | Nil
					

/* One cell of a singly linked list with an untyped payload pointer.
 * `next` is NULL in the last cell. */
typedef struct list list;
struct list {
    void* item;
    struct list* next;
};
					

-- Plain recursive map: builds a new Cons for every element of the input.
map f Nil = Nil
map f (Cons item next) = Cons (f item) (map f next)
					

/* Recursive map that allocates a fresh output cell per input cell.
 * f is applied to the head before recursing on the tail.
 * Aborts if allocation fails rather than dereferencing a NULL pointer. */
list* map(function f, list* xs) {
    if(xs == NULL) {
        return NULL;
    }
    void* item = call(f, xs->item);
    list* next = map(f, xs->next);
    list* cons = malloc(sizeof *cons);
    if(cons == NULL) {
        abort(); // out of memory: fail loudly instead of writing through NULL
    }
    cons->item = item;
    cons->next = next;
    return cons;
}
					

Reuse analysis


Static analysis: Compile-time garbage collection

Types: Clean's uniqueness types, Hofmann's LFPL, Modal types in OCaml

Reference Counting: OPAL, Lean, Koka
"Reference Counting With Frame-Limited Reuse", Lorenzen, Leijen, ICFP'22


-- `Cons@r` binds the matched cell as r; the result `Cons@r` reuses that
-- cell's memory instead of allocating a new one (reuse analysis).
map f Nil = Nil
map f (Cons@r item next) = Cons@r (f item) (map f next)
					

/* Recursive map that overwrites each input cell in place (the cell is
 * "reused" as the output cell). f runs on the head before the tail is mapped;
 * the cell's fields are updated only after the recursive call returns. */
list* map(function f, list* xs) {
    if (xs == NULL) {
        return NULL;
    }
    void* mapped = call(f, xs->item);
    list* rest = map(f, xs->next);
    xs->item = mapped;
    xs->next = rest;
    return xs;
}
					

CPS transformation


-- CPS-transformed map: acc is a continuation that receives the mapped tail,
-- making every recursive call a tail call.
map f Nil acc = acc Nil
map f (Cons item next) acc =
   let y = f item in
   map f next (\res -> acc (Cons y res))
					

Defunctionalize!


-- Defunctionalized CPS map: the continuation is represented as a list of
-- already-mapped items in reverse order, so the base case reverses it.
map f Nil acc = reverse acc Nil
map f (Cons item next) acc = map f next (Cons (f item) acc)

-- Tail-recursive reverse: moves each cell of the first list onto xs.
reverse Nil xs = xs
reverse (Cons item next) xs = reverse next (Cons item xs)
					

-- Defunctionalized map where the accumulator cell reuses the input cell r.
map f Nil acc = reverse acc Nil
map f (Cons@r item next) acc = map f next (Cons@r (f item) acc)
					

/* Tail-recursive accumulator map: each input cell is reused and pushed onto
 * acc, so acc holds the mapped prefix in reverse; reverse() restores the order. */
list* map(function f, list* xs, list* acc) {
    if (xs != NULL) {
        void* mapped = call(f, xs->item);
        list* rest = xs->next;
        xs->item = mapped;
        xs->next = acc;
        return map(f, rest, xs);
    }
    return reverse(acc, NULL);
}
					

/* The tail call of the accumulator map turned into a loop: xs and acc are
 * rebound on each iteration instead of being passed to a recursive call. */
list* map(function f, list* xs, list* acc) {
    for (;;) {
        if (xs == NULL) {
            return reverse(acc, NULL);
        }
        void* mapped = call(f, xs->item);
        list* rest = xs->next;
        xs->item = mapped;
        xs->next = acc;   /* reuse the cell as the new accumulator head */
        acc = xs;
        xs = rest;
    }
}
					

/* Simplified loop: map each cell in place while reversing its link onto acc,
 * then reverse the accumulated list back into the original order. */
list* map(function f, list* xs, list* acc) {
    for (list* cur = xs; cur != NULL; ) {
        cur->item = call(f, cur->item);
        list* rest = cur->next;
        cur->next = acc;
        acc = cur;
        cur = rest;
    }
    return reverse(acc, NULL);
}
					

/* Bottom-up in-place map in two passes.
 * Pass 1 applies f front to back while reversing every link.
 * Pass 2 reverses the links again to restore the original order. */
list* map_bottomup(function f, list* xs) {
    list* rev = NULL;
    for (list* cur = xs; cur != NULL; ) {
        cur->item = call(f, cur->item);
        list* rest = cur->next;
        cur->next = rev;
        rev = cur;
        cur = rest;
    }
    list* out = NULL;
    while (rev != NULL) {
        list* rest = rev->next;
        rev->next = out;
        out = rev;
        rev = rest;
    }
    return out;
}
					

Pointer reversal


/* Pointer-reversal view of the bottom-up map. The first loop is the compiled
 * defunctionalized CPS map: reversed links play the role of the continuation
 * stack. The second loop is the compiled reverse, which walks that stack back. */
list* map_bottomup(function f, list* xs) {
    list* stack = NULL;   /* mapped cells, most recently visited first */
    list* cur = xs;
    while (cur) {
        list* rest;
        cur->item = call(f, cur->item);
        rest = cur->next;
        cur->next = stack;
        stack = cur;
        cur = rest;
    }
    list* result = NULL;
    while (stack) {
        list* below = stack->next;
        stack->next = result;
        result = stack;
        stack = below;
    }
    return result;
}
					

Top-down vs bottom-up insertion


/* Top-down in-place map: one forward pass applying f to each cell, with no
 * link rewriting. Returns the original head. */
list* map_topdown(function f, list* xs) {
    for (list* cur = xs; cur != NULL; cur = cur->next) {
        cur->item = call(f, cur->item);
    }
    return xs;
}
					

Tail recursion modulo cons


-- Not tail recursive: the recursive call sits under a Cons constructor,
-- which makes it a candidate for tail recursion modulo cons.
map f Nil = Nil
map f (Cons item next) = Cons (f item) (map f next)
					

Laziness: Listless machine; Prolog: Difference lists; "Tail Recursion Modulo Context – An Equational Approach", Leijen, Lorenzen, POPL'23


-- Same map; under TRMC the recursive call writes its result directly into
-- the `next` field of the Cons being built.
map f Nil = Nil
map f (Cons item next) = Cons (f item) (map f next)
					

/* Destination-passing map (tail recursion modulo cons): the result is written
 * into *hole, and the recursive call fills the new cell's `next` field.
 * Aborts if allocation fails instead of dereferencing NULL. */
void map(function f, list* xs, list** hole) {
    if(xs == NULL) {
        *hole = NULL;
        return;
    }
    list* cell = malloc(sizeof *cell);
    if(cell == NULL) {
        abort(); // out of memory
    }
    *hole = cell;
    cell->item = call(f, xs->item);
    map(f, xs->next, &cell->next); // the new cell's `next` is the next hole
}
					

/* Destination-passing map that reuses the input cell as the output cell.
 * *hole receives xs itself, which is NULL at the end of the list. */
void map(function f, list* xs, list** hole) {
    *hole = xs;
    if (xs != NULL) {
        xs->item = call(f, xs->item);
        map(f, xs->next, &xs->next);
    }
}
					

/* In-place recursive map: the hole always points at the cell's own `next`,
 * so the hole writes are no-ops and only the item update remains. */
void map(function f, list* xs) {
    if (xs != NULL) {
        xs->item = call(f, xs->item);
        map(f, xs->next);
    }
}
					

/* Loop form of the in-place map (the recursive call is a plain tail call). */
void map(function f, list* xs) {
    for (list* cur = xs; cur != NULL; cur = cur->next) {
        cur->item = call(f, cur->item);
    }
}
					

/* Top-down in-place map; the first cell stays the head, so return it. */
list* map_topdown(function f, list* xs) {
    list* node = xs;
    while (node) {
        node->item = call(f, node->item);
        node = node->next;
    }
    return xs;
}
					

Splay Trees


-- Specification of splay access: the root becomes the node for i, with the
-- keys smaller / bigger than i as its subtrees. Assumes i occurs in t
-- (no Leaf cases).
access i t = Node (smaller i t) (find i t) (bigger i t)

-- smaller i t: the subtree of t containing exactly the keys < i.
smaller i (Node l x r) =
  if i == x then l
  else if x < i then Node l x (smaller i r)
  else smaller i l
					

access i t = Node (smaller i t) (find i t) (bigger i t)

-- `smaller` unrolled two levels deep (mirrors the zig/zig-zig/zig-zag cases).
-- When x < i < rx, x and its left subtree still belong to the result, so the
-- zig-zag branch must keep `Node l x (...)` around the recursive call.
smaller i (Node l x r) =
  if i == x then l
  else if x < i then case r of
  | Node rl rx rr ->
      if i == rx then Node l x rl
      else if rx < i then Node l x (Node rl rx (smaller i rr))
      else Node l x (smaller i rl)
  else case l of
  | Node ll lx lr ->
      if i == lx then ll
      else if lx < i then Node ll lx (smaller i lr)
      else smaller i ll
					

access i t = Node (smaller i t) (find i t) (bigger i t)

-- Unrolled `smaller` with the zig-zig case rotated: Node l x (Node rl rx s)
-- becomes Node (Node l x rl) rx s. The zig-zag case keeps x above the
-- recursive result.
smaller i (Node l x r) =
  if i == x then l
  else if x < i then case r of
  | Node rl rx rr ->
      if i == rx then Node l x rl
      else if rx < i then Node (Node l x rl) rx (smaller i rr)
      else Node l x (smaller i rl)
  else case l of
  | Node ll lx lr ->
      if i == lx then ll
      else if lx < i then Node ll lx (smaller i lr)
      else smaller i ll
					

(Similar to) Okasaki's version


-- Okasaki-style splay access: partition returns (keys < i, node holding i,
-- keys > i) and access rebuilds the tree with that node at the root.
-- There are no Leaf cases, so i is assumed to occur in t.
access i t =
  let (l, x, r) = partition i t in case x of
    Node _ x' _ -> Node l x' r

-- Restructures two levels per step (zig-zig / zig-zag and mirrors).
partition i t@(Node l x r) =
  if x == i then (l, t, r)
  else if x < i then case r of
  | Node rl rx rr ->
      if rx == i then (Node l x rl, r, rr)
      -- zig-zig: x rotates down to the left of rx
      else if rx < i then
        let (smaller, item, bigger) = partition i rr in
        (Node (Node l x rl) rx smaller, item, bigger)
      -- zig-zag: x joins the smaller side, rx the bigger side
      else 
        let (smaller, item, bigger) = partition i rl in
        (Node l x smaller, item, Node bigger rx rr)
   else case l of
   | Node ll lx lr ->
      if lx == i then (ll, l, Node lr x r)
      -- zag-zig: lx joins the smaller side, x the bigger side
      else if lx < i then 
        let (smaller, item, bigger) = partition i lr in
        (Node ll lx smaller, item, Node bigger x r)
      -- zag-zag: x rotates down to the right of lx
      else 
        let (smaller, item, bigger) = partition i ll in
        (smaller, item, Node bigger lx (Node lr x r))
					

/* Destination-passing (TRMC) translation of the functional `partition`.
 * Splits t around key i:
 *   *l receives the tree of nodes with items < i,
 *   *x receives the node whose item == i,
 *   *r receives the tree of nodes with items > i.
 * Recursive calls pass the address of the field that the remaining result
 * must fill (the "hole"). t->left / t->right are dereferenced without NULL
 * checks, so i must occur in t. */
void partition_trmc(int i, tree* t, tree** l, tree** x, tree** r) {
    if(t->item == i) {
        *l = t->left;
        *x = t;
        *r = t->right;
        return;
    }
    else if(t->item < i) {
        if(t->right->item == i) {
            // zig: t goes left, its right child is the target
            tree* right = t->right;
            t->right = t->right->left;
            *l = t;
            *x = right;
            *r = right->right;
            return;
        } else if(t->right->item < i) {
            // zig-zig: rotate t under its right child, which joins the left tree
            tree* rr = t->right->right;
            *l = t->right;
            t->right = t->right->left;
            (*l)->left = t;
            partition_trmc(i, rr, &((*l)->right), x, r);
        } else {
            // zig-zag: t joins the left tree, its right child the right tree
            *l = t;
            *r = t->right;
            partition_trmc(i, t->right->left, &((*l)->right), x, &((*r)->left));
        }
    } else { // t->item > i
        if(t->left->item == i) {
            // zag: t goes right, its left child is the target
            *l = t->left->left;
            *x = t->left;
            t->left = t->left->right;
            *r = t;
            return;
        } else if(t->left->item < i) {
            // zag-zig: left child joins the left tree, t the right tree
            *l = t->left;
            *r = t;
            partition_trmc(i, t->left->right, &((*l)->right), x, &((*r)->left));
        } else {
            // zag-zag: rotate t under its left child, which joins the right tree
            tree* ll = t->left->left;
            *r = t->left;
            t->left = t->left->right;
            (*r)->right = t;
            partition_trmc(i, ll, l, x, &((*r)->left));
        }
    }
}

/* Splay access via partition_trmc: partition around i, then make the found
 * node the root with the smaller/bigger trees as its children.
 * partition_trmc writes its results through tree** out-parameters, so the
 * addresses of the locals must be passed. */
tree* access_trmc(int i, tree* t) {
    tree* l = NULL;
    tree* x = NULL;
    tree* r = NULL;
    partition_trmc(i, t, &l, &x, &r);
    x->left = l;
    x->right = r;
    return x;
}
					

// Iterative form of partition_trmc. Each tail call becomes an update of t and
// of the hole pointers l / r, so the split needs no recursion.
// On return:
//   *l holds the nodes with items < i,
//   *x the node whose item == i,
//   *r the nodes with items > i.
// t->left / t->right are dereferenced without NULL checks, so i must occur in t.
void partition_trmc_tail(int i, tree* t, tree** l, tree** x, tree** r) {
    while(t->item != i) {
        if(t->item < i) {
            if(t->right->item == i) {
                // zig: t goes to the left tree; its right child is the target
                tree* right = t->right;
                t->right = t->right->left;
                *l = t;
                *x = right;
                *r = right->right;
                return;
            } else if(t->right->item < i) {
                // zig-zig: rotate t under its right child and hang that child
                // into the left hole; the new hole is the child's right field
                tree* rr = t->right->right;
                *l = t->right;
                t->right = t->right->left;
                (*l)->left = t;
                t = rr;
                l = &((*l)->right);
            } else {
                // zig-zag: t fills the left hole, its right child the right
                // hole; continue in the right child's left subtree
                *l = t;
                *r = t->right;
                t = t->right->left;
                l = &((*l)->right);
                r = &((*r)->left);
            }
        } else { // t->item > i
            if(t->left->item == i) {
                // zag: t goes to the right tree; its left child is the target
                *l = t->left->left;
                *x = t->left;
                t->left = t->left->right;
                *r = t;
                return;
            } else if(t->left->item < i) {
                // zag-zig: left child fills the left hole, t the right hole
                *l = t->left;
                *r = t;
                t = t->left->right;
                l = &((*l)->right);
                r = &((*r)->left);
            } else {
                // zag-zag: rotate t under its left child and hang that child
                // into the right hole; the new hole is the child's left field
                tree* ll = t->left->left;
                *r = t->left;
                t->left = t->left->right;
                (*r)->right = t;
                t = ll;
                r = &((*r)->left);
            }
        }
    }
    // t holds i: its subtrees close the remaining holes
    *l = t->left;
    *x = t;
    *r = t->right;
}

/* Splay access via the iterative partition_trmc_tail. The partition results
 * come back through tree** out-parameters, so the locals' addresses are passed. */
tree* access_trmc_tail(int i, tree* t) {
    tree* l = NULL;
    tree* x = NULL;
    tree* r = NULL;
    partition_trmc_tail(i, t, &l, &x, &r);
    x->left = l;
    x->right = r;
    return x;
}
					

// partition_trmc_tail inlined into access.
// l_ / r_ hold the roots of the smaller / bigger trees being built.
// l / r point at the field where the next node of each tree must be stored.
// On every exit path the node holding i gets l_ and r_ as its children.
// Assumes i occurs in t (no NULL checks).
tree* access_trmc_tail_simp(int i, tree* t) {
    tree* l_ = NULL; tree** l = &l_;
    tree* r_ = NULL; tree** r = &r_;
    while(t->item != i) {
        if(t->item < i) {
            if(t->right->item == i) {
                // zig: close both holes, then make the right child the root
                tree* right = t->right;
                t->right = t->right->left;
                *l = t;
                *r = right->right;
                right->left = l_;
                right->right = r_;
                return right;
            } else if(t->right->item < i) {
                // zig-zig: rotate, then link the right child into the left tree
                tree* rr = t->right->right;
                *l = t->right;
                t->right = t->right->left;
                (*l)->left = t;
                t = rr;
                l = &((*l)->right);
            } else {
                // zig-zag: t into the left tree, its right child into the right tree
                *l = t;
                *r = t->right;
                t = t->right->left;
                l = &((*l)->right);
                r = &((*r)->left);
            }
        } else { // t->item > i
            if(t->left->item == i) {
                // zag: close both holes, then make the left child the root
                tree* left = t->left;
                *l = t->left->left;
                t->left = t->left->right;
                *r = t;
                left->left = l_;
                left->right = r_;
                return left;
            } else if(t->left->item < i) {
                // zag-zig: left child into the left tree, t into the right tree
                *l = t->left;
                *r = t;
                t = t->left->right;
                l = &((*l)->right);
                r = &((*r)->left);
            } else {
                // zag-zag: rotate, then link the left child into the right tree
                tree* ll = t->left->left;
                *r = t->left;
                t->left = t->left->right;
                (*r)->right = t;
                t = ll;
                r = &((*r)->left);
            }
        }
    }
    // t holds i: its subtrees close the holes, then it becomes the root
    *l = t->left;
    *r = t->right;
    t->left = l_;
    t->right = r_;
    return t;
}
					

/* Top-down splay with a stack-allocated header node `null` (Sleator & Tarjan).
 * - null.right collects the left (smaller) tree; null.left collects the right
 *   (bigger) tree.
 * - l is the node whose ->right is the next left hole; r is the node whose
 *   ->left is the next right hole.
 * - `null` is left uninitialized: each of its fields is written (through l or r
 *   while they still point at it) before it is read on any exit path.
 * Assumes i occurs in t (no NULL checks). */
tree* access_trmc_tail_simp_stack(int i, tree* t) {
    struct tree null;
    tree* l = &null;
    tree* r = &null;
    while(t->item != i) {
        if(t->item < i) {
            if(t->right->item == i) {
                // zig: finish with the right child as the new root
                tree* right = t->right;
                t->right = t->right->left;
                l->right = t;
                r->left = right->right;
                right->left = null.right;
                right->right = null.left;
                return right;
            } else if(t->right->item < i) {
                // zig-zig: rotate t under its right child, then link that child
                // into the left hole. t must become the child's LEFT subtree.
                tree* rr = t->right->right;
                l->right = t->right;
                t->right = t->right->left;
                l->right->left = t;
                t = rr;
                l = l->right;
            } else {
                // zig-zag: t into the left tree, its right child into the right tree
                l->right = t;
                r->left = t->right;
                t = t->right->left;
                l = l->right;
                r = r->left;
            }
        } else { // t->item > i
            if(t->left->item == i) {
                // zag: finish with the left child as the new root
                tree* left = t->left;
                l->right = t->left->left;
                t->left = t->left->right;
                r->left = t;
                left->left = null.right;
                left->right = null.left;
                return left;
            } else if(t->left->item < i) {
                // zag-zig: left child into the left tree, t into the right tree
                l->right = t->left;
                r->left = t;
                t = t->left->right;
                l = l->right;
                r = r->left;
            } else {
                // zag-zag: rotate t under its left child, then link that child
                // into the right hole. t must become the child's RIGHT subtree.
                tree* ll = t->left->left;
                r->left = t->left;
                t->left = t->left->right;
                r->left->right = t;
                t = ll;
                r = r->left;
            }
        }
    }
    // t holds i: close the holes and reassemble around t
    l->right = t->left;
    r->left = t->right;
    t->left = null.right;
    t->right = null.left;
    return t;
}
					

/* Top-down splay rewritten in terms of the Sleator & Tarjan primitives,
 * inlined:
 *   link left:    l->right = t; l = t; t = t->right;
 *   link right:   r->left  = t; r = t; t = t->left;
 *   rotate left:  y = t->right; t->right = y->left; y->left = t; t = y;
 *   rotate right: y = t->left;  t->left = y->right; y->right = t; t = y;
 * The link must store into the OLD l / r before moving them; assigning first
 * would create a self-loop. The header node `null` collects the left tree in
 * null.right and the right tree in null.left.
 * Assumes i occurs in t (no NULL checks). */
tree* access_trmc_tail_simp_stack_simp(int i, tree* t) {
    struct tree null;
    tree* l = &null;
    tree* r = &null;
    while(t->item != i) {
        if(t->item < i) {
            if(t->right->item == i) {
                // link left
                tree* tmp = t;
                t = tmp->right;
                l->right = tmp;
                l = tmp;
                // assemble
                r->left = t->right;
                l->right = t->left;
                t->left = null.right;
                t->right = null.left;
                return t;
            } else if(t->right->item < i) {
                // rotate left
                tree* tmp = t;
                t = tmp->right;
                tmp->right = t->left;
                t->left = tmp;
                // link left
                tmp = t;
                t = tmp->right;
                l->right = tmp;
                l = tmp;
            } else {
                // link left
                tree* tmp = t;
                t = tmp->right;
                l->right = tmp;
                l = tmp;
                // link right
                tmp = t;
                t = tmp->left;
                r->left = tmp;
                r = tmp;
            }
        } else { // t->item > i
            if(t->left->item == i) {
                // link right
                tree* tmp = t;
                t = tmp->left;
                r->left = tmp;
                r = tmp;
                // assemble
                r->left = t->right;
                l->right = t->left;
                t->left = null.right;
                t->right = null.left;
                return t;
            } else if(t->left->item < i) {
                // link right
                tree* tmp = t;
                t = tmp->left;
                r->left = tmp;
                r = tmp;
                // link left
                tmp = t;
                t = tmp->right;
                l->right = tmp;
                l = tmp;
            } else {
                // rotate right
                tree* tmp = t;
                t = tmp->left;
                tmp->left = t->right;
                t->right = tmp;
                // link right
                tmp = t;
                t = tmp->left;
                r->left = tmp;
                r = tmp;
            }
        }
    }
    // assemble
    r->left = t->right;
    l->right = t->left;
    t->left = null.right;
    t->right = null.left;
    return t;
}
					

/* Top-down splay expressed with the link/rotate helpers (defined elsewhere).
 * header.right accumulates the left tree; header.left accumulates the right tree.
 * When a child holds i, a single link moves t onto it and the loop exits. */
tree* access_trmc_tail_simp_stack_simp_extract(int i, tree* t) {
    struct tree header;
    tree* lmax = &header;
    tree* rmin = &header;
    for (; t->item != i; ) {
        if (t->item > i) {
            if (t->left->item > i) {          /* zag-zag */
                rotate_right(&t);
                link_right(&lmax, &t, &rmin);
            } else if (t->left->item < i) {   /* zag-zig */
                link_right(&lmax, &t, &rmin);
                link_left(&lmax, &t, &rmin);
            } else {                          /* zag */
                link_right(&lmax, &t, &rmin);
            }
        } else {
            if (t->right->item < i) {         /* zig-zig */
                rotate_left(&t);
                link_left(&lmax, &t, &rmin);
            } else if (t->right->item > i) {  /* zig-zag */
                link_left(&lmax, &t, &rmin);
                link_right(&lmax, &t, &rmin);
            } else {                          /* zig */
                link_left(&lmax, &t, &rmin);
            }
        }
    }
    rmin->left = t->right;
    lmax->right = t->left;
    t->left = header.right;
    t->right = header.left;
    return t;
}
					

Top-down (Sleator & Tarjan)

(Similar to) Okasaki's version


-- Okasaki-style splay access (repeated as the starting point for deriving
-- the bottom-up version). partition returns (keys < i, node holding i,
-- keys > i). i is assumed to occur in t (no Leaf cases).
access i t =
  let (l, x, r) = partition i t in case x of
    Node _ x' _ -> Node l x' r

-- The four non-terminal cases each recurse two levels down; these are the
-- calls that the next step defunctionalizes into RR / RL / LR / LL frames.
partition i t@(Node l x r) =
  if x == i then (l, t, r)
  else if x < i then case r of
  | Node rl rx rr ->
      if rx == i then (Node l x rl, r, rr)
      else if rx < i then
        let (smaller, item, bigger) = partition i rr in
        (Node (Node l x rl) rx smaller, item, bigger)
      else 
        let (smaller, item, bigger) = partition i rl in
        (Node l x smaller, item, Node bigger rx rr)
   else case l of
   | Node ll lx lr ->
      if lx == i then (ll, l, Node lr x r)
      else if lx < i then 
        let (smaller, item, bigger) = partition i lr in
        (Node ll lx smaller, item, Node bigger x r)
      else 
        let (smaller, item, bigger) = partition i ll in
        (smaller, item, Node bigger lx (Node lr x r))
					

-- CPS + defunctionalized partition: the pending "rebuild" work is stored in
-- frames RR / RL / LR / LL, and splay consumes them from the innermost
-- (most recent) frame outwards.
access i t = partition i t Done
partition i t@(Node l x r) acc =
  if x == i then splay l t r acc
  else if x < i then case r of
  | Node rl rx rr ->
      if rx == i then splay (Node l x rl) r rr acc
      else if rx < i then
        partition i rr (RR l x rl rx acc)
      else
        partition i rl (RL l x rx rr acc)
  else case l of
  | Node ll lx lr ->
      if lx == i then splay ll l (Node lr x r) acc
      else if lx < i then
        partition i lr (LR ll lx x r acc)
      else
        partition i ll (LL lx lr x r acc)
-- splay s t b acc: s / b are the smaller / bigger trees built so far. Every
-- step must pass the remaining frames `acc` along, and Done places the target
-- node between s and b.
splay s t b acc = case acc of
  | Done -> case t of Node _ x _ -> Node s x b
  | RR l x rl rx acc -> splay (Node (Node l x rl) rx s) t b acc
  | RL l x rx rr acc -> splay (Node l x s) t (Node b rx rr) acc
  | LR ll lx x r acc -> splay (Node ll lx s) t (Node b x r) acc
  | LL lx lr x r acc -> splay s t (Node b lx (Node lr x r)) acc
					

-- Consumes the defunctionalized frames; the rest of the stack `acc` is
-- threaded through every recursive call, and the final node joins s and b.
splay s t b acc = case acc of
  | Done -> case t of Node _ x _ -> Node s x b
  | RR l x rl rx acc -> splay (Node (Node l x rl) rx s) t b acc
  | RL l x rx rr acc -> splay (Node l x s) t (Node b rx rr) acc
  | LR ll lx x r acc -> splay (Node ll lx s) t (Node b x r) acc
  | LL lx lr x r acc -> splay s t (Node b lx (Node lr x r)) acc
					

-- Same splay with each two-level frame split into single R / L frames:
-- `R l x c` means "right child of Node l x _", `L c x r` means "left child
-- of Node _ x r". The remaining context `acc` is passed on each step.
splay s t b acc = case acc of
  | Done -> case t of Node _ x _ -> Node s x b
  | R l x (R rl rx acc) -> splay (Node (Node l x rl) rx s) t b acc
  | R l x (L acc rx rr) -> splay (Node l x s) t (Node b rx rr) acc
  | L (R ll lx acc) x r -> splay (Node ll lx s) t (Node b x r) acc
  | L (L acc lx lr) x r -> splay s t (Node b lx (Node lr x r)) acc
					

-- Bottom-up splay on a zipper: t is the current subtree and p its context.
-- Each case restructures two levels and continues with the outer context
-- `acc`. splay takes two arguments, so `acc` must be passed explicitly.
splay t@(Node s i b) p = case p of
  | Done -> Node s i b
  | R l x (R rl rx acc) ->
      splay (Node (Node (Node l x rl) rx s) i b) acc
  | R l x (L acc rx rr) ->
      splay (Node (Node l x s) i (Node b rx rr)) acc
  | L (R ll lx acc) x r ->
      splay (Node (Node ll lx s) i (Node b x r)) acc
  | L (L acc lx lr) x r ->
      splay (Node s i (Node b lx (Node lr x r))) acc
					

-- Bottom-up splay expressed with rotations (Sleator & Tarjan). Assumes
-- `app t c` plugs t into the hole of c's outermost frame only (e.g.
-- app t (R l x _) = Node l x t).
-- For zig-zig / zag-zag the inner frame g must be filled before the outer p,
-- rebuilding Node l x (Node rl rx t) (resp. Node (Node t lx lr) x r); filling
-- p first would not yield a search tree. The remaining context `acc` is passed
-- to each recursive call.
splay t@(Node s i b) p = case p of
  | Done -> Node s i b
  | R l x g@(R rl rx acc) ->
      splay (rotate_left (rotate_left (app (app t g) p))) acc
  | R l x g@(L acc rx rr) ->
      splay (rotate_right (app (rotate_left (app t p)) g)) acc
  | L g@(R ll lx acc) x r ->
      splay (rotate_left (app (rotate_right (app t p)) g)) acc
  | L g@(L acc lx lr) x r ->
      splay (rotate_right (rotate_right (app (app t g) p))) acc
					

Bottom-up (Sleator & Tarjan)

Completed


  • Splay Trees
  • Zip Tree insertion (treaps, isomorphic to skip lists)
  • Bottom-up red-black-tree insertion

Future Work


  • Top-down red-black trees require extension to TRMC
  • Use equational reasoning to prove equivalence (Relational Hoare logic? Separation logic?)

Questions?