Skip to content

Commit

Permalink
revert snapshot check and minor improvements
Browse files Browse the repository at this point in the history
Signed-off-by: ekexium <[email protected]>
  • Loading branch information
ekexium committed Feb 11, 2025
1 parent 941d94b commit de43ee4
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 44 deletions.
2 changes: 1 addition & 1 deletion internal/unionstore/art/art.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type ART struct {

// The counter of every write operation, used to invalidate iterators that were created before the write operation.
WriteSeqNo int
// Increased by 1 when an operation that may affect the content returned by "snapshot iter" (i.e. stage[0]) happens.
// Increased by 1 when an operation that may affect the content returned by "snapshot" (i.e. stage[0]) happens.
// It's used to invalidate snapshot iterators.
// invariant: no concurrent access to it
SnapshotSeqNo int
Expand Down
8 changes: 4 additions & 4 deletions internal/unionstore/art/art_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import (
"github.com/tikv/client-go/v2/internal/unionstore/arena"
)

// GetSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
// getSnapshot returns the "snapshot" for snapshotGetter or snapshotIterator, which is usually the snapshot
// of stage[0]
func (t *ART) GetSnapshot() arena.MemDBCheckpoint {
func (t *ART) getSnapshot() arena.MemDBCheckpoint {
if len(t.stages) > 0 {
return t.stages[0]
}
Expand All @@ -34,7 +34,7 @@ func (t *ART) GetSnapshot() arena.MemDBCheckpoint {
func (t *ART) SnapshotGetter() *SnapGetter {
return &SnapGetter{
tree: t,
cp: t.GetSnapshot(),
cp: t.getSnapshot(),
}
}

Expand All @@ -54,7 +54,7 @@ func (t *ART) newSnapshotIterator(start, end []byte, desc bool) *SnapIter {
inner.ignoreSeqNo = true
it := &SnapIter{
Iterator: inner,
cp: t.GetSnapshot(),
cp: t.getSnapshot(),
}
it.tree.allocator.snapshotInc()
for !it.setValue() && it.Valid() {
Expand Down
50 changes: 19 additions & 31 deletions internal/unionstore/memdb_art.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,59 +198,47 @@ func (db *artDBWithContext) SnapshotGetter() Getter {
}

type snapshotBatchedIter struct {
db *artDBWithContext
snapshotTruncateSeqNo int
lower []byte
upper []byte
reverse bool
err error
db *artDBWithContext
snapshotSeqNo int
lower []byte
upper []byte
reverse bool
err error

// current batch
keys [][]byte
values [][]byte
pos int
batchSize int
nextKey []byte

// only used to check if the snapshot ever changes between batches. It is not supposed to change.
snapshot MemDBCheckpoint
}

func (db *artDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
if len(db.Stages()) == 0 {
logutil.BgLogger().Error("should not use BatchedSnapshotIter for a memdb without any staging buffer")
}
iter := &snapshotBatchedIter{
db: db,
snapshotTruncateSeqNo: db.SnapshotSeqNo,
lower: lower,
upper: upper,
reverse: reverse,
batchSize: 32,
db: db,
snapshotSeqNo: db.SnapshotSeqNo,
lower: lower,
upper: upper,
reverse: reverse,
batchSize: 32,
}

iter.snapshot = db.GetSnapshot()
iter.err = iter.fillBatch()
return iter
}

func (it *snapshotBatchedIter) fillBatch() error {
if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo {
if it.snapshotSeqNo != it.db.SnapshotSeqNo {
return errors.Errorf(
"invalid iter: truncation happened, iter's=%d, db's=%d",
it.snapshotTruncateSeqNo,
"invalid iter: snapshotSeqNo changed, iter's=%d, db's=%d",
it.snapshotSeqNo,
it.db.SnapshotSeqNo,
)
}

if it.db.GetSnapshot() != it.snapshot {
return errors.Errorf(
"snapshot changed between batches, expected=%v, actual=%v",
it.snapshot,
it.db.GetSnapshot(),
)
}

it.db.RLock()
defer it.db.RUnlock()

Expand Down Expand Up @@ -321,7 +309,7 @@ func (it *snapshotBatchedIter) fillBatch() error {
}

func (it *snapshotBatchedIter) Valid() bool {
return it.snapshotTruncateSeqNo == it.db.SnapshotSeqNo &&
return it.snapshotSeqNo == it.db.SnapshotSeqNo &&
it.pos < len(it.keys) &&
it.err == nil
}
Expand All @@ -330,11 +318,11 @@ func (it *snapshotBatchedIter) Next() error {
if it.err != nil {
return it.err
}
if it.snapshotTruncateSeqNo != it.db.SnapshotSeqNo {
if it.snapshotSeqNo != it.db.SnapshotSeqNo {
return errors.New(
fmt.Sprintf(
"invalid snapshotBatchedIter: truncation happened, iter's=%d, db's=%d",
it.snapshotTruncateSeqNo,
"invalid snapshotBatchedIter: snapshotSeqNo changed, iter's=%d, db's=%d",
it.snapshotSeqNo,
it.db.SnapshotSeqNo,
),
)
Expand Down
2 changes: 1 addition & 1 deletion internal/unionstore/memdb_rbt.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (db *rbtDBWithContext) SnapshotGetter() Getter {
}

func (db *rbtDBWithContext) BatchedSnapshotIter(lower, upper []byte, reverse bool) Iterator {
// TODO: implement this
// TODO: implement *batched* iter
if reverse {
return db.SnapshotIterReverse(upper, lower)
} else {
Expand Down
35 changes: 28 additions & 7 deletions internal/unionstore/memdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1441,7 +1441,8 @@ func TestBatchedSnapshotIter(t *testing.T) {
func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
t.Run("EdgeCases", func(t *testing.T) {
db := newArtDBWithContext()

h := db.Staging()
defer db.Release(h)
// invalid range - should be invalid immediately
iter := db.BatchedSnapshotIter([]byte{1}, []byte{1}, false)
require.False(t, iter.Valid())
Expand All @@ -1453,7 +1454,7 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()

// Single element range
db.Set([]byte{1}, []byte{1})
_ = db.Set([]byte{1}, []byte{1})
iter = db.BatchedSnapshotIter([]byte{1}, []byte{2}, false)
require.True(t, iter.Valid())
require.Equal(t, []byte{1}, iter.Key())
Expand All @@ -1462,9 +1463,9 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
iter.Close()

// Multiple elements
db.Set([]byte{2}, []byte{2})
db.Set([]byte{3}, []byte{3})
db.Set([]byte{4}, []byte{4})
_ = db.Set([]byte{2}, []byte{2})
_ = db.Set([]byte{3}, []byte{3})
_ = db.Set([]byte{4}, []byte{4})

// Forward iteration [2,4)
iter = db.BatchedSnapshotIter([]byte{2}, []byte{4}, false)
Expand All @@ -1489,11 +1490,13 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {

t.Run("BoundaryTests", func(t *testing.T) {
db := newArtDBWithContext()
h := db.Staging()
defer db.Release(h)
keys := [][]byte{
{1, 0}, {1, 2}, {1, 4}, {1, 6}, {1, 8},
}
for _, k := range keys {
db.Set(k, k)
_ = db.Set(k, k)
}

// lower bound included
Expand Down Expand Up @@ -1529,6 +1532,8 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {

t.Run("AlphabeticalOrder", func(t *testing.T) {
db := newArtDBWithContext()
h := db.Staging()
defer db.Release(h)
keys := [][]byte{
{2},
{2, 1},
Expand Down Expand Up @@ -1564,8 +1569,10 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {

t.Run("BatchSizeGrowth", func(t *testing.T) {
db := newArtDBWithContext()
h := db.Staging()
defer db.Release(h)
for i := 0; i < 100; i++ {
db.Set([]byte{3, byte(i)}, []byte{3, byte(i)})
_ = db.Set([]byte{3, byte(i)}, []byte{3, byte(i)})
}

// forward
Expand All @@ -1590,4 +1597,18 @@ func TestBatchedSnapshotIterEdgeCase(t *testing.T) {
require.Equal(t, -1, count)
iter.Close()
})

t.Run("SnapshotChange", func(t *testing.T) {
db := newArtDBWithContext()
_ = db.Set([]byte{0}, []byte{0})
h := db.Staging()
_ = db.Set([]byte{byte(1)}, []byte{byte(1)})
iter := db.BatchedSnapshotIter([]byte{0}, []byte{255}, false)
require.True(t, iter.Valid())
require.NoError(t, iter.Next())
db.Release(h)
db.Staging()
require.False(t, iter.Valid())
require.Error(t, iter.Next())
})
}

0 comments on commit de43ee4

Please sign in to comment.