Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion plugin/objectsigner/gpg/example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gpg_test

import (
"context"
"fmt"
"strings"

Expand All @@ -21,7 +22,7 @@ func ExampleFromKey() {
panic(err)
}

sig, err := signer.Sign(strings.NewReader("signed commit message\n"))
sig, err := signer.Sign(context.Background(), strings.NewReader("signed commit message\n"))
if err != nil {
panic(err)
}
Expand Down
7 changes: 5 additions & 2 deletions plugin/objectsigner/gpg/gpg.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package gpg

import (
"bytes"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -36,8 +37,10 @@ type signer struct {
}

// Sign reads message and returns an ASCII-armored detached GPG
// signature created with the signer's OpenPGP key.
func (s *signer) Sign(message io.Reader) ([]byte, error) {
// signature created with the signer's OpenPGP key. The context is accepted
// for interface uniformity across signers; native OpenPGP signing is purely
// local and does not consult it.
func (s *signer) Sign(_ context.Context, message io.Reader) ([]byte, error) {
if message == nil {
return nil, ErrNilMessage
}
Expand Down
8 changes: 4 additions & 4 deletions plugin/objectsigner/gpg/gpg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestSign(t *testing.T) {
signer, err := gpg.FromKey(generateTestKey(t))
require.NoError(t, err)

sig, err := signer.Sign(test.message)
sig, err := signer.Sign(t.Context(), test.message)
if test.wantErr != "" {
require.ErrorContains(t, err, test.wantErr)
require.Nil(t, sig)
Expand All @@ -85,7 +85,7 @@ func TestSignVerifyRoundTrip(t *testing.T) {

message := "signed commit message\n"

sig, err := signer.Sign(strings.NewReader(message))
sig, err := signer.Sign(t.Context(), strings.NewReader(message))
require.NoError(t, err)

keyring := openpgp.EntityList{key}
Expand All @@ -106,10 +106,10 @@ func TestSignDifferentMessagesProduceDifferentSignatures(t *testing.T) {
signer, err := gpg.FromKey(key)
require.NoError(t, err)

sig1, err := signer.Sign(strings.NewReader("message one"))
sig1, err := signer.Sign(t.Context(), strings.NewReader("message one"))
require.NoError(t, err)

sig2, err := signer.Sign(strings.NewReader("message two"))
sig2, err := signer.Sign(t.Context(), strings.NewReader("message two"))
require.NoError(t, err)

assert.NotEqual(t, sig1, sig2, "different messages produced identical signatures")
Expand Down
8 changes: 4 additions & 4 deletions plugin/objectsigner/program/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,17 @@ func resolveProgram(program string, lookPath func(string) (string, error)) (stri
}

// Sign reads message and returns the signature produced by the external
// binary.
func (s *signer) Sign(message io.Reader) ([]byte, error) {
// binary. The context cancels the external program invocation.
func (s *signer) Sign(ctx context.Context, message io.Reader) ([]byte, error) {
if message == nil {
Comment thread
pjbgf marked this conversation as resolved.
return nil, ErrNilMessage
}

switch s.format {
case FormatOpenPGP, FormatX509:
return s.signStdio(context.Background(), message)
return s.signStdio(ctx, message)
case FormatSSH:
return s.signSSH(context.Background(), message)
return s.signSSH(ctx, message)
default:
return nil, fmt.Errorf("%w: %q", ErrUnsupportedFormat, s.format)
}
Expand Down
76 changes: 64 additions & 12 deletions plugin/objectsigner/program/program_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,61 @@ func TestSign_NilMessage(t *testing.T) {

signer, _ := newTestSigner(t, FormatOpenPGP, "ABC", nil)

sig, err := signer.Sign(nil)
sig, err := signer.Sign(t.Context(), nil)
require.ErrorIs(t, err, ErrNilMessage)
require.Nil(t, sig)
}

// TestSign_ThreadsContext asserts Sign threads the caller's context into the
// command invocation for every format, rather than building the command with a
// fresh context.Background(). The proof is propagation: cancelling the context
// passed to Sign is observable through the context the command was created
// with, which is only possible if it is the same context.
func TestSign_ThreadsContext(t *testing.T) {
t.Parallel()

tests := []struct {
name string
format Format
signingKey string
}{
{name: "openpgp", format: FormatOpenPGP, signingKey: "KEYID"},
{name: "x509", format: FormatX509, signingKey: "KEYID"},
{name: "ssh", format: FormatSSH, signingKey: "/path/to/key"},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

signer, calls := newTestSigner(t, test.format, test.signingKey, func(cmd *mockCommand) error {
if test.format == FormatSSH {
return writeSignatureFile(cmd.args[len(cmd.args)-1])
}

_, err := io.WriteString(cmd.stdout, "SIG\n")
require.NoError(t, err)

return nil
})

ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)

_, err := signer.Sign(ctx, strings.NewReader("body"))
require.NoError(t, err)

require.Len(t, calls(), 1)
got := calls()[0].ctx
require.NotNil(t, got)

require.NoError(t, got.Err())
cancel()
assert.ErrorIs(t, got.Err(), context.Canceled)
})
}
}

func TestSign_StdioFormats(t *testing.T) {
t.Parallel()

Expand All @@ -163,7 +213,7 @@ func TestSign_StdioFormats(t *testing.T) {
return nil
})

sig, err := signer.Sign(strings.NewReader("commit body\n"))
sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n"))
require.NoError(t, err)
assert.Equal(t, "STDIO-SIG\n", string(sig))
assert.Equal(t, "commit body\n", stdin)
Expand All @@ -184,7 +234,7 @@ func TestSign_StdioFailure(t *testing.T) {
return errStdioExit
})

sig, err := signer.Sign(strings.NewReader("body"))
sig, err := signer.Sign(t.Context(), strings.NewReader("body"))
require.Error(t, err)
assert.Contains(t, err.Error(), "stdio failed")
require.Nil(t, sig)
Expand All @@ -206,7 +256,7 @@ func TestSign_SSH(t *testing.T) {
return writeSignatureFile(bufferFile)
})

sig, err := signer.Sign(strings.NewReader("commit body\n"))
sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n"))
require.NoError(t, err)
assert.Equal(t, "SSH-SIG\n", string(sig))
assert.Equal(t, "commit body\n", buffer)
Expand Down Expand Up @@ -234,7 +284,7 @@ func TestSign_SSHExpandsHomePath(t *testing.T) {
return writeSignatureFile(bufferFile)
})

sig, err := signer.Sign(strings.NewReader("commit body\n"))
sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n"))
require.NoError(t, err)
assert.Equal(t, "SSH-SIG\n", string(sig))

Expand Down Expand Up @@ -290,7 +340,7 @@ func TestSign_SSHLiteralKey(t *testing.T) {
return writeSignatureFile(bufferFile)
})

sig, err := signer.Sign(strings.NewReader("commit body\n"))
sig, err := signer.Sign(t.Context(), strings.NewReader("commit body\n"))
require.NoError(t, err)
assert.Equal(t, "SSH-SIG\n", string(sig))
assert.Equal(t, "commit body\n", buffer)
Expand Down Expand Up @@ -320,7 +370,7 @@ func TestSign_SSHFailure(t *testing.T) {
return errSSHExit
})

sig, err := signer.Sign(strings.NewReader("body"))
sig, err := signer.Sign(t.Context(), strings.NewReader("body"))
require.Error(t, err)
assert.Contains(t, err.Error(), "ssh failed")
require.Nil(t, sig)
Expand All @@ -335,7 +385,7 @@ func TestSign_SSHPathPrefixedSshDash(t *testing.T) {
return writeSignatureFile(bufferFile)
})

sig, err := signer.Sign(strings.NewReader("body"))
sig, err := signer.Sign(t.Context(), strings.NewReader("body"))
require.NoError(t, err)
require.NotNil(t, sig)

Expand All @@ -359,7 +409,7 @@ func TestSign_StdioOutputTooLarge(t *testing.T) {
return nil
})

