競プロで使えるGoのOrdered Set(Treapの実装)

競技プログラミング(競プロ)において、C++の std::set のような順序集合が必要な場面が出てきますがGoの標準ライブラリには存在しないため、コピペで使える実装とその使い方を書きました。

順序集合のデータ構造はいくつかありますが、std::set のような機能に加え、「K番目の値の取得」や「要素のインデックス取得」 を高速に行いたい時に非常に強力なのが Treap(トリープ) です。 以下のような特徴があります。

  • 実装が Red-Black Tree などに比べ軽い
  • 平均 O(logN)O(log N)
  • 理論がきれい
  • 機能拡張がしやすい
  • 「Split(分割)」と「Merge(結合)」をベースにすることで、複雑な区間操作にも応用しやすい

この実装でできること

  • 挿入・削除: Insert, Erase
  • 検索: Contains, Min, Max
  • 順序統計: Kth (K番目に小さい値), Index (値が何番目か)
  • 境界検索: LowerBound, UpperBound, Prev, Next
  • 集合操作: Split (特定の値を境に2つに分割), Merge (2つのTreapを結合)

この実装の特徴

  1. Split/Merge型: 回転(Rotation)を使わない実装のため、コードが比較的シンプルで、永続化などの応用も効きやすい形式です。
  2. Generics対応: any 型と cmp 関数を使用しているため、int だけでなく構造体や文字列など、比較可能なあらゆる型で使用できます。
  3. 高速な乱数: 標準ライブラリの math/rand よりも高速な Xorshift を採用しています。

実装

type Treap[T any] struct { root *treapNode[T] cmp func(a, b T) int } func NewTreap[T any](cmp func(a, b T) int) *Treap[T] { return &Treap[T]{cmp: cmp} } func (t *Treap[T]) Insert(key T) { if t.Contains(key) { return } l, r := tnsplit(t.root, func(node T) int { return t.cmp(node, key) }) t.root = tnmerge(l, tnmerge(newTreapNode(key), r)) } func (t *Treap[T]) Erase(key T) { l, r := tnsplit(t.root, func(node T) int { return t.cmp(node, key) }) _, r = tnsplit(r, func(node T) int { return t.cmp(node, key) - 1 }) t.root = tnmerge(l, r) } func (t *Treap[T]) Split(key T) (*Treap[T], *Treap[T]) { l, r := tnsplit(t.root, func(node T) int { return t.cmp(node, key) }) return &Treap[T]{root: l, cmp: t.cmp}, &Treap[T]{root: r, cmp: t.cmp} } func (t *Treap[T]) Merge(tr *Treap[T]) *Treap[T] { return &Treap[T]{ root: tnmerge(t.root, tr.root), cmp: t.cmp, } } func (t *Treap[T]) LowerBound(key T) (v T, found bool) { cur := t.root for cur != nil { if t.cmp(cur.key, key) >= 0 { v = cur.key found = true cur = cur.left } else { cur = cur.right } } return } func (t *Treap[T]) UpperBound(key T) (v T, found bool) { return t.Next(key) } func (t *Treap[T]) Prev(key T) (v T, found bool) { cur := t.root for cur != nil { if t.cmp(cur.key, key) < 0 { v = cur.key found = true cur = cur.right } else { cur = cur.left } } return } func (t *Treap[T]) Next(key T) (v T, found bool) { cur := t.root for cur != nil { if t.cmp(cur.key, key) > 0 { v = cur.key found = true cur = cur.left } else { cur = cur.right } } return } func (t *Treap[T]) Index(key T) int { cur := t.root idx := 0 for cur != nil { v := t.cmp(cur.key, key) if v > 0 { cur = cur.left continue } idx += tnsz(cur.left) if v == 0 { return idx } idx++ cur = cur.right } return -1 } func (t *Treap[T]) CountLess(key T) int { cur := t.root cnt := 0 for cur != nil { v := t.cmp(cur.key, key) if v > 0 { cur = cur.left continue } cnt += tnsz(cur.left) if v == 0 { return cnt } cnt++ cur = cur.right } return cnt } func (t *Treap[T]) Kth(k int) (v T, found bool) { cur := t.root for { if cur == nil || k < 0 || k >= tnsz(cur) { return } l := tnsz(cur.left) if k < l { cur = cur.left } else if k == l { return cur.key, true } else { cur = cur.right k = k - l - 1 } } } func (t *Treap[T]) Size() int { return tnsz(t.root) } func (t *Treap[T]) Min() (v T, found bool) { return t.Kth(0) } func (t *Treap[T]) Max() (v T, found bool) { return t.Kth(t.Size() - 1) } func (t *Treap[T]) Contains(key T) bool { return t.Index(key) != -1 } func (t *Treap[T]) Iter(f func(T)) { iter(t.root, f) } func iter[T any](cur *treapNode[T], f func(T)) { if cur != nil { iter(cur.left, f) f(cur.key) iter(cur.right, f) } } type treapNode[T any] struct { key T priority uint64 size int left, right *treapNode[T] } func newTreapNode[T any](key T) *treapNode[T] { return &treapNode[T]{ key: key, priority: nextRand(), size: 1, } } func (t *treapNode[T]) upd() { t.size = 1 + tnsz(t.left) + tnsz(t.right) } func tnsz[T any](t *treapNode[T]) int { if t == nil { return 0 } return t.size } func tnsplit[T any](t *treapNode[T], cmp func(T) int) (l, r *treapNode[T]) { if t == nil { return nil, nil } if cmp(t.key) >= 0 { l, t.left = tnsplit(t.left, cmp) t.upd() return l, t } else { t.right, r = tnsplit(t.right, cmp) t.upd() return t, r } } func tnmerge[T any](l, r *treapNode[T]) *treapNode[T] { if l == nil { return r } if r == nil { return l } if l.priority > r.priority { l.right = tnmerge(l.right, r) l.upd() return l } else { r.left = tnmerge(l, r.left) r.upd() return r } } var rng uint64 = 88172645463325252 func nextRand() uint64 { rng ^= rng << 7 rng ^= rng >> 9 return rng }

使い方

この Treap を使って、基本的な集合操作と統計操作を行う例です。

import ( "cmp" "fmt" ) func main() { // int型のTreapを作成(昇順比較) t := .NewTreap[int](cmp.Compare[int]) // 要素の挿入 values := []int{10, 5, 20, 15, 25} for _, v := range values { t.Insert(v) } // サイズの確認 fmt.Println("Size:", t.Size()) // Size: 5 // K番目の要素を取得 (0-indexed) val, _ := t.Kth(2) fmt.Println("2nd smallest (index 2):", val) // 15 (5, 10, 15, 20, 25) // 値のインデックスを取得 fmt.Println("Index of 20:", t.Index(20)) // 3 // LowerBound (12以上の最小の要素) lb, _ := t.LowerBound(12) fmt.Println("LowerBound(12):", lb) // 15 // 要素の削除 t.Erase(15) fmt.Println("Contains 15 after erase:", t.Contains(15)) // false // 範囲削除(10~20を削除) l, mid := t.Split(10) _, r := mid.Split(21) t = l.Merge(r) // 全要素の走査(昇順) fmt.Print("Elements: ") // Elements: 5 25 t.Iter(func(v int) { fmt.Print(v, " ") }) fmt.Println() }