From 64d64010b1d89c1e5a2b8fb33538975fd605aa7a Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Mon, 15 Jun 2026 15:07:26 +0200 Subject: [PATCH 1/5] [raft/memstore] Add aux memstore --- pkg/aux_/store/memstore/dss.go | 69 ++++++++++++++- pkg/aux_/store/memstore/dss_test.go | 125 ++++++++++++++++++++++++++++ pkg/aux_/store/memstore/store.go | 28 ++++++- 3 files changed, 216 insertions(+), 6 deletions(-) create mode 100644 pkg/aux_/store/memstore/dss_test.go diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 38fa4d5bf..7198b3936 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -2,6 +2,8 @@ package memstore import ( "context" + "database/sql" + "time" auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" @@ -9,17 +11,76 @@ import ( ) func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SaveOwnMetadata not implemented for memstore") + if locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if publicEndpoint == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") + } + + r.participants[locality] = &participant{ + publicEndpoint: publicEndpoint, + updatedAt: time.Now(), + } + return nil } func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSMetadata not implemented for memstore") + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.participants)) + for locality, p := range r.participants { + updatedAt := p.updatedAt + m := &auxmodels.DSSMetadata{ + Locality: locality, + PublicEndpoint: p.publicEndpoint, + UpdatedAt: &updatedAt, + } + + // Find the latest heartbeat across all sources for this locality. + var latest auxmodels.Heartbeat + found := false + for key, hb := range r.heartbeats { + if key.locality != locality { + continue + } + if !found || hb.Timestamp.After(*latest.Timestamp) { + latest = hb + found = true + } + } + + if found { + m.LatestTimestamp.Source = sql.NullString{String: latest.Source, Valid: true} + m.LatestTimestamp.Timestamp = latest.Timestamp + m.LatestTimestamp.NextHeartbeatExpectedBefore = latest.NextHeartbeatExpectedBefore + m.LatestTimestamp.Reporter = sql.NullString{String: latest.Reporter, Valid: true} + } + + metadata = append(metadata, m) + } + return metadata, nil } func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { - return stacktrace.NewErrorWithCode(dsserr.NotImplemented, "RecordHeartbeat not implemented for memstore") + if heartbeat.Locality == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") + } + if heartbeat.Source == "" { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Source not set") + } + + if heartbeat.Timestamp == nil { + now := time.Now() + heartbeat.Timestamp = &now + } + + if heartbeat.NextHeartbeatExpectedBefore != nil && heartbeat.NextHeartbeatExpectedBefore.Before(*heartbeat.Timestamp) { + return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") + } + + r.heartbeats[heartbeatKey{locality: heartbeat.Locality, source: heartbeat.Source}] = heartbeat + return nil } func (r *repo) GetDSSAirspaceRepresentationID(_ context.Context) (string, error) { - return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implemented for memstore") + return "", stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetDSSAirspaceRepresentationID not implementable for memstore") } diff --git a/pkg/aux_/store/memstore/dss_test.go b/pkg/aux_/store/memstore/dss_test.go new file mode 100644 index 000000000..43563d27b --- /dev/null +++ b/pkg/aux_/store/memstore/dss_test.go @@ -0,0 +1,125 @@ +package memstore + +import ( + "context" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/stacktrace" + "github.com/stretchr/testify/require" +) + +func TestSaveOwnMetadataValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "", "https://example.com"))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "dss-1", ""))) +} + +func TestSaveOwnMetadataRoundTrip(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) + require.Equal(t, "https://example.com", md[0].PublicEndpoint) + require.NotNil(t, md[0].UpdatedAt) + + // No heartbeat recorded yet. + require.False(t, md[0].LatestTimestamp.Source.Valid) + require.Nil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestSaveOwnMetadataUpsert(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.Equal(t, "https://new.example.com", md[0].PublicEndpoint) +} + +func TestRecordHeartbeatValidation(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Source: "source1"}))) + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1"}))) + + ts := time.Now() + before := ts.Add(-time.Minute) + err := r.RecordHeartbeat(ctx, auxmodels.Heartbeat{ + Locality: "dss-1", + Source: "source1", + Timestamp: &ts, + NextHeartbeatExpectedBefore: &before, + }) + + require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(err)) +} + +func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Source.Valid) + require.NotNil(t, md[0].LatestTimestamp.Timestamp) +} + +func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + older := time.Now().Add(-time.Hour) + newer := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &older, Reporter: "uss1"})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source2", Timestamp: &newer, Reporter: "uss2"})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(newer)) + require.Equal(t, "source2", md[0].LatestTimestamp.Source.String) + require.Equal(t, "uss2", md[0].LatestTimestamp.Reporter.String) +} + +func TestGetDSSMetadataUpdatesHeartbeatPerSource(t *testing.T) { + ctx := context.Background() + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + first := time.Now().Add(-time.Hour) + second := time.Now() + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &first})) + require.NoError(t, r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source1", Timestamp: &second})) + + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + + require.Len(t, md, 1) + require.True(t, md[0].LatestTimestamp.Timestamp.Equal(second)) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go index c2b875f9b..f94a1ffc2 100644 --- a/pkg/aux_/store/memstore/store.go +++ b/pkg/aux_/store/memstore/store.go @@ -2,17 +2,41 @@ package memstore import ( "context" + "time" + auxmodels "github.com/interuss/dss/pkg/aux_/models" "github.com/interuss/dss/pkg/aux_/repos" "github.com/interuss/dss/pkg/memstore" "go.uber.org/zap" ) // repo is a full implementation of aux_.repos.Repository for memory-based storage. -type repo struct{} +type repo struct { + // participants holds pool participants metadata, keyed by locality. + participants map[string]*participant + // heartbeats holds the latest heartbeat per (locality, source). + heartbeats map[heartbeatKey]auxmodels.Heartbeat +} + +type participant struct { + publicEndpoint string + updatedAt time.Time +} + +type heartbeatKey struct { + locality string + source string +} + +func newRepo() *repo { + return &repo{ + participants: map[string]*participant{}, + heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + } +} func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { - return memstore.Init(ctx, logger, "aux_", &repo{}) + return memstore.Init(ctx, logger, "aux_", newRepo()) } func (r *repo) GetRepo() repos.Repository { return r } From 49f28517b751b39a1e400fa22577e5735f72193c Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 16 Jun 2026 11:05:14 +0200 Subject: [PATCH 2/5] [raft/memstore] Add snapshop capability to aux memstore --- pkg/aux_/store/memstore/dss.go | 22 ++++----- pkg/aux_/store/memstore/snapshot.go | 35 ++++++++++++++ pkg/aux_/store/memstore/snapshot_test.go | 59 ++++++++++++++++++++++++ pkg/aux_/store/memstore/store.go | 24 +++++++--- pkg/memstore/store.go | 2 + pkg/rid/store/memstore/snapshot.go | 13 ++++++ pkg/scd/store/memstore/snapshot.go | 13 ++++++ 7 files changed, 150 insertions(+), 18 deletions(-) create mode 100644 pkg/aux_/store/memstore/snapshot.go create mode 100644 pkg/aux_/store/memstore/snapshot_test.go create mode 100644 pkg/rid/store/memstore/snapshot.go create mode 100644 pkg/scd/store/memstore/snapshot.go diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 7198b3936..79aeeee3a 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -18,28 +18,28 @@ func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoin return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") } - r.participants[locality] = &participant{ - publicEndpoint: publicEndpoint, - updatedAt: time.Now(), + r.state.Participants[locality] = &participant{ + PublicEndpoint: publicEndpoint, + UpdatedAt: time.Now().UTC(), } return nil } func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, error) { - metadata := make([]*auxmodels.DSSMetadata, 0, len(r.participants)) - for locality, p := range r.participants { - updatedAt := p.updatedAt + metadata := make([]*auxmodels.DSSMetadata, 0, len(r.state.Participants)) + for locality, p := range r.state.Participants { + updatedAt := p.UpdatedAt m := &auxmodels.DSSMetadata{ Locality: locality, - PublicEndpoint: p.publicEndpoint, + PublicEndpoint: p.PublicEndpoint, UpdatedAt: &updatedAt, } // Find the latest heartbeat across all sources for this locality. var latest auxmodels.Heartbeat found := false - for key, hb := range r.heartbeats { - if key.locality != locality { + for key, hb := range r.state.Heartbeats { + if key.Locality != locality { continue } if !found || hb.Timestamp.After(*latest.Timestamp) { @@ -69,7 +69,7 @@ func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) } if heartbeat.Timestamp == nil { - now := time.Now() + now := time.Now().UTC() heartbeat.Timestamp = &now } @@ -77,7 +77,7 @@ func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Cannot expect the timestamp of the next heartbeat before the timestamp of the new heartbeat") } - r.heartbeats[heartbeatKey{locality: heartbeat.Locality, source: heartbeat.Source}] = heartbeat + r.state.Heartbeats[heartbeatKey{Locality: heartbeat.Locality, Source: heartbeat.Source}] = heartbeat return nil } diff --git a/pkg/aux_/store/memstore/snapshot.go b/pkg/aux_/store/memstore/snapshot.go new file mode 100644 index 000000000..15deb274b --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot.go @@ -0,0 +1,35 @@ +package memstore + +import ( + "bytes" + "encoding/gob" + + "github.com/interuss/stacktrace" +) + +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + +func (r *repo) GetSnapshot() ([]byte, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + return nil +} diff --git a/pkg/aux_/store/memstore/snapshot_test.go b/pkg/aux_/store/memstore/snapshot_test.go new file mode 100644 index 000000000..2085e2488 --- /dev/null +++ b/pkg/aux_/store/memstore/snapshot_test.go @@ -0,0 +1,59 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + "time" + + auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + ts := time.Now().UTC() + require.NoError(t, src.RecordHeartbeat(ctx, auxmodels.Heartbeat{Locality: "dss-1", Source: "source-1", Timestamp: &ts, Reporter: "uss-1"})) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.RestoreFromSnapshot(data)) + + want, err := src.GetDSSMetadata(ctx) + require.NoError(t, err) + got, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Equal(t, want, got) +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := newRepo() + require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := newRepo() + require.NoError(t, dst.SaveOwnMetadata(ctx, "dss-2", "https://other.example.com")) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + md, err := dst.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, newRepo().RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, newRepo().RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go index f94a1ffc2..52ed59d12 100644 --- a/pkg/aux_/store/memstore/store.go +++ b/pkg/aux_/store/memstore/store.go @@ -12,6 +12,15 @@ import ( // repo is a full implementation of aux_.repos.Repository for memory-based storage. type repo struct { + state state +} + +// state is the serializable in-memory state. +type state struct { + // Participants holds pool participants metadata, keyed by locality. + Participants map[string]*participant + // Heartbeats holds the latest heartbeat per (locality, source). + Heartbeats map[heartbeatKey]auxmodels.Heartbeat // participants holds pool participants metadata, keyed by locality. participants map[string]*participant // heartbeats holds the latest heartbeat per (locality, source). @@ -19,20 +28,21 @@ type repo struct { } type participant struct { - publicEndpoint string - updatedAt time.Time + PublicEndpoint string + UpdatedAt time.Time } type heartbeatKey struct { - locality string - source string + Locality string + Source string } func newRepo() *repo { return &repo{ - participants: map[string]*participant{}, - heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, - } + state: state{ + Participants: map[string]*participant{}, + Heartbeats: map[heartbeatKey]auxmodels.Heartbeat{}, + }} } func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { diff --git a/pkg/memstore/store.go b/pkg/memstore/store.go index 9c850f2c3..2fe7a12e6 100644 --- a/pkg/memstore/store.go +++ b/pkg/memstore/store.go @@ -18,6 +18,8 @@ import ( type MemRepo[R any] interface { GetRepo() R + GetSnapshot() ([]byte, error) + RestoreFromSnapshot([]byte) error } type Store[R any] struct { diff --git a/pkg/rid/store/memstore/snapshot.go b/pkg/rid/store/memstore/snapshot.go new file mode 100644 index 000000000..dea64e6a4 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot.go @@ -0,0 +1,13 @@ +package memstore + +import ( + "github.com/interuss/stacktrace" +) + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") +} diff --git a/pkg/scd/store/memstore/snapshot.go b/pkg/scd/store/memstore/snapshot.go new file mode 100644 index 000000000..dea64e6a4 --- /dev/null +++ b/pkg/scd/store/memstore/snapshot.go @@ -0,0 +1,13 @@ +package memstore + +import ( + "github.com/interuss/stacktrace" +) + +func (r *repo) GetSnapshot() ([]byte, error) { + return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") +} + +func (r *repo) RestoreFromSnapshot(data []byte) error { + return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") +} From 3d305b75e55ce7761b16e4d47a137ab02fbdad78 Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Tue, 16 Jun 2026 11:06:01 +0200 Subject: [PATCH 3/5] [raft/memstore] Add rid memstore --- go.mod | 5 +- .../memstore/identification_service_area.go | 129 ++++++- .../identification_service_area_test.go | 321 ++++++++++++++++ pkg/rid/store/memstore/snapshot.go | 34 +- pkg/rid/store/memstore/snapshot_test.go | 82 ++++ pkg/rid/store/memstore/store.go | 124 +++++- pkg/rid/store/memstore/store_test.go | 49 +++ pkg/rid/store/memstore/subscriptions.go | 187 ++++++++- pkg/rid/store/memstore/subscriptions_test.go | 360 ++++++++++++++++++ 9 files changed, 1263 insertions(+), 28 deletions(-) create mode 100644 pkg/rid/store/memstore/identification_service_area_test.go create mode 100644 pkg/rid/store/memstore/snapshot_test.go create mode 100644 pkg/rid/store/memstore/store_test.go create mode 100644 pkg/rid/store/memstore/subscriptions_test.go diff --git a/go.mod b/go.mod index a8fa1da50..d0722521b 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-jose/go-jose/v4 v4.1.4 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/golang/geo v0.0.0-20230421003525-6adc56603217 + github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/interuss/stacktrace v1.0.0 github.com/jackc/pgx/v5 v5.9.2 @@ -27,11 +28,13 @@ require ( go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 go.opentelemetry.io/otel/exporters/prometheus v0.65.0 + go.opentelemetry.io/otel/metric v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 + golang.org/x/sync v0.20.0 ) require ( @@ -71,13 +74,11 @@ require ( go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect - go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.14.0 // indirect diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go index 3bd315fb7..dde19d9ff 100644 --- a/pkg/rid/store/memstore/identification_service_area.go +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -11,30 +11,141 @@ import ( "github.com/interuss/stacktrace" ) -func (r *repo) GetISA(_ context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetISA not implemented for memstore") +func isaRecordFromModel(isa *ridmodels.IdentificationServiceArea, updatedAt time.Time) *isaRecord { + return &isaRecord{ + ID: isa.ID, + URL: isa.URL, + Owner: isa.Owner, + Cells: cloneCells(isa.Cells), + StartTime: cloneTime(isa.StartTime), + EndTime: cloneTime(isa.EndTime), + AltitudeHi: cloneFloat32(isa.AltitudeHi), + AltitudeLo: cloneFloat32(isa.AltitudeLo), + Writer: isa.Writer, + UpdatedAt: updatedAt, + } } -func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteISA not implemented for memstore") +// toModel rebuilds the ISA model +func (rec *isaRecord) toModel() *ridmodels.IdentificationServiceArea { + return &ridmodels.IdentificationServiceArea{ + ID: rec.ID, + URL: rec.URL, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + +func (r *repo) GetISA(_ context.Context, id dssmodels.ID, _ bool) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil } func (r *repo) InsertISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertISA not implemented for memstore") + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.ISAs[isa.ID]; ok { + return nil, stacktrace.NewError("ISA with id %s already exists", isa.ID) + } + rec := isaRecordFromModel(isa, r.clock.Now()) + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil } func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateISA not implemented for memstore") + if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + rec := isaRecordFromModel(isa, r.clock.Now()) + rec.Owner = prev.Owner + r.state.ISAs[isa.ID] = rec + return rec.toModel(), nil +} + +func (r *repo) DeleteISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { + rec, ok := r.state.ISAs[isa.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(isa.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.ISAs, isa.ID) + return out, nil } func (r *repo) SearchISAs(_ context.Context, cells s2.CellUnion, earliest *time.Time, latest *time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchISAs not implemented for memstore") + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "Missing cell IDs for query") + } + if earliest == nil { + return nil, stacktrace.NewError("Earliest start time is missing") + } + + want := cellSet(cells) + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at >= earliest + if rec.EndTime == nil || rec.EndTime.Before(*earliest) { + continue + } + // COALESCE(starts_at <= latest, true) + if latest != nil && rec.StartTime != nil && rec.StartTime.After(*latest) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } func (r *repo) ListExpiredISAs(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.IdentificationServiceArea, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredISAs not implemented for memstore") + var out []*ridmodels.IdentificationServiceArea + for _, rec := range r.state.ISAs { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } func (r *repo) CountISAs(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountISAs not implemented for memstore") + return int64(len(r.state.ISAs)), nil } diff --git a/pkg/rid/store/memstore/identification_service_area_test.go b/pkg/rid/store/memstore/identification_service_area_test.go new file mode 100644 index 000000000..70699e415 --- /dev/null +++ b/pkg/rid/store/memstore/identification_service_area_test.go @@ -0,0 +1,321 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.ISA = &repo{} + overflow = uint64(17106221850767130624) // face 5 L13 overflows + serviceArea = &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner(uuid.New().String()), + URL: "https://no/place/like/home/for/flights", + StartTime: &startTime, + EndTime: &endTime, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + s2.CellID(17106221850767130624), + }, + } +) + +func TestStoreSearchISAs(t *testing.T) { + ctx := context.Background() + cells := s2.CellUnion{ + s2.CellID(17106221850767130624), + s2.CellID(17106221885126868992), + s2.CellID(17106221919486607360), + s2.CellID(uint64(overflow)), + } + repo := setUpStore(t) + + isa := *serviceArea + isa.Cells = cells + saOut, err := repo.InsertISA(ctx, &isa) + require.NoError(t, err) + require.NotNil(t, saOut) + require.Equal(t, isa.ID, saOut.ID) + + for _, r := range []struct { + name string + cells s2.CellUnion + timestampMutator func(time.Time, time.Time) (*time.Time, *time.Time) + expectedLen int + }{ + { + name: "search for empty cell", + cells: s2.CellUnion{s2.CellID(17106221953846345728)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 0, + }, + { + name: "search for only one cell", + cells: s2.CellUnion{s2.CellID(17106221850767130624)}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search for only one cell with high bit set", + cells: s2.CellUnion{s2.CellID(uint64(overflow))}, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with nil ends_at", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, nil + }, + expectedLen: 1, + }, + { + name: "search with exact timestamps", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + return &start, &end + }, + expectedLen: 1, + }, + { + name: "search with non-matching time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = end.Add(offset) + latest = end.Add(offset * 2) + ) + + return &earliest, &latest + }, + expectedLen: 0, + }, + { + name: "search with expanded time span", + cells: cells, + timestampMutator: func(start time.Time, end time.Time) (*time.Time, *time.Time) { + var ( + offset = time.Duration(100 * time.Second) + earliest = start.Add(-offset) + latest = end.Add(offset) + ) + + return &earliest, &latest + }, + expectedLen: 1, + }, + } { + t.Run(r.name, func(t *testing.T) { + earliest, latest := r.timestampMutator(*saOut.StartTime, *saOut.EndTime) + + serviceAreas, err := repo.SearchISAs(ctx, r.cells, earliest, latest) + require.NoError(t, err) + require.Len(t, serviceAreas, r.expectedLen) + }) + } +} + +func TestBadVersion(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut1, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Rewriting service area should fail + saOut2, err := repo.UpdateISA(ctx, serviceArea) + require.NoError(t, err) + require.Nil(t, saOut2) + + // Rewriting, but with the correct version should work. + newEndTime := saOut1.EndTime.Add(time.Minute) + saOut1.EndTime = &newEndTime + saOut3, err := repo.UpdateISA(ctx, saOut1) + require.NoError(t, err) + require.NotNil(t, saOut3) +} + +func TestStoreExpiredISA(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + saOut, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + require.NotNil(t, saOut) + + // The ISA's endTime is one hour from now. + fakeClock.Advance(59 * time.Minute) + + // We should still be able to find the ISA by searching and by ID. + now := fakeClock.Now() + serviceAreas, err := repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) + + ret, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) + + // But now the ISA has expired. + fakeClock.Advance(2 * time.Minute) + now = fakeClock.Now() + + serviceAreas, err = repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) + require.NoError(t, err) + require.Len(t, serviceAreas, 0) + + // A get should work even if it is expired. + ret, err = repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, ret) +} + +func TestStoreDeleteISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) +} + +func TestStoreISAWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.IdentificationServiceArea{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertISA(ctx, sub) + require.Error(t, err) +} + +func TestListExpiredISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestListExpiredISAsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert ISA with endtime 1 day from now + isa1 := *serviceArea + startTime := fakeClock.Now() + isa1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + isa1.EndTime = &endTime + isa1.Writer = "" + saOut1, err := repo.InsertISA(ctx, &isa1) + require.NoError(t, err) + require.NotNil(t, saOut1) + + // Insert ISA with endtime to 30 minutes ago + isa2 := *serviceArea + startTime = fakeClock.Now().Add(-1 * time.Hour) + isa2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + isa2.EndTime = &endTime + isa2.ID = dssmodels.ID(uuid.New().String()) + isa2.Writer = "" + saOut2, err := repo.InsertISA(ctx, &isa2) + require.NoError(t, err) + require.NotNil(t, saOut2) + + serviceAreas, err := repo.ListExpiredISAs(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, serviceAreas, 1) +} + +func TestStoreCountISAs(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + // Insert the ISA. + copy := *serviceArea + isa, err := repo.InsertISA(ctx, ©) + require.NoError(t, err) + require.NotNil(t, isa) + + //Cound should be one + count, err := repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + // Delete the ISA. + // Ensure a fresh Get, then delete still updates the sub indexes + isa, err = repo.GetISA(ctx, isa.ID, false) + require.NoError(t, err) + + serviceAreaOut, err := repo.DeleteISA(ctx, isa) + require.NoError(t, err) + require.Equal(t, isa, serviceAreaOut) + + //Cound should be zero + count, err = repo.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) +} diff --git a/pkg/rid/store/memstore/snapshot.go b/pkg/rid/store/memstore/snapshot.go index dea64e6a4..fe2aa2c70 100644 --- a/pkg/rid/store/memstore/snapshot.go +++ b/pkg/rid/store/memstore/snapshot.go @@ -1,13 +1,43 @@ package memstore import ( + "bytes" + "encoding/gob" + + dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/stacktrace" ) +const snapshotVersion = 1 + +type snapshotEnvelope struct { + Version int + State state +} + func (r *repo) GetSnapshot() ([]byte, error) { - return nil, stacktrace.NewError("GetSnapshot not yet implemented for rid") + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion, State: r.state}); err != nil { + return nil, stacktrace.Propagate(err, "Failed to encode memstore snapshot") + } + return buf.Bytes(), nil } func (r *repo) RestoreFromSnapshot(data []byte) error { - return stacktrace.NewError("RestoreFromSnapshot not yet implemented for rid") + var env snapshotEnvelope + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&env); err != nil { + return stacktrace.Propagate(err, "Failed to decode memstore snapshot") + } + if env.Version != snapshotVersion { + return stacktrace.NewError("Unsupported memstore snapshot version %d, expected %d", env.Version, snapshotVersion) + } + r.state = env.State + // gob decodes an empty map as nil; re-initialize to keep the repo writable. + if r.state.ISAs == nil { + r.state.ISAs = map[dssmodels.ID]*isaRecord{} + } + if r.state.Subscriptions == nil { + r.state.Subscriptions = map[dssmodels.ID]*subscriptionRecord{} + } + return nil } diff --git a/pkg/rid/store/memstore/snapshot_test.go b/pkg/rid/store/memstore/snapshot_test.go new file mode 100644 index 000000000..a9764f812 --- /dev/null +++ b/pkg/rid/store/memstore/snapshot_test.go @@ -0,0 +1,82 @@ +package memstore + +import ( + "bytes" + "context" + "encoding/gob" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/interuss/dss/pkg/models" + "github.com/stretchr/testify/require" +) + +func TestSnapshotRoundTrip(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + _, err = src.InsertSubscription(ctx, subscriptionsPool[0].input) + require.NoError(t, err) + + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + wantISA, err := src.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + gotISA, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + + if diff := cmp.Diff(wantISA, gotISA, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("IdentificationServiceArea mismatch (-want +got):\n%s", diff) + } + + wantSub, err := src.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + gotSub, err := dst.GetSubscription(ctx, subscriptionsPool[0].input.ID) + require.NoError(t, err) + + if diff := cmp.Diff(wantSub, gotSub, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("Subscription mismatch (-want +got):\n%s", diff) + } +} + +func TestRestoreFromSnapshotReplacesState(t *testing.T) { + ctx := context.Background() + src := setUpStore(t) + _, err := src.InsertISA(ctx, serviceArea) + require.NoError(t, err) + data, err := src.GetSnapshot() + require.NoError(t, err) + + dst := setUpStore(t) + other := *serviceArea + other.ID = "00000000-0000-4000-8000-000000000002" + _, err = dst.InsertISA(ctx, &other) + require.NoError(t, err) + require.NoError(t, dst.RestoreFromSnapshot(data)) + + count, err := dst.CountISAs(ctx) + require.NoError(t, err) + require.Equal(t, int64(1), count) + got, err := dst.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, got) + gone, err := dst.GetISA(ctx, other.ID, false) + require.NoError(t, err) + require.Nil(t, gone) +} + +func TestRestoreFromSnapshotInvalidData(t *testing.T) { + require.Error(t, setUpStore(t).RestoreFromSnapshot([]byte("random value that is definitely not valid"))) +} + +func TestRestoreFromSnapshotVersionMismatch(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, gob.NewEncoder(&buf).Encode(snapshotEnvelope{Version: snapshotVersion + 1})) + require.Error(t, setUpStore(t).RestoreFromSnapshot(buf.Bytes())) +} diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go index b50c2b169..824ea4901 100644 --- a/pkg/rid/store/memstore/store.go +++ b/pkg/rid/store/memstore/store.go @@ -2,17 +2,137 @@ package memstore import ( "context" + "time" + "github.com/golang/geo/s2" + "github.com/interuss/dss/pkg/geo" "github.com/interuss/dss/pkg/memstore" + dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/stacktrace" + "github.com/jonboulle/clockwork" "go.uber.org/zap" ) // repo is a full implementation of rid.repos.Repository for memory-based storage. -type repo struct{} +type repo struct { + state state + clock clockwork.Clock +} + +// state is the serializable in-memory state. +type state struct { + // ISAs holds the stored ISAs keyed by ID. + ISAs map[dssmodels.ID]*isaRecord + // Subscriptions holds the stored subscriptions keyed by ID. + Subscriptions map[dssmodels.ID]*subscriptionRecord +} + +// isaRecord is the gob-serializable representation of an ISA. It intentionally +// stores only primitive fields: the model's Version is never persisted, it is +// derived from UpdatedAt on read. +type isaRecord struct { + ID dssmodels.ID + URL string + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +// subscriptionRecord is the gob-serializable representation of a Subscription. +type subscriptionRecord struct { + ID dssmodels.ID + URL string + NotificationIndex int + Owner dssmodels.Owner + Cells s2.CellUnion + StartTime *time.Time + EndTime *time.Time + AltitudeHi *float32 + AltitudeLo *float32 + Writer string + UpdatedAt time.Time +} + +func newRepo() *repo { + r := &repo{clock: clockwork.NewRealClock()} + r.resetState() + return r +} + +func (r *repo) resetState() { + r.state = state{ + ISAs: map[dssmodels.ID]*isaRecord{}, + Subscriptions: map[dssmodels.ID]*subscriptionRecord{}, + } +} func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Repository], error) { - return memstore.Init(ctx, logger, "rid", &repo{}) + return memstore.Init(ctx, logger, "rid", newRepo()) } func (r *repo) GetRepo() repos.Repository { return r } + +// validateWriteData validate constraints on an ISA +func validateWriteData(cells s2.CellUnion, start, end *time.Time) error { + if len(cells) == 0 { + return stacktrace.NewError("At least one cell must be provided") + } + for _, c := range cells { + if err := geo.ValidateCell(c); err != nil { + return stacktrace.Propagate(err, "Error validating cell") + } + } + if start != nil && end != nil && !start.Before(*end) { + return stacktrace.NewError("Start time must be strictly before end time") + } + return nil +} + +// cellSet builds a lookup set from a cell union. +func cellSet(cells s2.CellUnion) map[s2.CellID]struct{} { + set := make(map[s2.CellID]struct{}, len(cells)) + for _, c := range cells { + set[c] = struct{}{} + } + return set +} + +// overlaps reports whether any cell is present in set (equivalent to the SQL +// "cells && $x" array-overlap operator). +func overlaps(cells s2.CellUnion, set map[s2.CellID]struct{}) bool { + for _, c := range cells { + if _, ok := set[c]; ok { + return true + } + } + return false +} + +func cloneCells(cells s2.CellUnion) s2.CellUnion { + if cells == nil { + return nil + } + return append(s2.CellUnion(nil), cells...) +} + +func cloneTime(t *time.Time) *time.Time { + if t == nil { + return nil + } + v := *t + return &v +} + +func cloneFloat32(f *float32) *float32 { + if f == nil { + return nil + } + v := *f + return &v +} diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go new file mode 100644 index 000000000..05a78d0c6 --- /dev/null +++ b/pkg/rid/store/memstore/store_test.go @@ -0,0 +1,49 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + fakeClock = clockwork.NewFakeClock() + startTime = fakeClock.Now().UTC().Add(-time.Minute) + endTime = fakeClock.Now().UTC().Add(time.Hour) + writer = "writer" +) + +// setUpStore returns a fresh in-memory repo whose clock is the (reset) package +// fakeClock, so tests can advance time deterministically. +func setUpStore(t *testing.T) *repo { + t.Helper() + fakeClock = clockwork.NewFakeClock() + r := newRepo() + r.clock = fakeClock + return r +} + +func TestDatabaseEnsuresBeginsBeforeExpires(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + begins = time.Now().UTC() + expires = begins.Add(-5 * time.Minute) + ) + _, err := repo.InsertSubscription(ctx, &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me-myself-and-i", + URL: "https://no/place/like/home", + NotificationIndex: 42, + StartTime: &begins, + EndTime: &expires, + }) + require.Error(t, err) +} diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go index ac744ab10..addf17e9a 100644 --- a/pkg/rid/store/memstore/subscriptions.go +++ b/pkg/rid/store/memstore/subscriptions.go @@ -11,42 +11,203 @@ import ( "github.com/interuss/stacktrace" ) +func subRecordFromModel(s *ridmodels.Subscription, updatedAt time.Time) *subscriptionRecord { + return &subscriptionRecord{ + ID: s.ID, + URL: s.URL, + NotificationIndex: s.NotificationIndex, + Owner: s.Owner, + Cells: cloneCells(s.Cells), + StartTime: cloneTime(s.StartTime), + EndTime: cloneTime(s.EndTime), + AltitudeHi: cloneFloat32(s.AltitudeHi), + AltitudeLo: cloneFloat32(s.AltitudeLo), + Writer: s.Writer, + UpdatedAt: updatedAt, + } +} + +func (rec *subscriptionRecord) toModel() *ridmodels.Subscription { + return &ridmodels.Subscription{ + ID: rec.ID, + URL: rec.URL, + NotificationIndex: rec.NotificationIndex, + Owner: rec.Owner, + Cells: cloneCells(rec.Cells), + StartTime: cloneTime(rec.StartTime), + EndTime: cloneTime(rec.EndTime), + Version: dssmodels.VersionFromTime(rec.UpdatedAt), + AltitudeHi: cloneFloat32(rec.AltitudeHi), + AltitudeLo: cloneFloat32(rec.AltitudeLo), + Writer: rec.Writer, + } +} + func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "GetSubscription not implemented for memstore") + rec, ok := r.state.Subscriptions[id] + if !ok { + return nil, nil + } + return rec.toModel(), nil } -func (r *repo) DeleteSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "DeleteSubscription not implemented for memstore") +func (r *repo) InsertSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + if _, ok := r.state.Subscriptions[s.ID]; ok { + return nil, stacktrace.NewError("Subscription with id %s already exists", s.ID) + } + rec := subRecordFromModel(s, r.clock.Now()) + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil } -func (r *repo) InsertSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "InsertSubscription not implemented for memstore") +func (r *repo) UpdateSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { + return nil, err + } + prev, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(s.Version) { + return nil, nil + } + rec := subRecordFromModel(s, r.clock.Now()) + rec.Owner = prev.Owner + r.state.Subscriptions[s.ID] = rec + return rec.toModel(), nil } -func (r *repo) UpdateSubscription(_ context.Context, sub *ridmodels.Subscription) (*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateSubscription not implemented for memstore") +func (r *repo) DeleteSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { + rec, ok := r.state.Subscriptions[s.ID] + if !ok { + return nil, nil + } + if !dssmodels.VersionFromTime(rec.UpdatedAt).Matches(s.Version) { + return nil, nil + } + out := rec.toModel() + delete(r.state.Subscriptions, s.ID) + return out, nil } func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptions not implemented for memstore") + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := r.clock.Now() + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "SearchSubscriptionsByOwner not implemented for memstore") + if len(cells) == 0 { + return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") + } + now := r.clock.Now() + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + out = append(out, rec.toModel()) + + if len(out) > dssmodels.MaxResultLimit { // This miminc sqlstore behaviour, but it's not very good. + break + } + } + return out, nil } +// UpdateNotificationIdxsInCells increments the notification index for each +// subscription in the given cells. func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "UpdateNotificationIdxsInCells not implemented for memstore") + now := r.clock.Now() + want := cellSet(cells) + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + if !overlaps(rec.Cells, want) { + continue + } + rec.NotificationIndex++ + out = append(out, rec.toModel()) + } + return out, nil } func (r *repo) MaxSubscriptionCountInCellsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "MaxSubscriptionCountInCellsByOwner not implemented for memstore") + now := r.clock.Now() + want := cellSet(cells) + counts := make(map[s2.CellID]int, len(cells)) + for _, rec := range r.state.Subscriptions { + if rec.Owner != owner { + continue + } + if rec.EndTime == nil || rec.EndTime.Before(now) { + continue + } + for _, c := range rec.Cells { + if _, ok := want[c]; ok { + counts[c]++ + } + } + } + best := 0 + for _, n := range counts { + if n > best { + best = n + } + } + return best, nil } func (r *repo) ListExpiredSubscriptions(_ context.Context, writer string, threshold time.Time) ([]*ridmodels.Subscription, error) { - return nil, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "ListExpiredSubscriptions not implemented for memstore") + var out []*ridmodels.Subscription + for _, rec := range r.state.Subscriptions { + // ends_at <= threshold + if rec.EndTime == nil || rec.EndTime.After(threshold) { + continue + } + if writer == "" { + if rec.Writer != "" { + continue + } + } else if rec.Writer != writer { + continue + } + out = append(out, rec.toModel()) + + // TODO: This miminc sqlstore inconsistency of not limiting results there, comparted to ISAs. Should it be normalized? + } + return out, nil } func (r *repo) CountSubscriptions(_ context.Context) (int64, error) { - return 0, stacktrace.NewErrorWithCode(dsserr.NotImplemented, "CountSubscriptions not implemented for memstore") + return int64(len(r.state.Subscriptions)), nil } diff --git a/pkg/rid/store/memstore/subscriptions_test.go b/pkg/rid/store/memstore/subscriptions_test.go new file mode 100644 index 000000000..6f7832379 --- /dev/null +++ b/pkg/rid/store/memstore/subscriptions_test.go @@ -0,0 +1,360 @@ +package memstore + +import ( + "context" + "testing" + "time" + + "github.com/golang/geo/s2" + "github.com/google/uuid" + dssmodels "github.com/interuss/dss/pkg/models" + ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/rid/repos" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +var ( + // Ensure the struct conforms to the interface + _ repos.Subscription = &repo{} + subscriptionsPool = []struct { + name string + input *ridmodels.Subscription + }{ + { + name: "a subscription with startTime and endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Writer: writer, + Cells: s2.CellUnion{ + s2.CellID(uint64(overflow)), + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "myself", + URL: "https://no/place/like/home", + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + { + name: "a subscription without startTime and with endTime", + input: &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: "me", + URL: "https://no/place/like/home", + StartTime: &startTime, + EndTime: &endTime, + NotificationIndex: 42, + Cells: s2.CellUnion{ + 12494535935418957824, + }, + }, + }, + } +) + +func TestStoreGetSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + sub2, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub2) + + require.Equal(t, *sub1, *sub2) + }) + } +} + +func TestStoreInsertSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Test changes without the version differing. + r2 := *sub1 + r2.URL = "new url" + sub2, err := repo.UpdateSubscription(ctx, &r2) + require.NoError(t, err) + require.NotNil(t, sub2) + require.Equal(t, "new url", sub2.URL) + + // Test it doesn't work when Version is nil. + r3 := *sub2 + r3.URL = "new url 2" + r3.Version = nil + sub3, err := repo.UpdateSubscription(ctx, &r3) + require.NoError(t, err) + require.Nil(t, sub3) + + // Bad version doesn't work. + r4 := *sub2 + r4.URL = "new url 3" + r4.Version = dssmodels.VersionFromTime(time.Now()) + sub4, err := repo.UpdateSubscription(ctx, &r4) + require.NoError(t, err) + require.Nil(t, sub4) + + sub5, err := repo.GetSubscription(ctx, sub1.ID) + require.NoError(t, err) + require.NotNil(t, sub5) + + require.Equal(t, *sub2, *sub5) + }) + } +} + +func TestStoreDeleteSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + // Ensure mismatched versions returns nothing + sub1BadVersion := *sub1 + sub1BadVersion.Version, err = dssmodels.VersionFromString("a3cg3tcuhk00") + require.NoError(t, err) + sub2, err := repo.DeleteSubscription(ctx, &sub1BadVersion) + require.NoError(t, err) + require.Nil(t, sub2) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + require.Equal(t, *sub1, *sub4) + }) + } +} + +func TestStoreSearchSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + var ( + // pick an L13 value that overflows. + overflow = uint64(17106221850767130624) + + cells = s2.CellUnion{ + s2.CellID(12494535935418957824), + s2.CellID(12494535866699481088), + s2.CellID(12494535901059219456), + s2.CellID(12494535866699481088), + s2.CellID(overflow), + } + owners = []dssmodels.Owner{ + "me", + "my", + "self", + } + ) + + for i, r := range subscriptionsPool { + subscription := *r.input + subscription.Owner = owners[i] + subscription.Cells = cells[:i+1] + sub1, err := repo.InsertSubscription(ctx, &subscription) + require.NoError(t, err) + require.NotNil(t, sub1) + } + // Test normal search + found, err := repo.SearchSubscriptions(ctx, cells) + require.NoError(t, err) + require.Len(t, found, 3) + for _, owner := range owners { + found, err := repo.SearchSubscriptionsByOwner(ctx, cells, owner) + require.NoError(t, err) + require.NotNil(t, found) + // We insert one subscription per owner. Hence, no matter how many cells are touched by the subscription, + // the result should always be 1. + require.Len(t, found, 1) + } +} + +func TestStoreExpiredSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + Cells: s2.CellUnion{s2.CellID(12494535866699481088)}, + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.NoError(t, err) + + // The subscription's endTime is 24 hours from now. + fakeClock.Advance(23 * time.Hour) + + // We should still be able to find the subscription by searching and by ID. + subs, err := repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 1) + + ret, err := repo.GetSubscription(ctx, sub.ID) + require.NoError(t, err) + require.NotNil(t, &ret) + + // But now the subscription has expired. + fakeClock.Advance(2 * time.Hour) + + subs, err = repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") + require.NoError(t, err) + require.Len(t, subs, 0) + + ret, err = repo.GetSubscription(ctx, sub.ID) + require.NotNil(t, ret) + require.NoError(t, err) +} + +func TestStoreSubscriptionWithNoGeoData(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + endTime := fakeClock.Now().Add(24 * time.Hour) + sub := &ridmodels.Subscription{ + ID: dssmodels.ID(uuid.New().String()), + Owner: dssmodels.Owner("original owner"), + EndTime: &endTime, + } + _, err := repo.InsertSubscription(ctx, sub) + require.Error(t, err) +} + +func TestMaxSubscriptionCountInCellsByOwner(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, s := range subscriptionsPool { + _, err := repo.InsertSubscription(ctx, s.input) + require.NoError(t, err) + } + + count, err := repo.MaxSubscriptionCountInCellsByOwner(ctx, s2.CellUnion{12494535935418957824}, "myself") + require.NoError(t, err) + require.Equal(t, 2, count) +} + +func TestListExpiredSubscriptions(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, writer, fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestListExpiredSubscriptionsWithEmptyWriter(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + fakeClock := clockwork.NewFakeClockAt(time.Now()) + + // Insert Subscription with endtime 1 day from now + subscripiton1 := *subscriptionsPool[0].input + startTime := fakeClock.Now() + subscripiton1.StartTime = &startTime + endTime := fakeClock.Now().Add(24 * time.Hour) + subscripiton1.EndTime = &endTime + subscripiton1.Writer = "" + subOut1, err := repo.InsertSubscription(ctx, &subscripiton1) + require.NoError(t, err) + require.NotNil(t, subOut1) + + // Insert Subscription with endtime to 30 minutes ago + subscripiton2 := *subscriptionsPool[0].input + startTime = fakeClock.Now().Add(-1 * time.Hour) + subscripiton2.StartTime = &startTime + endTime = fakeClock.Now().Add(-30 * time.Minute) + subscripiton2.EndTime = &endTime + subscripiton2.ID = dssmodels.ID(uuid.New().String()) + subscripiton2.Writer = "" + subOut2, err := repo.InsertSubscription(ctx, &subscripiton2) + require.NoError(t, err) + require.NotNil(t, subOut2) + + subscriptions, err := repo.ListExpiredSubscriptions(ctx, "", fakeClock.Now().Add(-30*time.Minute)) + require.NoError(t, err) + require.Len(t, subscriptions, 1) +} + +func TestStoreCountSubscription(t *testing.T) { + ctx := context.Background() + repo := setUpStore(t) + + for _, r := range subscriptionsPool { + t.Run(r.name, func(t *testing.T) { + sub1, err := repo.InsertSubscription(ctx, r.input) + require.NoError(t, err) + require.NotNil(t, sub1) + + //Cound should be one + count, err := repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(1)) + + sub4, err := repo.DeleteSubscription(ctx, sub1) + require.NoError(t, err) + require.NotNil(t, sub4) + + //Cound should be zero + count, err = repo.CountSubscriptions(ctx) + require.NoError(t, err) + require.Equal(t, count, int64(0)) + }) + } +} From 0269e1862ece9b826fb683d851b633c95ac03dab Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Wed, 1 Jul 2026 11:19:59 +0200 Subject: [PATCH 4/5] [raft/memstore] Use now from context --- pkg/aux_/store/memstore/dss.go | 21 +++++-- pkg/aux_/store/memstore/dss_test.go | 11 ++++ pkg/aux_/store/memstore/snapshot_test.go | 10 +++- .../memstore/identification_service_area.go | 23 ++++++-- .../identification_service_area_test.go | 21 ++++--- pkg/rid/store/memstore/snapshot_test.go | 3 + pkg/rid/store/memstore/store.go | 4 +- pkg/rid/store/memstore/store_test.go | 2 - pkg/rid/store/memstore/subscriptions.go | 58 +++++++++++++++---- pkg/rid/store/memstore/subscriptions_test.go | 17 ++++-- 10 files changed, 130 insertions(+), 40 deletions(-) diff --git a/pkg/aux_/store/memstore/dss.go b/pkg/aux_/store/memstore/dss.go index 79aeeee3a..dc86c8e90 100644 --- a/pkg/aux_/store/memstore/dss.go +++ b/pkg/aux_/store/memstore/dss.go @@ -3,14 +3,14 @@ package memstore import ( "context" "database/sql" - "time" auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) -func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoint string) error { +func (r *repo) SaveOwnMetadata(ctx context.Context, locality string, publicEndpoint string) error { if locality == "" { return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") } @@ -18,9 +18,15 @@ func (r *repo) SaveOwnMetadata(_ context.Context, locality string, publicEndpoin return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Public endpoint not set") } + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return err + } + r.state.Participants[locality] = &participant{ PublicEndpoint: publicEndpoint, - UpdatedAt: time.Now().UTC(), + UpdatedAt: now, } return nil } @@ -60,7 +66,7 @@ func (r *repo) GetDSSMetadata(_ context.Context) ([]*auxmodels.DSSMetadata, erro return metadata, nil } -func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) error { +func (r *repo) RecordHeartbeat(ctx context.Context, heartbeat auxmodels.Heartbeat) error { if heartbeat.Locality == "" { return stacktrace.NewErrorWithCode(dsserr.BadRequest, "Locality not set") } @@ -69,7 +75,12 @@ func (r *repo) RecordHeartbeat(_ context.Context, heartbeat auxmodels.Heartbeat) } if heartbeat.Timestamp == nil { - now := time.Now().UTC() + + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return err + } heartbeat.Timestamp = &now } diff --git a/pkg/aux_/store/memstore/dss_test.go b/pkg/aux_/store/memstore/dss_test.go index 43563d27b..332c2b409 100644 --- a/pkg/aux_/store/memstore/dss_test.go +++ b/pkg/aux_/store/memstore/dss_test.go @@ -7,12 +7,17 @@ import ( auxmodels "github.com/interuss/dss/pkg/aux_/models" dsserr "github.com/interuss/dss/pkg/errors" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" ) +var fakeClock = clockwork.NewFakeClock() + func TestSaveOwnMetadataValidation(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.SaveOwnMetadata(ctx, "", "https://example.com"))) @@ -21,6 +26,7 @@ func TestSaveOwnMetadataValidation(t *testing.T) { func TestSaveOwnMetadataRoundTrip(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) @@ -40,6 +46,7 @@ func TestSaveOwnMetadataRoundTrip(t *testing.T) { func TestSaveOwnMetadataUpsert(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) @@ -54,6 +61,7 @@ func TestSaveOwnMetadataUpsert(t *testing.T) { func TestRecordHeartbeatValidation(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.Equal(t, dsserr.BadRequest, stacktrace.GetCode(r.RecordHeartbeat(ctx, auxmodels.Heartbeat{Source: "source1"}))) @@ -73,6 +81,7 @@ func TestRecordHeartbeatValidation(t *testing.T) { func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) @@ -88,6 +97,7 @@ func TestRecordHeartbeatDefaultsTimestamp(t *testing.T) { func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) @@ -108,6 +118,7 @@ func TestGetDSSMetadataPicksLatestHeartbeat(t *testing.T) { func TestGetDSSMetadataUpdatesHeartbeatPerSource(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) r := newRepo() require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) diff --git a/pkg/aux_/store/memstore/snapshot_test.go b/pkg/aux_/store/memstore/snapshot_test.go index 2085e2488..cd2f03a9f 100644 --- a/pkg/aux_/store/memstore/snapshot_test.go +++ b/pkg/aux_/store/memstore/snapshot_test.go @@ -7,12 +7,17 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" auxmodels "github.com/interuss/dss/pkg/aux_/models" + "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/stretchr/testify/require" ) func TestSnapshotRoundTrip(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) src := newRepo() require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) ts := time.Now().UTC() @@ -28,11 +33,14 @@ func TestSnapshotRoundTrip(t *testing.T) { require.NoError(t, err) got, err := dst.GetDSSMetadata(ctx) require.NoError(t, err) - require.Equal(t, want, got) + if diff := cmp.Diff(want, got, cmpopts.EquateApproxTime(0), cmp.AllowUnexported(models.Version{})); diff != "" { + t.Errorf("Store mismatch (-want +got):\n%s", diff) + } } func TestRestoreFromSnapshotReplacesState(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) src := newRepo() require.NoError(t, src.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) data, err := src.GetSnapshot() diff --git a/pkg/rid/store/memstore/identification_service_area.go b/pkg/rid/store/memstore/identification_service_area.go index dde19d9ff..4dc666fb1 100644 --- a/pkg/rid/store/memstore/identification_service_area.go +++ b/pkg/rid/store/memstore/identification_service_area.go @@ -8,6 +8,7 @@ import ( dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) @@ -50,19 +51,26 @@ func (r *repo) GetISA(_ context.Context, id dssmodels.ID, _ bool) (*ridmodels.Id return rec.toModel(), nil } -func (r *repo) InsertISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { return nil, err } if _, ok := r.state.ISAs[isa.ID]; ok { return nil, stacktrace.NewError("ISA with id %s already exists", isa.ID) } - rec := isaRecordFromModel(isa, r.clock.Now()) + + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } + + rec := isaRecordFromModel(isa, now) r.state.ISAs[isa.ID] = rec return rec.toModel(), nil } -func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { +func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServiceArea) (*ridmodels.IdentificationServiceArea, error) { if err := validateWriteData(isa.Cells, isa.StartTime, isa.EndTime); err != nil { return nil, err } @@ -73,7 +81,14 @@ func (r *repo) UpdateISA(_ context.Context, isa *ridmodels.IdentificationService if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(isa.Version) { return nil, nil } - rec := isaRecordFromModel(isa, r.clock.Now()) + + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } + + rec := isaRecordFromModel(isa, now) rec.Owner = prev.Owner r.state.ISAs[isa.ID] = rec return rec.toModel(), nil diff --git a/pkg/rid/store/memstore/identification_service_area_test.go b/pkg/rid/store/memstore/identification_service_area_test.go index 70699e415..dc27053a7 100644 --- a/pkg/rid/store/memstore/identification_service_area_test.go +++ b/pkg/rid/store/memstore/identification_service_area_test.go @@ -10,7 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" - "github.com/jonboulle/clockwork" + "github.com/interuss/dss/pkg/timestamp" "github.com/stretchr/testify/require" ) @@ -34,6 +34,7 @@ var ( func TestStoreSearchISAs(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) cells := s2.CellUnion{ s2.CellID(17106221850767130624), s2.CellID(17106221885126868992), @@ -136,6 +137,7 @@ func TestStoreSearchISAs(t *testing.T) { func TestBadVersion(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) saOut1, err := repo.InsertISA(ctx, serviceArea) @@ -157,6 +159,7 @@ func TestBadVersion(t *testing.T) { func TestStoreExpiredISA(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) saOut, err := repo.InsertISA(ctx, serviceArea) @@ -164,10 +167,10 @@ func TestStoreExpiredISA(t *testing.T) { require.NotNil(t, saOut) // The ISA's endTime is one hour from now. - fakeClock.Advance(59 * time.Minute) + now := fakeClock.Now() + now = now.Add(59 * time.Minute) // We should still be able to find the ISA by searching and by ID. - now := fakeClock.Now() serviceAreas, err := repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) require.NoError(t, err) require.Len(t, serviceAreas, 1) @@ -177,8 +180,7 @@ func TestStoreExpiredISA(t *testing.T) { require.NotNil(t, ret) // But now the ISA has expired. - fakeClock.Advance(2 * time.Minute) - now = fakeClock.Now() + now = now.Add(2 * time.Minute) serviceAreas, err = repo.SearchISAs(ctx, serviceArea.Cells, &now, nil) require.NoError(t, err) @@ -192,6 +194,7 @@ func TestStoreExpiredISA(t *testing.T) { func TestStoreDeleteISAs(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) // Insert the ISA. @@ -212,6 +215,7 @@ func TestStoreDeleteISAs(t *testing.T) { func TestStoreISAWithNoGeoData(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) endTime := fakeClock.Now().Add(24 * time.Hour) @@ -226,10 +230,9 @@ func TestStoreISAWithNoGeoData(t *testing.T) { func TestListExpiredISAs(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) - fakeClock := clockwork.NewFakeClockAt(time.Now()) - // Insert ISA with endtime 1 day from now isa1 := *serviceArea startTime := fakeClock.Now() @@ -258,10 +261,9 @@ func TestListExpiredISAs(t *testing.T) { func TestListExpiredISAsWithEmptyWriter(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) - fakeClock := clockwork.NewFakeClockAt(time.Now()) - // Insert ISA with endtime 1 day from now isa1 := *serviceArea startTime := fakeClock.Now() @@ -292,6 +294,7 @@ func TestListExpiredISAsWithEmptyWriter(t *testing.T) { func TestStoreCountISAs(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) // Insert the ISA. diff --git a/pkg/rid/store/memstore/snapshot_test.go b/pkg/rid/store/memstore/snapshot_test.go index a9764f812..a795d5f5f 100644 --- a/pkg/rid/store/memstore/snapshot_test.go +++ b/pkg/rid/store/memstore/snapshot_test.go @@ -9,11 +9,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/interuss/dss/pkg/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/stretchr/testify/require" ) func TestSnapshotRoundTrip(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) src := setUpStore(t) _, err := src.InsertISA(ctx, serviceArea) require.NoError(t, err) @@ -47,6 +49,7 @@ func TestSnapshotRoundTrip(t *testing.T) { func TestRestoreFromSnapshotReplacesState(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) src := setUpStore(t) _, err := src.InsertISA(ctx, serviceArea) require.NoError(t, err) diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go index 824ea4901..7cd02e161 100644 --- a/pkg/rid/store/memstore/store.go +++ b/pkg/rid/store/memstore/store.go @@ -10,14 +10,12 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" "github.com/interuss/dss/pkg/rid/repos" "github.com/interuss/stacktrace" - "github.com/jonboulle/clockwork" "go.uber.org/zap" ) // repo is a full implementation of rid.repos.Repository for memory-based storage. type repo struct { state state - clock clockwork.Clock } // state is the serializable in-memory state. @@ -60,7 +58,7 @@ type subscriptionRecord struct { } func newRepo() *repo { - r := &repo{clock: clockwork.NewRealClock()} + r := &repo{} r.resetState() return r } diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go index 05a78d0c6..3bf484463 100644 --- a/pkg/rid/store/memstore/store_test.go +++ b/pkg/rid/store/memstore/store_test.go @@ -23,9 +23,7 @@ var ( // fakeClock, so tests can advance time deterministically. func setUpStore(t *testing.T) *repo { t.Helper() - fakeClock = clockwork.NewFakeClock() r := newRepo() - r.clock = fakeClock return r } diff --git a/pkg/rid/store/memstore/subscriptions.go b/pkg/rid/store/memstore/subscriptions.go index addf17e9a..9037b0f38 100644 --- a/pkg/rid/store/memstore/subscriptions.go +++ b/pkg/rid/store/memstore/subscriptions.go @@ -8,6 +8,7 @@ import ( dsserr "github.com/interuss/dss/pkg/errors" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/interuss/stacktrace" ) @@ -51,19 +52,26 @@ func (r *repo) GetSubscription(_ context.Context, id dssmodels.ID) (*ridmodels.S return rec.toModel(), nil } -func (r *repo) InsertSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { return nil, err } if _, ok := r.state.Subscriptions[s.ID]; ok { return nil, stacktrace.NewError("Subscription with id %s already exists", s.ID) } - rec := subRecordFromModel(s, r.clock.Now()) + + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } + + rec := subRecordFromModel(s, now) r.state.Subscriptions[s.ID] = rec return rec.toModel(), nil } -func (r *repo) UpdateSubscription(_ context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { +func (r *repo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) { if err := validateWriteData(s.Cells, s.StartTime, s.EndTime); err != nil { return nil, err } @@ -74,7 +82,14 @@ func (r *repo) UpdateSubscription(_ context.Context, s *ridmodels.Subscription) if !dssmodels.VersionFromTime(prev.UpdatedAt).Matches(s.Version) { return nil, nil } - rec := subRecordFromModel(s, r.clock.Now()) + + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } + + rec := subRecordFromModel(s, now) rec.Owner = prev.Owner r.state.Subscriptions[s.ID] = rec return rec.toModel(), nil @@ -93,11 +108,16 @@ func (r *repo) DeleteSubscription(_ context.Context, s *ridmodels.Subscription) return out, nil } -func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { +func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { if len(cells) == 0 { return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - now := r.clock.Now() + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } + want := cellSet(cells) var out []*ridmodels.Subscription for _, rec := range r.state.Subscriptions { @@ -116,11 +136,16 @@ func (r *repo) SearchSubscriptions(_ context.Context, cells s2.CellUnion) ([]*ri return out, nil } -func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { +func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) { if len(cells) == 0 { return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - now := r.clock.Now() + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } + want := cellSet(cells) var out []*ridmodels.Subscription for _, rec := range r.state.Subscriptions { @@ -144,8 +169,12 @@ func (r *repo) SearchSubscriptionsByOwner(_ context.Context, cells s2.CellUnion, // UpdateNotificationIdxsInCells increments the notification index for each // subscription in the given cells. -func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { - now := r.clock.Now() +func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) { + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return nil, err + } want := cellSet(cells) var out []*ridmodels.Subscription for _, rec := range r.state.Subscriptions { @@ -161,8 +190,13 @@ func (r *repo) UpdateNotificationIdxsInCells(_ context.Context, cells s2.CellUni return out, nil } -func (r *repo) MaxSubscriptionCountInCellsByOwner(_ context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { - now := r.clock.Now() +func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) { + now, err := timestamp.RequestTimestampFromContext(ctx) + + if err != nil { + return 0, err + } + want := cellSet(cells) counts := make(map[s2.CellID]int, len(cells)) for _, rec := range r.state.Subscriptions { diff --git a/pkg/rid/store/memstore/subscriptions_test.go b/pkg/rid/store/memstore/subscriptions_test.go index 6f7832379..47797f10d 100644 --- a/pkg/rid/store/memstore/subscriptions_test.go +++ b/pkg/rid/store/memstore/subscriptions_test.go @@ -10,6 +10,7 @@ import ( dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" "github.com/interuss/dss/pkg/rid/repos" + "github.com/interuss/dss/pkg/timestamp" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" ) @@ -69,6 +70,7 @@ var ( func TestStoreGetSubscription(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) for _, r := range subscriptionsPool { @@ -88,6 +90,7 @@ func TestStoreGetSubscription(t *testing.T) { func TestStoreInsertSubscription(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) for _, r := range subscriptionsPool { @@ -131,6 +134,7 @@ func TestStoreInsertSubscription(t *testing.T) { func TestStoreDeleteSubscription(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) for _, r := range subscriptionsPool { @@ -158,6 +162,7 @@ func TestStoreDeleteSubscription(t *testing.T) { func TestStoreSearchSubscription(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now().UTC()) repo := setUpStore(t) var ( @@ -202,6 +207,7 @@ func TestStoreSearchSubscription(t *testing.T) { func TestStoreExpiredSubscription(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) endTime := fakeClock.Now().Add(24 * time.Hour) @@ -215,7 +221,7 @@ func TestStoreExpiredSubscription(t *testing.T) { require.NoError(t, err) // The subscription's endTime is 24 hours from now. - fakeClock.Advance(23 * time.Hour) + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now().Add(23*time.Hour)) // We should still be able to find the subscription by searching and by ID. subs, err := repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") @@ -227,7 +233,7 @@ func TestStoreExpiredSubscription(t *testing.T) { require.NotNil(t, &ret) // But now the subscription has expired. - fakeClock.Advance(2 * time.Hour) + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now().Add(25*time.Hour)) subs, err = repo.SearchSubscriptionsByOwner(ctx, sub.Cells, "original owner") require.NoError(t, err) @@ -240,6 +246,7 @@ func TestStoreExpiredSubscription(t *testing.T) { func TestStoreSubscriptionWithNoGeoData(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) endTime := fakeClock.Now().Add(24 * time.Hour) @@ -254,6 +261,7 @@ func TestStoreSubscriptionWithNoGeoData(t *testing.T) { func TestMaxSubscriptionCountInCellsByOwner(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) for _, s := range subscriptionsPool { @@ -268,6 +276,7 @@ func TestMaxSubscriptionCountInCellsByOwner(t *testing.T) { func TestListExpiredSubscriptions(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) fakeClock := clockwork.NewFakeClockAt(time.Now()) @@ -300,10 +309,9 @@ func TestListExpiredSubscriptions(t *testing.T) { func TestListExpiredSubscriptionsWithEmptyWriter(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) - fakeClock := clockwork.NewFakeClockAt(time.Now()) - // Insert Subscription with endtime 1 day from now subscripiton1 := *subscriptionsPool[0].input startTime := fakeClock.Now() @@ -334,6 +342,7 @@ func TestListExpiredSubscriptionsWithEmptyWriter(t *testing.T) { func TestStoreCountSubscription(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) for _, r := range subscriptionsPool { From 8a60b2001e3ef536531aee8c3c0ee5173b0e1033 Mon Sep 17 00:00:00 2001 From: Maximilien Cuony Date: Wed, 1 Jul 2026 11:22:56 +0200 Subject: [PATCH 5/5] [raft/memstore] Add Checkpoint/Restore --- pkg/aux_/store/memstore/store.go | 31 +++++++++++++++ pkg/aux_/store/memstore/store_test.go | 55 +++++++++++++++++++++++++++ pkg/memstore/store.go | 12 ++++++ pkg/rid/store/memstore/store.go | 34 +++++++++++++++++ pkg/rid/store/memstore/store_test.go | 54 ++++++++++++++++++++++++++ pkg/scd/store/memstore/store.go | 9 +++++ 6 files changed, 195 insertions(+) create mode 100644 pkg/aux_/store/memstore/store_test.go diff --git a/pkg/aux_/store/memstore/store.go b/pkg/aux_/store/memstore/store.go index 52ed59d12..bdb8d35d3 100644 --- a/pkg/aux_/store/memstore/store.go +++ b/pkg/aux_/store/memstore/store.go @@ -7,6 +7,7 @@ import ( auxmodels "github.com/interuss/dss/pkg/aux_/models" "github.com/interuss/dss/pkg/aux_/repos" "github.com/interuss/dss/pkg/memstore" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -50,3 +51,33 @@ func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Reposi } func (r *repo) GetRepo() repos.Repository { return r } + +// clone returns a copy of s with independent maps and participant records. +func (s state) clone() state { + ps := make(map[string]*participant, len(s.Participants)) + for k, v := range s.Participants { + cp := *v + ps[k] = &cp + } + hb := make(map[heartbeatKey]auxmodels.Heartbeat, len(s.Heartbeats)) + for k, v := range s.Heartbeats { + hb[k] = v + } + return state{Participants: ps, Heartbeats: hb} +} + +// Checkpoint returns a fast, restorable in-memory copy of the current state. +func (r *repo) Checkpoint() any { + return r.state.clone() +} + +// Restore replaces the current state with a checkpoint previously returned by +// Checkpoint. The checkpoint is copied, so it stays reusable. +func (r *repo) Restore(cp any) error { + s, ok := cp.(state) + if !ok { + return stacktrace.NewError("Invalid checkpoint type %T", cp) + } + r.state = s.clone() + return nil +} diff --git a/pkg/aux_/store/memstore/store_test.go b/pkg/aux_/store/memstore/store_test.go new file mode 100644 index 000000000..c0f41b158 --- /dev/null +++ b/pkg/aux_/store/memstore/store_test.go @@ -0,0 +1,55 @@ +package memstore + +import ( + "context" + "testing" + + "github.com/interuss/dss/pkg/timestamp" + "github.com/stretchr/testify/require" +) + +func TestCheckpointRestore(t *testing.T) { + ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) + + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://example.com")) + + cp := r.Checkpoint() + + // Mutate after the checkpoint. + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-2", "https://other.example.com")) + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 2) + + // Restore drops dss-2. + require.NoError(t, r.Restore(cp)) + md, err = r.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "dss-1", md[0].Locality) +} + +func TestCheckpointIsolatesUpsert(t *testing.T) { + ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) + r := newRepo() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://old.example.com")) + + cp := r.Checkpoint() + + require.NoError(t, r.SaveOwnMetadata(ctx, "dss-1", "https://new.example.com")) + + require.NoError(t, r.Restore(cp)) + md, err := r.GetDSSMetadata(ctx) + require.NoError(t, err) + require.Len(t, md, 1) + require.Equal(t, "https://old.example.com", md[0].PublicEndpoint) +} + +func TestRestoreInvalidType(t *testing.T) { + require.Error(t, newRepo().Restore("not a checkpoint")) +} diff --git a/pkg/memstore/store.go b/pkg/memstore/store.go index 2fe7a12e6..20ac4bfe5 100644 --- a/pkg/memstore/store.go +++ b/pkg/memstore/store.go @@ -20,6 +20,8 @@ type MemRepo[R any] interface { GetRepo() R GetSnapshot() ([]byte, error) RestoreFromSnapshot([]byte) error + Checkpoint() any + Restore(any) error } type Store[R any] struct { @@ -60,6 +62,16 @@ func (s *Store[R]) Interact(_ context.Context) (R, error) { return s.memRepo.GetRepo(), nil } +// Checkpoint returns a fast, restorable in-memory copy of the current state. +func (s *Store[R]) Checkpoint() any { + return s.memRepo.Checkpoint() +} + +// Restore replaces the current state with a checkpoint returned by Checkpoint. +func (s *Store[R]) Restore(cp any) error { + return s.memRepo.Restore(cp) +} + func (s *Store[R]) Close() error { return nil } diff --git a/pkg/rid/store/memstore/store.go b/pkg/rid/store/memstore/store.go index 7cd02e161..cefd0bb4c 100644 --- a/pkg/rid/store/memstore/store.go +++ b/pkg/rid/store/memstore/store.go @@ -134,3 +134,37 @@ func cloneFloat32(f *float32) *float32 { v := *f return &v } + +// clone returns a copy of s with independent maps and records. Cell slices and +// time pointers are shared, as they are never mutated in place. +func (s state) clone() state { + isas := make(map[dssmodels.ID]*isaRecord, len(s.ISAs)) + for id, rec := range s.ISAs { + cp := *rec + isas[id] = &cp + } + subs := make(map[dssmodels.ID]*subscriptionRecord, len(s.Subscriptions)) + for id, rec := range s.Subscriptions { + cp := *rec + subs[id] = &cp + } + return state{ISAs: isas, Subscriptions: subs} +} + +// Checkpoint returns a fast, restorable in-memory copy of the current state. +// Unlike GetSnapshot it does not serialize, so it is cheap but only valid +// in-process. +func (r *repo) Checkpoint() any { + return r.state.clone() +} + +// Restore replaces the current state with a checkpoint previously returned by +// Checkpoint. The checkpoint is copied, so it stays reusable. +func (r *repo) Restore(cp any) error { + s, ok := cp.(state) + if !ok { + return stacktrace.NewError("Invalid checkpoint type %T", cp) + } + r.state = s.clone() + return nil +} diff --git a/pkg/rid/store/memstore/store_test.go b/pkg/rid/store/memstore/store_test.go index 3bf484463..8a798b5d8 100644 --- a/pkg/rid/store/memstore/store_test.go +++ b/pkg/rid/store/memstore/store_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" dssmodels "github.com/interuss/dss/pkg/models" ridmodels "github.com/interuss/dss/pkg/rid/models" + "github.com/interuss/dss/pkg/timestamp" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" ) @@ -29,6 +30,7 @@ func setUpStore(t *testing.T) *repo { func TestDatabaseEnsuresBeginsBeforeExpires(t *testing.T) { ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) repo := setUpStore(t) var ( @@ -45,3 +47,55 @@ func TestDatabaseEnsuresBeginsBeforeExpires(t *testing.T) { }) require.Error(t, err) } + +func TestCheckpointRestoreISA(t *testing.T) { + ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) + repo := setUpStore(t) + + _, err := repo.InsertISA(ctx, serviceArea) + require.NoError(t, err) + + cp := repo.Checkpoint() + + // Mutate after the checkpoint. + isa, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + _, err = repo.DeleteISA(ctx, isa) + require.NoError(t, err) + gone, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.Nil(t, gone) + + // Restore brings it back. + require.NoError(t, repo.Restore(cp)) + back, err := repo.GetISA(ctx, serviceArea.ID, false) + require.NoError(t, err) + require.NotNil(t, back) +} + +func TestCheckpointIsolatesNotificationIndex(t *testing.T) { + ctx := context.Background() + ctx = timestamp.WithRequestTimestamp(ctx, fakeClock.Now()) + repo := setUpStore(t) + + sub, err := repo.InsertSubscription(ctx, subscriptionsPool[0].input) + require.NoError(t, err) + + cp := repo.Checkpoint() + + // In-place notification-index bump must not leak into the checkpoint. + updated, err := repo.UpdateNotificationIdxsInCells(ctx, sub.Cells) + require.NoError(t, err) + require.Len(t, updated, 1) + require.Equal(t, sub.NotificationIndex+1, updated[0].NotificationIndex) + + require.NoError(t, repo.Restore(cp)) + restored, err := repo.GetSubscription(ctx, sub.ID) + require.NoError(t, err) + require.Equal(t, sub.NotificationIndex, restored.NotificationIndex) +} + +func TestRestoreInvalidType(t *testing.T) { + require.Error(t, setUpStore(t).Restore("not a checkpoint")) +} diff --git a/pkg/scd/store/memstore/store.go b/pkg/scd/store/memstore/store.go index 45365614a..05d670cfd 100644 --- a/pkg/scd/store/memstore/store.go +++ b/pkg/scd/store/memstore/store.go @@ -5,6 +5,7 @@ import ( "github.com/interuss/dss/pkg/memstore" "github.com/interuss/dss/pkg/scd/repos" + "github.com/interuss/stacktrace" "go.uber.org/zap" ) @@ -16,3 +17,11 @@ func Init(ctx context.Context, logger *zap.Logger) (*memstore.Store[repos.Reposi } func (r *repo) GetRepo() repos.Repository { return r } + +func (r *repo) Checkpoint() any { + return nil +} + +func (r *repo) Restore(any) error { + return stacktrace.NewError("Restore not yet implemented for scd") +}