/*
Copyright (c) JSC iCore.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
package hydra
import (
"bytes"
"encoding/json"
2019-05-15 14:03:05 +02:00
"errors"
2019-02-18 14:57:54 +01:00
"fmt"
"io/ioutil"
"net/http"
"net/url"
2019-05-15 14:03:05 +02:00
)
2019-02-18 14:57:54 +01:00
2019-05-15 14:03:05 +02:00
var (
2019-05-24 15:13:15 +02:00
// ErrChallengeMissed is an error that happens when a challenge is missed.
ErrChallengeMissed = errors.New("challenge missed")
2019-05-15 14:03:05 +02:00
// ErrUnauthenticated is an error that happens when authentication is failed.
ErrUnauthenticated = errors.New("unauthenticated")
// ErrChallengeNotFound is an error that happens when an unknown challenge is used.
ErrChallengeNotFound = errors.New("challenge not found")
// ErrChallengeExpired is an error that happens when a challenge is already used.
ErrChallengeExpired = errors.New("challenge expired")
2019-02-18 14:57:54 +01:00
)
// reqType names a Hydra request flow. Its value is used both as a path
// segment of the Hydra endpoint and as the prefix of the challenge query
// parameter (for example "login" -> "login_challenge").
type reqType string

// The request flows known to this package.
const (
	login   = reqType("login")
	consent = reqType("consent")
	logout  = reqType("logout")
)

// ReqInfo contains information on an ongoing login or consent request.
//
// It is populated by decoding the JSON body Hydra returns for a request
// lookup; JSON keys not listed below are ignored by encoding/json.
// NOTE(review): the reqType set also includes "logout"; confirm whether logout
// lookups are decoded into this type as well and which fields they populate.
type ReqInfo struct {
	// Challenge is decoded from the "challenge" key.
	Challenge string `json:"challenge"`
	// RequestedScopes is decoded from the "requested_scope" key
	// (presumably only set for consent requests — verify against Hydra).
	RequestedScopes []string `json:"requested_scope"`
	// Skip is decoded from the "skip" key; presumably Hydra sets it when the
	// step can be skipped for this subject — verify against Hydra docs.
	Skip bool `json:"skip"`
	// Subject is decoded from the "subject" key.
	Subject string `json:"subject"`
}
func initiateRequest(typ reqType, hydraURL, challenge string) (*ReqInfo, error) {
2019-05-24 15:13:15 +02:00
if challenge == "" {
return nil, ErrChallengeMissed
}
2019-05-15 14:03:05 +02:00
ref, err := url.Parse(fmt.Sprintf("oauth2/auth/requests/%[1]s?%[1]s_challenge=%s", string(typ), challenge))
2019-02-18 14:57:54 +01:00
if err != nil {
return nil, err
}
u, err := parseURL(hydraURL)
if err != nil {
return nil, err
}
u = u.ResolveReference(ref)
resp, err := http.Get(u.String())
if err != nil {
return nil, err
}
if err = checkResponse(resp); err != nil {
return nil, err
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
2019-05-15 14:03:05 +02:00
var ri ReqInfo
2019-02-18 14:57:54 +01:00
if err := json.Unmarshal(data, &ri); err != nil {
return nil, err
}
return &ri, nil
}
func acceptRequest(typ reqType, hydraURL, challenge string, data interface{}) (string, error) {
2019-05-24 15:13:15 +02:00
if challenge == "" {
return "", ErrChallengeMissed
}
2019-05-15 14:03:05 +02:00
ref, err := url.Parse(fmt.Sprintf("oauth2/auth/requests/%[1]s/accept?%[1]s_challenge=%s", string(typ), challenge))
2019-02-18 14:57:54 +01:00
if err != nil {
return "", err
}
u, err := parseURL(hydraURL)
if err != nil {
return "", err
}
u = u.ResolveReference(ref)
2019-05-15 14:03:05 +02:00
var body []byte
if data != nil {
if body, err = json.Marshal(data); err != nil {
return "", err
}
2019-02-18 14:57:54 +01:00
}
r, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewBuffer(body))
if err != nil {
return "", err
}
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r)
if err != nil {
return "", err
}
defer resp.Body.Close()
if err := checkResponse(resp); err != nil {
return "", err
}
2019-05-15 14:03:05 +02:00
var rs struct {
2019-02-18 14:57:54 +01:00
RedirectTo string `json:"redirect_to"`
}
dec := json.NewDecoder(resp.Body)
if err := dec.Decode(&rs); err != nil {
return "", err
}
return rs.RedirectTo, nil
}
2019-05-15 14:03:05 +02:00
func checkResponse(resp *http.Response) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 302 {
return nil
}
switch resp.StatusCode {
case 401:
return ErrUnauthenticated
case 404:
return ErrChallengeNotFound
case 409:
return ErrChallengeExpired
default:
var rs struct {
Message string `json:"error"`
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
if err := json.Unmarshal(data, &rs); err != nil {
return err
}
return fmt.Errorf("bad HTTP status code %d with message %q", resp.StatusCode, rs.Message)
}
}
// parseURL parses s as the base URL of the Hydra API. A non-empty s that does
// not already end in '/' gets one appended first, so that ResolveReference
// treats the last path segment as a directory instead of replacing it.
func parseURL(s string) (*url.URL, error) {
	if n := len(s); n != 0 && s[n-1] != '/' {
		s = s + "/"
	}
	return url.Parse(s)
}