sig, err := signer.Sign(strings.NewReader("body"))
sig, err := signer.Sign(t.Context(), strings.NewReader("body"))
require.ErrorIs(t, err, ErrOutputLimitExceeded)
assert.Contains(t, err.Error(), "stdout")
require.Nil(t, sig)
Expand All @@ -379,7 +429,7 @@ func TestSign_StderrTooLarge(t *testing.T) {
return nil
})

sig, err := signer.Sign(strings.NewReader("body"))
sig, err := signer.Sign(t.Context(), strings.NewReader("body"))
require.ErrorIs(t, err, ErrOutputLimitExceeded)
assert.Contains(t, err.Error(), "stderr")
require.Nil(t, sig)
Expand All @@ -401,12 +451,13 @@ func TestSign_SSHSignatureTooLarge(t *testing.T) {
return nil
})

sig, err := signer.Sign(strings.NewReader("body"))
sig, err := signer.Sign(t.Context(), strings.NewReader("body"))
require.ErrorIs(t, err, ErrSignatureTooLarge)
require.Nil(t, sig)
}

type mockCommand struct {
ctx context.Context //nolint:containedctx // captured to assert Sign threads its context into the command.
run func(*mockCommand) error
stdin io.Reader
stdout io.Writer
Expand Down Expand Up @@ -457,8 +508,9 @@ func stubCommand(run func(*mockCommand) error) (
) {
calls := make([]*mockCommand, 0, 1)

commandContext := func(_ context.Context, binary string, args ...string) command {
commandContext := func(ctx context.Context, binary string, args ...string) command {
cmd := &mockCommand{
ctx: ctx,
run: run,
stdin: nil,
stdout: nil,
Expand Down
3 changes: 2 additions & 1 deletion plugin/objectsigner/ssh/example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh_test

import (
"context"
"crypto/ed25519"
"crypto/rand"
"fmt"
Expand Down Expand Up @@ -29,7 +30,7 @@ func ExampleFromKey() {
panic(err)
}

sig, err := signer.Sign(strings.NewReader("signed commit message\n"))
sig, err := signer.Sign(context.Background(), strings.NewReader("signed commit message\n"))
if err != nil {
panic(err)
}
Expand Down
13 changes: 11 additions & 2 deletions plugin/objectsigner/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package ssh

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -69,8 +70,10 @@ type signer struct {

// Sign reads the message from r and returns an armored SSH signature created
// with the signer's SSH key and hash algorithm.
// The signature uses the "git" namespace.
func (s *signer) Sign(message io.Reader) ([]byte, error) {
// The signature uses the "git" namespace. The context is accepted for
// interface uniformity across signers; SSH signing is purely local and does
// not consult it.
func (s *signer) Sign(_ context.Context, message io.Reader) ([]byte, error) {
if message == nil {
return nil, ErrNilMessage
}
Expand All @@ -82,3 +85,9 @@ func (s *signer) Sign(message io.Reader) ([]byte, error) {

return sshsig.Armor(sig), nil
}

// KeyID returns the SHA256 fingerprint of the signer's SSH public key,
// in the form "SHA256:...".
func (s *signer) KeyID() string {
return gossh.FingerprintSHA256(s.signer.PublicKey())
}
19 changes: 15 additions & 4 deletions plugin/objectsigner/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestSign(t *testing.T) {
signer, err := ssh.FromKey(sshSigner, ssh.WithHashAlgorithm(test.algo))
require.NoError(t, err)

sig, err := signer.Sign(test.message)
sig, err := signer.Sign(t.Context(), test.message)
if test.wantErr == "" {
require.NoError(t, err)
assert.NotEmpty(t, sig)
Expand All @@ -95,7 +95,7 @@ func TestSignVerifyRoundTrip(t *testing.T) {

message := "signed commit message\n"

sig, err := signer.Sign(strings.NewReader(message))
sig, err := signer.Sign(t.Context(), strings.NewReader(message))
require.NoError(t, err)

ssig, err := sshsig.Unarmor(sig)
Expand All @@ -113,15 +113,26 @@ func TestSignDifferentMessagesProduceDifferentSignatures(t *testing.T) {
signer, err := ssh.FromKey(sshSigner)
require.NoError(t, err)

sig1, err := signer.Sign(strings.NewReader("message one"))
sig1, err := signer.Sign(t.Context(), strings.NewReader("message one"))
require.NoError(t, err)

sig2, err := signer.Sign(strings.NewReader("message two"))
sig2, err := signer.Sign(t.Context(), strings.NewReader("message two"))
require.NoError(t, err)

assert.NotEqual(t, sig1, sig2, "different messages produced identical signatures")
}

func TestKeyID(t *testing.T) {
t.Parallel()

key := generateTestSigner(t)
signer, err := ssh.FromKey(key)
require.NoError(t, err)

assert.Equal(t, gossh.FingerprintSHA256(key.PublicKey()), signer.KeyID())
assert.Contains(t, signer.KeyID(), "SHA256:")
}

//nolint:ireturn // gossh.NewSignerFromKey returns gossh.Signer (interface); no concrete type is accessible
func generateTestSigner(t *testing.T) gossh.Signer {
t.Helper()
Expand Down
Loading