// Source: hydra-webauthn/internal/hydra/client.go
// (Go, 264 lines, 6.4 KiB; last changed 2023-11-15 20:38:25 +01:00)

package hydra
import (
"bytes"
"encoding/json"
"net/http"
"net/url"
"time"
"github.com/pkg/errors"
)
// Client issues HTTP requests against the Hydra endpoints used by the
// login, consent and logout flows.
type Client struct {
	// baseURL provides the scheme and host of every request; its path and
	// query are replaced per call.
	baseURL *url.URL

	// http is the HTTP client that performs all requests.
	http *http.Client
}
// LoginRequest fetches the login request identified by challenge from the
// "/oauth2/auth/requests/login" endpoint and decodes the JSON response body
// into a LoginResponse.
//
// A transport failure or a JSON decoding failure is returned wrapped. A
// status code outside the [200, 400) range yields an error wrapping
// ErrUnexpectedHydraResponse.
func (c *Client) LoginRequest(challenge string) (*LoginResponse, error) {
	u := fromURL(*c.baseURL, "/oauth2/auth/requests/login", url.Values{
		"login_challenge": []string{challenge},
	})

	res, err := c.http.Get(u)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve login response")
	}
	// Register the close before any early return so the body is released on
	// the unexpected-status path as well.
	defer res.Body.Close()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
		return nil, errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode)
	}

	loginRes := &LoginResponse{}
	if err := json.NewDecoder(res.Body).Decode(loginRes); err != nil {
		return nil, errors.Wrap(err, "could not decode json response")
	}

	return loginRes, nil
}
// AcceptLoginRequest sends req via PUT to the
// "/oauth2/auth/requests/login/accept" endpoint for the given login challenge
// and returns the decoded AcceptResponse.
func (c *Client) AcceptLoginRequest(challenge string, req *AcceptLoginRequest) (*AcceptResponse, error) {
	query := url.Values{}
	query.Set("login_challenge", challenge)
	endpoint := fromURL(*c.baseURL, "/oauth2/auth/requests/login/accept", query)

	var accepted AcceptResponse
	if err := c.putJSON(endpoint, req, &accepted); err != nil {
		return nil, err
	}

	return &accepted, nil
}
// RejectLoginRequest sends req via PUT to the
// "/oauth2/auth/requests/login/reject" endpoint for the given login challenge
// and returns the decoded RejectResponse.
func (c *Client) RejectLoginRequest(challenge string, req *RejectRequest) (*RejectResponse, error) {
	query := url.Values{}
	query.Set("login_challenge", challenge)
	endpoint := fromURL(*c.baseURL, "/oauth2/auth/requests/login/reject", query)

	var rejected RejectResponse
	if err := c.putJSON(endpoint, req, &rejected); err != nil {
		return nil, err
	}

	return &rejected, nil
}
// LogoutRequest fetches the logout request identified by challenge from the
// "/oauth2/auth/requests/logout" endpoint and decodes the JSON response body
// into a LogoutResponse.
//
// A transport failure or a JSON decoding failure is returned wrapped. A
// status code outside the [200, 400) range yields an error wrapping
// ErrUnexpectedHydraResponse.
func (c *Client) LogoutRequest(challenge string) (*LogoutResponse, error) {
	u := fromURL(*c.baseURL, "/oauth2/auth/requests/logout", url.Values{
		"logout_challenge": []string{challenge},
	})

	res, err := c.http.Get(u)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve logout response")
	}
	// Register the close before any early return so the body is released on
	// the unexpected-status path as well.
	defer res.Body.Close()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
		return nil, errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode)
	}

	logoutRes := &LogoutResponse{}
	if err := json.NewDecoder(res.Body).Decode(logoutRes); err != nil {
		return nil, errors.Wrap(err, "could not decode json response")
	}

	return logoutRes, nil
}
// AcceptLogoutRequest issues a PUT with a nil payload to the
// "/oauth2/auth/requests/logout/accept" endpoint for the given logout
// challenge and returns the decoded AcceptResponse.
func (c *Client) AcceptLogoutRequest(challenge string) (*AcceptResponse, error) {
	query := url.Values{}
	query.Set("logout_challenge", challenge)
	endpoint := fromURL(*c.baseURL, "/oauth2/auth/requests/logout/accept", query)

	var accepted AcceptResponse
	if err := c.putJSON(endpoint, nil, &accepted); err != nil {
		return nil, err
	}

	return &accepted, nil
}
// RejectLogoutRequest sends req via PUT to the
// "/oauth2/auth/requests/logout/reject" endpoint for the given logout
// challenge and returns the decoded RejectResponse.
func (c *Client) RejectLogoutRequest(challenge string, req *RejectRequest) (*RejectResponse, error) {
	query := url.Values{}
	query.Set("logout_challenge", challenge)
	endpoint := fromURL(*c.baseURL, "/oauth2/auth/requests/logout/reject", query)

	var rejected RejectResponse
	if err := c.putJSON(endpoint, req, &rejected); err != nil {
		return nil, err
	}

	return &rejected, nil
}
// ConsentRequest fetches the consent request identified by challenge from the
// "/oauth2/auth/requests/consent" endpoint and decodes the JSON response body
// into a ConsentResponse.
//
// A transport failure or a JSON decoding failure is returned wrapped. A
// status code outside the [200, 400) range yields an error wrapping
// ErrUnexpectedHydraResponse.
func (c *Client) ConsentRequest(challenge string) (*ConsentResponse, error) {
	u := fromURL(*c.baseURL, "/oauth2/auth/requests/consent", url.Values{
		"consent_challenge": []string{challenge},
	})

	res, err := c.http.Get(u)
	if err != nil {
		// The message names the consent flow; it previously said "login".
		return nil, errors.Wrap(err, "could not retrieve consent response")
	}
	// Register the close before any early return so the body is released on
	// the unexpected-status path as well.
	defer res.Body.Close()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
		return nil, errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode)
	}

	consentRes := &ConsentResponse{}
	if err := json.NewDecoder(res.Body).Decode(consentRes); err != nil {
		return nil, errors.Wrap(err, "could not decode json response")
	}

	return consentRes, nil
}
// AcceptConsentRequest sends req via PUT to the
// "/oauth2/auth/requests/consent/accept" endpoint for the given consent
// challenge and returns the decoded AcceptResponse.
func (c *Client) AcceptConsentRequest(challenge string, req *AcceptConsentRequest) (*AcceptResponse, error) {
	query := url.Values{}
	query.Set("consent_challenge", challenge)
	endpoint := fromURL(*c.baseURL, "/oauth2/auth/requests/consent/accept", query)

	var accepted AcceptResponse
	if err := c.putJSON(endpoint, req, &accepted); err != nil {
		return nil, err
	}

	return &accepted, nil
}
// RejectConsentRequest sends req via PUT to the
// "/oauth2/auth/requests/consent/reject" endpoint for the given consent
// challenge and returns the decoded RejectResponse.
func (c *Client) RejectConsentRequest(challenge string, req *RejectRequest) (*RejectResponse, error) {
	query := url.Values{}
	query.Set("consent_challenge", challenge)
	endpoint := fromURL(*c.baseURL, "/oauth2/auth/requests/consent/reject", query)

	var rejected RejectResponse
	if err := c.putJSON(endpoint, req, &rejected); err != nil {
		return nil, err
	}

	return &rejected, nil
}
// LoginChallenge returns the "login_challenge" query parameter of r, or
// whatever error the challenge lookup reports when it is absent.
func (c *Client) LoginChallenge(r *http.Request) (string, error) {
	const param = "login_challenge"

	return c.challenge(r, param)
}
// ConsentChallenge returns the "consent_challenge" query parameter of r, or
// whatever error the challenge lookup reports when it is absent.
func (c *Client) ConsentChallenge(r *http.Request) (string, error) {
	const param = "consent_challenge"

	return c.challenge(r, param)
}
// LogoutChallenge returns the "logout_challenge" query parameter of r, or
// whatever error the challenge lookup reports when it is absent.
func (c *Client) LogoutChallenge(r *http.Request) (string, error) {
	const param = "logout_challenge"

	return c.challenge(r, param)
}
// challenge reads the query parameter called name from the URL of r.
// It returns ErrChallengeNotFound when the parameter is missing or empty.
func (c *Client) challenge(r *http.Request, name string) (string, error) {
	if value := r.URL.Query().Get(name); value != "" {
		return value, nil
	}

	return "", ErrChallengeNotFound
}
// putJSON encodes payload as JSON, sends it with a PUT request to u and
// decodes the JSON response body into result.
//
// payload may be nil, in which case the JSON value null is sent as the body.
// result must be a pointer the JSON decoder can populate.
//
// Encoding, request construction, transport and decoding failures are
// returned wrapped. A status code outside the [200, 400) range yields an
// error wrapping ErrUnexpectedHydraResponse.
func (c *Client) putJSON(u string, payload interface{}, result interface{}) error {
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(payload); err != nil {
		return errors.Wrap(err, "could not encode request body")
	}

	req, err := http.NewRequest(http.MethodPut, u, &buf)
	if err != nil {
		return errors.Wrap(err, "could not create request")
	}
	// Declare the body format explicitly instead of leaving it unspecified.
	req.Header.Set("Content-Type", "application/json")

	res, err := c.http.Do(req)
	if err != nil {
		// This helper serves the login, consent and logout flows alike, so
		// the message no longer claims the failure is login-specific.
		return errors.Wrap(err, "could not perform request")
	}
	// Register the close before any early return so the body is released on
	// the unexpected-status path as well.
	defer res.Body.Close()

	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
		return errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode)
	}

	if err := json.NewDecoder(res.Body).Decode(result); err != nil {
		return errors.Wrap(err, "could not decode json response")
	}

	return nil
}
// fromURL returns the string form of url after replacing its path with path
// and its raw query with the encoded form of query. url is received by value,
// so the caller's URL is left untouched.
func fromURL(url url.URL, path string, query url.Values) string {
	target := url
	target.Path, target.RawQuery = path, query.Encode()

	return target.String()
}
// fakeSSLTerminationTransport is an http.RoundTripper that adds an
// "X-Forwarded-Proto: https" header to every outgoing request before handing
// it to the wrapped transport.
type fakeSSLTerminationTransport struct {
	// T performs the actual round trip. http.DefaultTransport is used when
	// T is nil.
	T http.RoundTripper
}

