package jwtutil

import (
	"crypto/rand"
	"crypto/rsa"
	"os"
	"path/filepath"

	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/pkg/errors"
)

// LoadOrGenerateKey loads the PEM-encoded key stored at path. If the file does
// not exist, it generates a new RSA key of defaultKeySize bits, persists it to
// path with SaveKey, and returns it. Any other load failure (permission denied,
// unparsable PEM, ...) is returned instead of silently replacing the key.
//
// NOTE(review): nothing here coordinates concurrent callers racing on the same
// missing path; each may generate its own key and the last SaveKey wins.
func LoadOrGenerateKey(path string, defaultKeySize int) (jwk.Key, error) {
	key, err := LoadKey(path)
	if err == nil {
		return key, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		// LoadKey already attached a stack trace; wrapping again would
		// duplicate it in %+v output.
		return nil, err
	}

	key, err = GenerateKey(defaultKeySize)
	if err != nil {
		return nil, err
	}
	if err := SaveKey(path, key); err != nil {
		return nil, err
	}
	return key, nil
}

// LoadKey reads the file at path and parses its contents as a PEM-encoded key.
// A missing file yields an error for which errors.Is(err, os.ErrNotExist)
// reports true.
func LoadKey(path string) (jwk.Key, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		// The *os.PathError from ReadFile already names the path.
		return nil, errors.WithStack(err)
	}
	key, err := jwk.ParseKey(data, jwk.WithPEM(true))
	if err != nil {
		return nil, errors.Wrapf(err, "parsing PEM key from %q", path)
	}
	return key, nil
}

// SaveKey PEM-encodes key and stores it at path with owner-only (0600, before
// umask) permissions, replacing any existing file (or symlink) at path.
//
// The data is first written to a temporary file in the same directory, synced
// to disk, and then renamed over path, so a crash or a concurrent LoadKey never
// observes a partially written key. The parent directory must already exist.
func SaveKey(path string, key jwk.Key) error {
	data, err := jwk.Pem(key)
	if err != nil {
		return errors.WithStack(err)
	}

	// os.CreateTemp creates the file with mode 0600, so the private key is
	// never readable by other users, even transiently. Creating it next to
	// path keeps the final rename on the same filesystem.
	tmp, err := os.CreateTemp(filepath.Dir(path), "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return errors.WithStack(err)
	}
	tmpName := tmp.Name()
	renamed := false
	defer func() {
		if !renamed {
			// Best-effort cleanup; the original error is what matters.
			_ = os.Remove(tmpName)
		}
	}()

	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close() // best effort; the write error is more informative
		return errors.WithStack(err)
	}
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close() // best effort; the sync error is more informative
		return errors.WithStack(err)
	}
	if err := tmp.Close(); err != nil {
		return errors.WithStack(err)
	}
	if err := os.Rename(tmpName, path); err != nil {
		return errors.WithStack(err)
	}
	renamed = true
	return nil
}

// GenerateKey creates a new RSA private key of keySize bits using
// crypto/rand and returns it as a jwk.Key.
func GenerateKey(keySize int) (jwk.Key, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, keySize)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	key, err := jwk.FromRaw(rsaKey)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	return key, nil
}