package hydra import ( "bytes" "encoding/json" "net/http" "net/url" "time" "github.com/pkg/errors" ) type Client struct { baseURL *url.URL http *http.Client } func (c *Client) LoginRequest(challenge string) (*LoginResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/login", url.Values{ "login_challenge": []string{challenge}, }) res, err := c.http.Get(u) if err != nil { return nil, errors.Wrap(err, "could not retrieve login response") } if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { return nil, errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode) } defer res.Body.Close() decoder := json.NewDecoder(res.Body) loginRes := &LoginResponse{} if err := decoder.Decode(loginRes); err != nil { return nil, errors.Wrap(err, "could not decode json response") } return loginRes, nil } func (c *Client) AcceptLoginRequest(challenge string, req *AcceptLoginRequest) (*AcceptResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/login/accept", url.Values{ "login_challenge": []string{challenge}, }) res := &AcceptResponse{} if err := c.putJSON(u, req, res); err != nil { return nil, err } return res, nil } func (c *Client) RejectLoginRequest(challenge string, req *RejectRequest) (*RejectResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/login/reject", url.Values{ "login_challenge": []string{challenge}, }) res := &RejectResponse{} if err := c.putJSON(u, req, res); err != nil { return nil, err } return res, nil } func (c *Client) LogoutRequest(challenge string) (*LogoutResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/logout", url.Values{ "logout_challenge": []string{challenge}, }) res, err := c.http.Get(u) if err != nil { return nil, errors.Wrap(err, "could not retrieve logout response") } if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { return nil, errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode) } defer res.Body.Close() decoder := json.NewDecoder(res.Body) logoutRes := &LogoutResponse{} if err := decoder.Decode(logoutRes); err != nil { return nil, errors.Wrap(err, "could not decode json response") } return logoutRes, nil } func (c *Client) AcceptLogoutRequest(challenge string) (*AcceptResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/logout/accept", url.Values{ "logout_challenge": []string{challenge}, }) res := &AcceptResponse{} if err := c.putJSON(u, nil, res); err != nil { return nil, err } return res, nil } func (c *Client) RejectLogoutRequest(challenge string, req *RejectRequest) (*RejectResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/logout/reject", url.Values{ "logout_challenge": []string{challenge}, }) res := &RejectResponse{} if err := c.putJSON(u, req, res); err != nil { return nil, err } return res, nil } func (c *Client) ConsentRequest(challenge string) (*ConsentResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/consent", url.Values{ "consent_challenge": []string{challenge}, }) res, err := c.http.Get(u) if err != nil { return nil, errors.Wrap(err, "could not retrieve login response") } if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { return nil, errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode) } defer res.Body.Close() decoder := json.NewDecoder(res.Body) consentRes := &ConsentResponse{} if err := decoder.Decode(consentRes); err != nil { return nil, errors.Wrap(err, "could not decode json response") } return consentRes, nil } func (c *Client) AcceptConsentRequest(challenge string, req *AcceptConsentRequest) (*AcceptResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/consent/accept", url.Values{ "consent_challenge": []string{challenge}, }) res := &AcceptResponse{} if err := c.putJSON(u, req, res); err != nil { return nil, err } return res, nil } func (c *Client) RejectConsentRequest(challenge string, req *RejectRequest) (*RejectResponse, error) { u := fromURL(*c.baseURL, "/oauth2/auth/requests/consent/reject", url.Values{ "consent_challenge": []string{challenge}, }) res := &RejectResponse{} if err := c.putJSON(u, req, res); err != nil { return nil, err } return res, nil } func (c *Client) LoginChallenge(r *http.Request) (string, error) { return c.challenge(r, "login_challenge") } func (c *Client) ConsentChallenge(r *http.Request) (string, error) { return c.challenge(r, "consent_challenge") } func (c *Client) LogoutChallenge(r *http.Request) (string, error) { return c.challenge(r, "logout_challenge") } func (c *Client) challenge(r *http.Request, name string) (string, error) { challenge := r.URL.Query().Get(name) if challenge == "" { return "", ErrChallengeNotFound } return challenge, nil } func (c *Client) putJSON(u string, payload interface{}, result interface{}) error { var buf bytes.Buffer encoder := json.NewEncoder(&buf) if err := encoder.Encode(payload); err != nil { return errors.Wrap(err, "could not encode request body") } req, err := http.NewRequest("PUT", u, &buf) if err != nil { return errors.Wrap(err, "could not create request") } res, err := c.http.Do(req) if err != nil { return errors.Wrap(err, "could not retrieve login response") } if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { return errors.Wrapf(ErrUnexpectedHydraResponse, "hydra responded with status code '%d'", res.StatusCode) } defer res.Body.Close() decoder := json.NewDecoder(res.Body) if err := decoder.Decode(result); err != nil { return errors.Wrap(err, "could not decode json response") } return nil } func fromURL(url url.URL, path string, query url.Values) string { url.Path = path url.RawQuery = query.Encode() return url.String() } type fakeSSLTerminationTransport struct { T http.RoundTripper } func (t *fakeSSLTerminationTransport) RoundTrip(req *http.Request) (*http.Response, error) { req.Header.Add("X-Forwarded-Proto", "https") return t.T.RoundTrip(req) } func NewClient(baseURL *url.URL, fakeSSLTermination bool, httpTimeout time.Duration) *Client { httpClient := &http.Client{ Timeout: httpTimeout, } if fakeSSLTermination { httpClient.Transport = &fakeSSLTerminationTransport{http.DefaultTransport} } return &Client{ baseURL: baseURL, http: httpClient, } }