1

我正在 Splay Tree 节点上编写一些代码。不需要太技术性,我想实现一棵基础树和一棵支持左右子树反转的派生树。当前摘录如下所示:

// Base splay-tree node (abridged excerpt of the full listing further down).
struct node {
  node *f;     // parent link
  node *c[2];  // child links
  int size;
  // Hook for flushing lazy flags; the base node carries none.
  void push_down() {}
};

// Derived node that adds a lazy "reverse" flag (abridged excerpt).
// NOTE: does not compile as written -- c[0] and c[1] are declared as node*,
// and node has no member r.
struct reversable_node : node {
  int r;  // non-zero while a child swap is still pending for this node
  void push_down() {
    if (r) {
      std::swap(c[0], c[1]);
      // Intended effect: toggle the flag on both children, then clear our own.
      c[0]->r ^= 1, c[1]->r ^= 1, r = 0;
    }
  }
};

这显然行不通,因为 c[0] 的类型是 node*,而 node 没有成员 r。不过我知道,node 的 c[0] 只会指向 node,而 reversable_node 的 c[0] 只会指向 reversable_node。所以我可以做一些强制类型转换:

      // Same statement, with each child downcast from node* to reversable_node* so that r is reachable.
      ((reversable_node *)c[0])->r ^= 1, ((reversable_node *)c[1])->r ^= 1, r = 0;

但这看起来非常笨拙。有没有更好的办法,让基类中的自引用指针在派生类中也能直接当作派生类指针使用?

PS 整个代码如下所示:

// Plain splay-tree node: a parent link, two child links and a subtree size.
struct node {
  node *f, *c[2];
  int size;
  // A fresh node is detached and counts only itself.
  node() : f(nullptr), c{nullptr, nullptr}, size(1) {}
  // Nothing lazy to flush on the base node.
  void push_down() {}
  // Recompute size as 1 plus the sizes of the children that exist.
  void update() {
    int total = 1;
    if (c[0]) total += c[0]->size;
    if (c[1]) total += c[1]->size;
    size = total;
  }
};

struct reversable_node : node {
  int r;
  reversable_node() : node() { r = 0; }
  void push_down() {
    if (r) {
      std::swap(c[0], c[1]);
      ((reversable_node *)c[0])->r ^= 1, ((reversable_node *)c[1])->r ^= 1, r = 0;
    }
  }
};

