競プロで使えるGoのOrdered Set(Treapの実装)
競技プログラミング(競プロ)において、C++の std::set のような順序集合が必要な場面が出てきますがGoの標準ライブラリには存在しないため、コピペで使える実装とその使い方を書きました。
順序集合のデータ構造はいくつかありますが、std::set のような機能に加え、「K番目の値の取得」や「要素のインデックス取得」 を高速に行いたい時に非常に強力なのが Treap(トリープ) です。
以下のような特徴があります。
- 実装が Red-Black Tree などに比べ軽い
- 平均
- 理論がきれい
- 機能拡張がしやすい
- 「Split(分割)」と「Merge(結合)」をベースにすることで、複雑な区間操作にも応用しやすい
この実装でできること
- 挿入・削除:
Insert,Erase - 検索:
Contains,Min,Max - 順序統計:
Kth(K番目に小さい値),Index(値が何番目か) - 境界検索:
LowerBound,UpperBound,Prev,Next - 集合操作:
Split(特定の値を境に2つに分割),Merge(2つのTreapを結合)
この実装の特徴
- Split/Merge型: 回転(Rotation)を使わない実装のため、コードが比較的シンプルで、永続化などの応用も効きやすい形式です。
- Generics対応:
any型とcmp関数を使用しているため、intだけでなく構造体や文字列など、比較可能なあらゆる型で使用できます。 - 高速な乱数: 標準ライブラリの
math/randよりも高速な Xorshift を採用しています。
実装
type Treap[T any] struct {
root *treapNode[T]
cmp func(a, b T) int
}
func NewTreap[T any](cmp func(a, b T) int) *Treap[T] {
return &Treap[T]{cmp: cmp}
}
func (t *Treap[T]) Insert(key T) {
if t.Contains(key) {
return
}
l, r := tnsplit(t.root, func(node T) int { return t.cmp(node, key) })
t.root = tnmerge(l, tnmerge(newTreapNode(key), r))
}
func (t *Treap[T]) Erase(key T) {
l, r := tnsplit(t.root, func(node T) int { return t.cmp(node, key) })
_, r = tnsplit(r, func(node T) int { return t.cmp(node, key) - 1 })
t.root = tnmerge(l, r)
}
func (t *Treap[T]) Split(key T) (*Treap[T], *Treap[T]) {
l, r := tnsplit(t.root, func(node T) int { return t.cmp(node, key) })
return &Treap[T]{root: l, cmp: t.cmp},
&Treap[T]{root: r, cmp: t.cmp}
}
func (t *Treap[T]) Merge(tr *Treap[T]) *Treap[T] {
return &Treap[T]{
root: tnmerge(t.root, tr.root),
cmp: t.cmp,
}
}
func (t *Treap[T]) LowerBound(key T) (v T, found bool) {
cur := t.root
for cur != nil {
if t.cmp(cur.key, key) >= 0 {
v = cur.key
found = true
cur = cur.left
} else {
cur = cur.right
}
}
return
}
func (t *Treap[T]) UpperBound(key T) (v T, found bool) { return t.Next(key) }
func (t *Treap[T]) Prev(key T) (v T, found bool) {
cur := t.root
for cur != nil {
if t.cmp(cur.key, key) < 0 {
v = cur.key
found = true
cur = cur.right
} else {
cur = cur.left
}
}
return
}
func (t *Treap[T]) Next(key T) (v T, found bool) {
cur := t.root
for cur != nil {
if t.cmp(cur.key, key) > 0 {
v = cur.key
found = true
cur = cur.left
} else {
cur = cur.right
}
}
return
}
func (t *Treap[T]) Index(key T) int {
cur := t.root
idx := 0
for cur != nil {
v := t.cmp(cur.key, key)
if v > 0 {
cur = cur.left
continue
}
idx += tnsz(cur.left)
if v == 0 {
return idx
}
idx++
cur = cur.right
}
return -1
}
func (t *Treap[T]) CountLess(key T) int {
cur := t.root
cnt := 0
for cur != nil {
v := t.cmp(cur.key, key)
if v > 0 {
cur = cur.left
continue
}
cnt += tnsz(cur.left)
if v == 0 {
return cnt
}
cnt++
cur = cur.right
}
return cnt
}
func (t *Treap[T]) Kth(k int) (v T, found bool) {
cur := t.root
for {
if cur == nil || k < 0 || k >= tnsz(cur) {
return
}
l := tnsz(cur.left)
if k < l {
cur = cur.left
} else if k == l {
return cur.key, true
} else {
cur = cur.right
k = k - l - 1
}
}
}
func (t *Treap[T]) Size() int { return tnsz(t.root) }
func (t *Treap[T]) Min() (v T, found bool) { return t.Kth(0) }
func (t *Treap[T]) Max() (v T, found bool) { return t.Kth(t.Size() - 1) }
func (t *Treap[T]) Contains(key T) bool { return t.Index(key) != -1 }
func (t *Treap[T]) Iter(f func(T)) { iter(t.root, f) }
func iter[T any](cur *treapNode[T], f func(T)) {
if cur != nil {
iter(cur.left, f)
f(cur.key)
iter(cur.right, f)
}
}
type treapNode[T any] struct {
key T
priority uint64
size int
left, right *treapNode[T]
}
func newTreapNode[T any](key T) *treapNode[T] {
return &treapNode[T]{
key: key,
priority: nextRand(),
size: 1,
}
}
func (t *treapNode[T]) upd() {
t.size = 1 + tnsz(t.left) + tnsz(t.right)
}
func tnsz[T any](t *treapNode[T]) int {
if t == nil {
return 0
}
return t.size
}
func tnsplit[T any](t *treapNode[T], cmp func(T) int) (l, r *treapNode[T]) {
if t == nil {
return nil, nil
}
if cmp(t.key) >= 0 {
l, t.left = tnsplit(t.left, cmp)
t.upd()
return l, t
} else {
t.right, r = tnsplit(t.right, cmp)
t.upd()
return t, r
}
}
func tnmerge[T any](l, r *treapNode[T]) *treapNode[T] {
if l == nil {
return r
}
if r == nil {
return l
}
if l.priority > r.priority {
l.right = tnmerge(l.right, r)
l.upd()
return l
} else {
r.left = tnmerge(l, r.left)
r.upd()
return r
}
}
var rng uint64 = 88172645463325252
func nextRand() uint64 {
rng ^= rng << 7
rng ^= rng >> 9
return rng
}
使い方
この Treap を使って、基本的な集合操作と統計操作を行う例です。
import (
"cmp"
"fmt"
)
func main() {
// int型のTreapを作成(昇順比較)
t := .NewTreap[int](cmp.Compare[int])
// 要素の挿入
values := []int{10, 5, 20, 15, 25}
for _, v := range values {
t.Insert(v)
}
// サイズの確認
fmt.Println("Size:", t.Size()) // Size: 5
// K番目の要素を取得 (0-indexed)
val, _ := t.Kth(2)
fmt.Println("2nd smallest (index 2):", val) // 15 (5, 10, 15, 20, 25)
// 値のインデックスを取得
fmt.Println("Index of 20:", t.Index(20)) // 3
// LowerBound (12以上の最小の要素)
lb, _ := t.LowerBound(12)
fmt.Println("LowerBound(12):", lb) // 15
// 要素の削除
t.Erase(15)
fmt.Println("Contains 15 after erase:", t.Contains(15)) // false
// 範囲削除(10~20を削除)
l, mid := t.Split(10)
_, r := mid.Split(21)
t = l.Merge(r)
// 全要素の走査(昇順)
fmt.Print("Elements: ") // Elements: 5 25
t.Iter(func(v int) {
fmt.Print(v, " ")
})
fmt.Println()
}