/*
Copyright (c) JSC iCore.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
*/
|
|
|
|
|
|
|
|
package hydra
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"encoding/json"
|
2019-05-15 14:03:05 +02:00
|
|
|
"errors"
|
2019-02-18 14:57:54 +01:00
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
2019-05-15 14:03:05 +02:00
|
|
|
)
|
2019-02-18 14:57:54 +01:00
|
|
|
|
2019-05-15 14:03:05 +02:00
|
|
|
var (
	// ErrChallengeMissed is an error that happens when a challenge is missed,
	// i.e. a request to Hydra is attempted with an empty challenge.
	ErrChallengeMissed = errors.New("challenge missed")
	// ErrUnauthenticated is an error that happens when authentication fails.
	// checkResponse returns it when Hydra answers with HTTP 401 Unauthorized.
	ErrUnauthenticated = errors.New("unauthenticated")
	// ErrChallengeNotFound is an error that happens when an unknown challenge
	// is used. checkResponse returns it when Hydra answers with HTTP 404 Not Found.
	ErrChallengeNotFound = errors.New("challenge not found")
	// ErrChallengeExpired is an error that happens when a challenge has already
	// been used. checkResponse returns it when Hydra answers with HTTP 409 Conflict.
	ErrChallengeExpired = errors.New("challenge expired")
	// ErrServiceUnavailable is an error that happens when the Hydra admin service
	// is unavailable. checkResponse returns it when Hydra answers with HTTP 503
	// Service Unavailable.
	ErrServiceUnavailable = errors.New("hydra service unavailable")
)
|
|
|
|
|
|
|
|
// reqType identifies which kind of Hydra flow request is being handled.
// Its string value is spliced into the admin API path
// (oauth2/auth/requests/<type>) and into the name of the challenge query
// parameter (<type>_challenge).
type reqType string

