Initial commit
This commit is contained in:
101
oidc/client.go
Normal file
101
oidc/client.go
Normal file
@ -0,0 +1,101 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/middleware/container"
|
||||
"gitlab.com/wpetit/goweb/service/session"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
oauth2 *oauth2.Config
|
||||
provider *oidc.Provider
|
||||
verifier *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
func (c *Client) Verifier() *oidc.IDTokenVerifier {
|
||||
return c.verifier
|
||||
}
|
||||
|
||||
func (c *Client) Provider() *oidc.Provider {
|
||||
return c.provider
|
||||
}
|
||||
|
||||
func (c *Client) Redirect(w http.ResponseWriter, r *http.Request) {
|
||||
ctn := container.Must(r.Context())
|
||||
|
||||
sess, err := session.Must(ctn).Get(w, r)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "could not retrieve session"))
|
||||
}
|
||||
|
||||
state := uniuri.New()
|
||||
|
||||
sess.Set(SessionOIDCStateKey, state)
|
||||
|
||||
if err := sess.Save(w, r); err != nil {
|
||||
panic(errors.Wrap(err, "could not save session"))
|
||||
}
|
||||
|
||||
http.Redirect(w, r, c.oauth2.AuthCodeURL(state), http.StatusFound)
|
||||
}
|
||||
|
||||
func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
|
||||
ctx := r.Context()
|
||||
ctn := container.Must(ctx)
|
||||
|
||||
sess, err := session.Must(ctn).Get(w, r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not retrieve session")
|
||||
}
|
||||
|
||||
state, ok := sess.Get(SessionOIDCStateKey).(string)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid state")
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("state") != state {
|
||||
return nil, errors.New("state mismatch")
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
token, err := c.oauth2.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not exchange token")
|
||||
}
|
||||
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find id token")
|
||||
}
|
||||
|
||||
idToken, err := c.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not verify id token")
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func NewClient(opts ...OptionFunc) *Client {
|
||||
opt := fromDefault(opts...)
|
||||
|
||||
oauth2 := &oauth2.Config{
|
||||
ClientID: opt.ClientID,
|
||||
ClientSecret: opt.ClientSecret,
|
||||
Endpoint: opt.Provider.Endpoint(),
|
||||
RedirectURL: opt.RedirectURL,
|
||||
Scopes: opt.Scopes,
|
||||
}
|
||||
|
||||
verifier := opt.Provider.Verifier(&oidc.Config{
|
||||
ClientID: opt.ClientID,
|
||||
})
|
||||
|
||||
return &Client{oauth2, opt.Provider, verifier}
|
||||
}
|
52
oidc/middleware.go
Normal file
52
oidc/middleware.go
Normal file
@ -0,0 +1,52 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/middleware/container"
|
||||
"gitlab.com/wpetit/goweb/service/session"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionOIDCTokenKey = "oidc-token"
|
||||
SessionOIDCStateKey = "oidc-state"
|
||||
)
|
||||
|
||||
func Middleware(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
if _, err := IDToken(w, r); err != nil {
|
||||
ctn := container.Must(r.Context())
|
||||
|
||||
log.Println("retrieving oidc client")
|
||||
|
||||
client := Must(ctn)
|
||||
|
||||
client.Redirect(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
|
||||
ctn := container.Must(r.Context())
|
||||
|
||||
sess, err := session.Must(ctn).Get(w, r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not retrieve session")
|
||||
}
|
||||
|
||||
idToken, ok := sess.Get(SessionOIDCTokenKey).(*oidc.IDToken)
|
||||
if !ok || idToken == nil {
|
||||
return nil, errors.New("invalid id token")
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
}
|
52
oidc/option.go
Normal file
52
oidc/option.go
Normal file
@ -0,0 +1,52 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
)
|
||||
|
||||
type OptionFunc func(*Option)
|
||||
|
||||
type Option struct {
|
||||
Provider *oidc.Provider
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
RedirectURL string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
func WithCredentials(clientID, clientSecret string) OptionFunc {
|
||||
return func(opt *Option) {
|
||||
opt.ClientID = clientID
|
||||
opt.ClientSecret = clientSecret
|
||||
}
|
||||
}
|
||||
|
||||
func WithScopes(scopes ...string) OptionFunc {
|
||||
return func(opt *Option) {
|
||||
opt.Scopes = scopes
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(ctx context.Context, issuer string) (*oidc.Provider, error) {
|
||||
return oidc.NewProvider(ctx, issuer)
|
||||
}
|
||||
|
||||
func WithProvider(provider *oidc.Provider) OptionFunc {
|
||||
return func(opt *Option) {
|
||||
opt.Provider = provider
|
||||
}
|
||||
}
|
||||
|
||||
func fromDefault(funcs ...OptionFunc) *Option {
|
||||
opt := &Option{
|
||||
Scopes: []string{oidc.ScopeOpenID},
|
||||
}
|
||||
|
||||
for _, f := range funcs {
|
||||
f(opt)
|
||||
}
|
||||
|
||||
return opt
|
||||
}
|
11
oidc/provider.go
Normal file
11
oidc/provider.go
Normal file
@ -0,0 +1,11 @@
|
||||
package oidc
|
||||
|
||||
import "gitlab.com/wpetit/goweb/service"
|
||||
|
||||
func ServiceProvider(opts ...OptionFunc) service.Provider {
|
||||
client := NewClient(opts...)
|
||||
|
||||
return func(ctn *service.Container) (interface{}, error) {
|
||||
return client, nil
|
||||
}
|
||||
}
|
33
oidc/service.go
Normal file
33
oidc/service.go
Normal file
@ -0,0 +1,33 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/service"
|
||||
)
|
||||
|
||||
const ServiceName service.Name = "oidc"
|
||||
|
||||
// From retrieves the oidc service in the given container
|
||||
func From(container *service.Container) (*Client, error) {
|
||||
service, err := container.Service(ServiceName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error while retrieving '%s' service", ServiceName)
|
||||
}
|
||||
|
||||
srv, ok := service.(*Client)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("retrieved service is not a valid '%s' service", ServiceName)
|
||||
}
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// Must retrieves the oidc service in the given container or panic otherwise
|
||||
func Must(container *service.Container) *Client {
|
||||
srv, err := From(container)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return srv
|
||||
}
|
Reference in New Issue
Block a user