package jwk import ( "crypto/rand" "crypto/rsa" "encoding/json" "os" "github.com/btcsuite/btcd/btcutil/base58" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/oklog/ulid/v2" "github.com/pkg/errors" ) const DefaultKeySize = 2048 type ( Key = jwk.Key Set = jwk.Set ParseOption = jwk.ParseOption ) var ( FromRaw = jwk.FromRaw NewSet = jwk.NewSet ParseKey = jwk.ParseKey ) const AlgorithmKey = jwk.AlgorithmKey func Parse(src []byte, options ...jwk.ParseOption) (Set, error) { return jwk.Parse(src, options...) } func PublicKeySet(keys ...jwk.Key) (jwk.Set, error) { set := jwk.NewSet() for _, k := range keys { pubkey, err := k.PublicKey() if err != nil { return nil, errors.WithStack(err) } if err := pubkey.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { return nil, errors.WithStack(err) } if err := set.AddKey(pubkey); err != nil { return nil, errors.WithStack(err) } } return set, nil } func LoadOrGenerate(path string, size int) (jwk.Key, error) { data, err := os.ReadFile(path) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, errors.WithStack(err) } if errors.Is(err, os.ErrNotExist) { key, err := Generate(size) if err != nil { return nil, errors.WithStack(err) } data, err = json.Marshal(key) if err != nil { return nil, errors.WithStack(err) } if err := os.WriteFile(path, data, 0o640); err != nil { return nil, errors.WithStack(err) } } key, err := jwk.ParseKey(data) if err != nil { return nil, errors.WithStack(err) } return key, nil } func Generate(size int) (jwk.Key, error) { privKey, err := rsa.GenerateKey(rand.Reader, size) if err != nil { return nil, errors.WithStack(err) } key, err := jwk.FromRaw(privKey) if err != nil { return nil, errors.WithStack(err) } keyID := ulid.Make().String() if err := key.Set(jwk.KeyIDKey, keyID); err != nil { return nil, errors.WithStack(err) } return key, nil } func Sign(key jwk.Key, payload ...any) (string, error) { json, err := json.Marshal(payload) if err != nil { return "", errors.WithStack(err) } rawSignature, err := jws.Sign( nil, jws.WithKey(jwa.RS256, key), jws.WithDetachedPayload(json), ) if err != nil { return "", errors.WithStack(err) } signature := base58.Encode(rawSignature) return signature, nil } func Verify(jwks jwk.Set, signature string, payload ...any) (bool, error) { json, err := json.Marshal(payload) if err != nil { return false, errors.WithStack(err) } decoded := base58.Decode(signature) _, err = jws.Verify( decoded, jws.WithKeySet(jwks, jws.WithRequireKid(false)), jws.WithDetachedPayload(json), ) if err != nil { return false, errors.WithStack(err) } return true, nil }