Skip to content
122 changes: 112 additions & 10 deletions cmd/ateapi/internal/controlapi/functional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package controlapi
import (
"context"
"fmt"
"log"
"net"
"os"
"os/exec"
Expand Down Expand Up @@ -60,9 +61,11 @@ var (

func TestMain(m *testing.M) {
cmd := exec.Command("bash", "../../../../hack/run-tool.sh", "setup-envtest", "use", "--print", "path")
var stderr strings.Builder
cmd.Stderr = &stderr
out, err := cmd.Output()
if err != nil {
os.Exit(1)
log.Fatalf("setup-envtest: %v (stderr: %s)", err, stderr.String())
}
binaryAssetsDirectory := strings.TrimSpace(string(out))

Expand All @@ -73,19 +76,19 @@ func TestMain(m *testing.M) {

cfg, err = testEnv.Start()
if err != nil {
os.Exit(1)
log.Fatalf("testEnv.Start: %v", err)
}

// Create ate-system namespace
k8sClient, err := kubernetes.NewForConfig(cfg)
if err != nil {
os.Exit(1)
log.Fatalf("kubernetes.NewForConfig: %v", err)
}
_, err = k8sClient.CoreV1().Namespaces().Create(context.Background(), &corev1.Namespace{
ObjectMeta: metav1.ObjectMeta{Name: "ate-system"},
}, metav1.CreateOptions{})
if err != nil && !strings.Contains(err.Error(), "already exists") {
os.Exit(1)
log.Fatalf("create ate-system namespace: %v", err)
}

// Create shared Atelet Pod
Expand All @@ -106,14 +109,14 @@ func TestMain(m *testing.M) {
}
createdAtelet, err := k8sClient.CoreV1().Pods("ate-system").Create(context.Background(), ateletPod, metav1.CreateOptions{})
if err != nil && !strings.Contains(err.Error(), "already exists") {
os.Exit(1)
log.Fatalf("create atelet pod: %v", err)
}
if err == nil {
createdAtelet.Status.PodIPs = []corev1.PodIP{{IP: "127.0.0.1"}}
createdAtelet.Status.Phase = corev1.PodRunning
_, err = k8sClient.CoreV1().Pods("ate-system").UpdateStatus(context.Background(), createdAtelet, metav1.UpdateOptions{})
if err != nil {
os.Exit(1)
log.Fatalf("update atelet pod status: %v", err)
}
}

Expand All @@ -122,7 +125,7 @@ func TestMain(m *testing.M) {
ateletpb.RegisterAteomHerderServer(ateletGrpcServer, fakeAtelet)
ateletLis, err := net.Listen("tcp", "127.0.0.1:8085")
if err != nil {
os.Exit(1)
log.Fatalf("listen on 127.0.0.1:8085: %v", err)
}
go func() {
if err := ateletGrpcServer.Serve(ateletLis); err != nil {
Expand All @@ -136,7 +139,7 @@ func TestMain(m *testing.M) {

err = testEnv.Stop()
if err != nil {
os.Exit(1)
log.Fatalf("testEnv.Stop: %v", err)
}

os.Exit(code)
Expand Down Expand Up @@ -868,6 +871,7 @@ func TestResumeActor(t *testing.T) {
ActorTemplate: "tmpl1",
ActorId: id,
Ip: "127.0.0.1",
NodeName: "node1",
}

if diff := cmp.Diff(wantWorker, actorWorker, protocmp.Transform(), protocmp.IgnoreFields(&ateapipb.Worker{}, "version"), protocmp.IgnoreFields(&ateapipb.Worker{}, "worker_pod_uid")); diff != "" {
Expand Down Expand Up @@ -1123,14 +1127,112 @@ func TestSuspendActor(t *testing.T) {
ActorTemplateNamespace: ns,
ActorTemplateName: "tmpl1",
Status: ateapipb.Actor_STATUS_SUSPENDED,
LastSnapshot: fmt.Sprintf("gs://my-bucket/%s/tmpl1/%s/", ns, id),
LatestSnapshotInfo: &ateapipb.SnapshotInfo{
Type: ateapipb.SnapshotType_SNAPSHOT_TYPE_EXTERNAL,
Data: &ateapipb.SnapshotInfo_External{
External: &ateapipb.ExternalSnapshotInfo{
SnapshotUriPrefix: fmt.Sprintf("gs://fake-fake-fake/%s/", id),
},
},
},
},
}

if diff := cmp.Diff(want, getResp, protocmp.Transform(), protocmp.IgnoreFields(&ateapipb.Actor{}, "version", "last_snapshot", "ateom_pod_uid")); diff != "" {
if diff := cmp.Diff(want, getResp,
protocmp.Transform(),
protocmp.IgnoreFields(&ateapipb.Actor{}, "version"),
protocmp.IgnoreFields(&ateapipb.Actor{}, "ateom_pod_uid"),
protocmp.FilterField(&ateapipb.ExternalSnapshotInfo{}, "snapshot_uri_prefix", cmp.Comparer(func(x, y string) bool {
return strings.HasPrefix(y, x)
})),
); diff != "" {
t.Errorf("GetActor response mismatch (-want +got):\n%s", diff)
}
}

// TestPauseActor exercises pausing a running actor end to end.
// Workflow:
//  1. Creates the actor template via createTemplate.
//  2. Creates worker pod "worker-1" on "node1" via createWorkerPod.
//  3. Creates actor "id1" from template "tmpl1".
//  4. Calls ResumeActor so the actor is running before it is paused.
//  5. Calls the PauseActor RPC.
//  6. Checks that the fake Atelet recorded a Checkpoint call.
//  7. Fetches the actor and expects STATUS_PAUSED with local snapshot info
//     for "node1" whose snapshot prefix starts with the actor ID.
//
// NOTE(review): the Atelet pod and the mirroring of the worker into the store
// are presumably handled by TestMain and the helpers above; confirm there.
func TestPauseActor(t *testing.T) {
	// Declared up front so every request and the expectation share one ID.
	const id = "id1"

	ctx := context.Background()
	ns := namespaceForTest("ns-pause")
	tc := setupTest(t, ns)
	defer tc.cleanup()

	createTemplate(t, tc, ns)
	createWorkerPod(t, tc, ns, "worker-1", "node1")

	if _, err := tc.client.CreateActor(ctx, &ateapipb.CreateActorRequest{
		ActorTemplateNamespace: ns,
		ActorTemplateName:      "tmpl1",
		ActorId:                id,
	}); err != nil {
		t.Fatalf("CreateActor failed: %v", err)
	}

	// Resume first to make it running.
	if _, err := tc.client.ResumeActor(ctx, &ateapipb.ResumeActorRequest{
		ActorId: id,
	}); err != nil {
		t.Fatalf("ResumeActor failed: %v", err)
	}

	// Pause.
	if _, err := tc.client.PauseActor(ctx, &ateapipb.PauseActorRequest{
		ActorId: id,
	}); err != nil {
		t.Fatalf("PauseActor failed: %v", err)
	}

	// NOTE(review): the test treats a Checkpoint call on the fake Atelet as
	// the atelet-side signal for a pause — confirm against CallAteletPauseStep.
	if !tc.fakeAtelet.CheckpointCalled {
		t.Errorf("expected atelet Checkpoint to be called")
	}

	getResp, err := tc.client.GetActor(ctx, &ateapipb.GetActorRequest{
		ActorId: id,
	})
	if err != nil {
		t.Fatalf("GetActor failed: %v", err)
	}

	// snapshot_prefix is left out of want: it is excluded from the diff below
	// and verified separately as a prefix match.
	want := &ateapipb.GetActorResponse{
		Actor: &ateapipb.Actor{
			ActorId:                id,
			ActorTemplateNamespace: ns,
			ActorTemplateName:      "tmpl1",
			Status:                 ateapipb.Actor_STATUS_PAUSED,
			LatestSnapshotInfo: &ateapipb.SnapshotInfo{
				Type: ateapipb.SnapshotType_SNAPSHOT_TYPE_LOCAL,
				Data: &ateapipb.SnapshotInfo_Local{
					Local: &ateapipb.LocalSnapshotInfo{
						NodeVmsWithLocalSnapshots: []string{"node1"},
					},
				},
			},
		},
	}

	if diff := cmp.Diff(want, getResp,
		protocmp.Transform(),
		protocmp.IgnoreFields(&ateapipb.Actor{}, "version"),
		protocmp.IgnoreFields(&ateapipb.Actor{}, "ateom_pod_uid"),
		protocmp.IgnoreFields(&ateapipb.LocalSnapshotInfo{}, "snapshot_prefix"),
	); diff != "" {
		t.Errorf("GetActor response mismatch (-want +got):\n%s", diff)
	}

	// cmp.Comparer requires a symmetric function, and a one-directional
	// strings.HasPrefix is not symmetric, so the prefix is asserted here
	// instead of through a custom comparer.
	gotPrefix := getResp.GetActor().GetLatestSnapshotInfo().GetLocal().GetSnapshotPrefix()
	if !strings.HasPrefix(gotPrefix, id) {
		t.Errorf("LocalSnapshotInfo.SnapshotPrefix = %q, want a value starting with %q", gotPrefix, id)
	}
}

// TestValidation tests the negative validation cases for all gRPC methods.
Expand Down
51 changes: 51 additions & 0 deletions cmd/ateapi/internal/controlapi/pause_actor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package controlapi

import (
"context"
"errors"

"github.com/agent-substrate/substrate/cmd/ateapi/internal/store"
"github.com/agent-substrate/substrate/pkg/proto/ateapipb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// PauseActor handles the PauseActor RPC. It validates the request, runs the
// pause workflow for the requested actor ID, and returns the resulting actor.
//
// Errors:
//   - the validation error, unchanged, when the request fails validation;
//   - codes.Aborted when the workflow reports store.ErrPersistenceRetry;
//   - codes.NotFound when the workflow reports store.ErrNotFound;
//   - any other workflow error, returned as-is.
func (s *Service) PauseActor(ctx context.Context, req *ateapipb.PauseActorRequest) (*ateapipb.PauseActorResponse, error) {
	if invalid := validatePauseActorRequest(req); invalid != nil {
		return nil, invalid
	}

	actorID := req.GetActorId()
	paused, err := s.actorWorkflow.PauseActor(ctx, actorID)
	switch {
	case err == nil:
		return &ateapipb.PauseActorResponse{Actor: paused}, nil
	case errors.Is(err, store.ErrPersistenceRetry):
		return nil, status.Error(codes.Aborted, "concurrent update conflict, please retry")
	case errors.Is(err, store.ErrNotFound):
		return nil, status.Errorf(codes.NotFound, "Actor %s not found", actorID)
	default:
		return nil, err
	}
}

// validatePauseActorRequest checks that req carries a non-empty actor ID.
// It returns a codes.InvalidArgument status error when the ID is empty and
// nil otherwise.
func validatePauseActorRequest(req *ateapipb.PauseActorRequest) error {
	if actorID := req.GetActorId(); len(actorID) == 0 {
		return status.Error(codes.InvalidArgument, "id is required")
	}
	return nil
}
1 change: 1 addition & 0 deletions cmd/ateapi/internal/controlapi/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func (s *WorkerPoolSyncer) syncWorkerToStore(ctx context.Context, pod *corev1.Po
WorkerPod: pod.Name,
Ip: pod.Status.PodIP,
WorkerPodUid: string(pod.UID),
NodeName: pod.Spec.NodeName,
})
if err != nil && !errors.Is(err, store.ErrAlreadyExists) {
slog.ErrorContext(ctx, "Failed to create worker in store", slog.Any("err", err))
Expand Down
14 changes: 11 additions & 3 deletions cmd/ateapi/internal/controlapi/syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,15 @@ func TestSyncer_DeleteBoundWorker_ClearsActor(t *testing.T) {
ActorId: actorID, ActorTemplateNamespace: ns, ActorTemplateName: "tmpl",
Status: ateapipb.Actor_STATUS_RUNNING,
AteomPodNamespace: ns, AteomPodName: pod, AteomPodIp: ip,
LastSnapshot: "gs://snapshots/last", InProgressSnapshot: "gs://snapshots/partial",
InProgressSnapshot: "gs://snapshots/partial",
LatestSnapshotInfo: &ateapipb.SnapshotInfo{
Type: ateapipb.SnapshotType_SNAPSHOT_TYPE_EXTERNAL,
Data: &ateapipb.SnapshotInfo_External{
External: &ateapipb.ExternalSnapshotInfo{
SnapshotUriPrefix: "gs://snapshots/last",
},
},
},
}); err != nil {
t.Fatalf("create actor: %v", err)
}
Expand All @@ -210,7 +218,7 @@ func TestSyncer_DeleteBoundWorker_ClearsActor(t *testing.T) {
if got.AteomPodName != "" || got.AteomPodNamespace != "" || got.AteomPodIp != "" || got.InProgressSnapshot != "" {
t.Errorf("bind fields not cleared: %+v", got)
}
if got.LastSnapshot == "" {
t.Errorf("LastSnapshot must be preserved")
if got.GetLatestSnapshotInfo().GetExternal().SnapshotUriPrefix == "" {
t.Errorf("External SnapshotUriPrefix must be preserved")
}
}
31 changes: 30 additions & 1 deletion cmd/ateapi/internal/controlapi/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (w *ActorWorkflow) SuspendActor(ctx context.Context, id string) (*ateapipb.
state := &SuspendState{}

// Acquire lock and get the timeout context for the workflow
// Lock TTL is 7 seconds, with 2 seconds padding for workflow timeout
// Lock TTL is 30 seconds, with 2 seconds padding for workflow timeout
ctx, releaseLock, err := w.acquireActorLock(ctx, id, 30*time.Second, 2*time.Second)
if err != nil {
return nil, err
Expand All @@ -191,6 +191,35 @@ func (w *ActorWorkflow) SuspendActor(ctx context.Context, id string) (*ateapipb.
return state.Actor, nil
}

// PauseActor runs the pause workflow for the actor identified by id and
// returns the actor recorded in the workflow state.
//
// The per-actor lock is acquired first and released when the function returns.
// The steps then run in order on the context returned by acquireActorLock:
// load the actor, mark it pausing, call the atelet, and finalize it as paused.
// An error from lock acquisition or from RunWorkflow is returned unchanged,
// with a nil actor.
//
// NOTE(review): the original comment describes this workflow as idempotent;
// that property lives in the individual steps, which are defined elsewhere.
func (w *ActorWorkflow) PauseActor(ctx context.Context, id string) (*ateapipb.Actor, error) {
	const (
		// Lock TTL handed to acquireActorLock.
		pauseLockTTL = 30 * time.Second
		// Padding handed to acquireActorLock; per the original comment it
		// is the margin applied for the workflow timeout.
		pauseLockPadding = 2 * time.Second
	)

	workflowCtx, unlock, err := w.acquireActorLock(ctx, id, pauseLockTTL, pauseLockPadding)
	if err != nil {
		return nil, err
	}
	defer unlock()

	in := &PauseInput{ActorID: id}
	result := &PauseState{}
	pipeline := []WorkflowStep[*PauseInput, *PauseState]{
		&LoadActorForPauseStep{store: w.store, actorTemplateLister: w.actorTemplateLister},
		&MarkPausingStep{store: w.store},
		&CallAteletPauseStep{dialer: w.dialer},
		&FinalizePausedStep{store: w.store},
	}

	if runErr := RunWorkflow(workflowCtx, in, result, pipeline); runErr != nil {
		return nil, runErr
	}
	return result.Actor, nil
}

func (w *ActorWorkflow) acquireActorLock(ctx context.Context, id string, ttl time.Duration, padding time.Duration) (context.Context, func(), error) {
lockKey := "lock:actor:" + id
lockValue := uuid.New().String()
Expand Down
Loading
Loading