package main import ( "crypto/rand" "log" "net/http" "forge.cadoles.com/Cadoles/ldap-profile/cmd/server/config" "forge.cadoles.com/Cadoles/ldap-profile/ldap" "forge.cadoles.com/wpetit/goweb/middleware/container" "forge.cadoles.com/wpetit/goweb/service/session" "forge.cadoles.com/wpetit/goweb/service/template" "forge.cadoles.com/wpetit/goweb/static" "forge.cadoles.com/wpetit/goweb/template/html" "github.com/go-chi/chi" "github.com/gorilla/csrf" "github.com/pkg/errors" ldapv3 "gopkg.in/ldap.v3" ) func mountRoutes(r *chi.Mux, config *config.Config) { csrfSecret, err := generateRandomBytes(32) if err != nil { panic(errors.Wrap(err, "error while generating CSRF secret")) } csrfMiddleware := csrf.Protect( csrfSecret, csrf.Secure(false), ) r.Use(csrfMiddleware) r.Get("/login", serveLoginPage) r.Post("/login", handleLoginForm) r.Get("/logout", handleLogout) r.Group(func(r chi.Router) { r.Use(authMiddleware) r.Get("/", serveHomePage) r.Get("/profile", serveProfilePage) r.Post("/profile/password", handlePasswordChange) }) r.Get("/*", static.Dir(config.HTTP.PublicDir, "", r.NotFoundHandler())) } func serveHomePage(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/profile", http.StatusTemporaryRedirect) } func serveLoginPage(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) tmpl := template.Must(ctn) data := template.Data{ csrf.TemplateTag: csrf.TemplateField(r), } if err := tmpl.RenderPage(w, "login.html.tmpl", data); err != nil { panic(errors.Wrap(err, "error while rendering page")) } } func handleLogout(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "error while retrieving session")) } if err := sess.Delete(w, r); err != nil { panic(errors.Wrap(err, "error while deleting session")) } http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } func handleLoginForm(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { panic(errors.Wrap(err, "error while parsing form")) } username := r.Form.Get("username") password := r.Form.Get("password") ctn := container.Must(r.Context()) ldapSrv := ldap.Must(ctn) tmplSrv := template.Must(ctn) conf := config.Must(ctn) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "error while retrieving session")) } renderInvalidCredentials := func() { sess.AddFlash(session.FlashError, "Identifiants invalides.") data := extendTemplateData(w, r, template.Data{ "Username": username, csrf.TemplateTag: csrf.TemplateField(r), }) if err := tmplSrv.RenderPage(w, "login.html.tmpl", data); err != nil { panic(errors.Wrap(err, "error while rendering page")) } } if username == "" || password == "" { renderInvalidCredentials() return } results, err := ldapSrv.Search( ldap.EscapeFilter(conf.LDAP.UserSearchFilterPattern, username), ldap.WithBaseDN(conf.LDAP.BaseDN), ldap.WithScope(ldapv3.ScopeWholeSubtree), ldap.WithSizeLimit(1), ldap.WithAttributes("dn"), ) if err != nil { panic(errors.Wrap(err, "error while searching ldap entry")) } if len(results.Entries) == 0 { renderInvalidCredentials() return } userDN := results.Entries[0].DN log.Printf("authenticating user '%s' with DN '%s'", username, userDN) if err := ldapSrv.Bind(userDN, password); err != nil { // If the provided credentials are invalid, add flash message and rerender // the page if ldapv3.IsErrorWithCode(errors.Cause(err), ldapv3.LDAPResultInvalidCredentials) { renderInvalidCredentials() return } if ldapv3.IsErrorWithCode(errors.Cause(err), ldapv3.LDAPResultNoSuchObject) { renderInvalidCredentials() return } panic(errors.Wrap(err, "error while binding ldap connection")) } log.Printf("successful authentication for user with DN '%s'", userDN) sess.Set("password", password) sess.Set("dn", userDN) sess.AddFlash(session.FlashSuccess, "Bienvenue !") if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "error while saving session")) } http.Redirect(w, r, "/", http.StatusSeeOther) } func serveProfilePage(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "error while retrieving session")) } ldapSrv := ldap.Must(ctn) conn, err := ldapSrv.Connect() if err != nil { panic(errors.Wrap(err, "error while connecting to ldap server")) } defer conn.Close() userDN := sess.Get("dn").(string) password := sess.Get("password").(string) if err := ldapSrv.BindConn(conn, userDN, password); err != nil { panic(errors.Wrap(err, "error while binding ldap connection")) } results, err := ldapSrv.Search( "(&)", ldap.WithBaseDN(userDN), ldap.WithScope(ldapv3.ScopeBaseObject), ldap.WithSizeLimit(1), ) if err != nil { panic(errors.Wrap(err, "error while searching ldap entry")) } if len(results.Entries) == 0 { panic(errors.Errorf("could not retrieve ldap entry '%s'", userDN)) } tmpl := template.Must(ctn) data := extendTemplateData(w, r, template.Data{ "EntryAttributes": results.Entries[0].Attributes, csrf.TemplateTag: csrf.TemplateField(r), }) if err := tmpl.RenderPage(w, "profile.html.tmpl", data); err != nil { panic(errors.Wrap(err, "error while rendering page")) } } func handlePasswordChange(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) tmpl := template.Must(ctn) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "error while retrieving session")) } renderError := func(message string) { sess.AddFlash(session.FlashError, message) data := extendTemplateData(w, r, template.Data{ csrf.TemplateTag: csrf.TemplateField(r), }) if err := tmpl.RenderPage(w, "profile.html.tmpl", data); err != nil { panic(errors.Wrap(err, "error while rendering page")) } } if err := r.ParseForm(); err != nil { panic(errors.Wrap(err, "error while parsing form")) } currentPassword := r.Form.Get("currentPassword") if currentPassword == "" { renderError("Vous devez renseigner votre mot de passe actuel.") return } if currentPassword == "" { renderError("Vous devez renseigner votre mot de passe actuel.") return } password := sess.Get("password").(string) if currentPassword != password { renderError("Votre mot de passe est invalide.") return } newPassword := r.Form.Get("newPassword") newPasswordConfirm := r.Form.Get("newPasswordConfirm") if newPassword == "" { renderError("Vous devez renseigner un nouveau mot de passe.") return } if newPassword != newPasswordConfirm { renderError("La confirmation de votre mot de passe n'est pas identique à votre nouveau mot de passe.") return } ldapSrv := ldap.Must(ctn) conn, err := ldapSrv.Connect() if err != nil { panic(errors.Wrap(err, "error while connecting to ldap server")) } defer conn.Close() userDN := sess.Get("dn").(string) if err := ldapSrv.BindConn(conn, userDN, password); err != nil { panic(errors.Wrap(err, "error while binding ldap connection")) } if err := ldapSrv.ModifyPasswordConn(conn, userDN, password, newPassword); err != nil { panic(errors.Wrap(err, "error while modifying password")) } sess.Set("password", newPassword) sess.AddFlash(session.FlashSuccess, "Votre mot de passe a été modifié.") if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "error while saving session")) } http.Redirect(w, r, "/profile", http.StatusSeeOther) } func authMiddleware(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "error while retrieving session")) } dn, ok := sess.Get("dn").(string) if !ok || dn == "" { http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) return } next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } func extendTemplateData(w http.ResponseWriter, r *http.Request, data template.Data) template.Data { ctn := container.Must(r.Context()) data, err := template.Extend(data, html.WithFlashes(w, r, ctn), ) if err != nil { panic(errors.Wrap(err, "error while extending template data")) } return data } func generateRandomBytes(n int) ([]byte, error) { b := make([]byte, n) _, err := rand.Read(b) if err != nil { return nil, err } return b, nil }