// RoundTrip forwards a copy of req, with the "X-Forwarded-Proto: https"
// header added, to the wrapped transport and returns its result unchanged.
func (t *fakeSSLTerminationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	next := t.T
	if next == nil {
		next = http.DefaultTransport
	}

	// A RoundTripper must not modify the caller's request, so the header is
	// added to a clone. This also stops the header from accumulating when the
	// same request headers are reused (for example across redirects).
	forwarded := req.Clone(req.Context())
	if forwarded.Header == nil {
		forwarded.Header = make(http.Header)
	}
	forwarded.Header.Add("X-Forwarded-Proto", "https")

	return next.RoundTrip(forwarded)
}
// NewClient builds a Client that sends its requests to baseURL using an HTTP
// client whose timeout is httpTimeout.
//
// When fakeSSLTermination is true, requests go through a transport that wraps
// http.DefaultTransport and adds an "X-Forwarded-Proto: https" header;
// otherwise the HTTP client's transport is left unset.
func NewClient(baseURL *url.URL, fakeSSLTermination bool, httpTimeout time.Duration) *Client {
	// A nil RoundTripper interface leaves http.Client.Transport unset.
	var transport http.RoundTripper
	if fakeSSLTermination {
		transport = &fakeSSLTerminationTransport{http.DefaultTransport}
	}

	return &Client{
		baseURL: baseURL,
		http: &http.Client{
			Transport: transport,
			Timeout:   httpTimeout,
		},
	}
}