diff --git a/chart/templates/deployment.yaml b/chart/templates/deployment.yaml index 863d4ab..525a633 100644 --- a/chart/templates/deployment.yaml +++ b/chart/templates/deployment.yaml @@ -56,6 +56,11 @@ spec: - "-discovery-period-seconds={{ .Values.discoveryPeriodSeconds }}" - "-auth-enabled={{ .Values.auth.enabled }}" - "-auth-cookie-name={{ .Values.auth.cookieName }}" + - "-auth-cache-enabled={{ .Values.auth.cache.enabled }}" + - "-auth-cache-ttl-seconds={{ .Values.auth.cache.ttlSeconds }}" + - "-auth-cache-capacity={{ .Values.auth.cache.capacity }}" + - "-auth-cache-max-concurrent-backend-requests={{ .Values.auth.cache.maxConcurrentBackendRequests }}" + - "-auth-cache-refresh-before-seconds={{ .Values.auth.cache.refreshBeforeSeconds }}" ports: - containerPort: 9090 name: grpc diff --git a/chart/values.yaml b/chart/values.yaml index 0610c20..895f780 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -15,6 +15,12 @@ discoveryPeriodSeconds: 60 auth: enabled: true cookieName: YTCypressCookie + cache: + enabled: false + ttlSeconds: 30 + capacity: 1000 + maxConcurrentBackendRequests: 2 + refreshBeforeSeconds: 5 tls: enabled: false diff --git a/dashboards/auth.json b/dashboards/auth.json new file mode 100644 index 0000000..a72dc07 --- /dev/null +++ b/dashboards/auth.json @@ -0,0 +1,399 @@ +{ + "title": "TaskProxy Auth", + "uid": "task-proxy-auth-dashboard", + "version": 1, + "refresh": "30s", + "timezone": "browser", + "editable": true, + "tags": [ + "task-proxy", + "auth", + "auth-cache" + ], + "time": { + "from": "now-1h", + "to": "now" + }, + "templating": { + "list": [ + { + "name": "datasource", + "label": "datasource", + "type": "datasource", + "query": "prometheus", + "current": { + "selected": true + }, + "refresh": 1 + }, + { + "name": "namespace", + "label": "namespace", + "type": "query", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "query": { + "query": "label_values(yt_task_proxy_auth_success_total, namespace)", + 
"refId": "NamespaceVariableQuery" + }, + "includeAll": true, + "allValue": ".*", + "multi": true, + "refresh": 2 + }, + { + "name": "pod", + "label": "pod", + "type": "query", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "query": { + "query": "label_values(yt_task_proxy_auth_success_total{namespace=~\"$namespace\"}, pod)", + "refId": "PodVariableQuery" + }, + "includeAll": true, + "allValue": ".*", + "multi": true, + "refresh": 2 + }, + { + "name": "rate_interval", + "label": "rate interval", + "type": "custom", + "query": "1m,5m,15m", + "current": { + "selected": true, + "text": "5m", + "value": "5m" + } + } + ] + }, + "panels": [ + { + "id": 1, + "type": "row", + "title": "Auth Overview", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + } + }, + { + "id": 2, + "type": "timeseries", + "title": "Auth Success Rate", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 8, + "w": 8, + "x": 0, + "y": 1 + }, + "targets": [ + { + "expr": "sum(rate(yt_task_proxy_auth_success_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval]))", + "legendFormat": "success/s", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ops" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "table", + "placement": "bottom" + } + } + }, + { + "id": 3, + "type": "timeseries", + "title": "Auth Failed Rate by Reason", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 8, + "w": 16, + "x": 8, + "y": 1 + }, + "targets": [ + { + "expr": "sum by (reason) (rate(yt_task_proxy_auth_failed_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval]))", + "legendFormat": "{{reason}}", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ops" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "table", + "placement": "right" + } + } + }, + { + "id": 4, + "type": "timeseries", + "title": "Auth Errors by 
Stage", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 9 + }, + "targets": [ + { + "expr": "sum by (stage) (rate(yt_task_proxy_auth_errors_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval]))", + "legendFormat": "{{stage}}", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ops" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "table", + "placement": "right" + } + } + }, + { + "id": 5, + "type": "timeseries", + "title": "Auth Infra Errors by Kind/Stage", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 9 + }, + "targets": [ + { + "expr": "sum by (kind, stage) (rate(yt_task_proxy_auth_infra_errors_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval]))", + "legendFormat": "{{kind}} / {{stage}}", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ops" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "table", + "placement": "right" + } + } + }, + { + "id": 6, + "type": "row", + "title": "Auth Cache Core", + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 17 + } + }, + { + "id": 7, + "type": "stat", + "title": "Auth Cache Entries", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 6, + "w": 6, + "x": 0, + "y": 18 + }, + "targets": [ + { + "expr": "sum(yt_task_proxy_auth_cache_entries{namespace=~\"$namespace\",pod=~\"$pod\"})", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "none" + }, + "overrides": [] + } + }, + { + "id": 8, + "type": "timeseries", + "title": "Auth Cache Hit/Miss Rate", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 6, + "w": 12, + "x": 6, + "y": 18 + }, + "targets": [ + { + "expr": 
"sum(rate(yt_task_proxy_auth_cache_hits_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval]))", + "legendFormat": "hits/s", + "refId": "A" + }, + { + "expr": "sum(rate(yt_task_proxy_auth_cache_misses_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval]))", + "legendFormat": "misses/s", + "refId": "B" + } + ], + "fieldConfig": { + "defaults": { + "unit": "ops" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "table", + "placement": "bottom" + } + } + }, + { + "id": 9, + "type": "stat", + "title": "Auth Cache Hit Ratio", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 6, + "w": 6, + "x": 18, + "y": 18 + }, + "targets": [ + { + "expr": "sum(rate(yt_task_proxy_auth_cache_hits_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval])) / clamp_min(sum(rate(yt_task_proxy_auth_cache_hits_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval])) + sum(rate(yt_task_proxy_auth_cache_misses_total{namespace=~\"$namespace\",pod=~\"$pod\"}[$rate_interval])), 1e-9)", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "percentunit", + "min": 0, + "max": 1 + }, + "overrides": [] + } + }, + { + "id": 10, + "type": "timeseries", + "title": "Auth Cache Inflight Backend Requests", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 6, + "w": 12, + "x": 0, + "y": 24 + }, + "targets": [ + { + "expr": "sum(yt_task_proxy_auth_cache_inflight_backend_requests{namespace=~\"$namespace\",pod=~\"$pod\"})", + "legendFormat": "inflight", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "none" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + } + } + }, + { + "id": 11, + "type": "timeseries", + "title": "Auth Cache Waiting Requests", + "datasource": { + "type": "prometheus", + "uid": "$datasource" + }, + "gridPos": { + "h": 6, + "w": 12, + "x": 12, + "y": 24 + }, + 
"targets": [ + { + "expr": "sum(yt_task_proxy_auth_cache_waiting_requests{namespace=~\"$namespace\",pod=~\"$pod\"})", + "legendFormat": "waiting", + "refId": "A" + } + ], + "fieldConfig": { + "defaults": { + "unit": "none" + }, + "overrides": [] + }, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom" + } + } + } + ] +} \ No newline at end of file diff --git a/server/go.mod b/server/go.mod index 4c9bbd3..34f495c 100644 --- a/server/go.mod +++ b/server/go.mod @@ -29,6 +29,7 @@ require ( github.com/golang/snappy v1.0.0 // indirect github.com/google/tink/go v1.7.0 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect diff --git a/server/main.go b/server/main.go index fde04a3..a4aa5e0 100644 --- a/server/main.go +++ b/server/main.go @@ -21,13 +21,18 @@ func main() { ctx := context.Background() var args struct { - ytProxy string - ytTokenPath string - baseDomain string - dirPath string - discoveryPeriodSeconds uint - authEnabled bool - authCookieName string + ytProxy string + ytTokenPath string + baseDomain string + dirPath string + discoveryPeriodSeconds uint + authEnabled bool + authCookieName string + authCacheEnabled bool + authCacheTTLSeconds int + authCacheCapacity int + authCacheMaxConcurrency int + authCacheRefreshBefore int } flag.StringVar(&args.ytProxy, "yt-proxy", "", "YT proxy host") flag.StringVar(&args.ytTokenPath, "yt-token-path", "", "YT token path") @@ -36,6 +41,11 @@ func main() { flag.UintVar(&args.discoveryPeriodSeconds, "discovery-period-seconds", 60, "services discovery period in seconds") flag.BoolVar(&args.authEnabled, "auth-enabled", true, "operation auth enabled") flag.StringVar(&args.authCookieName, "auth-cookie-name", "", "auth cookie name") + flag.BoolVar(&args.authCacheEnabled, 
"auth-cache-enabled", false, "enable auth cache") + flag.IntVar(&args.authCacheTTLSeconds, "auth-cache-ttl-seconds", 0, "auth cache entry TTL in seconds (0 means no expiration)") + flag.IntVar(&args.authCacheCapacity, "auth-cache-capacity", 0, "auth cache maximum number of entries (0 means unlimited)") + flag.IntVar(&args.authCacheMaxConcurrency, "auth-cache-max-concurrent-backend-requests", 0, "auth cache max concurrent backend requests per key on misses (0 means unlimited)") + flag.IntVar(&args.authCacheRefreshBefore, "auth-cache-refresh-before-seconds", 0, "auth cache proactive refresh threshold in seconds before TTL deadline (0 disables proactive refresh)") flag.Parse() if args.ytProxy == "" { @@ -53,6 +63,18 @@ func main() { if args.discoveryPeriodSeconds < 1 || args.discoveryPeriodSeconds > 24*60*60 { log.Fatal("'discovery-period-seconds' argument must be positive and not greater than 24 hours") } + if args.authCacheTTLSeconds < 0 { + log.Fatal("'auth-cache-ttl-seconds' argument must be non-negative") + } + if args.authCacheCapacity < 0 { + log.Fatal("'auth-cache-capacity' argument must be non-negative") + } + if args.authCacheMaxConcurrency < 0 { + log.Fatal("'auth-cache-max-concurrent-backend-requests' argument must be non-negative") + } + if args.authCacheRefreshBefore < 0 { + log.Fatal("'auth-cache-refresh-before-seconds' argument must be non-negative") + } ytTokenBytes, err := os.ReadFile(args.ytTokenPath) if err != nil { @@ -79,7 +101,13 @@ func main() { taskDiscovery := pkg.CreateTaskDiscovery(args.baseDomain, args.dirPath, ytClient, &logger) - authServer := pkg.CreateAuthServer(ytClient, &logger, args.authCookieName) + authServer := pkg.CreateAuthServer(ytClient, args.ytProxy, &logger, args.authCookieName, pkg.AuthCacheConfig{ + Enabled: args.authCacheEnabled, + TTLSeconds: args.authCacheTTLSeconds, + Capacity: args.authCacheCapacity, + MaxConcurrentBackendRequests: args.authCacheMaxConcurrency, + RefreshBeforeSeconds: args.authCacheRefreshBefore, + 
}) taskUpdater := pkg.CreateTaskUpdater(args.baseDomain, tls, args.authEnabled, authServer, taskDiscovery, cache) diff --git a/server/pkg/auth.go b/server/pkg/auth.go index c7b5c40..70089d2 100644 --- a/server/pkg/auth.go +++ b/server/pkg/auth.go @@ -17,6 +17,14 @@ import ( "google.golang.org/grpc/codes" ) +type AuthCacheConfig struct { + Enabled bool + TTLSeconds int + Capacity int + MaxConcurrentBackendRequests int + RefreshBeforeSeconds int +} + type authServer struct { authv3.UnimplementedAuthorizationServer @@ -24,17 +32,22 @@ type authServer struct { hashToTasks map[string]Task operationAliasToID map[string]string yt ytsdk.Client + ytProxy string logger *SimpleLogger authCookieName string + cache *authPermissionCache // cache is nil when caching is disabled. } -func CreateAuthServer(yt ytsdk.Client, logger *SimpleLogger, authCookieName string) *authServer { +func CreateAuthServer(yt ytsdk.Client, ytProxy string, logger *SimpleLogger, authCookieName string, cacheCfg AuthCacheConfig) *authServer { + cache := newAuthPermissionCache(cacheCfg, logger) return &authServer{ hashToTasks: make(map[string]Task), mx: sync.RWMutex{}, yt: yt, + ytProxy: ytProxy, logger: logger, authCookieName: authCookieName, + cache: cache, } } @@ -125,47 +138,66 @@ func (s *authServer) checkOperationPermission(ctx context.Context, operationID s return false, nil } - whoAmIStarted := time.Now() - userResp, err := s.yt.WhoAmI(ytsdk.WithCredentials(ctx, userCredentials), nil) - defaultMetrics.ObserveYTDuration("whoami", time.Since(whoAmIStarted)) - if err != nil { - s.logger.Errorf("whoami failed: dur=%s, err=%v", time.Since(whoAmIStarted), err) - defaultMetrics.ObserveAuthYTError("whoami", err) - return false, err + cacheKey := authCacheKey{ + credentials: credentialsKey(userCredentials), + operationID: operationID, } - user := userResp.Login - if user == "" { - s.logger.Errorf("user not identified by provided credentials: %v", userResp) - 
defaultMetrics.ObserveAuthFailure(authReasonUserNotIdentified, nil) - return false, nil - } - s.logger.Debugf("auth user is %q", user) + allowed, _, err := s.cache.GetOrLoad(ctx, cacheKey, func(checkCtx context.Context) (bool, string, error) { + userYT, err := CreateYTClient(s.ytProxy, userCredentials, s.logger) + if err != nil { + defaultMetrics.ObserveAuthYTError("create_client", err) + return false, "", err + } - operationIDg, err := guid.ParseString(operationID) - if err != nil { - s.logger.Warnf("invalid operation ID %s", operationID) - defaultMetrics.ObserveAuthFailure(authReasonInvalidOperation, nil) - return false, nil - } + whoAmIStarted := time.Now() + userResp, err := userYT.WhoAmI(checkCtx, nil) + defaultMetrics.ObserveYTDuration("whoami", time.Since(whoAmIStarted)) + if err != nil { + s.logger.Errorf("whoami failed: dur=%s, err=%v", time.Since(whoAmIStarted), err) + defaultMetrics.ObserveAuthYTError("whoami", err) + return false, "", err + } + + user := userResp.Login + if user == "" { + s.logger.Errorf("user not identified by provided credentials: %v", userResp) + defaultMetrics.ObserveAuthFailure(authReasonUserNotIdentified, nil) + return false, "", nil + } + s.logger.Debugf("auth user is %q", user) - permissionCheckStarted := time.Now() - resp, err := s.yt.CheckOperationPermission( - ctx, - yt.OperationID(operationIDg), - user, - yt.PermissionRead, - nil, - ) - defaultMetrics.ObserveYTDuration("check_operation_permission", time.Since(permissionCheckStarted)) + operationIDg, err := guid.ParseString(operationID) + if err != nil { + s.logger.Warnf("invalid operation ID %s", operationID) + defaultMetrics.ObserveAuthFailure(authReasonInvalidOperation, nil) + return false, "", nil + } + + permissionCheckStarted := time.Now() + resp, err := s.yt.CheckOperationPermission( + checkCtx, + yt.OperationID(operationIDg), + user, + yt.PermissionRead, + nil, + ) + defaultMetrics.ObserveYTDuration("check_operation_permission", time.Since(permissionCheckStarted)) + 
if err != nil { + s.logger.Infof("permission check failed: dur=%s, err=%v", time.Since(permissionCheckStarted), err) + defaultMetrics.ObserveAuthYTError("permission_check", err) + return false, "", err + } + + allowed := resp.Action == "allow" + s.logger.Debugf("check operation permission result is %q for user %q and operation %q", resp.Action, user, operationID) + return allowed, user, nil + }) if err != nil { - s.logger.Infof("permission check failed: dur=%s, err=%v", time.Since(permissionCheckStarted), err) - defaultMetrics.ObserveAuthYTError("permission_check", err) return false, err } - s.logger.Debugf("check operation permission result is %q for user %q and operation %q", resp.Action, user, operationID) - if resp.Action != "allow" { + if !allowed { defaultMetrics.ObserveAuthFailure(authReasonPermissionDenied, nil) return false, nil } diff --git a/server/pkg/auth_cache.go b/server/pkg/auth_cache.go new file mode 100644 index 0000000..6ca8dee --- /dev/null +++ b/server/pkg/auth_cache.go @@ -0,0 +1,328 @@ +package pkg + +import ( + "container/list" + "context" + "fmt" + "sync" + "time" + + ytsdk "go.ytsaurus.tech/yt/go/yt" +) + +type authCacheKey struct { + credentials string + operationID string +} + +func credentialsKey(creds ytsdk.Credentials) string { + if creds == nil { + return "" + } + switch v := creds.(type) { + case *ytsdk.TokenCredentials: + return "token:" + v.Token + case *ytsdk.BearerCredentials: + return "bearer:" + v.Token + case *ytsdk.CookieCredentials: + return "cookie:" + v.Cookie.Value + default: + return fmt.Sprintf("%T:%v", v, v) + } +} + +type authCacheEntry struct { + allowed bool + expiresAt time.Time + login string +} + +type authCacheItem struct { + key authCacheKey + entry authCacheEntry +} + +type authCacheLoadState struct { + inFlight int + waitCh chan struct{} +} + +type authPermissionCache struct { + logger *SimpleLogger + metrics *Metrics + + ttl time.Duration + refreshBefore time.Duration + capacity int + 
maxConcurrentLoadsPerKeyMiss int + nowFn func() time.Time + + mx sync.Mutex + lru *list.List + entries map[authCacheKey]*list.Element + loadStates map[authCacheKey]*authCacheLoadState +} + +func newAuthPermissionCache(cfg AuthCacheConfig, logger *SimpleLogger) *authPermissionCache { + if !cfg.Enabled { + return nil + } + + cache := &authPermissionCache{ + logger: logger, + metrics: DefaultMetrics(), + ttl: time.Duration(cfg.TTLSeconds) * time.Second, + refreshBefore: time.Duration(cfg.RefreshBeforeSeconds) * time.Second, + capacity: cfg.Capacity, + maxConcurrentLoadsPerKeyMiss: cfg.MaxConcurrentBackendRequests, + nowFn: time.Now, + lru: list.New(), + entries: make(map[authCacheKey]*list.Element), + loadStates: make(map[authCacheKey]*authCacheLoadState), + } + return cache +} + +func (c *authPermissionCache) GetOrLoad( + ctx context.Context, + key authCacheKey, + loadFn func(context.Context) (bool, string, error), +) (bool, string, error) { + if c == nil { + return loadFn(ctx) + } + + if allowed, ok, needsRefresh, login := c.get(key); ok { + c.metrics.ObserveAuthCacheHit() + c.logger.Debugf("auth cache hit: operation_id=%q user=%q allowed=%v", key.operationID, login, allowed) + if needsRefresh { + c.logger.Debugf("auth cache preventive refresh scheduled: operation_id=%q, user=%q", key.operationID, login) + c.triggerRefresh(key, login, loadFn) + } + return allowed, login, nil + } + c.metrics.ObserveAuthCacheMiss() + c.logger.Debugf("auth cache miss: operation_id=%q user=%q", key.operationID, "unknown") + + return c.loadOnMiss(ctx, key, loadFn) +} + +func (c *authPermissionCache) get(key authCacheKey) (allowed bool, ok bool, needsRefresh bool, login string) { + c.lock() + defer c.unlock() + + elem, exists := c.entries[key] + if !exists { + return false, false, false, "" + } + + item := elem.Value.(*authCacheItem) + now := c.nowFn() + if c.isExpired(item.entry, now) { + c.removeElement(elem) + return false, false, false, "" + } + + c.lru.MoveToFront(elem) + + if 
c.refreshBefore > 0 && !item.entry.expiresAt.IsZero() { + remaining := item.entry.expiresAt.Sub(now) + if remaining < c.refreshBefore { + if !c.hasLoadInFlightLocked(key) { + needsRefresh = true + } + } + } + + return item.entry.allowed, true, needsRefresh, item.entry.login +} + +func (c *authPermissionCache) triggerRefresh( + key authCacheKey, + login string, + loadFn func(context.Context) (bool, string, error), +) { + started, _ := c.tryStartLoad(key, 1) + if !started { + c.logger.Debugf( + "auth cache preventive refresh skipped: operation_id=%q user=%q in-flight request already exists", + key.operationID, + login, + ) + return + } + + c.logger.Debugf("auth cache preventive refresh started: operation_id=%q user=%q", key.operationID, login) + + go func() { + _, _, _ = c.executeLoad(context.Background(), key, loadFn) + }() +} + +func (c *authPermissionCache) loadOnMiss( + ctx context.Context, + key authCacheKey, + loadFn func(context.Context) (bool, string, error), +) (bool, string, error) { + for { + if allowed, ok, needsRefresh, login := c.get(key); ok { + if needsRefresh { + c.triggerRefresh(key, login, loadFn) + } + return allowed, login, nil + } + + started, waitCh := c.tryStartLoad(key, c.maxConcurrentLoadsPerKeyMiss) + if started { + return c.executeLoad(ctx, key, loadFn) + } + c.logger.Debugf( + "auth cache waiting for in-flight backend request: operation_id=%q user=%q max_concurrent_per_key=%d", + key.operationID, + "unknown", + c.maxConcurrentLoadsPerKeyMiss, + ) + c.metrics.IncAuthCacheWaitingRequests() + + select { + case <-waitCh: + // Some in-flight load has completed, retry from cache. 
+ case <-ctx.Done(): + c.metrics.DecAuthCacheWaitingRequests() + return false, "", ctx.Err() + } + c.metrics.DecAuthCacheWaitingRequests() + } +} + +func (c *authPermissionCache) tryStartLoad( + key authCacheKey, + maxInFlight int, +) (bool, chan struct{}) { + c.lock() + defer c.unlock() + + state := c.loadStates[key] + if state == nil { + state = &authCacheLoadState{ + waitCh: make(chan struct{}), + } + c.loadStates[key] = state + } + + if maxInFlight <= 0 || state.inFlight < maxInFlight { + state.inFlight++ + c.metrics.IncAuthCacheInflightBackendRequests() + return true, nil + } + + return false, state.waitCh +} + +// hasLoadInFlightLocked expects c.mx to be held by the caller. +func (c *authPermissionCache) hasLoadInFlightLocked(key authCacheKey) bool { + state := c.loadStates[key] + return state != nil && state.inFlight > 0 +} + +func (c *authPermissionCache) executeLoad( + ctx context.Context, + key authCacheKey, + loadFn func(context.Context) (bool, string, error), +) (bool, string, error) { + allowed, login, err := loadFn(ctx) + if err == nil { + c.logger.Debugf("auth cache backend load succeeded: operation_id=%q user=%q allowed=%v", key.operationID, login, allowed) + c.set(key, authCacheEntry{ + allowed: allowed, + expiresAt: c.expiration(), + login: login, + }) + } else { + c.logger.Debugf("auth cache backend load failed: operation_id=%q user=%q err=%v", key.operationID, login, err) + } + c.finishLoad(key) + return allowed, login, err +} + +func (c *authPermissionCache) finishLoad(key authCacheKey) { + c.lock() + defer c.unlock() + + state := c.loadStates[key] + if state == nil { + return + } + if state.inFlight > 0 { + state.inFlight-- + c.metrics.DecAuthCacheInflightBackendRequests() + } + close(state.waitCh) + if state.inFlight == 0 { + delete(c.loadStates, key) + return + } + state.waitCh = make(chan struct{}) +} + +func (c *authPermissionCache) set(key authCacheKey, entry authCacheEntry) { + c.lock() + defer c.unlock() + + if elem, ok := 
c.entries[key]; ok { + item := elem.Value.(*authCacheItem) + item.entry = entry + c.lru.MoveToFront(elem) + return + } + + elem := c.lru.PushFront(&authCacheItem{ + key: key, + entry: entry, + }) + c.entries[key] = elem + c.metrics.IncAuthCacheEntries() + + if c.capacity > 0 && c.lru.Len() > c.capacity { + c.removeOldest() + } +} + +// removeOldest expects c.mx to be held by the caller. +func (c *authPermissionCache) removeOldest() { + elem := c.lru.Back() + if elem == nil { + return + } + c.removeElement(elem) +} + +// removeElement expects c.mx to be held by the caller. +func (c *authPermissionCache) removeElement(elem *list.Element) { + item := elem.Value.(*authCacheItem) + delete(c.entries, item.key) + c.lru.Remove(elem) + c.metrics.DecAuthCacheEntries() +} + +func (c *authPermissionCache) expiration() time.Time { + if c.ttl <= 0 { + return time.Time{} + } + return c.nowFn().Add(c.ttl) +} + +func (c *authPermissionCache) isExpired(entry authCacheEntry, now time.Time) bool { + if entry.expiresAt.IsZero() { + return false + } + return !now.Before(entry.expiresAt) +} + +func (c *authPermissionCache) lock() { + c.mx.Lock() +} + +func (c *authPermissionCache) unlock() { + c.mx.Unlock() +} diff --git a/server/pkg/auth_cache_metrics_test.go b/server/pkg/auth_cache_metrics_test.go new file mode 100644 index 0000000..99b839b --- /dev/null +++ b/server/pkg/auth_cache_metrics_test.go @@ -0,0 +1,114 @@ +package pkg + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" +) + +func TestAuthCacheMetricsHitMissAndSize(t *testing.T) { + cache := newAuthPermissionCache(AuthCacheConfig{ + Enabled: true, + TTLSeconds: 60, + Capacity: 100, + MaxConcurrentBackendRequests: 1, + }, &SimpleLogger{}) + require.NotNil(t, cache) + + cache.metrics = NewMetrics(prometheus.NewRegistry()) + + beforeHits := 
testutil.ToFloat64(cache.metrics.authCacheHits) + beforeMisses := testutil.ToFloat64(cache.metrics.authCacheMisses) + beforeSize := testutil.ToFloat64(cache.metrics.authCacheEntries) + + key := authCacheKey{credentials: "token:metrics-user", operationID: "metrics-op"} + loadFn := func(context.Context) (bool, string, error) { + return true, "user", nil + } + + allowed, login, err := cache.GetOrLoad(context.Background(), key, loadFn) + require.NoError(t, err) + require.True(t, allowed) + require.Equal(t, "user", login) + + allowed, login, err = cache.GetOrLoad(context.Background(), key, loadFn) + require.NoError(t, err) + require.True(t, allowed) + require.Equal(t, "user", login) + + require.Equal(t, beforeHits+1, testutil.ToFloat64(cache.metrics.authCacheHits)) + require.Equal(t, beforeMisses+1, testutil.ToFloat64(cache.metrics.authCacheMisses)) + require.Equal(t, beforeSize+1, testutil.ToFloat64(cache.metrics.authCacheEntries)) +} + +func TestAuthCacheMetricsInFlightAndWaitingRequests(t *testing.T) { + cache := newAuthPermissionCache(AuthCacheConfig{ + Enabled: true, + TTLSeconds: 60, + Capacity: 100, + MaxConcurrentBackendRequests: 1, + }, &SimpleLogger{}) + require.NotNil(t, cache) + + cache.metrics = NewMetrics(prometheus.NewRegistry()) + + beforeInflight := testutil.ToFloat64(cache.metrics.authCacheInflightBackend) + beforeWaiting := testutil.ToFloat64(cache.metrics.authCacheWaitingRequests) + + key := authCacheKey{credentials: "token:wait-user", operationID: "wait-op"} + release := make(chan struct{}) + loadFn := func(context.Context) (bool, string, error) { + <-release + return true, "user", nil + } + + var wg sync.WaitGroup + errs := make(chan error, 2) + wg.Add(2) + go func() { + defer wg.Done() + allowed, login, err := cache.GetOrLoad(context.Background(), key, loadFn) + if err == nil && (!allowed || login != "user") { + err = errors.New("unexpected auth cache result for first request") + } + errs <- err + }() + + require.Eventually(t, func() bool { + 
return testutil.ToFloat64(cache.metrics.authCacheInflightBackend) == beforeInflight+1 + }, time.Second, 10*time.Millisecond) + + go func() { + defer wg.Done() + allowed, login, err := cache.GetOrLoad(context.Background(), key, loadFn) + if err == nil && (!allowed || login != "user") { + err = errors.New("unexpected auth cache result for second request") + } + errs <- err + }() + + require.Eventually(t, func() bool { + return testutil.ToFloat64(cache.metrics.authCacheWaitingRequests) == beforeWaiting+1 + }, time.Second, 10*time.Millisecond) + + close(release) + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + require.Eventually(t, func() bool { + return testutil.ToFloat64(cache.metrics.authCacheInflightBackend) == beforeInflight + }, time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { + return testutil.ToFloat64(cache.metrics.authCacheWaitingRequests) == beforeWaiting + }, time.Second, 10*time.Millisecond) +} diff --git a/server/pkg/auth_cache_test.go b/server/pkg/auth_cache_test.go new file mode 100644 index 0000000..4773119 --- /dev/null +++ b/server/pkg/auth_cache_test.go @@ -0,0 +1,246 @@ +package pkg + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAuthPermissionCacheSingleflightByKey(t *testing.T) { + cache := newAuthPermissionCache(AuthCacheConfig{ + Enabled: true, + TTLSeconds: 60, + Capacity: 100, + MaxConcurrentBackendRequests: 1, + }, &SimpleLogger{}) + require.NotNil(t, cache) + + key := authCacheKey{credentials: "token:u", operationID: "op1"} + + releaseLoad := make(chan struct{}) + var loadCalls atomic.Int32 + loadFn := func(ctx context.Context) (bool, string, error) { + loadCalls.Add(1) + <-releaseLoad + return true, "user", nil + } + + const goroutines = 20 + var wg sync.WaitGroup + wg.Add(goroutines) + results := make(chan bool, goroutines) + logins := make(chan string, goroutines) + errs := make(chan 
error, goroutines)
+	for range goroutines {
+		go func() {
+			defer wg.Done()
+			allowed, login, err := cache.GetOrLoad(context.Background(), key, loadFn)
+			results <- allowed
+			logins <- login
+			errs <- err
+		}()
+	}
+
+	time.Sleep(50 * time.Millisecond)
+	close(releaseLoad)
+	wg.Wait()
+	close(results)
+	close(logins)
+	close(errs)
+
+	require.Equal(t, int32(1), loadCalls.Load())
+	for err := range errs {
+		require.NoError(t, err)
+	}
+	for allowed := range results {
+		require.True(t, allowed)
+	}
+	for login := range logins {
+		require.Equal(t, "user", login)
+	}
+}
+
+// TestAuthPermissionCacheRespectsPerKeyConcurrentMissLimit issues 20 concurrent
+// GetOrLoad calls for one key with MaxConcurrentBackendRequests set to 2 and
+// asserts that loadFn ran at least twice while never running more than two
+// times simultaneously.
+func TestAuthPermissionCacheRespectsPerKeyConcurrentMissLimit(t *testing.T) {
+	cache := newAuthPermissionCache(AuthCacheConfig{
+		Enabled:                      true,
+		TTLSeconds:                   60,
+		Capacity:                     100,
+		MaxConcurrentBackendRequests: 2,
+	}, &SimpleLogger{})
+	require.NotNil(t, cache)
+
+	var inFlight atomic.Int32
+	var maxInFlight atomic.Int32
+	var loadCalls atomic.Int32
+	releaseLoad := make(chan struct{})
+	loadFn := func(ctx context.Context) (bool, string, error) {
+		cur := inFlight.Add(1)
+		// Record the high-water mark of simultaneously running loads.
+		for {
+			prev := maxInFlight.Load()
+			if cur <= prev || maxInFlight.CompareAndSwap(prev, cur) {
+				break
+			}
+		}
+		loadCalls.Add(1)
+		// Hold the load open until the test closes releaseLoad.
+		<-releaseLoad
+		inFlight.Add(-1)
+		return true, "user", nil
+	}
+
+	key := authCacheKey{
+		credentials: "token:u",
+		operationID: "same-op",
+	}
+
+	const requests = 20
+	var wg sync.WaitGroup
+	wg.Add(requests)
+	errs := make(chan error, requests)
+	for range requests {
+		go func() {
+			defer wg.Done()
+			allowed, login, err := cache.GetOrLoad(context.Background(), key, loadFn)
+			if err != nil {
+				errs <- err
+				return
+			}
+			if !allowed {
+				errs <- errors.New("permission should be allowed")
+				return
+			}
+			if login != "user" {
+				errs <- errors.New("login should be user")
+			}
+		}()
+	}
+
+	require.Eventually(t, func() bool {
+		return loadCalls.Load() >= 2
+	}, time.Second, 10*time.Millisecond)
+	close(releaseLoad)
+
+	wg.Wait()
+	close(errs)
+	for err := range errs {
+		require.NoError(t, err)
+	}
+
+	require.GreaterOrEqual(t, loadCalls.Load(), int32(2))
+	require.LessOrEqual(t, maxInFlight.Load(), int32(2))
+}
+
+// TestAuthPermissionCacheLimitIsNotGlobal loads two different keys concurrently
+// with MaxConcurrentBackendRequests set to 1 and asserts that both loads are in
+// flight at the same time.
+func TestAuthPermissionCacheLimitIsNotGlobal(t *testing.T) {
+	cache := newAuthPermissionCache(AuthCacheConfig{
+		Enabled:                      true,
+		TTLSeconds:                   60,
+		Capacity:                     100,
+		MaxConcurrentBackendRequests: 1,
+	}, &SimpleLogger{})
+	require.NotNil(t, cache)
+
+	var inFlight atomic.Int32
+	var maxInFlight atomic.Int32
+	releaseLoad := make(chan struct{})
+	loadFn := func(ctx context.Context) (bool, string, error) {
+		cur := inFlight.Add(1)
+		// Record the high-water mark of simultaneously running loads.
+		for {
+			prev := maxInFlight.Load()
+			if cur <= prev || maxInFlight.CompareAndSwap(prev, cur) {
+				break
+			}
+		}
+		// Hold the load open until the test closes releaseLoad.
+		<-releaseLoad
+		inFlight.Add(-1)
+		return true, "user", nil
+	}
+
+	key1 := authCacheKey{credentials: "token:u1", operationID: "op1"}
+	key2 := authCacheKey{credentials: "token:u2", operationID: "op2"}
+
+	var wg sync.WaitGroup
+	wg.Add(2)
+	errs := make(chan error, 2)
+	go func() {
+		defer wg.Done()
+		allowed, login, err := cache.GetOrLoad(context.Background(), key1, loadFn)
+		if err == nil && (!allowed || login != "user") {
+			err = errors.New("unexpected result for key1")
+		}
+		errs <- err
+	}()
+	go func() {
+		defer wg.Done()
+		allowed, login, err := cache.GetOrLoad(context.Background(), key2, loadFn)
+		if err == nil && (!allowed || login != "user") {
+			err = errors.New("unexpected result for key2")
+		}
+		errs <- err
+	}()
+
+	require.Eventually(t, func() bool {
+		return maxInFlight.Load() >= 2
+	}, time.Second, 10*time.Millisecond)
+	close(releaseLoad)
+
+	wg.Wait()
+	close(errs)
+	for err := range errs {
+		require.NoError(t, err)
+	}
+	require.GreaterOrEqual(t, maxInFlight.Load(), int32(2))
+}
+
+// TestAuthPermissionCacheProactiveRefresh shortens ttl to 100ms and
+// refreshBefore to 80ms, then checks that a read made 35ms after the initial
+// load still returns the first value, that loadFn gets called a second time,
+// and that later reads eventually return the second loadFn result.
+func TestAuthPermissionCacheProactiveRefresh(t *testing.T) {
+	cache := newAuthPermissionCache(AuthCacheConfig{
+		Enabled:                      true,
+		TTLSeconds:                   60,
+		Capacity:                     100,
+		MaxConcurrentBackendRequests: 1,
+		RefreshBeforeSeconds:         30,
+	}, &SimpleLogger{})
+	require.NotNil(t, cache)
+
+	// Keep the test fast and deterministic.
+	cache.ttl = 100 * time.Millisecond
+	cache.refreshBefore = 80 * time.Millisecond
+
+	key := authCacheKey{credentials: "token:u", operationID: "op-proactive"}
+
+	var loadCalls atomic.Int32
+	loadFn := func(ctx context.Context) (bool, string, error) {
+		call := loadCalls.Add(1)
+		// first load: true, proactive refresh load: false
+		if call == 1 {
+			return true, "user-first", nil
+		}
+		return false, "user-second", nil
+	}
+
+	allowed, login, err := cache.GetOrLoad(context.Background(), key, loadFn)
+	require.NoError(t, err)
+	require.True(t, allowed)
+	require.Equal(t, "user-first", login)
+	require.Equal(t, int32(1), loadCalls.Load())
+
+	time.Sleep(35 * time.Millisecond) // remaining TTL is now below refresh threshold.
+
+	allowed, login, err = cache.GetOrLoad(context.Background(), key, loadFn)
+	require.NoError(t, err)
+	require.True(t, allowed) // stale value while refresh is happening
+	require.Equal(t, "user-first", login)
+
+	require.Eventually(t, func() bool {
+		return loadCalls.Load() >= 2
+	}, time.Second, 10*time.Millisecond)
+
+	require.Eventually(t, func() bool {
+		// Eventually runs this condition on its own goroutine, so it must not
+		// call require.* (t.FailNow is only valid on the test goroutine).
+		// Results stay local; a load error simply means "not satisfied yet".
+		gotAllowed, gotLogin, gotErr := cache.GetOrLoad(context.Background(), key, loadFn)
+		return gotErr == nil && !gotAllowed && gotLogin == "user-second"
+	}, time.Second, 10*time.Millisecond)
+}
diff --git a/server/pkg/auth_test.go b/server/pkg/auth_test.go
index c633520..8668b7e 100644
--- a/server/pkg/auth_test.go
+++ b/server/pkg/auth_test.go
@@ -42,7 +42,7 @@ func TestFindTaskByRequest(t *testing.T) {
 		"anotheralias": "op-999",
 	}
 
-	server := CreateAuthServer(nil, &SimpleLogger{}, "")
+	server := CreateAuthServer(nil, "", &SimpleLogger{}, "", AuthCacheConfig{})
 	server.SetTasksData(hashToTasks, operationAliasToID)
 
 	tests := []struct {
diff --git a/server/pkg/metrics.go b/server/pkg/metrics.go
index 77e9edc..4d85933 100644
--- a/server/pkg/metrics.go
+++ b/server/pkg/metrics.go
@@ -22,6 +22,11 @@ type Metrics struct {
 	authFailures             *prometheus.CounterVec
 	authErrors               *prometheus.CounterVec
authInfrastructureErrors *prometheus.CounterVec + authCacheEntries prometheus.Gauge + authCacheHits prometheus.Counter + authCacheMisses prometheus.Counter + authCacheInflightBackend prometheus.Gauge + authCacheWaitingRequests prometheus.Gauge discoverySuccesses *prometheus.CounterVec discoveryFailures *prometheus.CounterVec discoveryErrors *prometheus.CounterVec @@ -83,6 +88,36 @@ func NewMetrics(registerer prometheus.Registerer) *Metrics { }, []string{"stage", "kind", "grpc_code"}, ), + authCacheEntries: prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "yt_task_proxy_auth_cache_entries", + Help: "Current number of entries in auth cache.", + }, + ), + authCacheHits: prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "yt_task_proxy_auth_cache_hits_total", + Help: "Total number of auth cache hits.", + }, + ), + authCacheMisses: prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "yt_task_proxy_auth_cache_misses_total", + Help: "Total number of auth cache misses.", + }, + ), + authCacheInflightBackend: prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "yt_task_proxy_auth_cache_inflight_backend_requests", + Help: "Current number of in-flight backend requests for auth cache.", + }, + ), + authCacheWaitingRequests: prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "yt_task_proxy_auth_cache_waiting_requests", + Help: "Current number of requests waiting on in-flight auth cache backend requests.", + }, + ), discoverySuccesses: prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "yt_task_proxy_discovery_success_total", @@ -133,6 +168,11 @@ func NewMetrics(registerer prometheus.Registerer) *Metrics { m.authFailures, m.authErrors, m.authInfrastructureErrors, + m.authCacheEntries, + m.authCacheHits, + m.authCacheMisses, + m.authCacheInflightBackend, + m.authCacheWaitingRequests, m.discoverySuccesses, m.discoveryFailures, m.discoveryErrors, @@ -179,6 +219,38 @@ func (m *Metrics) ObserveAuthYTError(stage string, err error) string { return 
m.ObserveAuthFailure(stage, err)
 }
 
+// ObserveAuthCacheHit increments the auth cache hit counter
+// (yt_task_proxy_auth_cache_hits_total) by one.
+func (m *Metrics) ObserveAuthCacheHit() {
+	m.authCacheHits.Inc()
+}
+
+// ObserveAuthCacheMiss increments the auth cache miss counter
+// (yt_task_proxy_auth_cache_misses_total) by one.
+func (m *Metrics) ObserveAuthCacheMiss() {
+	m.authCacheMisses.Inc()
+}
+
+// IncAuthCacheEntries increments the gauge tracking the current number of auth
+// cache entries (yt_task_proxy_auth_cache_entries) by one.
+func (m *Metrics) IncAuthCacheEntries() {
+	m.authCacheEntries.Inc()
+}
+
+// DecAuthCacheEntries decrements the gauge tracking the current number of auth
+// cache entries (yt_task_proxy_auth_cache_entries) by one.
+func (m *Metrics) DecAuthCacheEntries() {
+	m.authCacheEntries.Dec()
+}
+
+// IncAuthCacheInflightBackendRequests increments the gauge of in-flight backend
+// requests for the auth cache
+// (yt_task_proxy_auth_cache_inflight_backend_requests) by one.
+func (m *Metrics) IncAuthCacheInflightBackendRequests() {
+	m.authCacheInflightBackend.Inc()
+}
+
+// DecAuthCacheInflightBackendRequests decrements the gauge of in-flight backend
+// requests for the auth cache
+// (yt_task_proxy_auth_cache_inflight_backend_requests) by one.
+func (m *Metrics) DecAuthCacheInflightBackendRequests() {
+	m.authCacheInflightBackend.Dec()
+}
+
+// IncAuthCacheWaitingRequests increments the gauge of requests waiting on
+// in-flight auth cache backend requests
+// (yt_task_proxy_auth_cache_waiting_requests) by one.
+func (m *Metrics) IncAuthCacheWaitingRequests() {
+	m.authCacheWaitingRequests.Inc()
+}
+
+// DecAuthCacheWaitingRequests decrements the gauge of requests waiting on
+// in-flight auth cache backend requests
+// (yt_task_proxy_auth_cache_waiting_requests) by one.
+func (m *Metrics) DecAuthCacheWaitingRequests() {
+	m.authCacheWaitingRequests.Dec()
+}
+
 func (m *Metrics) ObserveDiscoverySuccess(reason string) {
 	m.discoverySuccesses.WithLabelValues(reason).Inc()
 }
diff --git a/server/pkg/metrics_test.go b/server/pkg/metrics_test.go
index 43ac993..074f017 100644
--- a/server/pkg/metrics_test.go
+++ b/server/pkg/metrics_test.go
@@ -24,6 +24,11 @@ func TestMetricsHandler(t *testing.T) {
 	metrics.ObserveAuthFailure(authReasonPermissionDenied, nil)
 	metrics.ObserveAuthYTError("permission_check", context.DeadlineExceeded)
 	metrics.ObserveAuthFailure(authReasonTaskLookup, nil)
+	metrics.ObserveAuthCacheHit()
+	metrics.ObserveAuthCacheMiss()
+	metrics.IncAuthCacheEntries()
+	metrics.IncAuthCacheInflightBackendRequests()
+	metrics.IncAuthCacheWaitingRequests()
 	metrics.ObserveDiscoverySuccess("updated")
 	metrics.ObserveDiscoverySuccess("no_changes")
 	metrics.ObserveDiscoveryFailure("discovery", nil)
@@ -46,6 +51,11 @@ func TestMetricsHandler(t *testing.T) {
 	require.True(t, strings.Contains(body, `yt_task_proxy_auth_errors_total{stage="task_lookup"} 1`))
 	require.True(t, strings.Contains(body, `yt_task_proxy_auth_errors_total{stage="permission_check"} 1`))
 	require.True(t, strings.Contains(body, `yt_task_proxy_auth_infra_errors_total{grpc_code="none",kind="context_deadline_exceeded",stage="permission_check"} 1`))
+	require.True(t,
strings.Contains(body, `yt_task_proxy_auth_cache_hits_total 1`)) + require.True(t, strings.Contains(body, `yt_task_proxy_auth_cache_misses_total 1`)) + require.True(t, strings.Contains(body, `yt_task_proxy_auth_cache_entries 1`)) + require.True(t, strings.Contains(body, `yt_task_proxy_auth_cache_inflight_backend_requests 1`)) + require.True(t, strings.Contains(body, `yt_task_proxy_auth_cache_waiting_requests 1`)) require.True(t, strings.Contains(body, `yt_task_proxy_discovery_success_total{reason="updated"} 1`)) require.True(t, strings.Contains(body, `yt_task_proxy_discovery_success_total{reason="no_changes"} 1`)) require.True(t, strings.Contains(body, `yt_task_proxy_discovery_failed_total{reason="discovery"} 1`)) diff --git a/server/pkg/updater_test.go b/server/pkg/updater_test.go index 47ffd27..67f7c83 100644 --- a/server/pkg/updater_test.go +++ b/server/pkg/updater_test.go @@ -20,7 +20,7 @@ func (s *failingSnapshotSetter) SetSnapshot(_ context.Context, _ string, _ cache } func TestUpdateDoesNotChangeAuthDataIfSetSnapshotFails(t *testing.T) { - authServer := CreateAuthServer(nil, &SimpleLogger{}, "") + authServer := CreateAuthServer(nil, "", &SimpleLogger{}, "", AuthCacheConfig{}) oldTask := Task{ operationID: "op-old",