Skip to content

Commit

Permalink
Merge pull request #403 from smallstep/mariano/yubikey-cache
Browse files Browse the repository at this point in the history
Implement yubikey connection cache
  • Loading branch information
maraino authored Jan 11, 2024
2 parents 2f7051e + 4817292 commit efd90c5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
33 changes: 29 additions & 4 deletions kms/yubikey/yubikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var oidYubicoSerialNumber = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 41482, 3, 7}
type YubiKey struct {
yk pivKey
pin string
card string
managementKey [24]byte
}

Expand All @@ -45,11 +46,28 @@ type pivKey interface {
}

var pivCards = piv.Cards
var pivMap sync.Map

// pivOpen calls piv.Open. It can be replaced by a custom functions for testing
// purposes.
var pivOpen = func(card string) (pivKey, error) {
return piv.Open(card)
}

// openCard wraps pivOpen with a cache. It loads a card connection from the
// cache if present.
func openCard(card string) (pivKey, error) {
if v, ok := pivMap.Load(card); ok {
return v.(pivKey), nil
}
yk, err := pivOpen(card)
if err != nil {
return nil, err
}
pivMap.Store(card, yk)
return yk, nil
}

// New initializes a new YubiKey KMS.
//
// The most common way to open a YubiKey is to add a URI in the options:
Expand Down Expand Up @@ -116,30 +134,33 @@ func New(_ context.Context, opts apiv1.Options) (*YubiKey, error) {
if len(cards) == 0 {
return nil, errors.New("error detecting yubikey: try removing and reconnecting the device")
}
card := cards[0]

var yk pivKey
if serial != "" {
// Attempt to locate the yubikey with the given serial.
for _, name := range cards {
if k, err := pivOpen(name); err == nil {
if k, err := openCard(name); err == nil {
if cert, err := k.Attest(piv.SlotAuthentication); err == nil {
if serial == getSerialNumber(cert) {
yk = k
card = name
break
}
}
}
}
if yk == nil {
return nil, errors.Errorf("failed to find key with serial number %s", serial)
return nil, errors.Errorf("failed to find key with serial number %s, slot 0x9a might be empty", serial)
}
} else if yk, err = pivOpen(cards[0]); err != nil {
} else if yk, err = openCard(cards[0]); err != nil {
return nil, errors.Wrap(err, "error opening yubikey")
}

return &YubiKey{
yk: yk,
pin: pin,
card: card,
managementKey: managementKey,
}, nil
}
Expand Down Expand Up @@ -338,7 +359,11 @@ func (k *YubiKey) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1

// Close releases the connection to the YubiKey.
func (k *YubiKey) Close() error {
return errors.Wrap(k.yk.Close(), "error closing yubikey")
if err := k.yk.Close(); err != nil {
return errors.Wrap(err, "error closing yubikey")
}
pivMap.Delete(k.card)
return nil
}

// getPublicKey returns the public key on a slot. First it attempts to do
Expand Down
42 changes: 31 additions & 11 deletions kms/yubikey/yubikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/asn1"
"errors"
"reflect"
"sync"
"testing"

"github.com/go-piv/piv-go/piv"
Expand All @@ -35,6 +36,7 @@ type stubPivKey struct {
certMap map[piv.Slot]*x509.Certificate
signerMap map[piv.Slot]interface{}
keyOptionsMap map[piv.Slot]piv.Key
closeErr error
}

type symmetricAlgorithm int
Expand Down Expand Up @@ -215,7 +217,7 @@ func (s *stubPivKey) Attest(slot piv.Slot) (*x509.Certificate, error) {
}

func (s *stubPivKey) Close() error {
return nil
return s.closeErr
}

func TestRegister(t *testing.T) {
Expand All @@ -242,6 +244,7 @@ func TestNew(t *testing.T) {
pOpen := pivOpen
pCards := pivCards
t.Cleanup(func() {
pivMap = sync.Map{}
pivOpen = pOpen
pivCards = pCards
})
Expand Down Expand Up @@ -285,57 +288,71 @@ func TestNew(t *testing.T) {
{"ok", args{ctx, apiv1.Options{}}, func() {
pivCards = okPivCards
pivOpen = okPivOpen
}, &YubiKey{yk: yk, pin: "123456", managementKey: piv.DefaultManagementKey}, false},
}, &YubiKey{yk: yk, pin: "123456", card: "Yubico YubiKey OTP+FIDO+CCID", managementKey: piv.DefaultManagementKey}, false},
{"ok with uri", args{ctx, apiv1.Options{
URI: "yubikey:pin-value=111111;management-key=001122334455667788990011223344556677889900112233",
}}, func() {
pivMap = sync.Map{}
pivCards = okMultiplePivCards
pivOpen = okPivOpen
}, &YubiKey{yk: yk, pin: "111111", managementKey: [24]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33}}, false},
}, &YubiKey{yk: yk, pin: "111111", card: "Yubico YubiKey OTP+FIDO+CCID", managementKey: [24]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33}}, false},
{"ok with uri and serial", args{ctx, apiv1.Options{
URI: "yubikey:serial=112233?pin-value=123456",
}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen
}, &YubiKey{yk: yk, pin: "123456", card: "Yubico YubiKey OTP+FIDO+CCID", managementKey: piv.DefaultManagementKey}, false},
{"ok with uri and serial from cache", args{ctx, apiv1.Options{
URI: "yubikey:serial=112233?pin-value=123456",
}}, func() {
pivCards = okPivCards
pivOpen = okPivOpen
}, &YubiKey{yk: yk, pin: "123456", managementKey: piv.DefaultManagementKey}, false},
}, &YubiKey{yk: yk, pin: "123456", card: "Yubico YubiKey OTP+FIDO+CCID", managementKey: piv.DefaultManagementKey}, false},
{"ok with Pin", args{ctx, apiv1.Options{Pin: "222222"}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen
}, &YubiKey{yk: yk, pin: "222222", managementKey: piv.DefaultManagementKey}, false},
}, &YubiKey{yk: yk, pin: "222222", card: "Yubico YubiKey OTP+FIDO+CCID", managementKey: piv.DefaultManagementKey}, false},
{"ok with ManagementKey", args{ctx, apiv1.Options{ManagementKey: "001122334455667788990011223344556677889900112233"}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen
}, &YubiKey{yk: yk, pin: "123456", managementKey: [24]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33}}, false},
}, &YubiKey{yk: yk, pin: "123456", card: "Yubico YubiKey OTP+FIDO+CCID", managementKey: [24]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, 0x11, 0x22, 0x33}}, false},
{"fail uri", args{ctx, apiv1.Options{URI: "badschema:"}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen
}, nil, true},
{"fail management key", args{ctx, apiv1.Options{URI: "yubikey:management-key=xxyyzz"}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen
}, nil, true},
{"fail management key size", args{ctx, apiv1.Options{URI: "yubikey:management-key=00112233"}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen
}, nil, true},
{"fail pivCards", args{ctx, apiv1.Options{}}, func() {
pivMap = sync.Map{}
pivCards = failPivCards
pivOpen = okPivOpen

}, nil, true},
{"fail no pivCards", args{ctx, apiv1.Options{}}, func() {
pivMap = sync.Map{}
pivCards = failNoPivCards
pivOpen = okPivOpen

}, nil, true},
{"fail no pivCards with serial", args{ctx, apiv1.Options{
URI: "yubikey:pin-value=111111;serial=332211?pin-value=123456",
}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = okPivOpen

}, nil, true},
{"fail pivOpen", args{ctx, apiv1.Options{}}, func() {
pivMap = sync.Map{}
pivCards = okPivCards
pivOpen = failPivOpen
}, nil, true},
Expand Down Expand Up @@ -1013,7 +1030,9 @@ func TestYubiKey_CreateAttestation(t *testing.T) {
}

func TestYubiKey_Close(t *testing.T) {
yk := newStubPivKey(t, ECDSA)
yk1 := newStubPivKey(t, ECDSA)
yk2 := newStubPivKey(t, RSA)
yk2.closeErr = errors.New("some error")

type fields struct {
yk pivKey
Expand All @@ -1025,7 +1044,8 @@ func TestYubiKey_Close(t *testing.T) {
fields fields
wantErr bool
}{
{"ok", fields{yk, "123456", piv.DefaultManagementKey}, false},
{"ok", fields{yk1, "123456", piv.DefaultManagementKey}, false},
{"fail", fields{yk2, "123456", piv.DefaultManagementKey}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit efd90c5

Please sign in to comment.