diff --git a/plugin/objectsigner/gpg/example_test.go b/plugin/objectsigner/gpg/example_test.go index 9e86ccf..fb40da6 100644 --- a/plugin/objectsigner/gpg/example_test.go +++ b/plugin/objectsigner/gpg/example_test.go @@ -1,6 +1,7 @@ package gpg_test import ( + "context" "fmt" "strings" @@ -21,7 +22,7 @@ func ExampleFromKey() { panic(err) } - sig, err := signer.Sign(strings.NewReader("signed commit message\n")) + sig, err := signer.Sign(context.Background(), strings.NewReader("signed commit message\n")) if err != nil { panic(err) } diff --git a/plugin/objectsigner/gpg/gpg.go b/plugin/objectsigner/gpg/gpg.go index b77676e..88790af 100644 --- a/plugin/objectsigner/gpg/gpg.go +++ b/plugin/objectsigner/gpg/gpg.go @@ -4,6 +4,7 @@ package gpg import ( "bytes" + "context" "errors" "fmt" "io" @@ -36,8 +37,10 @@ type signer struct { } // Sign reads message and returns an ASCII-armored detached GPG -// signature created with the signer's OpenPGP key. -func (s *signer) Sign(message io.Reader) ([]byte, error) { +// signature created with the signer's OpenPGP key. The context is accepted +// for interface uniformity across signers; native OpenPGP signing is purely +// local and does not consult it. +func (s *signer) Sign(_ context.Context, message io.Reader) ([]byte, error) { if message == nil { return nil, ErrNilMessage } diff --git a/plugin/objectsigner/gpg/gpg_test.go b/plugin/objectsigner/gpg/gpg_test.go index f1f26b5..4375932 100644 --- a/plugin/objectsigner/gpg/gpg_test.go +++ b/plugin/objectsigner/gpg/gpg_test.go @@ -62,7 +62,7 @@ func TestSign(t *testing.T) { signer, err := gpg.FromKey(generateTestKey(t)) require.NoError(t, err) - sig, err := signer.Sign(test.message) + sig, err := signer.Sign(t.Context(), test.message) if test.wantErr != "" { require.ErrorContains(t, err, test.wantErr) require.Nil(t, sig) @@ -85,7 +85,7 @@ func TestSignVerifyRoundTrip(t *testing.T) { message := "signed commit message\n" - sig, err := signer.Sign(strings.NewReader(message)) + sig, err := signer.Sign(t.Context(), strings.NewReader(message)) require.NoError(t, err) keyring := openpgp.EntityList{key} @@ -106,10 +106,10 @@ func TestSignDifferentMessagesProduceDifferentSignatures(t *testing.T) { signer, err := gpg.FromKey(key) require.NoError(t, err) - sig1, err := signer.Sign(strings.NewReader("message one")) + sig1, err := signer.Sign(t.Context(), strings.NewReader("message one")) require.NoError(t, err) - sig2, err := signer.Sign(strings.NewReader("message two")) + sig2, err := signer.Sign(t.Context(), strings.NewReader("message two")) require.NoError(t, err) assert.NotEqual(t, sig1, sig2, "different messages produced identical signatures") diff --git a/plugin/objectsigner/program/program.go b/plugin/objectsigner/program/program.go index e264438..0b1ab11 100644 --- a/plugin/objectsigner/program/program.go +++ b/plugin/objectsigner/program/program.go @@ -150,17 +150,17 @@ func resolveProgram(program string, lookPath func(string) (string, error)) (stri } // Sign reads message and returns the signature produced by the external -// binary. -func (s *signer) Sign(message io.Reader) ([]byte, error) { +// binary. The context cancels the external program invocation. +func (s *signer) Sign(ctx context.Context, message io.Reader) ([]byte, error) { if message == nil { return nil, ErrNilMessage } switch s.format { case FormatOpenPGP, FormatX509: - return s.signStdio(context.Background(), message) + return s.signStdio(ctx, message) case FormatSSH: - return s.signSSH(context.Background(), message) + return s.signSSH(ctx, message) default: return nil, fmt.Errorf("%w: %q", ErrUnsupportedFormat, s.format) } diff --git a/plugin/objectsigner/program/program_test.go b/plugin/objectsigner/program/program_test.go index 2a6c2af..7cf8c9c 100644 --- a/plugin/objectsigner/program/program_test.go +++ b/plugin/objectsigner/program/program_test.go @@ -135,11 +135,61 @@ func TestSign_NilMessage(t *testing.T) { signer, _ := newTestSigner(t, FormatOpenPGP, "ABC", nil) - sig, err := signer.Sign(nil) + sig, err := signer.Sign(t.Context(), nil) require.ErrorIs(t, err, ErrNilMessage) require.Nil(t, sig) } +// TestSign_ThreadsContext asserts Sign threads the caller's context into the +// command invocation for every format, rather than building the command with a +// fresh context.Background(). The proof is propagation: cancelling the context +// passed to Sign is observable through the context the command was created +// with, which is only possible if it is the same context. +func TestSign_ThreadsContext(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + format Format + signingKey string + }{ + {name: "openpgp", format: FormatOpenPGP, signingKey: "KEYID"}, + {name: "x509", format: FormatX509, signingKey: "KEYID"}, + {name: "ssh", format: FormatSSH, signingKey: "/path/to/key"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + signer, calls := newTestSigner(t, test.format, test.signingKey, func(cmd *mockCommand) error { + if test.format == FormatSSH { + return writeSignatureFile(cmd.args[len(cmd.args)-1]) + } + + _, err := io.WriteString(cmd.stdout, "SIG\n") + require.NoError(t, err) + + return nil + }) + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + _, err := signer.Sign(ctx, strings.NewReader("body")) + require.NoError(t, err) + + require.Len(t, calls(), 1) + got := calls()[0].ctx + require.NotNil(t, got) + + require.NoError(t, got.Err()) + cancel() + assert.ErrorIs(t, got.Err(), context.Canceled) + }) + } +} + func TestSign_StdioFormats(t *testing.T) { t.Parallel() @@ -163,7 +213,7 @@ func TestSign_StdioFormats(t *testing.T) { return nil }) - sig, err := signer.Sign(strings.NewReader("commit body\n")) + sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n")) require.NoError(t, err) assert.Equal(t, "STDIO-SIG\n", string(sig)) assert.Equal(t, "commit body\n", stdin) @@ -184,7 +234,7 @@ func TestSign_StdioFailure(t *testing.T) { return errStdioExit }) - sig, err := signer.Sign(strings.NewReader("body")) + sig, err := signer.Sign(t.Context(), strings.NewReader("body")) require.Error(t, err) assert.Contains(t, err.Error(), "stdio failed") require.Nil(t, sig) @@ -206,7 +256,7 @@ func TestSign_SSH(t *testing.T) { return writeSignatureFile(bufferFile) }) - sig, err := signer.Sign(strings.NewReader("commit body\n")) + sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n")) require.NoError(t, err) assert.Equal(t, "SSH-SIG\n", string(sig)) assert.Equal(t, "commit body\n", buffer) @@ -234,7 +284,7 @@ func TestSign_SSHExpandsHomePath(t *testing.T) { return writeSignatureFile(bufferFile) }) - sig, err := signer.Sign(strings.NewReader("commit body\n")) + sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n")) require.NoError(t, err) assert.Equal(t, "SSH-SIG\n", string(sig)) @@ -290,7 +340,7 @@ func TestSign_SSHLiteralKey(t *testing.T) { return writeSignatureFile(bufferFile) }) - sig, err := signer.Sign(strings.NewReader("commit body\n")) + sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n")) require.NoError(t, err) assert.Equal(t, "SSH-SIG\n", string(sig)) assert.Equal(t, "commit body\n", buffer) @@ -320,7 +370,7 @@ func TestSign_SSHFailure(t *testing.T) { return errSSHExit }) - sig, err := signer.Sign(strings.NewReader("body")) + sig, err := signer.Sign(t.Context(), strings.NewReader("body")) require.Error(t, err) assert.Contains(t, err.Error(), "ssh failed") require.Nil(t, sig) @@ -335,7 +385,7 @@ func TestSign_SSHPathPrefixedSshDash(t *testing.T) { return writeSignatureFile(bufferFile) }) - sig, err := signer.Sign(strings.NewReader("body")) + sig, err := signer.Sign(t.Context(), strings.NewReader("body")) require.NoError(t, err) require.NotNil(t, sig) @@ -359,7 +409,7 @@ func TestSign_StdioOutputTooLarge(t *testing.T) { return nil }) - sig, err := signer.Sign(strings.NewReader("body")) + sig, err := signer.Sign(t.Context(), strings.NewReader("body")) require.ErrorIs(t, err, ErrOutputLimitExceeded) assert.Contains(t, err.Error(), "stdout") require.Nil(t, sig) @@ -379,7 +429,7 @@ func TestSign_StderrTooLarge(t *testing.T) { return nil }) - sig, err := signer.Sign(strings.NewReader("body")) + sig, err := signer.Sign(t.Context(), strings.NewReader("body")) require.ErrorIs(t, err, ErrOutputLimitExceeded) assert.Contains(t, err.Error(), "stderr") require.Nil(t, sig) @@ -401,12 +451,13 @@ func TestSign_SSHSignatureTooLarge(t *testing.T) { return nil }) - sig, err := signer.Sign(strings.NewReader("body")) + sig, err := signer.Sign(t.Context(), strings.NewReader("body")) require.ErrorIs(t, err, ErrSignatureTooLarge) require.Nil(t, sig) } type mockCommand struct { + ctx context.Context //nolint:containedctx // captured to assert Sign threads its context into the command. run func(*mockCommand) error stdin io.Reader stdout io.Writer @@ -457,8 +508,9 @@ func stubCommand(run func(*mockCommand) error) ( ) { calls := make([]*mockCommand, 0, 1) - commandContext := func(_ context.Context, binary string, args ...string) command { + commandContext := func(ctx context.Context, binary string, args ...string) command { cmd := &mockCommand{ + ctx: ctx, run: run, stdin: nil, stdout: nil, diff --git a/plugin/objectsigner/ssh/example_test.go b/plugin/objectsigner/ssh/example_test.go index 67df0f5..04d3bcd 100644 --- a/plugin/objectsigner/ssh/example_test.go +++ b/plugin/objectsigner/ssh/example_test.go @@ -1,6 +1,7 @@ package ssh_test import ( + "context" "crypto/ed25519" "crypto/rand" "fmt" @@ -29,7 +30,7 @@ func ExampleFromKey() { panic(err) } - sig, err := signer.Sign(strings.NewReader("signed commit message\n")) + sig, err := signer.Sign(context.Background(), strings.NewReader("signed commit message\n")) if err != nil { panic(err) } diff --git a/plugin/objectsigner/ssh/ssh.go b/plugin/objectsigner/ssh/ssh.go index f858dc4..7af11aa 100644 --- a/plugin/objectsigner/ssh/ssh.go +++ b/plugin/objectsigner/ssh/ssh.go @@ -4,6 +4,7 @@ package ssh import ( + "context" "errors" "fmt" "io" @@ -69,8 +70,10 @@ type signer struct { // Sign reads the message from r and returns an armored SSH signature created // with the signer's SSH key and hash algorithm. -// The signature uses the "git" namespace. -func (s *signer) Sign(message io.Reader) ([]byte, error) { +// The signature uses the "git" namespace. The context is accepted for +// interface uniformity across signers; SSH signing is purely local and does +// not consult it. +func (s *signer) Sign(_ context.Context, message io.Reader) ([]byte, error) { if message == nil { return nil, ErrNilMessage } @@ -82,3 +85,9 @@ func (s *signer) Sign(message io.Reader) ([]byte, error) { return sshsig.Armor(sig), nil } + +// KeyID returns the SHA256 fingerprint of the signer's SSH public key, +// in the form "SHA256:...". +func (s *signer) KeyID() string { + return gossh.FingerprintSHA256(s.signer.PublicKey()) +} diff --git a/plugin/objectsigner/ssh/ssh_test.go b/plugin/objectsigner/ssh/ssh_test.go index e41d199..95d4092 100644 --- a/plugin/objectsigner/ssh/ssh_test.go +++ b/plugin/objectsigner/ssh/ssh_test.go @@ -71,7 +71,7 @@ func TestSign(t *testing.T) { signer, err := ssh.FromKey(sshSigner, ssh.WithHashAlgorithm(test.algo)) require.NoError(t, err) - sig, err := signer.Sign(test.message) + sig, err := signer.Sign(t.Context(), test.message) if test.wantErr == "" { require.NoError(t, err) assert.NotEmpty(t, sig) @@ -95,7 +95,7 @@ func TestSignVerifyRoundTrip(t *testing.T) { message := "signed commit message\n" - sig, err := signer.Sign(strings.NewReader(message)) + sig, err := signer.Sign(t.Context(), strings.NewReader(message)) require.NoError(t, err) ssig, err := sshsig.Unarmor(sig) @@ -113,15 +113,26 @@ func TestSignDifferentMessagesProduceDifferentSignatures(t *testing.T) { signer, err := ssh.FromKey(sshSigner) require.NoError(t, err) - sig1, err := signer.Sign(strings.NewReader("message one")) + sig1, err := signer.Sign(t.Context(), strings.NewReader("message one")) require.NoError(t, err) - sig2, err := signer.Sign(strings.NewReader("message two")) + sig2, err := signer.Sign(t.Context(), strings.NewReader("message two")) require.NoError(t, err) assert.NotEqual(t, sig1, sig2, "different messages produced identical signatures") } +func TestKeyID(t *testing.T) { + t.Parallel() + + key := generateTestSigner(t) + signer, err := ssh.FromKey(key) + require.NoError(t, err) + + assert.Equal(t, gossh.FingerprintSHA256(key.PublicKey()), signer.KeyID()) + assert.Contains(t, signer.KeyID(), "SHA256:") +} + //nolint:ireturn // gossh.NewSignerFromKey returns gossh.Signer (interface); no concrete type is accessible func generateTestSigner(t *testing.T) gossh.Signer { t.Helper()