diff --git a/store/store.go b/store/store.go index 78c3704..0be59a3 100644 --- a/store/store.go +++ b/store/store.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "slices" + "sync" "sync/atomic" "time" @@ -53,9 +54,12 @@ type Store[H header.Header[H]] struct { writesDn chan struct{} // writeHead maintains the current write head writeHead atomic.Pointer[H] + + knownHeadersLk sync.Mutex // knownHeaders tracks all processed headers // to advance writeHead only over continuous headers. knownHeaders map[uint64]H + // pending keeps headers pending to be written in one batch pending *batch[H] @@ -322,18 +326,18 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { var err error // take current write head to verify headers against var head H - headPtr := s.writeHead.Load() - if headPtr == nil { + if headPtr := s.writeHead.Load(); headPtr == nil { head, err = s.Head(ctx) if err != nil { return err } + // store header from the disk. + gotHead := head + s.writeHead.CompareAndSwap(nil, &gotHead) } else { head = *headPtr } - continuousHead := head - slices.SortFunc(headers, func(a, b H) int { return cmp.Compare(a.Height(), b.Height()) }) @@ -363,17 +367,15 @@ func (s *Store[H]) Append(ctx context.Context, headers ...H) error { verified = append(verified, h) head = h - if continuousHead.Height()+1 == head.Height() { - continuousHead = head - } else { + { + s.knownHeadersLk.Lock() s.knownHeaders[head.Height()] = head + s.knownHeadersLk.Unlock() } } onWrite := func() { - newHead := s.tryAdvanceHead(continuousHead) - s.writeHead.Store(&newHead) - + newHead := s.tryAdvanceHead() log.Infow("new head", "height", newHead.Height(), "hash", newHead.Hash()) s.metrics.newHead(newHead.Height()) } @@ -519,19 +521,32 @@ func (s *Store[H]) get(ctx context.Context, hash header.Hash) ([]byte, error) { } // try advance heighest header if we saw a higher continuous before. -func (s *Store[H]) tryAdvanceHead(highestHead H) H { - curr := highestHead.Height() +func (s *Store[H]) tryAdvanceHead() H { + s.knownHeadersLk.Lock() + defer s.knownHeadersLk.Unlock() + head := *s.writeHead.Load() + height := head.Height() + currHead := head + + // try to move to the next height. for len(s.knownHeaders) > 0 { - h, ok := s.knownHeaders[curr+1] + h, ok := s.knownHeaders[height+1] if !ok { break } - highestHead = h - delete(s.knownHeaders, curr+1) - curr++ + head = h + delete(s.knownHeaders, height+1) + height++ + } + + // if writeHead not set OR it's height is less then we found then update. + if currHead.Height() < head.Height() { + // we don't need CAS here because that's the only place + // where writeHead is updated, knownHeadersLk ensures 1 goroutine. + s.writeHead.Store(&head) } - return highestHead + return head } // indexTo saves mapping between header Height and Hash to the given batch. diff --git a/store/store_test.go b/store/store_test.go index b98f1b8..fcaefa6 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -2,6 +2,8 @@ package store import ( "context" + "math/rand" + stdsync "sync" "testing" "time" @@ -141,6 +143,52 @@ func TestStore_Append_BadHeader(t *testing.T) { require.Error(t, err) } +func TestStore_Append(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + + suite := headertest.NewTestSuite(t) + + ds := sync.MutexWrap(datastore.NewMapDatastore()) + store := NewTestStore(t, ctx, ds, suite.Head(), WithWriteBatchSize(4)) + + head, err := store.Head(ctx) + require.NoError(t, err) + assert.Equal(t, head.Hash(), suite.Head().Hash()) + + const workers = 10 + const chunk = 5 + headers := suite.GenDummyHeaders(workers * chunk) + + errCh := make(chan error, workers) + var wg stdsync.WaitGroup + wg.Add(workers) + + for i := range workers { + go func() { + defer wg.Done() + // make every append happened in random order. + time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond) + + err := store.Append(ctx, headers[i*chunk:(i+1)*chunk]...) + errCh <- err + }() + } + + wg.Wait() + close(errCh) + for err := range errCh { + assert.NoError(t, err) + } + + // wait for batch to be written. + time.Sleep(100 * time.Millisecond) + + head, err = store.Head(ctx) + assert.NoError(t, err) + assert.Equal(t, head.Hash(), headers[len(headers)-1].Hash()) +} + func TestStore_Append_stableHeadWhenGaps(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) t.Cleanup(cancel) @@ -155,22 +203,11 @@ func TestStore_Append_stableHeadWhenGaps(t *testing.T) { assert.Equal(t, head.Hash(), suite.Head().Hash()) firstChunk := suite.GenDummyHeaders(5) - for i := range firstChunk { - t.Log("firstChunk:", firstChunk[i].Height(), firstChunk[i].Hash()) - } missedChunk := suite.GenDummyHeaders(5) - for i := range missedChunk { - t.Log("missedChunk:", missedChunk[i].Height(), missedChunk[i].Hash()) - } lastChunk := suite.GenDummyHeaders(5) - for i := range lastChunk { - t.Log("lastChunk:", lastChunk[i].Height(), lastChunk[i].Hash()) - } wantHead := firstChunk[len(firstChunk)-1] - t.Log("wantHead", wantHead.Height(), wantHead.Hash()) latestHead := lastChunk[len(lastChunk)-1] - t.Log("latestHead", latestHead.Height(), latestHead.Hash()) { err := store.Append(ctx, firstChunk...) @@ -197,7 +234,6 @@ func TestStore_Append_stableHeadWhenGaps(t *testing.T) { head, err := store.Head(ctx) require.NoError(t, err) assert.Equal(t, head.Height(), wantHead.Height()) - t.Log("head", head.Height(), head.Hash()) assert.Equal(t, head.Hash(), wantHead.Hash()) // check that store height is aligned with the head.