diff --git a/internal/api/generated.go b/internal/api/generated.go index c98a9da..a8bb65d 100644 --- a/internal/api/generated.go +++ b/internal/api/generated.go @@ -36,12 +36,13 @@ func (e CleanupScopeRequestScope) Valid() bool { // Defines values for FlowStatus. const ( - FlowStatusCancelled FlowStatus = "cancelled" - FlowStatusDone FlowStatus = "done" - FlowStatusFailed FlowStatus = "failed" - FlowStatusPending FlowStatus = "pending" - FlowStatusRunning FlowStatus = "running" - FlowStatusWaiting FlowStatus = "waiting" + FlowStatusCancelled FlowStatus = "cancelled" + FlowStatusDone FlowStatus = "done" + FlowStatusFailed FlowStatus = "failed" + FlowStatusPending FlowStatus = "pending" + FlowStatusRescheduled FlowStatus = "rescheduled" + FlowStatusRunning FlowStatus = "running" + FlowStatusWaiting FlowStatus = "waiting" ) // Valid indicates whether the value is a known member of the FlowStatus enum. @@ -55,6 +56,8 @@ func (e FlowStatus) Valid() bool { return true case FlowStatusPending: return true + case FlowStatusRescheduled: + return true case FlowStatusRunning: return true case FlowStatusWaiting: diff --git a/internal/application/tickets/orchestrator.go b/internal/application/tickets/orchestrator.go index 5db1a4b..6cac50f 100644 --- a/internal/application/tickets/orchestrator.go +++ b/internal/application/tickets/orchestrator.go @@ -61,35 +61,35 @@ func NewWithStore(cfg config.Config, repoRoot string, store stateStore, provider // StartFlow begins or re-runs the workflow for a ticket. Creates a worktree on // first call; re-runs the current state if the ticket is already waiting or failed. -func (o *Orchestrator) StartFlow(ctx context.Context, ticketNumber string) error { +func (o *Orchestrator) StartFlow(ctx context.Context, ticketNumber string) (RunOutcome, error) { wflow, err := workflow.Load(o.RepoRoot) if err != nil { - return fmt.Errorf("load workflow: %w", err) + return RunOutcome{}, fmt.Errorf("load workflow: %w", err) } state, loadErr := o.Store.LoadState(ticketNumber) if errors.Is(loadErr, os.ErrNotExist) { state = workflowstate.New(ticketNumber) } else if loadErr != nil { - return fmt.Errorf("load ticket state: %w", loadErr) + return RunOutcome{}, fmt.Errorf("load ticket state: %w", loadErr) } if state.FlowStatus == workflowstate.FlowStatusDone || state.FlowStatus == workflowstate.FlowStatusCancelled { slog.Info("skipping ticket", "ticket", ticketNumber, "status", state.FlowStatus) - return nil + return RunOutcome{}, nil } if state.FlowStatus == workflowstate.FlowStatusRunning { - return fmt.Errorf("ticket %s: %w", ticketNumber, ErrTicketRunning) + return RunOutcome{}, fmt.Errorf("ticket %s: %w", ticketNumber, ErrTicketRunning) } err = o.ensureWorktreeAndContext(ctx, &state) if err != nil { - return err + return RunOutcome{}, err } // Determine which state to run. stateCfg, err := resolveStateForStart(state, wflow) if err != nil { - return err + return RunOutcome{}, err } slog.Info("starting flow", "ticket", ticketNumber, "state", stateCfg.Name) @@ -98,23 +98,23 @@ func (o *Orchestrator) StartFlow(ctx context.Context, ticketNumber string) error } // ApplyAction applies the named action to a ticket that is waiting for input. -func (o *Orchestrator) ApplyAction(ctx context.Context, ticketNumber, actionLabel, message string) error { +func (o *Orchestrator) ApplyAction(ctx context.Context, ticketNumber, actionLabel, message string) (RunOutcome, error) { wflow, err := workflow.Load(o.RepoRoot) if err != nil { - return fmt.Errorf("load workflow: %w", err) + return RunOutcome{}, fmt.Errorf("load workflow: %w", err) } state, err := o.Store.LoadState(ticketNumber) if err != nil { - return fmt.Errorf("load ticket state: %w", err) + return RunOutcome{}, fmt.Errorf("load ticket state: %w", err) } if state.FlowStatus != workflowstate.FlowStatusWaiting { - return fmt.Errorf("ticket %s (status: %s): %w", ticketNumber, state.FlowStatus, ErrTicketNotWaiting) + return RunOutcome{}, fmt.Errorf("ticket %s (status: %s): %w", ticketNumber, state.FlowStatus, ErrTicketNotWaiting) } stateCfg, ok := wflow.StateByName(state.CurrentState) if !ok { - return fmt.Errorf("state %q: %w", state.CurrentState, ErrStateNotFound) + return RunOutcome{}, fmt.Errorf("state %q: %w", state.CurrentState, ErrStateNotFound) } var action *workflow.ActionConfig @@ -131,7 +131,7 @@ func (o *Orchestrator) ApplyAction(ctx context.Context, ticketNumber, actionLabe labels[i] = a.Label } - return fmt.Errorf("action %q in state %q (available: %s): %w", actionLabel, state.CurrentState, strings.Join(labels, ", "), ErrActionNotFound) + return RunOutcome{}, fmt.Errorf("action %q in state %q (available: %s): %w", actionLabel, state.CurrentState, strings.Join(labels, ", "), ErrActionNotFound) } slog.Info("applying action", "ticket", ticketNumber, "action", actionLabel, "state", state.CurrentState) @@ -140,27 +140,27 @@ func (o *Orchestrator) ApplyAction(ctx context.Context, ticketNumber, actionLabe } // MoveToState force-transitions the ticket to target, creating a worktree if needed. -func (o *Orchestrator) MoveToState(ctx context.Context, ticketNumber, target string) error { +func (o *Orchestrator) MoveToState(ctx context.Context, ticketNumber, target string) (RunOutcome, error) { wflow, err := workflow.Load(o.RepoRoot) if err != nil { - return fmt.Errorf("load workflow: %w", err) + return RunOutcome{}, fmt.Errorf("load workflow: %w", err) } if strings.TrimSpace(target) == "" { - return ErrTargetStateRequired + return RunOutcome{}, ErrTargetStateRequired } state, loadErr := o.Store.LoadState(ticketNumber) if errors.Is(loadErr, os.ErrNotExist) { state = workflowstate.New(ticketNumber) } else if loadErr != nil { - return fmt.Errorf("load ticket state: %w", loadErr) + return RunOutcome{}, fmt.Errorf("load ticket state: %w", loadErr) } if state.FlowStatus == workflowstate.FlowStatusRunning { - return fmt.Errorf("ticket %s: %w", ticketNumber, ErrTicketRunning) + return RunOutcome{}, fmt.Errorf("ticket %s: %w", ticketNumber, ErrTicketRunning) } err = o.ensureWorktreeAndContext(ctx, &state) if err != nil { - return err + return RunOutcome{}, err } slog.Info("force moving to state", "ticket", ticketNumber, "target", target) @@ -316,6 +316,41 @@ func (o *Orchestrator) DiscoverTickets(ctx context.Context) ([]DiscoveredTicket, return tickets, nil } +// ProbeProvider checks whether the AI provider is reachable and the usage quota has reset. +func (o *Orchestrator) ProbeProvider(ctx context.Context) error { + // create temp work dir (WorkDir must be a real existing path) + workDir, err := os.MkdirTemp("", "autopr-probe-work-*") + if err != nil { + return fmt.Errorf("probe: create work dir: %w", err) + } + defer func() { _ = os.RemoveAll(workDir) }() + // create temp runtime dir (RuntimeDir must be a real existing path) + runtimeDir, err := os.MkdirTemp("", "autopr-probe-runtime-*") + if err != nil { + return fmt.Errorf("probe: create runtime dir: %w", err) + } + defer func() { _ = os.RemoveAll(runtimeDir) }() + + // write minimal prompt file (PromptPath must exist on disk — CLIProvider reads it) + promptPath := filepath.Join(workDir, "probe.md") + err = os.WriteFile(promptPath, []byte("ping"), 0o644) + if err != nil { + return fmt.Errorf("probe: write prompt: %w", err) + } + // call provider — discard output, only care about ErrTokensExhausted + _, err = o.Provider.Execute(ctx, providers.ExecuteRequest{ + PromptPath: promptPath, + WorkDir: workDir, + RuntimeDir: runtimeDir, + SessionData: "", // fresh call, no session + }) + if errors.Is(err, providers.ErrTokensExhausted) { + return providers.ErrTokensExhausted // quota still hit + } + + return nil // any other result = quota not the issue +} + func (o *Orchestrator) ensureWorktreeAndContext(ctx context.Context, state *workflowstate.State) error { if state.WorktreePath == "" { branchName := "auto-pr/" + state.TicketNumber @@ -362,11 +397,11 @@ func (o *Orchestrator) ensureWorktreeAndContext(ctx context.Context, state *work // --- internal helpers --- -func (o *Orchestrator) runState(ctx context.Context, state *workflowstate.State, stateCfg workflow.StateConfig) error { +func (o *Orchestrator) runState(ctx context.Context, state *workflowstate.State, stateCfg workflow.StateConfig) (RunOutcome, error) { slog.Info("running state", "ticket", state.TicketNumber, "state", stateCfg.Name) run, err := startStateRun(state, stateCfg) if err != nil { - return err + return RunOutcome{}, err } logPath := state.ResolveRef(run.LogRef) @@ -375,33 +410,33 @@ func (o *Orchestrator) runState(ctx context.Context, state *workflowstate.State, state.LastError = "" err = o.Store.SaveState(state.TicketNumber, *state) if err != nil { - return fmt.Errorf("save ticket state: %w", err) + return RunOutcome{}, fmt.Errorf("save ticket state: %w", err) } err = o.prepareRunContext(*state, stateCfg, run) if err != nil { - return o.failState(state, err) + return RunOutcome{}, o.failState(state, err) } err = o.runCommands(ctx, state.WorktreePath, stateCfg.PrePromptCommands, logPath, "Pre-prompt") if err != nil { - return o.failState(state, err) + return RunOutcome{}, o.failState(state, err) } promptContent, err := workflow.ReadPrompt(o.RepoRoot, stateCfg.Prompt) if err != nil { - return o.failState(state, fmt.Errorf("read prompt %s: %w", stateCfg.Prompt, err)) + return RunOutcome{}, o.failState(state, fmt.Errorf("read prompt %s: %w", stateCfg.Prompt, err)) } promptPath := state.RunPath(run.ID, "prompt.md") err = os.WriteFile(promptPath, promptContent, 0o644) if err != nil { - return o.failState(state, err) + return RunOutcome{}, o.failState(state, err) } runtimeDir := state.RunPath(run.ID, "provider") err = os.MkdirAll(runtimeDir, 0o755) if err != nil { - return o.failState(state, err) + return RunOutcome{}, o.failState(state, err) } slog.Info("executing provider", "ticket", state.TicketNumber, "state", stateCfg.Name) @@ -413,14 +448,21 @@ func (o *Orchestrator) runState(ctx context.Context, state *workflowstate.State, }) rawLogPath := state.RunPath(run.ID, "raw-provider.log") _ = os.WriteFile(rawLogPath, []byte(result.RawOutput+"\n\n[stderr]\n"+result.Stderr), 0o644) + outcome := RunOutcome{ + Provider: o.Provider.Name(), + QuotaReached: result.QuotaReached, + } + if result.QuotaReached { + _ = markdown.AppendSection(logPath, stateCfg.Name+" Reschedule", err.Error()) + + return outcome, o.rescheduledState(state, err) + } if err != nil { - if errors.Is(err, providers.ErrTokensExhausted) { - err = fmt.Errorf("token usage limit reached — wait for your quota to reset, then rerun this ticket to continue: %w", err) - } _ = markdown.AppendSection(logPath, stateCfg.Name+" Failed", err.Error()) - return o.failState(state, err) + return outcome, o.failState(state, err) } + if result.SessionData != "" { state.ProviderSessionData = result.SessionData } @@ -429,7 +471,7 @@ func (o *Orchestrator) runState(ctx context.Context, state *workflowstate.State, err = o.runCommands(ctx, state.WorktreePath, stateCfg.PostPromptCommands, logPath, "Post-prompt") if err != nil { - return o.failState(state, err) + return RunOutcome{}, o.failState(state, err) } if run.ArtifactRef != "" { @@ -449,10 +491,10 @@ func (o *Orchestrator) runState(ctx context.Context, state *workflowstate.State, state.FlowStatus = workflowstate.FlowStatusWaiting saveErr := o.Store.SaveState(state.TicketNumber, *state) if saveErr != nil { - return fmt.Errorf("save ticket state: %w", saveErr) + return RunOutcome{}, fmt.Errorf("save ticket state: %w", saveErr) } - return nil + return RunOutcome{}, nil } func (o *Orchestrator) failState(st *workflowstate.State, cause error) error { @@ -464,7 +506,16 @@ func (o *Orchestrator) failState(st *workflowstate.State, cause error) error { return cause } -func (o *Orchestrator) dispatchAction(ctx context.Context, state *workflowstate.State, wflow workflow.Config, action workflow.ActionConfig, message string) error { +func (o *Orchestrator) rescheduledState(st *workflowstate.State, cause error) error { + slog.Warn("token usage limit reached", "ticket", st.TicketNumber, "state", st.CurrentState, "err", cause) + st.FlowStatus = workflowstate.FlowStatusRescheduled + st.LastError = cause.Error() + _ = o.Store.SaveState(st.TicketNumber, *st) + + return cause +} + +func (o *Orchestrator) dispatchAction(ctx context.Context, state *workflowstate.State, wflow workflow.Config, action workflow.ActionConfig, message string) (RunOutcome, error) { logPath := state.CurrentRunLogPath() _ = markdown.AppendSection(logPath, "Human Action: "+action.Label, "") @@ -476,11 +527,11 @@ func (o *Orchestrator) dispatchAction(ctx context.Context, state *workflowstate. case workflow.ActionRunScript: return o.executeScript(ctx, state, wflow, action) default: - return fmt.Errorf("action type %q: %w", action.Type, ErrUnknownActionType) + return RunOutcome{}, fmt.Errorf("action type %q: %w", action.Type, ErrUnknownActionType) } } -func (o *Orchestrator) transitionTo(ctx context.Context, state *workflowstate.State, wflow workflow.Config, target string) error { +func (o *Orchestrator) transitionTo(ctx context.Context, state *workflowstate.State, wflow workflow.Config, target string) (RunOutcome, error) { if workflow.IsTerminal(target) { slog.Info("reached terminal state", "ticket", state.TicketNumber, "state", target) switch target { @@ -494,43 +545,43 @@ func (o *Orchestrator) transitionTo(ctx context.Context, state *workflowstate.St saveErr := o.Store.SaveState(state.TicketNumber, *state) if saveErr != nil { - return fmt.Errorf("save ticket state: %w", saveErr) + return RunOutcome{}, fmt.Errorf("save ticket state: %w", saveErr) } - return nil + return RunOutcome{}, nil } slog.Info("transitioning to state", "ticket", state.TicketNumber, "target", target) stateCfg, ok := wflow.StateByName(target) if !ok { - return fmt.Errorf("state %q: %w", target, ErrTargetNotFound) + return RunOutcome{}, fmt.Errorf("state %q: %w", target, ErrTargetNotFound) } return o.runState(ctx, state, stateCfg) } -func (o *Orchestrator) writeFeedbackAndRerun(ctx context.Context, state *workflowstate.State, wflow workflow.Config, message string) error { +func (o *Orchestrator) writeFeedbackAndRerun(ctx context.Context, state *workflowstate.State, wflow workflow.Config, message string) (RunOutcome, error) { if strings.TrimSpace(message) == "" { - return ErrFeedbackRequired + return RunOutcome{}, ErrFeedbackRequired } slog.Info("applying feedback", "ticket", state.TicketNumber, "state", state.CurrentState) if state.CurrentRunID == "" { - return ErrNoCurrentRunID + return RunOutcome{}, ErrNoCurrentRunID } content := []byte(strings.TrimSpace(message)) runFeedbackPath := state.RunPath(state.CurrentRunID, "feedback.md") writeErr := os.WriteFile(runFeedbackPath, content, 0o644) if writeErr != nil { - return fmt.Errorf("write feedback file: %w", writeErr) + return RunOutcome{}, fmt.Errorf("write feedback file: %w", writeErr) } stateCfg, ok := wflow.StateByName(state.CurrentState) if !ok { - return fmt.Errorf("state %q: %w", state.CurrentState, ErrStateNotFound) + return RunOutcome{}, fmt.Errorf("state %q: %w", state.CurrentState, ErrStateNotFound) } return o.runState(ctx, state, stateCfg) } -func (o *Orchestrator) executeScript(ctx context.Context, state *workflowstate.State, wflow workflow.Config, action workflow.ActionConfig) error { +func (o *Orchestrator) executeScript(ctx context.Context, state *workflowstate.State, wflow workflow.Config, action workflow.ActionConfig) (RunOutcome, error) { logPath := state.CurrentRunLogPath() var out strings.Builder @@ -552,44 +603,48 @@ func (o *Orchestrator) executeScript(ctx context.Context, state *workflowstate.S captured := strings.TrimSpace(out.String()) + var outcome RunOutcome if scriptErr == nil && action.OnSuccess != nil { - err := o.dispatchSubAction(ctx, state, wflow, *action.OnSuccess, captured) + sub, err := o.dispatchSubAction(ctx, state, wflow, *action.OnSuccess, captured) if err != nil { - return err + return RunOutcome{}, err } + outcome = sub } else if scriptErr != nil && action.OnFailure != nil { - err := o.dispatchSubAction(ctx, state, wflow, *action.OnFailure, captured) + sub, err := o.dispatchSubAction(ctx, state, wflow, *action.OnFailure, captured) if err != nil { - return err + return RunOutcome{}, err } + outcome = sub } if action.Always != nil { - err := o.dispatchSubAction(ctx, state, wflow, *action.Always, captured) + sub, err := o.dispatchSubAction(ctx, state, wflow, *action.Always, captured) if err != nil { - return err + return RunOutcome{}, err } + outcome = sub } - return nil + return outcome, nil } func (o *Orchestrator) dispatchSubAction( ctx context.Context, state *workflowstate.State, wflow workflow.Config, action workflow.ActionConfig, message string, -) error { +) (RunOutcome, error) { switch action.Type { case workflow.ActionProvideFeedback: if strings.TrimSpace(message) == "" { - return nil // no script output to feed back + return RunOutcome{}, nil // no script output to feed back } return o.writeFeedbackAndRerun(ctx, state, wflow, message) case workflow.ActionMoveToState: return o.transitionTo(ctx, state, wflow, action.Target) case workflow.ActionRunScript: - return ErrScriptSubAction + return RunOutcome{}, ErrScriptSubAction default: - return fmt.Errorf("action type %q: %w", action.Type, ErrUnsupportedSubAction) + return RunOutcome{}, fmt.Errorf("action type %q: %w", action.Type, ErrUnsupportedSubAction) } } @@ -661,6 +716,8 @@ func buildNextSteps(state workflowstate.State, wflow workflow.Config) string { return fmt.Sprintf("Ticket failed: %s\n\nRetry: auto-pr run %s", state.LastError, state.TicketNumber) case workflowstate.FlowStatusCancelled: return "Ticket was cancelled." + case workflowstate.FlowStatusRescheduled: + return "Ticket was rescheduled." } return "" diff --git a/internal/application/tickets/orchestrator_test.go b/internal/application/tickets/orchestrator_test.go index 7fd2270..abb4b05 100644 --- a/internal/application/tickets/orchestrator_test.go +++ b/internal/application/tickets/orchestrator_test.go @@ -188,7 +188,7 @@ func TestStartFlow_newTicket_endsWaiting(t *testing.T) { prov := &mockProvider{result: providers.ExecuteResult{RawOutput: "analysis done"}} prepareWorktree(t, store, "42") - err := newOrchestrator(root, store, prov).StartFlow(context.Background(), "42") + _, err := newOrchestrator(root, store, prov).StartFlow(context.Background(), "42") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -212,7 +212,7 @@ func TestStartFlow_doneTicket_isNoop(t *testing.T) { st.FlowStatus = workflowstate.FlowStatusDone _ = store.SaveState("10", st) - err := newOrchestrator(root, store, prov).StartFlow(context.Background(), "10") + _, err := newOrchestrator(root, store, prov).StartFlow(context.Background(), "10") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -232,7 +232,7 @@ func TestStartFlow_runningTicket_returnsError(t *testing.T) { st.FlowStatus = workflowstate.FlowStatusRunning _ = store.SaveState("7", st) - err := newOrchestrator(root, store, &mockProvider{}).StartFlow(context.Background(), "7") + _, err := newOrchestrator(root, store, &mockProvider{}).StartFlow(context.Background(), "7") if !errors.Is(err, tickets.ErrTicketRunning) { t.Errorf("expected ErrTicketRunning, got %v", err) } @@ -246,7 +246,7 @@ func TestStartFlow_providerError_setsFailedStatus(t *testing.T) { prov := &mockProvider{err: provErr} prepareWorktree(t, store, "5") - err := newOrchestrator(root, store, prov).StartFlow(context.Background(), "5") + _, err := newOrchestrator(root, store, prov).StartFlow(context.Background(), "5") if err == nil { t.Fatal("expected error from provider") } @@ -278,7 +278,7 @@ func TestStartFlow_createsWorktreeFromTicketBaseBranch(t *testing.T) { t.Fatal(err) } - err = orch.StartFlow(context.Background(), "42") + _, err = orch.StartFlow(context.Background(), "42") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -317,7 +317,7 @@ func TestStartFlow_writesBaseBranchIntoContextFiles(t *testing.T) { prov, ) - err = orch.StartFlow(context.Background(), "77") + _, err = orch.StartFlow(context.Background(), "77") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -351,7 +351,7 @@ func TestStartFlow_newTicketPersistsStateInWorktreeOnly(t *testing.T) { prov, ) - err := orch.StartFlow(context.Background(), "42") + _, err := orch.StartFlow(context.Background(), "42") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -440,7 +440,7 @@ func TestApplyAction_notWaiting_returnsError(t *testing.T) { st.FlowStatus = workflowstate.FlowStatusRunning _ = store.SaveState("3", st) - err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "3", "Approve", "") + _, err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "3", "Approve", "") if !errors.Is(err, tickets.ErrTicketNotWaiting) { t.Errorf("expected ErrTicketNotWaiting, got %v", err) } @@ -461,7 +461,7 @@ func TestApplyAction_unknownLabel_returnsError(t *testing.T) { st.FlowStatus = workflowstate.FlowStatusWaiting _ = store.SaveState("8", st) - err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "8", "NoSuchAction", "") + _, err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "8", "NoSuchAction", "") if !errors.Is(err, tickets.ErrActionNotFound) { t.Errorf("expected ErrActionNotFound, got %v", err) } @@ -482,7 +482,7 @@ func TestApplyAction_moveToStateDone_setsDone(t *testing.T) { st.FlowStatus = workflowstate.FlowStatusWaiting _ = store.SaveState("99", st) - err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "99", "Approve", "") + _, err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "99", "Approve", "") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -507,7 +507,7 @@ func TestApplyAction_provideFeedback_emptyMessage_returnsError(t *testing.T) { st.FlowStatus = workflowstate.FlowStatusWaiting _ = store.SaveState("11", st) - err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "11", "Feedback", "") + _, err := newOrchestrator(root, store, &mockProvider{}).ApplyAction(context.Background(), "11", "Feedback", "") if !errors.Is(err, tickets.ErrFeedbackRequired) { t.Errorf("expected ErrFeedbackRequired, got %v", err) } @@ -535,7 +535,7 @@ func TestApplyAction_provideFeedback_reruns(t *testing.T) { _ = store.SaveState("12", st) prov := &mockProvider{result: providers.ExecuteResult{RawOutput: "re-investigated"}} - err := newOrchestrator(root, store, prov).ApplyAction(context.Background(), "12", "Feedback", "please dig deeper") + _, err := newOrchestrator(root, store, prov).ApplyAction(context.Background(), "12", "Feedback", "please dig deeper") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -640,7 +640,7 @@ func TestApplyAction_moveToNextState_runsNextState(t *testing.T) { _ = store.SaveState("50", st) prov := &mockProvider{result: providers.ExecuteResult{RawOutput: "implemented"}} - err := newOrchestrator(root, store, prov).ApplyAction(context.Background(), "50", "Continue", "") + _, err := newOrchestrator(root, store, prov).ApplyAction(context.Background(), "50", "Continue", "") if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/application/tickets/outcome.go b/internal/application/tickets/outcome.go new file mode 100644 index 0000000..44bc4a5 --- /dev/null +++ b/internal/application/tickets/outcome.go @@ -0,0 +1,9 @@ +package tickets + +// RunOutcome carries the decision-relevant result of running a workflow state. +// It is returned alongside (not inside) the error so callers can inspect it even +// when the error is non-nil. The zero value means "no signal to report". +type RunOutcome struct { + Provider string + QuotaReached bool +} diff --git a/internal/domain/workflowstate/types.go b/internal/domain/workflowstate/types.go index 47b4f28..a6a4a28 100644 --- a/internal/domain/workflowstate/types.go +++ b/internal/domain/workflowstate/types.go @@ -18,6 +18,7 @@ const ( FlowStatusDone FlowStatus = "done" FlowStatusFailed FlowStatus = "failed" FlowStatusCancelled FlowStatus = "cancelled" + FlowStatusRescheduled FlowStatus = "rescheduled" ) // StateRun records a single execution of a workflow state for a ticket. diff --git a/internal/providers/cli_provider.go b/internal/providers/cli_provider.go index 2db2340..7ab9c97 100644 --- a/internal/providers/cli_provider.go +++ b/internal/providers/cli_provider.go @@ -3,6 +3,7 @@ package providers import ( "context" + "errors" "fmt" "os" "path/filepath" @@ -58,6 +59,10 @@ func (p *CLIProvider) Execute(ctx context.Context, req ExecuteRequest) (ExecuteR SessionData: sessionData, } if err != nil { + if errors.Is(err, ErrTokensExhausted) { + result.QuotaReached = true + } + return result, err } diff --git a/internal/providers/provider.go b/internal/providers/provider.go index 7bfef8e..67a81b0 100644 --- a/internal/providers/provider.go +++ b/internal/providers/provider.go @@ -10,7 +10,8 @@ type ExecuteRequest struct { // ExecuteResult holds the output produced by an AI provider after executing a prompt. type ExecuteResult struct { - RawOutput string // text produced by the AI (extracted from structured output when applicable) - Stderr string - SessionData string // opaque session token to persist for the next run; empty if unsupported + RawOutput string // text produced by the AI (extracted from structured output when applicable) + Stderr string + SessionData string // opaque session token to persist for the next run; empty if unsupported + QuotaReached bool } diff --git a/internal/server/jobs.go b/internal/server/jobs.go index 836202e..ed99dd0 100644 --- a/internal/server/jobs.go +++ b/internal/server/jobs.go @@ -8,14 +8,28 @@ import ( "sync" "github.com/Neokil/AutoPR/internal/api" + "github.com/Neokil/AutoPR/internal/application/tickets" workflowstate "github.com/Neokil/AutoPR/internal/domain/workflowstate" "github.com/Neokil/AutoPR/internal/serverstate" ) func (s *server) workerLoop() { for job := range s.jobs { + s.waitIfQuotaReached() s.setJobStatus(job.record, serverstate.JobStatusRunning, "") - err := s.executeJob(job) + outcome, err := s.executeJob(job) + if outcome.QuotaReached { + slog.Warn("LLM quota reached during job execution. Marking quota as reached and pausing further jobs.") + s.setJobStatus(job.record, "queued", "") + + s.setQuotaReached(true) + err := s.reQueueJob(job) + if err != nil { + slog.Error("quota re-queue failed", "job", job.record.ID, "err", err) + } + + continue + } if err != nil { s.setJobStatus(job.record, serverstate.JobStatusFailed, err.Error()) @@ -43,7 +57,7 @@ func (s *server) setJobStatus(job serverstate.JobRecord, status, errMsg string) }) } -func (s *server) executeJob(job queuedJob) error { +func (s *server) executeJob(job queuedJob) (tickets.RunOutcome, error) { repoRoot, repoID := job.record.RepoPath, job.record.RepoID ticket := job.record.TicketNumber @@ -64,22 +78,23 @@ func (s *server) executeJob(job queuedJob) error { repoRt, err := s.runtimeForRepo(repoRoot) if err != nil { - return err + return tickets.RunOutcome{}, err } + var outcome tickets.RunOutcome switch job.record.Action { case jobRun: - err = repoRt.svc.StartFlow(context.Background(), ticket) + outcome, err = repoRt.svc.StartFlow(context.Background(), ticket) if err == nil { err = s.syncTicketFromRepo(repoID, repoRoot, ticket, repoRt, true) } case jobAction: - err = repoRt.svc.ApplyAction(context.Background(), ticket, job.actionLabel, job.message) + outcome, err = repoRt.svc.ApplyAction(context.Background(), ticket, job.actionLabel, job.message) if err == nil { err = s.syncTicketFromRepo(repoID, repoRoot, ticket, repoRt, true) } case jobMoveToState: - err = repoRt.svc.MoveToState(context.Background(), ticket, job.targetState) + outcome, err = repoRt.svc.MoveToState(context.Background(), ticket, job.targetState) if err == nil { err = s.syncTicketFromRepo(repoID, repoRoot, ticket, repoRt, true) } @@ -109,14 +124,14 @@ func (s *server) executeJob(job queuedJob) error { default: err = fmt.Errorf("%w: %s", errUnsupportedJobAction, job.record.Action) } - if err != nil && ticket != "" { + if err != nil && ticket != "" && !outcome.QuotaReached { persistErr := s.persistTicketFailure(repoID, repoRoot, ticket, repoRt, job, err) if persistErr != nil { - return fmt.Errorf("%w (also failed to persist ticket failure: %w)", err, persistErr) + return outcome, fmt.Errorf("%w (also failed to persist ticket failure: %w)", err, persistErr) } } - return err + return outcome, err } func (s *server) getRepoLock(repoID string) *sync.RWMutex { @@ -183,3 +198,29 @@ func (s *server) recoverStuckTickets() { } } } + +func (s *server) waitIfQuotaReached() { + s.quotaMu.RLock() + quotaReached := s.quotaReached + resetCh := s.quotaResetCh + s.quotaMu.RUnlock() + + if !quotaReached { + return + } + + <-resetCh + slog.Info("LLM quota reset detected. Resuming job execution.") +} + +func (s *server) reQueueJob(job queuedJob) error { + select { + case s.jobs <- job: + return nil + // here we could also listen for a shutdown signal if we had one, to avoid trying to re-queue when the server is shutting down. For now, we'll just return an error if the job queue is full. + // case <-context.Background().Done(): + // return fmt.Errorf("re-queue aborted: server shutting down") + default: + return errJobQueueFull + } +} diff --git a/internal/server/quota_monitor.go b/internal/server/quota_monitor.go new file mode 100644 index 0000000..93425c1 --- /dev/null +++ b/internal/server/quota_monitor.go @@ -0,0 +1,75 @@ +package server + +import ( + "context" + "errors" + "log/slog" + "time" + + "github.com/Neokil/AutoPR/internal/providers" +) + +const ( + quotaMonitorInterval = 20 * time.Minute +) + +func (s *server) quotaMonitorLoop() { + ticker := time.NewTicker(quotaMonitorInterval) + + defer ticker.Stop() + for range ticker.C { + s.checkQuotaStatus() + } +} + +func (s *server) isQuotaReached() bool { + s.quotaMu.RLock() + defer s.quotaMu.RUnlock() + + return s.quotaReached +} + +func (s *server) setQuotaReached(reached bool) { + s.quotaMu.Lock() + defer s.quotaMu.Unlock() + if reached && !s.quotaReached { + // Create a fresh channel that workers will block on + s.quotaResetCh = make(chan struct{}) + } + if !reached && s.quotaReached { + // Signal all waiting workers to wake up + close(s.quotaResetCh) + } + s.quotaReached = reached +} + +func (s *server) checkQuotaStatus() { + if !s.isQuotaReached() { + return + } + + slog.Info("quota monitor: probing provider to check if quota has reset") + + repos := s.meta.ListRepos() + if len(repos) == 0 { + slog.Warn("quota monitor: no repos available for probe, skipping") + + return + } + repoRt, err := s.runtimeForRepo(repos[0].Path) + if err != nil { + slog.Error("quota monitor: failed to get runtime for probe", "err", err) + + return + } + + probeErr := repoRt.svc.ProbeProvider(context.Background()) + if errors.Is(probeErr, providers.ErrTokensExhausted) { + slog.Info("quota monitor: quota still reached, will check again later") + + return + } + + s.setQuotaReached(false) + slog.Info("quota monitor: quota has reset, resuming operations") +} diff --git a/internal/server/server.go b/internal/server/server.go index be575ea..d0a2ade 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -70,6 +70,10 @@ type server struct { ticketLockMu sync.Mutex ticketLocks map[string]*sync.Mutex + + quotaReached bool + quotaMu sync.RWMutex + quotaResetCh chan struct{} } var sectionHeaderRE = regexp.MustCompile(`^## (.+) \(([^)]+)\)$`) @@ -96,20 +100,23 @@ func Run(portOverride int) error { } daemon := &server{ - cfg: cfg, - meta: meta, - runtimes: map[string]*repoRuntime{}, - jobs: make(chan queuedJob, jobQueueSize), - repoLocks: map[string]*sync.RWMutex{}, - ticketLocks: map[string]*sync.Mutex{}, - webFS: distFS, - subscribers: map[string]chan api.ServerEvent{}, + cfg: cfg, + meta: meta, + runtimes: map[string]*repoRuntime{}, + jobs: make(chan queuedJob, jobQueueSize), + repoLocks: map[string]*sync.RWMutex{}, + ticketLocks: map[string]*sync.Mutex{}, + webFS: distFS, + subscribers: map[string]chan api.ServerEvent{}, + quotaMu: sync.RWMutex{}, + quotaResetCh: make(chan struct{}), } daemon.recoverStuckTickets() for range cfg.ServerWorkers { go daemon.workerLoop() } go daemon.prMonitorLoop() + go daemon.quotaMonitorLoop() port := cfg.ServerPort if portOverride > 0 { diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index e1e10af..3a5c5ec 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -305,7 +305,7 @@ components: schemas: FlowStatus: type: string - enum: [pending, running, waiting, done, failed, cancelled] + enum: [pending, running, waiting, done, failed, cancelled, rescheduled] JobStatus: type: string enum: [queued, running, done, failed] diff --git a/web/src/generated/api.ts b/web/src/generated/api.ts index deaed66..5c83e91 100644 --- a/web/src/generated/api.ts +++ b/web/src/generated/api.ts @@ -233,7 +233,7 @@ export type webhooks = Record; export interface components { schemas: { /** @enum {string} */ - FlowStatus: "pending" | "running" | "waiting" | "done" | "failed" | "cancelled"; + FlowStatus: "pending" | "running" | "waiting" | "done" | "failed" | "cancelled" | "rescheduled"; /** @enum {string} */ JobStatus: "queued" | "running" | "done" | "failed"; HealthResponse: {