package middleware import ( "encoding/json" "errors" "fmt" "net/http" "strconv" "time" "github.com/dgrijalva/jwt-go" "github.com/gofrs/uuid" "github.com/joho/godotenv" "strings" "github.com/jinzhu/gorm" "os" "golang.org/x/crypto/bcrypt" ) var ( accessExpires int refreshExpiration int audience string subject string ) func init() { e := godotenv.Load() //Load .env file if e != nil { fmt.Print(e) } var err error accessExpires, err = strconv.Atoi(os.Getenv("jwt_access_expiration")) if err != nil { panic(err) } refreshExpiration, err = strconv.Atoi(os.Getenv("jwt_access_expiration")) if err != nil { panic(err) } audience = os.Getenv("jwt_audience") subject = os.Getenv("jwt_subject") } // TokenDetails struct type TokenDetails struct { UserID uuid.UUID `json:"user_id"` jwt.StandardClaims } //Account struct to rep user account type Account struct { ID uuid.UUID `gorm:"type:uuid;primary_key;"` Email string `json:"email"` Password string `json:"password,omitempty"` Token string `json:"access_token,omitempty" sql:"-"` RefreshToken string `json:"refresh_token,omitempty" sql:"-"` TokenExpiresAt string `json:"-"` } //Validate incoming user details... func (account *Account) Validate() (map[string]interface{}, bool) { if !strings.Contains(account.Email, "@") { return message(false, "Email address is required"), false } if len(account.Password) < 1 { return message(false, "Password is required"), false } //Email must be unique temp := &Account{} //check for errors and duplicate emails err := GetDB().Table("accounts").Where("email = ?", account.Email).First(temp).Error if err != nil && err != gorm.ErrRecordNotFound { return message(false, "Connection error. Please retry"), false } if temp.Email != "" { return message(false, "Email address already in use by another user."), false } return message(false, "Requirement passed"), true } // Create user account func (account *Account) Create() map[string]interface{} { if resp, ok := account.Validate(); !ok { return resp } hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(account.Password), bcrypt.DefaultCost) account.Password = string(hashedPassword) GetDB().Create(account) if account.ID == uuid.Nil { return message(false, "Failed to create account, connection error.") } account.Password = "" //delete password response := message(true, "Account has been created") response["account"] = account return response } // Login and provides user token func Login(email, password string) map[string]interface{} { account := &Account{} err := GetDB().Table("accounts").Where("email = ?", email).First(account).Error if err != nil { if err == gorm.ErrRecordNotFound { return message(false, "Email address not found") } return message(false, "Connection error. Please retry") } err = bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(password)) if err != nil && err == bcrypt.ErrMismatchedHashAndPassword { //Password does not match! return message(false, "Invalid login credentials. Please try again") } //Worked! Logged In account.Password = "" account.GenerateTokenPair() //account.CreateAuth(td) resp := make(map[string]interface{}) resp["access_token"] = account.Token resp["refresh_token"] = account.RefreshToken return resp } // GenerateTokenPair return access_token and refresh_token pair func (account *Account) GenerateTokenPair() (*TokenDetails, error) { //Create JWT token tk := &TokenDetails{ account.ID, jwt.StandardClaims{ IssuedAt: time.Now().Unix(), ExpiresAt: time.Now().Add(time.Second * time.Duration(accessExpires)).Unix(), Audience: audience, Subject: subject, }, } token := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), tk) tokenString, err := token.SignedString([]byte(os.Getenv("jwt_access_token_password"))) if err != nil { return nil, err } account.Token = tokenString //Store the token in the response reftk := &TokenDetails{ account.ID, jwt.StandardClaims{ ExpiresAt: time.Now().Add(time.Minute * time.Duration(refreshExpiration)).Unix(), }, } refreshtoken := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), reftk) refreshtokenString, err := refreshtoken.SignedString([]byte(os.Getenv("jwt_refresh_token_password"))) if err != nil { return nil, err } account.RefreshToken = refreshtokenString //Store the token in the response return reftk, nil } // CreateAuth stores refresh_token expires // func (account *Account) CreateAuth(td *TokenDetails) { // rt := time.Unix(td.ExpiresAt, 0) // db.Model(&account).Update(account.TokenExpiresAt, rt.String()) // } // GetUser in the database func GetUser(u uuid.UUID) *Account { acc := &Account{} GetDB().Table("accounts").Where("id = ?", u).First(acc) if acc.Email == "" { //User not found! return nil } acc.Password = "" return acc } // CreateAccount middleware to create user account var CreateAccount = func(w http.ResponseWriter, r *http.Request) { account := &Account{} err := json.NewDecoder(r.Body).Decode(account) //decode the request body into struct and failed if any error occur if err != nil { respond(w, message(false, "Invalid request")) return } resp := account.Create() //Create account respond(w, resp) } // Authenticate middleware to authenticate the user var Authenticate = func(w http.ResponseWriter, r *http.Request) { account := &Account{} err := json.NewDecoder(r.Body).Decode(account) //decode the request body into struct and failed if any error occur if err != nil { respond(w, message(false, "Invalid request")) return } resp := Login(account.Email, account.Password) respond(w, resp) } // Refresh middleware to authenticate the user var Refresh = func(w http.ResponseWriter, r *http.Request) { tk, err := ValidateToken(w, r) if err != nil { w.WriteHeader(http.StatusForbidden) w.Header().Add("Content-Type", "application/json") response := message(true, "Refreshed expired") respond(w, response) return } // check if user exists in database acc := GetUser(tk.UserID) if acc == nil { respond(w, message(false, "User not found")) return } acc.Password = "" acc.GenerateTokenPair() //account.CreateAuth(td) resp := make(map[string]interface{}) resp["access_token"] = acc.Token resp["refresh_token"] = acc.RefreshToken //update refresh expirity // rt := time.Now().Add(time.Minute * time.Duration(refreshExpiration)).Unix() // log.Println(rt) // db.Model(&acc).Update(acc.TokenExpiresAt, rt) respond(w, resp) } // ValidateToken to check if token is valid func ValidateToken(w http.ResponseWriter, r *http.Request) (*TokenDetails, error) { var pwd string urltomatch := r.URL.String() if strings.Contains(urltomatch, "/api/user/refresh") { pwd = "refresh_token_password" } else { pwd = "access_token_password" } response := make(map[string]interface{}) tk := &TokenDetails{} bearToken := r.Header.Get("Authorization") //Grab the token from the header if bearToken == "" { //Token is missing, returns with error code 403 Unauthorized response = message(false, "Missing auth token") w.WriteHeader(http.StatusForbidden) w.Header().Add("Content-Type", "application/json") respond(w, response) return tk, errors.New("Missing auth token") } splitted := strings.Split(bearToken, " ") //The token normally comes in format `Bearer {token-body}`, we check if the retrieved token matched this requirement if len(splitted) != 2 { response = message(false, "Invalid/Malformed auth token") w.WriteHeader(http.StatusForbidden) w.Header().Add("Content-Type", "application/json") respond(w, response) return tk, errors.New("Invalid/Malformed auth token") } tokenPart := splitted[1] //Grab the token part, what we are truly interested in token, err := jwt.ParseWithClaims(tokenPart, tk, func(token *jwt.Token) (interface{}, error) { return []byte(os.Getenv(pwd)), nil }) if err != nil { //Malformed token, returns with http code 403 as usual response = message(false, "Malformed authentication token") w.WriteHeader(http.StatusForbidden) w.Header().Add("Content-Type", "application/json") respond(w, response) return tk, errors.New("Malformed authentication token") } if !token.Valid { //Token is invalid, maybe not signed on this server response = message(false, "Token is not valid.") w.WriteHeader(http.StatusForbidden) w.Header().Add("Content-Type", "application/json") respond(w, response) return tk, errors.New("Token is not valid") } return tk, nil } func message(status bool, message string) map[string]interface{} { return map[string]interface{}{"status": status, "message": message} } func respond(w http.ResponseWriter, data map[string]interface{}) { w.Header().Add("Content-Type", "application/json") json.NewEncoder(w).Encode(data) }