go-jwtserver/middleware/accounts.go

314 lines
8.5 KiB
Go
Raw Permalink Normal View History

2020-07-16 13:56:57 +02:00
package middleware
2020-07-16 10:51:50 +02:00
import (
"encoding/json"
2020-07-21 11:08:01 +02:00
"errors"
"fmt"
2020-07-16 10:51:50 +02:00
"net/http"
2020-07-21 11:08:01 +02:00
"strconv"
"time"
2020-07-16 10:51:50 +02:00
"github.com/dgrijalva/jwt-go"
2020-07-21 11:08:01 +02:00
"github.com/gofrs/uuid"
"github.com/joho/godotenv"
2020-07-16 10:51:50 +02:00
"strings"
"github.com/jinzhu/gorm"
"os"
"golang.org/x/crypto/bcrypt"
)
2020-07-21 11:08:01 +02:00
// JWT settings for this package. All four are assigned by init from
// environment variables.
var (
	// Token lifetimes as plain integers. GenerateTokenPair multiplies
	// accessExpires by time.Second and refreshExpiration by time.Minute.
	accessExpires, refreshExpiration int

	// Values copied into the Audience and Subject claims of access tokens.
	audience, subject string
)
func init() {
e := godotenv.Load() //Load .env file
if e != nil {
fmt.Print(e)
}
var err error
2020-07-22 10:14:13 +02:00
accessExpires, err = strconv.Atoi(os.Getenv("jwt_access_expiration"))
2020-07-21 11:08:01 +02:00
if err != nil {
panic(err)
}
2020-07-27 09:34:24 +02:00
refreshExpiration, err = strconv.Atoi(os.Getenv("jwt_refresh_expiration"))
2020-07-21 11:08:01 +02:00
if err != nil {
panic(err)
}
2020-07-22 10:14:13 +02:00
audience = os.Getenv("jwt_audience")
subject = os.Getenv("jwt_subject")
2020-07-21 11:08:01 +02:00
}
// TokenDetails struct
type TokenDetails struct {
UserID uuid.UUID `json:"user_id"`
2020-07-16 10:51:50 +02:00
jwt.StandardClaims
}
2020-07-21 11:08:01 +02:00
//Account struct to rep user account
2020-07-16 10:51:50 +02:00
type Account struct {
2020-07-21 11:08:01 +02:00
ID uuid.UUID `gorm:"type:uuid;primary_key;"`
Email string `json:"email"`
Password string `json:"password,omitempty"`
Token string `json:"access_token,omitempty" sql:"-"`
RefreshToken string `json:"refresh_token,omitempty" sql:"-"`
TokenExpiresAt string `json:"-"`
2020-07-16 10:51:50 +02:00
}
//Validate incoming user details...
func (account *Account) Validate() (map[string]interface{}, bool) {
if !strings.Contains(account.Email, "@") {
2020-07-21 11:08:01 +02:00
return message(false, "Email address is required"), false
2020-07-16 10:51:50 +02:00
}
if len(account.Password) < 1 {
2020-07-21 11:08:01 +02:00
return message(false, "Password is required"), false
2020-07-16 10:51:50 +02:00
}
//Email must be unique
temp := &Account{}
//check for errors and duplicate emails
err := GetDB().Table("accounts").Where("email = ?", account.Email).First(temp).Error
if err != nil && err != gorm.ErrRecordNotFound {
2020-07-21 11:08:01 +02:00
return message(false, "Connection error. Please retry"), false
2020-07-16 10:51:50 +02:00
}
if temp.Email != "" {
2020-07-21 11:08:01 +02:00
return message(false, "Email address already in use by another user."), false
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
return message(false, "Requirement passed"), true
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// Create user account
2020-07-16 10:51:50 +02:00
func (account *Account) Create() map[string]interface{} {
if resp, ok := account.Validate(); !ok {
return resp
}
hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(account.Password), bcrypt.DefaultCost)
account.Password = string(hashedPassword)
GetDB().Create(account)
2020-07-21 11:08:01 +02:00
if account.ID == uuid.Nil {
return message(false, "Failed to create account, connection error.")
2020-07-16 10:51:50 +02:00
}
account.Password = "" //delete password
2020-07-21 11:08:01 +02:00
response := message(true, "Account has been created")
2020-07-16 10:51:50 +02:00
response["account"] = account
return response
}
2020-07-21 11:08:01 +02:00
// Login and provides user token
2020-07-16 10:51:50 +02:00
func Login(email, password string) map[string]interface{} {
account := &Account{}
err := GetDB().Table("accounts").Where("email = ?", email).First(account).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
2020-07-21 11:08:01 +02:00
return message(false, "Email address not found")
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
return message(false, "Connection error. Please retry")
2020-07-16 10:51:50 +02:00
}
err = bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(password))
if err != nil && err == bcrypt.ErrMismatchedHashAndPassword { //Password does not match!
2020-07-21 11:08:01 +02:00
return message(false, "Invalid login credentials. Please try again")
2020-07-16 10:51:50 +02:00
}
//Worked! Logged In
account.Password = ""
2020-07-21 11:08:01 +02:00
account.GenerateTokenPair()
//account.CreateAuth(td)
2020-07-16 10:51:50 +02:00
2020-07-21 11:08:01 +02:00
resp := make(map[string]interface{})
resp["access_token"] = account.Token
resp["refresh_token"] = account.RefreshToken
return resp
}
// GenerateTokenPair return access_token and refresh_token pair
func (account *Account) GenerateTokenPair() (*TokenDetails, error) {
2020-07-16 10:51:50 +02:00
//Create JWT token
2020-07-21 11:08:01 +02:00
tk := &TokenDetails{
account.ID,
jwt.StandardClaims{
IssuedAt: time.Now().Unix(),
ExpiresAt: time.Now().Add(time.Second * time.Duration(accessExpires)).Unix(),
Audience: audience,
Subject: subject,
},
}
2020-07-16 10:51:50 +02:00
token := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), tk)
2020-07-22 10:14:13 +02:00
tokenString, err := token.SignedString([]byte(os.Getenv("jwt_access_token_password")))
2020-07-21 11:08:01 +02:00
if err != nil {
return nil, err
}
2020-07-16 10:51:50 +02:00
account.Token = tokenString //Store the token in the response
2020-07-21 11:08:01 +02:00
reftk := &TokenDetails{
account.ID,
jwt.StandardClaims{
ExpiresAt: time.Now().Add(time.Minute * time.Duration(refreshExpiration)).Unix(),
},
}
refreshtoken := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), reftk)
2020-07-22 10:14:13 +02:00
refreshtokenString, err := refreshtoken.SignedString([]byte(os.Getenv("jwt_refresh_token_password")))
2020-07-21 11:08:01 +02:00
if err != nil {
return nil, err
}
account.RefreshToken = refreshtokenString //Store the token in the response
return reftk, nil
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// CreateAuth stores refresh_token expires
// func (account *Account) CreateAuth(td *TokenDetails) {
// rt := time.Unix(td.ExpiresAt, 0)
// db.Model(&account).Update(account.TokenExpiresAt, rt.String())
// }
// GetUser in the database
func GetUser(u uuid.UUID) *Account {
2020-07-16 10:51:50 +02:00
acc := &Account{}
GetDB().Table("accounts").Where("id = ?", u).First(acc)
if acc.Email == "" { //User not found!
return nil
}
acc.Password = ""
return acc
}
2020-07-21 11:08:01 +02:00
// CreateAccount middleware to create user account
2020-07-16 10:51:50 +02:00
var CreateAccount = func(w http.ResponseWriter, r *http.Request) {
account := &Account{}
err := json.NewDecoder(r.Body).Decode(account) //decode the request body into struct and failed if any error occur
if err != nil {
2020-07-21 11:08:01 +02:00
respond(w, message(false, "Invalid request"))
2020-07-16 10:51:50 +02:00
return
}
resp := account.Create() //Create account
2020-07-21 11:08:01 +02:00
respond(w, resp)
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// Authenticate middleware to authenticate the user
2020-07-16 10:51:50 +02:00
var Authenticate = func(w http.ResponseWriter, r *http.Request) {
account := &Account{}
err := json.NewDecoder(r.Body).Decode(account) //decode the request body into struct and failed if any error occur
if err != nil {
2020-07-21 11:08:01 +02:00
respond(w, message(false, "Invalid request"))
2020-07-16 10:51:50 +02:00
return
}
resp := Login(account.Email, account.Password)
2020-07-21 11:08:01 +02:00
respond(w, resp)
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// Refresh is the HTTP handler that exchanges a valid refresh token for a new
// access/refresh token pair.
//
// The request is checked by ValidateToken. The user named in the token must
// still exist; if so a fresh pair is generated and returned as
// {"access_token": ..., "refresh_token": ...}.
var Refresh = func(w http.ResponseWriter, r *http.Request) {
	tk, err := ValidateToken(w, r)
	if err != nil {
		// ValidateToken has already written the 403 status and a JSON error
		// body. Writing again would trigger a superfluous WriteHeader call and
		// append a second JSON document to the response.
		return
	}

	// Check that the user still exists in the database.
	acc := GetUser(tk.UserID)
	if acc == nil {
		respond(w, message(false, "User not found"))
		return
	}

	if _, err := acc.GenerateTokenPair(); err != nil {
		// Headers must be set before WriteHeader to reach the client.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		respond(w, message(false, "Failed to generate tokens. Please retry"))
		return
	}

	respond(w, map[string]interface{}{
		"access_token":  acc.Token,
		"refresh_token": acc.RefreshToken,
	})
}
// ValidateToken to check if token is valid
func ValidateToken(w http.ResponseWriter, r *http.Request) (*TokenDetails, error) {
var pwd string
urltomatch := r.URL.String()
2020-07-21 11:20:06 +02:00
if strings.Contains(urltomatch, "/api/user/refresh") {
2020-07-27 09:34:24 +02:00
pwd = "jwt_refresh_token_password"
2020-07-21 11:08:01 +02:00
} else {
2020-07-27 09:34:24 +02:00
pwd = "jwt_access_token_password"
2020-07-21 11:08:01 +02:00
}
response := make(map[string]interface{})
tk := &TokenDetails{}
bearToken := r.Header.Get("Authorization") //Grab the token from the header
if bearToken == "" { //Token is missing, returns with error code 403 Unauthorized
response = message(false, "Missing auth token")
w.WriteHeader(http.StatusForbidden)
w.Header().Add("Content-Type", "application/json")
respond(w, response)
2020-07-21 11:21:39 +02:00
return tk, errors.New("Missing auth token")
2020-07-21 11:08:01 +02:00
}
splitted := strings.Split(bearToken, " ") //The token normally comes in format `Bearer {token-body}`, we check if the retrieved token matched this requirement
if len(splitted) != 2 {
response = message(false, "Invalid/Malformed auth token")
w.WriteHeader(http.StatusForbidden)
w.Header().Add("Content-Type", "application/json")
respond(w, response)
return tk, errors.New("Invalid/Malformed auth token")
}
tokenPart := splitted[1] //Grab the token part, what we are truly interested in
token, err := jwt.ParseWithClaims(tokenPart, tk, func(token *jwt.Token) (interface{}, error) {
return []byte(os.Getenv(pwd)), nil
})
if err != nil { //Malformed token, returns with http code 403 as usual
response = message(false, "Malformed authentication token")
w.WriteHeader(http.StatusForbidden)
w.Header().Add("Content-Type", "application/json")
respond(w, response)
return tk, errors.New("Malformed authentication token")
}
if !token.Valid { //Token is invalid, maybe not signed on this server
response = message(false, "Token is not valid.")
w.WriteHeader(http.StatusForbidden)
w.Header().Add("Content-Type", "application/json")
respond(w, response)
return tk, errors.New("Token is not valid")
}
return tk, nil
}
// message builds the response envelope used by the handlers in this file:
// a map with a boolean "status" entry and a string "message" entry.
func message(status bool, msg string) map[string]interface{} {
	return map[string]interface{}{"status": status, "message": msg}
}
2020-07-21 11:08:01 +02:00
// respond writes data to w as a JSON document followed by a newline, with the
// Content-Type header set to application/json.
//
// The payload is encoded before anything is written. If encoding fails, the
// client receives a 500 status and a status-false envelope rather than a
// success status with a truncated body.
func respond(w http.ResponseWriter, data map[string]interface{}) {
	// Set (not Add) so repeated calls never produce duplicate header values.
	w.Header().Set("Content-Type", "application/json")

	body, err := json.Marshal(data)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		// A failed write means the client is gone; nothing more can be done.
		_, _ = w.Write([]byte("{\"status\":false,\"message\":\"Internal server error\"}\n"))
		return
	}

	// Trailing newline keeps the output identical to json.Encoder.Encode.
	_, _ = w.Write(append(body, '\n'))
}