From 2a1980703d285a893d522819576569ac42154916 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Edstr=C3=B6m?= Date: Tue, 27 Jan 2026 19:37:36 +0100 Subject: [PATCH 1/8] Add WithFuncMinParallelism --- dao.go | 141 ++++++++++++++++++++++++++++++++++++------------ docket.go | 15 ++++-- option.go | 7 +++ scheduler.go | 27 ++++++---- want/counter.go | 73 +++++++++++++++++++++++++ worker.go | 2 +- 6 files changed, 215 insertions(+), 50 deletions(-) create mode 100644 want/counter.go diff --git a/dao.go b/dao.go index aef18e8..75d1deb 100644 --- a/dao.go +++ b/dao.go @@ -4,12 +4,14 @@ import ( "database/sql" "errors" "fmt" - "github.com/guregu/null/v6" - "github.com/lib/pq" "log/slog" "math" "strings" "time" + + "github.com/guregu/null/v6" + "github.com/lib/pq" + "github.com/modfin/pqdocket/want" ) const taskTableName = "pqdocket_task" @@ -36,6 +38,10 @@ func (d *docket) initTables() error { q2 := ` CREATE INDEX CONCURRENTLY IF NOT EXISTS pqdocket_task_scheduled_at_idx ON pqdocket_task (scheduled_at ASC) WHERE completed_at IS NULL + ` + q3 := ` + CREATE INDEX CONCURRENTLY IF NOT EXISTS pqdocket_task_ready_at_idx + ON pqdocket_task (greatest(scheduled_at, claimed_until)) WHERE completed_at IS NULL; ` _, err := d.db.Exec(strings.Replace(q, taskTableName, d.tableName(), 1)) if err != nil { @@ -45,6 +51,21 @@ func (d *docket) initTables() error { if err != nil { return err } + _, err = d.db.Exec(strings.Replace(q3, taskTableName, d.tableName(), 2)) + if err != nil { + return err + } + if len(d.parallelismMinByFunc) > 0 { + q4 := ` + CREATE INDEX CONCURRENTLY IF NOT EXISTS pqdocket_task_func_scheduled_at_idx + ON pqdocket_task (func, scheduled_at ASC) WHERE completed_at IS NULL + ` + _, err = d.db.Exec(strings.Replace(q4, taskTableName, d.tableName(), 2)) + if err != nil { + return err + } + } + if d.afterTableInitCallback == nil { return nil } @@ -154,46 +175,96 @@ func (d *docket) insertTasks(tx *sql.Tx, skipDuplicates bool, tcs ...TaskCreator return tasks, rows.Close() } -func (d *docket) claimTasks(wantNum int) ([]task, error) { - q := ` - WITH tasks_to_claim AS ( - SELECT task_id - FROM pqdocket_task t1 - WHERE completed_at IS NULL - AND (scheduled_at < now()) - AND (claimed_until IS NULL OR now() > claimed_until) - AND claim_count < $1 - ORDER BY scheduled_at ASC - LIMIT $2 - FOR UPDATE SKIP LOCKED - ) - UPDATE pqdocket_task - SET claimed_until = now() + make_interval(secs := claim_time_seconds), - claim_count = claim_count + 1 - WHERE task_id IN (select * from tasks_to_claim) - RETURNING * - ` - q = strings.Replace(q, taskTableName, d.tableName(), 2) - rows, err := d.db.Query(q, d.maxClaimCount, wantNum) - if err != nil { - return nil, err +func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { + + claim := func(q string, args ...any) ([]task, error) { + q = strings.Replace(q, taskTableName, d.tableName(), 2) + rows, err := d.db.Query(q, args...) + if err != nil { + return nil, err + } + defer rows.Close() + var claimedTasks []task + for rows.Next() { + t, err := scanTaskFrom(rows) + if err != nil { + return nil, err + } + t.docket = d + t.claimedByThisProcess = true + claimedTasks = append(claimedTasks, t) + } + err = rows.Err() + if err != nil { + return nil, err + } + return claimedTasks, rows.Close() } - defer rows.Close() + var claimedTasks []task - for rows.Next() { - t, err := scanTaskFrom(rows) + + wantByFunc := wantNum.WantByFuncNameTotal() + + if wantByFunc > 0 { + q := ` + WITH tasks_to_claim AS ( + SELECT task_id + FROM unnest($2 :: TEXT[], $3 :: INT[]) f(func_name, func_limit) + JOIN LATERAL ( + SELECT task_id + FROM pqdocket_task + WHERE completed_at IS NULL + AND func = f.func_name + AND (scheduled_at < now()) + AND (claimed_until IS NULL OR now() > claimed_until) + AND claim_count < $1 + ORDER BY scheduled_at ASC + LIMIT f.func_limit + ) a ON true + FOR UPDATE SKIP LOCKED + ) + UPDATE pqdocket_task + SET claimed_until = now() + make_interval(secs := claim_time_seconds), + claim_count = claim_count + 1 + WHERE task_id IN (select * from tasks_to_claim) + RETURNING * + ` + funcNames, funcLimits := wantNum.WantByFuncName() + tasks, err := claim(q, d.maxClaimCount, pq.StringArray(funcNames), pq.Int64Array(funcLimits)) if err != nil { return nil, err } - t.docket = d - t.claimedByThisProcess = true - claimedTasks = append(claimedTasks, t) + claimedTasks = append(claimedTasks, tasks...) } - err = rows.Err() - if err != nil { - return nil, err + + wantGeneral := wantNum.General() + wantByFunc - len(claimedTasks) + + if wantGeneral > 0 { + q := ` + WITH tasks_to_claim AS ( + SELECT task_id + FROM pqdocket_task t1 + WHERE completed_at IS NULL + AND (scheduled_at < now()) + AND (claimed_until IS NULL OR now() > claimed_until) + AND claim_count < $1 + ORDER BY scheduled_at ASC + LIMIT $2 + FOR UPDATE SKIP LOCKED + ) + UPDATE pqdocket_task + SET claimed_until = now() + make_interval(secs := claim_time_seconds), + claim_count = claim_count + 1 + WHERE task_id IN (select * from tasks_to_claim) + RETURNING * + ` + tasks, err := claim(q, d.maxClaimCount, wantGeneral) + if err != nil { + return nil, err + } + claimedTasks = append(claimedTasks, tasks...) } - return claimedTasks, rows.Close() + return claimedTasks, nil } func (d *docket) saveTaskResult(l *slog.Logger, t task, taskErr error) { diff --git a/docket.go b/docket.go index a773417..8b08b01 100644 --- a/docket.go +++ b/docket.go @@ -4,7 +4,6 @@ import ( "database/sql" "errors" "fmt" - "github.com/lib/pq" "io" "log/slog" "math" @@ -15,6 +14,9 @@ import ( "sync" "sync/atomic" "time" + + "github.com/lib/pq" + "github.com/modfin/pqdocket/want" ) type Docket interface { @@ -71,6 +73,7 @@ type docket struct { defaultClaimTime int pollInterval time.Duration parallelism int + parallelismMinByFunc map[string]int maxRetryBackoff int minRetryBackoff int maxClaimCount int @@ -79,7 +82,7 @@ type docket struct { useManuallyCreatedTable string claimedTasks chan task - taskCompleted chan bool + taskCompleted chan string functions map[string]TaskFunction closed bool @@ -117,6 +120,7 @@ func Init(dbUrl string, options ...Option) (Docket, error) { d.maxRetryBackoff = 128 d.minRetryBackoff = 4 d.maxClaimCount = math.MaxInt32 + d.parallelismMinByFunc = make(map[string]int) for _, opt := range options { opt(d) } @@ -126,7 +130,10 @@ func Init(dbUrl string, options ...Option) (Docket, error) { if d.useManuallyCreatedTable != "" && d.afterTableInitCallback != nil { return nil, errors.New("AfterTableInit will never be called when UseManuallyCreatedTable is enabled") } - + _, err := want.NewCounter(d.parallelism, d.parallelismMinByFunc) + if err != nil { + return nil, err + } d.close = make(chan bool) d.closeFinished = make(chan bool) @@ -139,7 +146,7 @@ func Init(dbUrl string, options ...Option) (Docket, error) { d.db = db d.claimedTasks = make(chan task, d.parallelism) - d.taskCompleted = make(chan bool, d.parallelism*2) + d.taskCompleted = make(chan string, d.parallelism*2) d.functions = make(map[string]TaskFunction) if err = d.initTables(); err != nil { return nil, err diff --git a/option.go b/option.go index 159a5c8..6456800 100644 --- a/option.go +++ b/option.go @@ -24,6 +24,13 @@ func Parallelism(parallelism int) Option { } } +// WithFuncMinParallelism sets the minimum number of parallel tasks for a given function. +func WithFuncMinParallelism(funcName string, minParallelism int) Option { + return func(docket *docket) { + docket.parallelismMinByFunc[funcName] = minParallelism + } +} + // DefaultClaimTime sets the default claim time for created tasks. Can be overridden on Task creation. func DefaultClaimTime(defaultClaimTimeSeconds int) Option { return func(docket *docket) { diff --git a/scheduler.go b/scheduler.go index 097b570..d96fac8 100644 --- a/scheduler.go +++ b/scheduler.go @@ -2,11 +2,13 @@ package pqdocket import ( "errors" - "github.com/lib/pq" "log/slog" "math" "math/rand" "time" + + "github.com/lib/pq" + "github.com/modfin/pqdocket/want" ) func (d *docket) reinitTablesIfError(l *slog.Logger, err error) { @@ -32,10 +34,15 @@ func (d *docket) startScheduler() { defer pollTicker.Stop() d.logger.Load().With("poll_interval", pollInterval, "parallelism", d.parallelism).Info("scheduler started") - wantNum := d.parallelism + wantNum, err := want.NewCounter(d.parallelism, d.parallelismMinByFunc) + if err != nil { + d.logger.Load().With("error", err).Error("failed to create want counter") + return + } + for { l := d.logger.Load() - if wantNum > 0 { + if wantNum.Total() > 0 { tasks, err := d.claimTasks(wantNum) if err != nil { l.With("error", err).Error("error in claimTasks") @@ -43,17 +50,17 @@ func (d *docket) startScheduler() { time.Sleep(20 * time.Second) continue } - l.With("want", wantNum, "got", len(tasks)).Info("claimTasks") + l.With("want", wantNum.Total(), "got", len(tasks)).Info("claimTasks") for _, t := range tasks { d.claimedTasks <- t - wantNum-- + wantNum.Decrement(t.function) } } timeout := d.timeToSleep() // protect against polling if our workers are full - if wantNum == 0 && timeout < 2*time.Second { + if wantNum.Total() == 0 && timeout < 2*time.Second { timeout = 2 * time.Second } if timeout < time.Duration(math.MaxInt64) { @@ -64,8 +71,8 @@ func (d *docket) startScheduler() { select { case <-taskScheduled: l = l.With("reason", "task_scheduled") - case <-d.taskCompleted: - wantNum++ + case funcName := <-d.taskCompleted: + wantNum.Increment(funcName) l = l.With("reason", "task_completed") case <-d.close: l = l.With("reason", "closed") @@ -87,8 +94,8 @@ func (d *docket) startScheduler() { // consume extra buffered taskScheduled/taskCompleted messages for { select { - case <-d.taskCompleted: - wantNum++ + case funcName := <-d.taskCompleted: + wantNum.Increment(funcName) continue case <-taskScheduled: continue diff --git a/want/counter.go b/want/counter.go new file mode 100644 index 0000000..82bb81a --- /dev/null +++ b/want/counter.go @@ -0,0 +1,73 @@ +package want + +import ( + "fmt" +) + +type Counter struct { + parallelism int + parallelismMinByFunc map[string]int + wantNumGeneral int + wantNumByFunc map[string]int +} + +func NewCounter(parallelism int, parallelismMinByFunc map[string]int) (*Counter, error) { + + wantNumByFunc := make(map[string]int) + wantNumGeneral := parallelism + for funcName, minParallelism := range parallelismMinByFunc { + wantNumByFunc[funcName] = minParallelism + wantNumGeneral-- + } + if wantNumGeneral < 1 { + return nil, fmt.Errorf("not enough parallelism for all functions, due to bad WithFuncMinParallelism config") + } + return &Counter{ + parallelism: parallelism, + parallelismMinByFunc: parallelismMinByFunc, + wantNumGeneral: wantNumGeneral, + wantNumByFunc: wantNumByFunc, + }, nil +} + +func (c *Counter) Total() int { + return c.wantNumGeneral + c.WantByFuncNameTotal() +} + +func (c *Counter) WantByFuncNameTotal() int { + num := 0 + for _, v := range c.wantNumByFunc { + num += v + } + return num +} + +func (c *Counter) General() int { + return c.wantNumGeneral +} + +func (c *Counter) WantByFuncName() ([]string, []int64) { + var funcNames []string + var counts []int64 + for funcName, count := range c.wantNumByFunc { + funcNames = append(funcNames, funcName) + counts = append(counts, int64(count)) + } + return funcNames, counts +} + +func (c *Counter) Increment(funcName string) { + if current, ok := c.wantNumByFunc[funcName]; ok && current < c.parallelismMinByFunc[funcName] { + c.wantNumByFunc[funcName]++ + return + } + c.wantNumGeneral++ +} + +func (c *Counter) Decrement(funcName string) { + if current, ok := c.wantNumByFunc[funcName]; ok && current > 0 { + c.wantNumByFunc[funcName]-- + return + } + c.wantNumGeneral-- +} diff --git a/worker.go b/worker.go index 9fec466..a6ed3cc 100644 --- a/worker.go +++ b/worker.go @@ -11,7 +11,7 @@ func (d *docket) worker(workerId int) { l.With("worker_id", workerId).Info("running task") err := d.workerBody(t) d.saveTaskResult(l, t, err) - d.taskCompleted <- true + d.taskCompleted <- t.function } d.logger.Load().With("worker_id", workerId).Info("worker terminated") d.mu.Lock() From f99873bae9f0054ef48a3d4266a5132ad112461c Mon Sep 17 00:00:00 2001 From: agaton Date: Wed, 28 Jan 2026 11:04:50 +0100 Subject: [PATCH 2/8] add basic test --- tests/func_parallelism_test.go | 148 +++++++++++++++++++++++++++++++++ tests/main_test.go | 4 +- 2 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 tests/func_parallelism_test.go diff --git a/tests/func_parallelism_test.go b/tests/func_parallelism_test.go new file mode 100644 index 0000000..6137a6f --- /dev/null +++ b/tests/func_parallelism_test.go @@ -0,0 +1,148 @@ +package tests + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/modfin/pqdocket" +) + +var ( + pfpD pqdocket.Docket + pfpMaxSlowCon atomic.Int32 + pfpMaxMediumCon atomic.Int32 + pfpMaxFastCon atomic.Int32 + pfpMaxGeneralCon atomic.Int32 + pfpWg sync.WaitGroup +) + +const ( + slowParallelism = 3 + mediumParallelism = 5 + fastParallelism = 2 + generalParallelism = 50 +) + +func PerFunctionParallelismInit() { + var err error + pfpD, err = pqdocket.Init("postgres://postgres:qwerty@localhost:9300/postgres?sslmode=disable", + pqdocket.Namespace("pfp_test"), + pqdocket.Parallelism(generalParallelism), + pqdocket.WithFuncMinParallelism("slowFunc", slowParallelism), + pqdocket.WithFuncMinParallelism("mediumFunc", mediumParallelism), + pqdocket.WithFuncMinParallelism("fastFunc", fastParallelism), + ) + if err != nil { + panic(err) + } + + pfpMaxSlowCon.Store(0) + pfpMaxMediumCon.Store(0) + pfpMaxFastCon.Store(0) + pfpMaxGeneralCon.Store(0) + +} + +func slowFunc(t *testing.T) func(task pqdocket.RunningTask) error { + return func(task pqdocket.RunningTask) error { + current := pfpMaxSlowCon.Add(1) + defer pfpMaxSlowCon.Add(-1) + defer pfpWg.Done() + + if current > slowParallelism { + t.Error("slowFunc exceeded its parallelism limit of 3") + return fmt.Errorf("slowFunc exceeded its parallelism limit of 3") + } + + time.Sleep(250 * time.Millisecond) + return nil + } +} +func mediumFunc(t *testing.T) func(task pqdocket.RunningTask) error { + return func(task pqdocket.RunningTask) error { + current := pfpMaxMediumCon.Add(1) + defer pfpMaxMediumCon.Add(-1) + defer pfpWg.Done() + + if current > mediumParallelism { + t.Error("mediumFunc exceeded parallelism limit of 5") + return fmt.Errorf("mediumFunc exceeded parallelism limit of 5") + } + + time.Sleep(50 * time.Millisecond) + return nil + } +} +func fastFunc(t *testing.T) func(task pqdocket.RunningTask) error { + return func(task pqdocket.RunningTask) error { + current := pfpMaxFastCon.Add(1) + defer pfpMaxFastCon.Add(-1) + defer pfpWg.Done() + + if current > fastParallelism { + t.Error("fastFunc exceeded parallelism limit of 2") + return fmt.Errorf("fastFunc exceeded parallelism limit of 2") + } + time.Sleep(10 * time.Millisecond) + return nil + } +} +func generalFunc(t *testing.T) func(task pqdocket.RunningTask) error { + return func(task pqdocket.RunningTask) error { + current := pfpMaxGeneralCon.Add(1) + defer pfpMaxGeneralCon.Add(-1) + defer pfpWg.Done() + + if current > generalParallelism { + t.Error("generalFunc exceeded default parallelism limit of 10") + return fmt.Errorf("generalFunc exceeded default parallelism limit of 50") + } + time.Sleep(20 * time.Millisecond) + return nil + } +} + +// TestMultipleFunctionSchedulers verifies each function respects its individual limit +func TestMultipleFunctionSchedulers(t *testing.T) { + var tcs []pqdocket.TaskCreator + + pfpD.RegisterFunctionWithFuncName("slowFunc", slowFunc(t)) + pfpD.RegisterFunctionWithFuncName("mediumFunc", mediumFunc(t)) + pfpD.RegisterFunctionWithFuncName("fastFunc", fastFunc(t)) + pfpD.RegisterFunctionWithFuncName("generalFunc", generalFunc(t)) + + for i := 0; i < slowParallelism*10; i++ { + pfpWg.Add(1) + tc := pfpD.CreateTaskWithFuncName("slowFunc") + tcs = append(tcs, tc) + } + + for i := 0; i < generalParallelism*10; i++ { + pfpWg.Add(1) + tc := pfpD.CreateTaskWithFuncName("generalFunc") + tcs = append(tcs, tc) + } + + for i := 0; i < mediumParallelism*10; i++ { + pfpWg.Add(1) + tc := pfpD.CreateTaskWithFuncName("mediumFunc") + tcs = append(tcs, tc) + } + + for i := 0; i < fastParallelism*10; i++ { + pfpWg.Add(1) + tc := pfpD.CreateTaskWithFuncName("fastFunc") + tcs = append(tcs, tc) + } + + _, err := pfpD.InsertTasks(nil, tcs...) + if err != nil { + t.Error(err) + t.Fail() + } + + pfpWg.Wait() +} diff --git a/tests/main_test.go b/tests/main_test.go index 547aac5..87b2314 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -4,12 +4,13 @@ import ( "database/sql" "errors" "fmt" - "github.com/modfin/pqdocket" "log/slog" "math/rand" "os" "testing" "time" + + "github.com/modfin/pqdocket" ) var d pqdocket.Docket @@ -67,6 +68,7 @@ func TestMain(m *testing.M) { FindInit() ChainInit() ExtendedClaimInit() + PerFunctionParallelismInit() db, err = sql.Open("postgres", dbUrl) if err != nil { From 6a3be542f7281ff4feff13b833cab542147b8175 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Edstr=C3=B6m?= Date: Wed, 28 Jan 2026 12:15:50 +0100 Subject: [PATCH 3/8] Fixes --- dao.go | 45 ++++++++------- docket.go | 5 +- option.go | 16 +++++- scheduler.go | 2 +- tests/func_parallelism_test.go | 12 ++-- tests/main_test.go | 3 +- want/counter.go | 100 +++++++++++++++++++++++---------- 7 files changed, 119 insertions(+), 64 deletions(-) diff --git a/dao.go b/dao.go index 75d1deb..fb72403 100644 --- a/dao.go +++ b/dao.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "log/slog" + "maps" "math" + "slices" "strings" "time" @@ -55,17 +57,6 @@ func (d *docket) initTables() error { if err != nil { return err } - if len(d.parallelismMinByFunc) > 0 { - q4 := ` - CREATE INDEX CONCURRENTLY IF NOT EXISTS pqdocket_task_func_scheduled_at_idx - ON pqdocket_task (func, scheduled_at ASC) WHERE completed_at IS NULL - ` - _, err = d.db.Exec(strings.Replace(q4, taskTableName, d.tableName(), 2)) - if err != nil { - return err - } - } - if d.afterTableInitCallback == nil { return nil } @@ -203,23 +194,23 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { var claimedTasks []task - wantByFunc := wantNum.WantByFuncNameTotal() + wantByGroups := wantNum.WantByGroupsTotal() - if wantByFunc > 0 { + if wantByGroups > 0 { q := ` WITH tasks_to_claim AS ( SELECT task_id - FROM unnest($2 :: TEXT[], $3 :: INT[]) f(func_name, func_limit) + FROM unnest($2 :: TEXT[], $3 :: INT[]) f(func_names, group_limit) JOIN LATERAL ( SELECT task_id FROM pqdocket_task WHERE completed_at IS NULL - AND func = f.func_name + AND func = ANY(string_to_array(f.func_names, ',')) AND (scheduled_at < now()) AND (claimed_until IS NULL OR now() > claimed_until) AND claim_count < $1 ORDER BY scheduled_at ASC - LIMIT f.func_limit + LIMIT f.group_limit ) a ON true FOR UPDATE SKIP LOCKED ) @@ -229,7 +220,7 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { WHERE task_id IN (select * from tasks_to_claim) RETURNING * ` - funcNames, funcLimits := wantNum.WantByFuncName() + funcNames, funcLimits := wantNum.WantByGroup() tasks, err := claim(q, d.maxClaimCount, pq.StringArray(funcNames), pq.Int64Array(funcLimits)) if err != nil { return nil, err @@ -237,9 +228,22 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { claimedTasks = append(claimedTasks, tasks...) } - wantGeneral := wantNum.General() + wantByFunc - len(claimedTasks) + wantGeneral := wantNum.General() if wantGeneral > 0 { + d.mu.RLock() + registeredFuncs := slices.Collect(maps.Keys(d.functions)) + d.mu.RUnlock() + handledByGroups := wantNum.FunctionsHandledByGroups() + + var generalFuncs []string + for _, rf := range registeredFuncs { + if _, ok := handledByGroups[rf]; ok { + continue + } + generalFuncs = append(generalFuncs, rf) + } + q := ` WITH tasks_to_claim AS ( SELECT task_id @@ -248,8 +252,9 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { AND (scheduled_at < now()) AND (claimed_until IS NULL OR now() > claimed_until) AND claim_count < $1 + AND func = ANY($2) ORDER BY scheduled_at ASC - LIMIT $2 + LIMIT $3 FOR UPDATE SKIP LOCKED ) UPDATE pqdocket_task @@ -258,7 +263,7 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { WHERE task_id IN (select * from tasks_to_claim) RETURNING * ` - tasks, err := claim(q, d.maxClaimCount, wantGeneral) + tasks, err := claim(q, d.maxClaimCount, pq.StringArray(generalFuncs), wantGeneral) if err != nil { return nil, err } diff --git a/docket.go b/docket.go index 8b08b01..c91127a 100644 --- a/docket.go +++ b/docket.go @@ -73,7 +73,7 @@ type docket struct { defaultClaimTime int pollInterval time.Duration parallelism int - parallelismMinByFunc map[string]int + parallelismGroups []want.ParallelismGroup maxRetryBackoff int minRetryBackoff int maxClaimCount int @@ -120,7 +120,6 @@ func Init(dbUrl string, options ...Option) (Docket, error) { d.maxRetryBackoff = 128 d.minRetryBackoff = 4 d.maxClaimCount = math.MaxInt32 - d.parallelismMinByFunc = make(map[string]int) for _, opt := range options { opt(d) } @@ -130,7 +129,7 @@ func Init(dbUrl string, options ...Option) (Docket, error) { if d.useManuallyCreatedTable != "" && d.afterTableInitCallback != nil { return nil, errors.New("AfterTableInit will never be called when UseManuallyCreatedTable is enabled") } - _, err := want.NewCounter(d.parallelism, d.parallelismMinByFunc) + _, err := want.NewCounter(d.parallelism, d.parallelismGroups) if err != nil { return nil, err } diff --git a/option.go b/option.go index 6456800..b0770a9 100644 --- a/option.go +++ b/option.go @@ -4,6 +4,8 @@ import ( "database/sql" "log/slog" "time" + + "github.com/modfin/pqdocket/want" ) type Option func(docket *docket) @@ -24,10 +26,18 @@ func Parallelism(parallelism int) Option { } } -// WithFuncMinParallelism sets the minimum number of parallel tasks for a given function. -func WithFuncMinParallelism(funcName string, minParallelism int) Option { +// WithDedicatedParallelismGroup dedicates N parallelism slots to a group of functions. +// This means that N slots will be dedicated to tasks of these functions only. +// And that no more than N tasks from this group will run in parallel. +// +// For example, if you specify total parallelism to 5 and make a dedicated group of size 4, +// then all other tasks will have to fight for the the remaining single slot. +func WithDedicatedParallelismGroup(size int, funcNames ...string) Option { return func(docket *docket) { - docket.parallelismMinByFunc[funcName] = minParallelism + if len(funcNames) == 0 { + return + } + docket.parallelismGroups = append(docket.parallelismGroups, want.ParallelismGroup{Functions: funcNames, Count: size}) } } diff --git a/scheduler.go b/scheduler.go index d96fac8..2fcf6b2 100644 --- a/scheduler.go +++ b/scheduler.go @@ -34,7 +34,7 @@ func (d *docket) startScheduler() { defer pollTicker.Stop() d.logger.Load().With("poll_interval", pollInterval, "parallelism", d.parallelism).Info("scheduler started") - wantNum, err := want.NewCounter(d.parallelism, d.parallelismMinByFunc) + wantNum, err := want.NewCounter(d.parallelism, d.parallelismGroups) if err != nil { d.logger.Load().With("error", err).Error("failed to create want counter") return diff --git a/tests/func_parallelism_test.go b/tests/func_parallelism_test.go index 6137a6f..b73712c 100644 --- a/tests/func_parallelism_test.go +++ b/tests/func_parallelism_test.go @@ -26,14 +26,14 @@ const ( generalParallelism = 50 ) -func PerFunctionParallelismInit() { +func ParallelismGroupsInit() { var err error pfpD, err = pqdocket.Init("postgres://postgres:qwerty@localhost:9300/postgres?sslmode=disable", pqdocket.Namespace("pfp_test"), pqdocket.Parallelism(generalParallelism), - pqdocket.WithFuncMinParallelism("slowFunc", slowParallelism), - pqdocket.WithFuncMinParallelism("mediumFunc", mediumParallelism), - pqdocket.WithFuncMinParallelism("fastFunc", fastParallelism), + pqdocket.WithDedicatedParallelismGroup(slowParallelism, "slowFunc"), + pqdocket.WithDedicatedParallelismGroup(mediumParallelism, "mediumFunc"), + pqdocket.WithDedicatedParallelismGroup(fastParallelism, "fastFunc"), ) if err != nil { panic(err) @@ -105,8 +105,8 @@ func generalFunc(t *testing.T) func(task pqdocket.RunningTask) error { } } -// TestMultipleFunctionSchedulers verifies each function respects its individual limit -func TestMultipleFunctionSchedulers(t *testing.T) { +// TestParallelismGroups verifies each function respects its individual limit +func TestParallelismGroups(t *testing.T) { var tcs []pqdocket.TaskCreator pfpD.RegisterFunctionWithFuncName("slowFunc", slowFunc(t)) diff --git a/tests/main_test.go b/tests/main_test.go index 87b2314..6d4b294 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -32,6 +32,7 @@ func TestMain(m *testing.M) { pqdocket.DefaultClaimTime(20), pqdocket.WithLogger(l.With("instance", "d1")), pqdocket.EnableTaskCleaner(pqdocket.TaskCleanerSettings{PollInterval: time.Second}), + pqdocket.WithDedicatedParallelismGroup(30, "Task1", "Task3", "StressTask"), ) if err != nil { fmt.Println("d1 init err", err) @@ -68,7 +69,7 @@ func TestMain(m *testing.M) { FindInit() ChainInit() ExtendedClaimInit() - PerFunctionParallelismInit() + ParallelismGroupsInit() db, err = sql.Open("postgres", dbUrl) if err != nil { diff --git a/want/counter.go b/want/counter.go index 82bb81a..c195d25 100644 --- a/want/counter.go +++ b/want/counter.go @@ -2,72 +2,112 @@ package want import ( "fmt" + "slices" + "strings" ) type Counter struct { - parallelism int - parallelismMinByFunc map[string]int - wantNumGeneral int - wantNumByFunc map[string]int + parallelism int + parallelismGroups []ParallelismGroup + wantNumGeneral int + wantNumByGroup []ParallelismGroup } -func NewCounter(parallelism int, parallelismMinByFunc map[string]int) (*Counter, error) { +type ParallelismGroup struct { + Functions []string + Count int +} + +func NewCounter(parallelism int, parallelismGroups []ParallelismGroup) (*Counter, error) { - wantNumByFunc := make(map[string]int) + var wantNumByFunc []ParallelismGroup wantNumGeneral := parallelism - for funcName, minParallelism := range parallelismMinByFunc { - wantNumByFunc[funcName] = minParallelism - wantNumGeneral-- + for _, g := range parallelismGroups { + if slices.ContainsFunc(g.Functions, func(s string) bool { + return strings.Contains(s, ",") + }) { + return nil, fmt.Errorf("functions names cannot contain commas") + } + wantNumByFunc = append(wantNumByFunc, g) + wantNumGeneral -= g.Count } if wantNumGeneral < 1 { - return nil, fmt.Errorf("not enough parallelism for all functions, due to bad WithFuncMinParallelism config") + return nil, fmt.Errorf("not enough parallelism for all functions, due to bad WithDedicatedParallelismGroup config") } return &Counter{ - parallelism: parallelism, - parallelismMinByFunc: parallelismMinByFunc, - wantNumGeneral: wantNumGeneral, - wantNumByFunc: wantNumByFunc, + parallelism: parallelism, + parallelismGroups: parallelismGroups, + wantNumGeneral: wantNumGeneral, + wantNumByGroup: wantNumByFunc, }, nil } func (c *Counter) Total() int { - return c.wantNumGeneral + c.WantByFuncNameTotal() + return c.wantNumGeneral + c.WantByGroupsTotal() } -func (c *Counter) WantByFuncNameTotal() int { +func (c *Counter) WantByGroupsTotal() int { num := 0 - for _, v := range c.wantNumByFunc { - num += v + for _, v := range c.wantNumByGroup { + num += v.Count } return num } +func (c *Counter) FunctionsHandledByGroups() map[string]bool { + funcs := make(map[string]bool) + for _, v := range c.wantNumByGroup { + for _, f := range v.Functions { + funcs[f] = true + } + } + return funcs +} + func (c *Counter) General() int { return c.wantNumGeneral } -func (c *Counter) WantByFuncName() ([]string, []int64) { - var funcNames []string +func (c *Counter) WantByGroup() ([]string, []int64) { + var groupFuncNames []string var counts []int64 - for funcName, count := range c.wantNumByFunc { - funcNames = append(funcNames, funcName) - counts = append(counts, int64(count)) + for _, g := range c.wantNumByGroup { + groupFuncNames = append(groupFuncNames, strings.Join(g.Functions, ",")) + counts = append(counts, int64(g.Count)) } - return funcNames, counts + return groupFuncNames, counts } func (c *Counter) Increment(funcName string) { - if current, ok := c.wantNumByFunc[funcName]; ok && current < c.parallelismMinByFunc[funcName] { - c.wantNumByFunc[funcName]++ - return + for i := range c.wantNumByGroup { + if slices.Contains(c.wantNumByGroup[i].Functions, funcName) { + c.wantNumByGroup[i].Count++ + if c.wantNumByGroup[i].Count > c.parallelismGroups[i].Count { // TODO remove + panic("shouldn't happen, count should never exceed group parallelism") + } + return + } } + c.wantNumGeneral++ + if c.wantNumGeneral > c.parallelism { // TODO remove + panic("general: shouldn't happen, count should never exceed parallelism") + } } func (c *Counter) Decrement(funcName string) { - if current, ok := c.wantNumByFunc[funcName]; ok && current > 0 { - c.wantNumByFunc[funcName]-- - return + for i := range c.wantNumByGroup { + if slices.Contains(c.wantNumByGroup[i].Functions, funcName) { + c.wantNumByGroup[i].Count-- + if c.wantNumByGroup[i].Count < 0 { // TODO remove + fmt.Println(c.wantNumByGroup[i].Count, c.parallelismGroups[i].Count) + panic("group: shouldn't happen, count should never be lower than 0") + } + return + } } c.wantNumGeneral-- + if c.wantNumGeneral < 0 { // TODO remove + panic("general: shouldn't happen, count should never go below 0") + } } From 7f48f2a46ecd6982dc8005db269cb31b96e1253d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Edstr=C3=B6m?= Date: Wed, 28 Jan 2026 12:43:33 +0100 Subject: [PATCH 4/8] Fix --- scheduler.go | 4 ++-- want/counter.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/scheduler.go b/scheduler.go index 2fcf6b2..4601b7f 100644 --- a/scheduler.go +++ b/scheduler.go @@ -59,8 +59,8 @@ func (d *docket) startScheduler() { timeout := d.timeToSleep() - // protect against polling if our workers are full - if wantNum.Total() == 0 && timeout < 2*time.Second { + // protect against polling if any of our worker groups are full + if wantNum.HasZero() && timeout < 2*time.Second { timeout = 2 * time.Second } if timeout < time.Duration(math.MaxInt64) { diff --git a/want/counter.go b/want/counter.go index c195d25..2f0affb 100644 --- a/want/counter.go +++ b/want/counter.go @@ -46,6 +46,18 @@ func (c *Counter) Total() int { return c.wantNumGeneral + c.WantByGroupsTotal() } +func (c *Counter) HasZero() bool { + if c.wantNumGeneral == 0 { + return true + } + for _, v := range c.wantNumByGroup { + if v.Count == 0 { + return true + } + } + return false +} + func (c *Counter) WantByGroupsTotal() int { num := 0 for _, v := range c.wantNumByGroup { From ecdb0702eb6615966691f788c05f85a5c97b78bb Mon Sep 17 00:00:00 2001 From: Alexander Olsson Date: Tue, 3 Feb 2026 13:56:30 +0100 Subject: [PATCH 5/8] Added get last error to task --- task.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/task.go b/task.go index b3bde8b..60bde80 100644 --- a/task.go +++ b/task.go @@ -4,10 +4,11 @@ import ( "database/sql" "encoding/json" "errors" - "github.com/guregu/null/v6" "log/slog" "sync" "time" + + "github.com/guregu/null/v6" ) type Task interface { @@ -16,6 +17,7 @@ type Task interface { BindMetadata(dest interface{}) error CompletedAt() null.Time ClaimCount() int + LastError() null.String CreatedAt() time.Time Func() string @@ -188,6 +190,10 @@ func (t task) ClaimCount() int { return t.claimCount } +func (t task) LastError() null.String { + return t.lastError +} + func (t task) CreatedAt() time.Time { return t.createdAt } From 83ff6bcd636a9945f29c8133beec905fa1bb07c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joel=20Edstr=C3=B6m?= Date: Sat, 7 Feb 2026 21:07:50 +0100 Subject: [PATCH 6/8] Try using enable_seqscan = off --- dao.go | 115 ++++++++++++++++++++++----------------------------------- 1 file changed, 44 insertions(+), 71 deletions(-) diff --git a/dao.go b/dao.go index fb72403..692cf26 100644 --- a/dao.go +++ b/dao.go @@ -1,6 +1,7 @@ package pqdocket import ( + "context" "database/sql" "errors" "fmt" @@ -167,37 +168,36 @@ func (d *docket) insertTasks(tx *sql.Tx, skipDuplicates bool, tcs ...TaskCreator } func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { + d.mu.RLock() + registeredFuncs := slices.Collect(maps.Keys(d.functions)) + d.mu.RUnlock() - claim := func(q string, args ...any) ([]task, error) { - q = strings.Replace(q, taskTableName, d.tableName(), 2) - rows, err := d.db.Query(q, args...) - if err != nil { - return nil, err - } - defer rows.Close() - var claimedTasks []task - for rows.Next() { - t, err := scanTaskFrom(rows) - if err != nil { - return nil, err - } - t.docket = d - t.claimedByThisProcess = true - claimedTasks = append(claimedTasks, t) - } - err = rows.Err() - if err != nil { - return nil, err + wantGeneral := wantNum.General() + handledByGroups := wantNum.FunctionsHandledByGroups() + + var generalFuncs []string + for _, rf := range registeredFuncs { + if _, ok := handledByGroups[rf]; ok { + continue } - return claimedTasks, rows.Close() + generalFuncs = append(generalFuncs, rf) } - var claimedTasks []task - - wantByGroups := wantNum.WantByGroupsTotal() + funcNames, funcLimits := wantNum.WantByGroup() + funcNames = append(funcNames, strings.Join(generalFuncs, ",")) + funcLimits = append(funcLimits, int64(wantGeneral)) - if wantByGroups > 0 { - q := ` + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tx, err := d.db.BeginTx(ctx, nil) // tx rollbacks are handled by context cancel + if err != nil { + return nil, err + } + _, err = tx.Exec(`SET LOCAL enable_seqscan = off`) + if err != nil { + return nil, err + } + q := ` WITH tasks_to_claim AS ( SELECT task_id FROM unnest($2 :: TEXT[], $3 :: INT[]) f(func_names, group_limit) @@ -220,56 +220,29 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { WHERE task_id IN (select * from tasks_to_claim) RETURNING * ` - funcNames, funcLimits := wantNum.WantByGroup() - tasks, err := claim(q, d.maxClaimCount, pq.StringArray(funcNames), pq.Int64Array(funcLimits)) - if err != nil { - return nil, err - } - claimedTasks = append(claimedTasks, tasks...) + q = strings.Replace(q, taskTableName, d.tableName(), 2) + rows, err := tx.Query(q, d.maxClaimCount, pq.StringArray(funcNames), pq.Int64Array(funcLimits)) + if err != nil { + return nil, err } - - wantGeneral := wantNum.General() - - if wantGeneral > 0 { - d.mu.RLock() - registeredFuncs := slices.Collect(maps.Keys(d.functions)) - d.mu.RUnlock() - handledByGroups := wantNum.FunctionsHandledByGroups() - - var generalFuncs []string - for _, rf := range registeredFuncs { - if _, ok := handledByGroups[rf]; ok { - continue - } - generalFuncs = append(generalFuncs, rf) - } - - q := ` - WITH tasks_to_claim AS ( - SELECT task_id - FROM pqdocket_task t1 - WHERE completed_at IS NULL - AND (scheduled_at < now()) - AND (claimed_until IS NULL OR now() > claimed_until) - AND claim_count < $1 - AND func = ANY($2) - ORDER BY scheduled_at ASC - LIMIT $3 - FOR UPDATE SKIP LOCKED - ) - UPDATE pqdocket_task - SET claimed_until = now() + make_interval(secs := claim_time_seconds), - claim_count = claim_count + 1 - WHERE task_id IN (select * from tasks_to_claim) - RETURNING * - ` - tasks, err := claim(q, d.maxClaimCount, pq.StringArray(generalFuncs), wantGeneral) + defer rows.Close() + var claimedTasks []task + for rows.Next() { + t, err := scanTaskFrom(rows) if err != nil { return nil, err } - claimedTasks = append(claimedTasks, tasks...) + t.docket = d + t.claimedByThisProcess = true + claimedTasks = append(claimedTasks, t) + } + if err = rows.Err(); err != nil { + return nil, err + } + if err = rows.Close(); err != nil { + return nil, err } - return claimedTasks, nil + return claimedTasks, tx.Commit() } func (d *docket) saveTaskResult(l *slog.Logger, t task, taskErr error) { From cb4396f8150e2daa7819b620e704451bbddc7f6d Mon Sep 17 00:00:00 2001 From: agaton Date: Tue, 10 Feb 2026 11:44:11 +0100 Subject: [PATCH 7/8] force index --- dao.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/dao.go b/dao.go index 692cf26..feadf17 100644 --- a/dao.go +++ b/dao.go @@ -193,7 +193,15 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { if err != nil { return nil, err } - _, err = tx.Exec(`SET LOCAL enable_seqscan = off`) + _, err = tx.Exec(`SET LOCAL enable_hashjoin = off;`) + if err != nil { + return nil, err + } + _, err = tx.Exec(`SET LOCAL enable_mergejoin = off;`) + if err != nil { + return nil, err + } + _, err = tx.Exec(`SET LOCAL jit = off;`) if err != nil { return nil, err } From e4f2d31a03d493ece198c68821168feba466a9ee Mon Sep 17 00:00:00 2001 From: agaton Date: Sun, 15 Feb 2026 13:44:31 +0100 Subject: [PATCH 8/8] improving query --- dao.go | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/dao.go b/dao.go index feadf17..3329cc6 100644 --- a/dao.go +++ b/dao.go @@ -210,14 +210,23 @@ func (d *docket) claimTasks(wantNum *want.Counter) ([]task, error) { SELECT task_id FROM unnest($2 :: TEXT[], $3 :: INT[]) f(func_names, group_limit) JOIN LATERAL ( - SELECT task_id - FROM pqdocket_task - WHERE completed_at IS NULL - AND func = ANY(string_to_array(f.func_names, ',')) - AND (scheduled_at < now()) - AND (claimed_until IS NULL OR now() > claimed_until) - AND claim_count < $1 - ORDER BY scheduled_at ASC + SELECT combined_tasks.task_id + FROM ( + SELECT single_func + FROM unnest(string_to_array(f.func_names, ',')) AS single_func + ) f2 + CROSS JOIN LATERAL ( + SELECT task_id, scheduled_at + FROM pqdocket_task + WHERE completed_at IS NULL + AND claim_count < $1 + AND func = f2.single_func + AND scheduled_at < now() + AND (claimed_until IS NULL OR now() > claimed_until) + ORDER BY scheduled_at ASC + LIMIT f.group_limit + ) combined_tasks + ORDER BY combined_tasks.scheduled_at ASC LIMIT f.group_limit ) a ON true FOR UPDATE SKIP LOCKED