// The request kinds supported by initiateRequest and acceptRequest.
const (
	login   = reqType("login")
	consent = reqType("consent")
	logout  = reqType("logout")
)
|
|
|
|
|
2019-05-15 14:03:05 +02:00
|
|
|
// ReqInfo contains information on an ongoing Hydra flow request (login,
// consent or logout, see reqType), as decoded from the JSON that Hydra's
// admin API returns for that request.
type ReqInfo struct {
	// Challenge identifies the request; decoded from "challenge".
	Challenge string `json:"challenge"`
	// RequestedScopes lists the scopes requested for the flow; decoded from
	// "requested_scope". Presumably absent for logout requests — confirm.
	RequestedScopes []string `json:"requested_scope"`
	// Skip is decoded from "skip". Presumably Hydra sets it when user
	// interaction can be skipped — confirm against the Hydra API docs.
	Skip bool `json:"skip"`
	// Subject is decoded from "subject" and presumably identifies the end
	// user — confirm.
	Subject string `json:"subject"`
}
|
|
|
|
|
2021-08-06 14:30:48 +02:00
|
|
|
func initiateRequest(typ reqType, hydraURL string, fakeTLSTermination bool, challenge string) (*ReqInfo, error) {
|
2019-05-24 15:13:15 +02:00
|
|
|
if challenge == "" {
|
|
|
|
return nil, ErrChallengeMissed
|
|
|
|
}
|
2019-05-15 14:03:05 +02:00
|
|
|
ref, err := url.Parse(fmt.Sprintf("oauth2/auth/requests/%[1]s?%[1]s_challenge=%s", string(typ), challenge))
|
2019-02-18 14:57:54 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2023-07-26 10:33:54 +02:00
|
|
|
|
2019-02-18 14:57:54 +01:00
|
|
|
u, err := parseURL(hydraURL)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
u = u.ResolveReference(ref)
|
|
|
|
|
2021-05-13 07:40:27 +02:00
|
|
|
req, err := http.NewRequest("GET", u.String(), nil)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2021-08-06 14:30:48 +02:00
|
|
|
if fakeTLSTermination {
|
2021-05-13 07:40:27 +02:00
|
|
|
req.Header.Add("X-Forwarded-Proto", "https")
|
|
|
|
}
|
|
|
|
|
|
|
|
client := &http.Client{}
|
|
|
|
resp, err := client.Do(req)
|
2019-02-18 14:57:54 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if err = checkResponse(resp); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
data, err := ioutil.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2019-05-15 14:03:05 +02:00
|
|
|
var ri ReqInfo
|
2019-02-18 14:57:54 +01:00
|
|
|
if err := json.Unmarshal(data, &ri); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &ri, nil
|
|
|
|
}
|
|
|
|
|
2021-08-06 14:30:48 +02:00
|
|
|
func acceptRequest(typ reqType, hydraURL string, fakeTLSTermination bool, challenge string, data interface{}) (string, error) {
|
2019-05-24 15:13:15 +02:00
|
|
|
if challenge == "" {
|
|
|
|
return "", ErrChallengeMissed
|
|
|
|
}
|
2019-05-15 14:03:05 +02:00
|
|
|
ref, err := url.Parse(fmt.Sprintf("oauth2/auth/requests/%[1]s/accept?%[1]s_challenge=%s", string(typ), challenge))
|
2019-02-18 14:57:54 +01:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
u, err := parseURL(hydraURL)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
u = u.ResolveReference(ref)
|
|
|
|
|
2019-05-15 14:03:05 +02:00
|
|
|
var body []byte
|
|
|
|
if data != nil {
|
|
|
|
if body, err = json.Marshal(data); err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
2019-02-18 14:57:54 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
r, err := http.NewRequest(http.MethodPut, u.String(), bytes.NewBuffer(body))
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
2021-08-06 14:30:48 +02:00
|
|
|
if fakeTLSTermination {
|
2021-05-13 07:45:37 +02:00
|
|
|
r.Header.Add("X-Forwarded-Proto", "https")
|
|
|
|
}
|
|
|
|
|
2019-02-18 14:57:54 +01:00
|
|
|
r.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(r)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
if err := checkResponse(resp); err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
2019-05-15 14:03:05 +02:00
|
|
|
var rs struct {
|
2019-02-18 14:57:54 +01:00
|
|
|
RedirectTo string `json:"redirect_to"`
|
|
|
|
}
|
|
|
|
dec := json.NewDecoder(resp.Body)
|
|
|
|
if err := dec.Decode(&rs); err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
return rs.RedirectTo, nil
|
|
|
|
}
|
|
|
|
|
2019-05-15 14:03:05 +02:00
|
|
|
func checkResponse(resp *http.Response) error {
|
|
|
|
if resp.StatusCode >= 200 && resp.StatusCode <= 302 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
switch resp.StatusCode {
|
|
|
|
case 401:
|
|
|
|
return ErrUnauthenticated
|
|
|
|
case 404:
|
|
|
|
return ErrChallengeNotFound
|
|
|
|
case 409:
|
|
|
|
return ErrChallengeExpired
|
2023-07-26 10:33:54 +02:00
|
|
|
case 503:
|
|
|
|
return ErrServiceUnavailable
|
2019-05-15 14:03:05 +02:00
|
|
|
default:
|
|
|
|
var rs struct {
|
|
|
|
Message string `json:"error"`
|
|
|
|
}
|
|
|
|
data, err := ioutil.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if err := json.Unmarshal(data, &rs); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return fmt.Errorf("bad HTTP status code %d with message %q", resp.StatusCode, rs.Message)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-02-18 14:57:54 +01:00
|
|
|
// parseURL parses s as a base URL. A trailing slash is appended to a non-empty
// s that lacks one, so that a relative reference passed to ResolveReference is
// resolved beneath the last path segment of s instead of replacing it.
// On failure it returns a nil URL and the parse error.
func parseURL(s string) (*url.URL, error) {
	if n := len(s); n != 0 && s[n-1] != '/' {
		s = s + "/"
	}
	return url.Parse(s)
}
|