template <typename T = node, int MAXSIZE = 500000>
struct tree {
  T pool[MAXSIZE + 2];
  node *root;
  int size;
  tree() {
    size = 2;
    root = pool[0], root->c[1] = pool[1], root->size = 2;
    pool[1]->f = root;
  }
  void rotate(T *n) {
    int v = n->f->c[0] == n;
    node *p = n->f, *m = n->c[v];
    p->push_down(), n->push_down();
    n->c[v] = p, p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  void splay(T *n, T *s = nullptr) {
    while (n->f != s) {
      T *m = n->f, *l = m->f;
      if (l == s)
        rotate(n);
      else if ((l->c[0] == m) == (m->c[0] == n))
        rotate(m), rotate(n);
      else
        rotate(n), rotate(n);
    }
    if (!s) root = n;
  }
  node *new_node() { return pool[size++]; }
  void walk(node *n, int &v, int &pos) {
    n->push_down();
    int s = n->c[0] ? n->c[0]->size : 0;
    (v = s > pos) && (pos -= s + 1);
  }
  void add_node(node *n, int pos) {
    node *c = root;
    int v;
    ++pos;
    do {
      walk(c, v, pos);
    } while (c->c[v] && (c = c->c[v]));
    c->c[v] = n, n->f = cur, splay(n);
  }
  node *find(int pos, int splay = true) {
    node *c = root;
    int v;
    ++pos;
    do {
      walk(c, v, pos);
    } while (pos && (c = c->c[v]));
    if (splay) splay(c);
    return c;
  }
  node *find_range(int posl, int posr) {
    node *l = find(posl - 1), *r = find(posr, false);
    splay(r, l);
    if (r->c[0]) r->c[0]->push_down();
    return r->c[0];
  }
};

所以基本上我们有一个节点是否反转的标志,当我们尝试旋转树时,我们将标志从节点向下推到它的子节点。这可能需要对 Splay Tree 有所了解。

PS2 它应该是一个库,但一些用例会是这样的:

#include "../template.h"

splay::tree<splay::reversable_node> s;

// In-order walk that flushes the pending reverse flag of every node it enters.
void dfs(splay::reversable_node *n) {
  if (!n) return;
  // Push down the flag.
  n->push_down();
  // c[] holds node*, which does not convert implicitly to the derived type.
  dfs(static_cast<splay::reversable_node *>(n->c[0]));
  // Do something about n...
  dfs(static_cast<splay::reversable_node *>(n->c[1]));
}

int main() {
  // Insert 5 nodes to the Splay Tree.
  for (int i = 0; i < 5; ++i) s.add_node(s.new_node(), 0);
  // Find a range of the tree.
  splay::reversable_node *n = s.find_range(0, 3);
  // Reverse it.
  n->r = 1;
  std::swap(n->c[0], n->c[1]);
  // Traverse it in inorder.
  dfs(s.root);
}
4

1 回答 1

0

总之,借助 CRTP(奇异递归模板模式),我把它实现出来了。

namespace splay {

/**
 * Abstract node struct.
 *
 * CRTP base: T is the concrete node type, so the parent and child links are
 * already typed as T *.
 */
template <typename T>
struct node {
  T *f, *c[2];
  int size;
  // A fresh node is detached and counts only itself.
  node() : f(nullptr), c{nullptr, nullptr}, size(1) {}
  // Nothing lazy to flush at this level.
  void push_down() {}
  // Recompute size as 1 plus the sizes of the children that exist.
  void update() {
    int total = 1;
    for (T *child : c)
      if (child) total += child->size;
    size = total;
  }
};

/**
 * Abstract reversible node struct.
 *
 * Adds a lazy flag r on top of node<T>. T must provide reverse(), which it
 * gets by deriving from this struct.
 */
template <typename T>
struct reversible_node : node<T> {
  int r;
  reversible_node() : node<T>(), r(0) {}
  /**
   * Flush the pending reversal: reverse each existing child, then clear r.
   */
  void push_down() {
    node<T>::push_down();
    if (r == 0) return;
    for (T *child : this->c) {
      if (child != nullptr) child->reverse();
    }
    r = 0;
  }
  void update() { node<T>::update(); }
  /**
   * Reverse the range of this node: swap the two child links and toggle r.
   */
  void reverse() {
    std::swap(this->c[0], this->c[1]);
    r ^= 1;
  }
};

/**
 * Splay tree over a fixed pool of nodes, addressed by in-order position.
 *
 * T must expose f, c[2], size, push_down() and update() (the members used
 * below). pool[0] and pool[1] are sentinels that sit before the first and
 * after the last real element, so position 0 is the first real element.
 *
 * The pool is embedded in the object (MAXSIZE + 2 nodes), so a default-sized
 * tree is large.
 */
template <typename T, int MAXSIZE = 500000>
struct tree {
  T pool[MAXSIZE + 2];
  T *root;
  int size;  // number of pool slots handed out, sentinels included
  tree() {
    size = 2;
    root = pool, root->c[1] = pool + 1, root->size = 2;
    pool[1].f = root;
  }
  /**
   * Helper function to rotate node. n must have a parent; lazy flags on n and
   * its parent must already be flushed (find/insert do that on the way down).
   */
  void rotate(T *n) {
    int v = n->f->c[0] == n;
    T *p = n->f, *m = n->c[v];
    // Hook n into the grandparent in place of p.
    if (p->f) p->f->c[p->f->c[1] == p] = n;
    n->f = p->f, n->c[v] = p;
    p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  /**
   * Splay n so that it is under s (or to root if s is null).
   * s must be an ancestor of n.
   */
  void splay(T *n, T *s = nullptr) {
    while (n->f != s) {
      T *m = n->f, *l = m->f;
      if (l == s)
        rotate(n);
      else if ((l->c[0] == m) == (m->c[0] == n))
        rotate(m), rotate(n);
      else
        rotate(n), rotate(n);
    }
    if (!s) root = n;
  }
  /**
   * Get a new node from the pool. There is no bounds check: requesting more
   * than MAXSIZE nodes runs past the pool.
   */
  T *new_node() { return pool + size++; }
  /**
   * Helper function to walk down the tree.
   * Flushes n, sets v to the child to continue into (1 = right) and, when
   * going right, rebases pos to the right subtree. Returns the size of n's
   * left subtree; pos matched n itself iff it equalled that value before the
   * call.
   */
  int walk(T *n, int &v, int &pos) {
    n->push_down();
    int s = n->c[0] ? n->c[0]->size : 0;
    v = s < pos;
    if (v) pos -= s + 1;
    return s;
  }
  /**
   * Insert node n to position pos, then splay it to the root.
   */
  void insert(T *n, int pos) {
    T *c = root;
    int v;
    ++pos;  // skip the leading sentinel
    for (;;) {
      walk(c, v, pos);
      if (!c->c[v]) break;
      c = c->c[v];
    }
    c->c[v] = n, n->f = c, splay(n);
  }
  /**
   * Find the node at position pos. If sp is true, splay it.
   * Position -1 and position <element count> name the two sentinels.
   * Returns nullptr (and splays nothing) when pos is out of range.
   */
  T *find(int pos, int sp = true) {
    T *c = root;
    ++pos;  // skip the leading sentinel
    while (c) {
      int v;
      // Snapshot pos before walk() rebases it, so the comparison does not
      // depend on the evaluation order of a single expression.
      const int want = pos;
      if (walk(c, v, pos) == want) break;
      c = c->c[v];
    }
    if (c && sp) splay(c);
    return c;
  }
  /**
   * Find the range [posl, posr) on the splay tree.
   * Returns the root of the subtree holding exactly that range, already
   * flushed, or nullptr when the range is empty or invalid.
   */
  T *find_range(int posl, int posr) {
    if (posl > posr) return nullptr;
    T *l = find(posl - 1);
    if (!l) return nullptr;
    T *r = find(posr, false);
    if (!r) return nullptr;
    // With l at the root and r directly below it, r's left subtree is
    // exactly the requested range.
    splay(r, l);
    if (r->c[0]) r->c[0]->push_down();
    return r->c[0];
  }
};

}  // namespace splay

一些用例:

// Concrete node type for the example: a reversible splay node with one int payload.
struct node : splay::reversible_node<node> {
  // inorder() treats val == 0 as "sentinel, do not print". The tree's two
  // sentinel nodes never get a value assigned, so initialise val explicitly
  // instead of relying on the tree having static storage.
  int val = 0;
  void push_down() { splay::reversible_node<node>::push_down(); }
  void update() { splay::reversible_node<node>::update(); }
};

splay::tree<node> t;

int N, M;

// Print the non-zero values of the subtree at n in in-order, separated by
// single spaces. The separator state lives in a function-local static, so it
// carries over between calls.
void inorder(node *n) {
  static int f = 0;
  if (n) {
    n->push_down();  // flush the pending reversal before reading the children
    inorder(n->c[0]);
    const int value = n->val;
    if (value != 0) {
      if (f != 0) printf(" ");
      printf("%d", value);
      f = 1;
    }
    inorder(n->c[1]);
  }
}

// Read N and M, build the sequence 1..N, apply M range reversals given as
// 1-based inclusive pairs (u, v), then print the resulting sequence.
int main() {
  if (scanf("%d%d", &N, &M) != 2) return 1;
  for (int i = 0; i < N; ++i) {
    node *n = t.new_node();
    n->val = i + 1;
    t.insert(n, i);
  }
  for (int i = 0, u, v; i < M; ++i) {
    if (scanf("%d%d", &u, &v) != 2) break;
    // Skip queries that do not describe a range inside 1..N; looking up a
    // position outside the tree is not safe.
    if (u < 1 || v > N || u > v) continue;
    // Convert to the tree's 0-based half-open range [u - 1, v).
    node *n = t.find_range(u - 1, v);
    if (n) n->reverse();
  }
  inorder(t.root);
}

希望这能让我在 CP 中更快地编写 Splay。

于 2021-10-31T09:31:11.783 回答