set uuid and refresh token path
This commit is contained in:
@ -2,10 +2,15 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/joho/godotenv"
|
||||
|
||||
"strings"
|
||||
|
||||
@ -16,31 +21,56 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
/*
|
||||
JWT claims struct
|
||||
*/
|
||||
type Token struct {
|
||||
UserId uint
|
||||
var (
|
||||
accessExpires int
|
||||
refreshExpiration int
|
||||
audience string
|
||||
subject string
|
||||
)
|
||||
|
||||
func init() {
|
||||
e := godotenv.Load() //Load .env file
|
||||
if e != nil {
|
||||
fmt.Print(e)
|
||||
}
|
||||
var err error
|
||||
accessExpires, err = strconv.Atoi(os.Getenv("access_expiration"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
refreshExpiration, err = strconv.Atoi(os.Getenv("access_expiration"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
audience = os.Getenv("audience")
|
||||
subject = os.Getenv("subject")
|
||||
}
|
||||
|
||||
// TokenDetails struct
|
||||
type TokenDetails struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
//a struct to rep user account
|
||||
//Account struct to rep user account
|
||||
type Account struct {
|
||||
gorm.Model
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Token string `json:"token";sql:"-"`
|
||||
ID uuid.UUID `gorm:"type:uuid;primary_key;"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Token string `json:"access_token,omitempty" sql:"-"`
|
||||
RefreshToken string `json:"refresh_token,omitempty" sql:"-"`
|
||||
TokenExpiresAt string `json:"-"`
|
||||
}
|
||||
|
||||
//Validate incoming user details...
|
||||
func (account *Account) Validate() (map[string]interface{}, bool) {
|
||||
|
||||
if !strings.Contains(account.Email, "@") {
|
||||
return Message(false, "Email address is required"), false
|
||||
return message(false, "Email address is required"), false
|
||||
}
|
||||
|
||||
if len(account.Password) < 1 {
|
||||
return Message(false, "Password is required"), false
|
||||
return message(false, "Password is required"), false
|
||||
}
|
||||
|
||||
//Email must be unique
|
||||
@ -49,15 +79,16 @@ func (account *Account) Validate() (map[string]interface{}, bool) {
|
||||
//check for errors and duplicate emails
|
||||
err := GetDB().Table("accounts").Where("email = ?", account.Email).First(temp).Error
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return Message(false, "Connection error. Please retry"), false
|
||||
return message(false, "Connection error. Please retry"), false
|
||||
}
|
||||
if temp.Email != "" {
|
||||
return Message(false, "Email address already in use by another user."), false
|
||||
return message(false, "Email address already in use by another user."), false
|
||||
}
|
||||
|
||||
return Message(false, "Requirement passed"), true
|
||||
return message(false, "Requirement passed"), true
|
||||
}
|
||||
|
||||
// Create user account
|
||||
func (account *Account) Create() map[string]interface{} {
|
||||
|
||||
if resp, ok := account.Validate(); !ok {
|
||||
@ -69,53 +100,89 @@ func (account *Account) Create() map[string]interface{} {
|
||||
|
||||
GetDB().Create(account)
|
||||
|
||||
if account.ID <= 0 {
|
||||
return Message(false, "Failed to create account, connection error.")
|
||||
if account.ID == uuid.Nil {
|
||||
return message(false, "Failed to create account, connection error.")
|
||||
}
|
||||
|
||||
//Create new JWT token for the newly registered account
|
||||
tk := &Token{UserId: account.ID}
|
||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), tk)
|
||||
tokenString, _ := token.SignedString([]byte(os.Getenv("token_password")))
|
||||
account.Token = tokenString
|
||||
|
||||
account.Password = "" //delete password
|
||||
|
||||
response := Message(true, "Account has been created")
|
||||
response := message(true, "Account has been created")
|
||||
response["account"] = account
|
||||
return response
|
||||
}
|
||||
|
||||
// Login and provides user token
|
||||
func Login(email, password string) map[string]interface{} {
|
||||
|
||||
account := &Account{}
|
||||
err := GetDB().Table("accounts").Where("email = ?", email).First(account).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return Message(false, "Email address not found")
|
||||
return message(false, "Email address not found")
|
||||
}
|
||||
return Message(false, "Connection error. Please retry")
|
||||
return message(false, "Connection error. Please retry")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(account.Password), []byte(password))
|
||||
if err != nil && err == bcrypt.ErrMismatchedHashAndPassword { //Password does not match!
|
||||
return Message(false, "Invalid login credentials. Please try again")
|
||||
return message(false, "Invalid login credentials. Please try again")
|
||||
}
|
||||
//Worked! Logged In
|
||||
account.Password = ""
|
||||
account.GenerateTokenPair()
|
||||
//account.CreateAuth(td)
|
||||
|
||||
//Create JWT token
|
||||
tk := &Token{UserId: account.ID}
|
||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), tk)
|
||||
tokenString, _ := token.SignedString([]byte(os.Getenv("token_password")))
|
||||
account.Token = tokenString //Store the token in the response
|
||||
|
||||
resp := Message(true, "Logged In")
|
||||
resp["account"] = account
|
||||
resp := make(map[string]interface{})
|
||||
resp["access_token"] = account.Token
|
||||
resp["refresh_token"] = account.RefreshToken
|
||||
return resp
|
||||
}
|
||||
|
||||
func GetUser(u uint) *Account {
|
||||
// GenerateTokenPair return access_token and refresh_token pair
|
||||
func (account *Account) GenerateTokenPair() (*TokenDetails, error) {
|
||||
//Create JWT token
|
||||
tk := &TokenDetails{
|
||||
account.ID,
|
||||
jwt.StandardClaims{
|
||||
IssuedAt: time.Now().Unix(),
|
||||
ExpiresAt: time.Now().Add(time.Second * time.Duration(accessExpires)).Unix(),
|
||||
Audience: audience,
|
||||
Subject: subject,
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), tk)
|
||||
tokenString, err := token.SignedString([]byte(os.Getenv("access_token_password")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Token = tokenString //Store the token in the response
|
||||
|
||||
reftk := &TokenDetails{
|
||||
account.ID,
|
||||
jwt.StandardClaims{
|
||||
ExpiresAt: time.Now().Add(time.Minute * time.Duration(refreshExpiration)).Unix(),
|
||||
},
|
||||
}
|
||||
refreshtoken := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), reftk)
|
||||
refreshtokenString, err := refreshtoken.SignedString([]byte(os.Getenv("refresh_token_password")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.RefreshToken = refreshtokenString //Store the token in the response
|
||||
|
||||
return reftk, nil
|
||||
}
|
||||
|
||||
// CreateAuth stores refresh_token expires
|
||||
// func (account *Account) CreateAuth(td *TokenDetails) {
|
||||
// rt := time.Unix(td.ExpiresAt, 0)
|
||||
|
||||
// db.Model(&account).Update(account.TokenExpiresAt, rt.String())
|
||||
|
||||
// }
|
||||
|
||||
// GetUser in the database
|
||||
func GetUser(u uuid.UUID) *Account {
|
||||
|
||||
acc := &Account{}
|
||||
GetDB().Table("accounts").Where("id = ?", u).First(acc)
|
||||
@ -127,41 +194,126 @@ func GetUser(u uint) *Account {
|
||||
return acc
|
||||
}
|
||||
|
||||
// CreateAccount middleware to create user account
|
||||
var CreateAccount = func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
account := &Account{}
|
||||
err := json.NewDecoder(r.Body).Decode(account) //decode the request body into struct and failed if any error occur
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
Respond(w, Message(false, "Invalid request"))
|
||||
respond(w, message(false, "Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
resp := account.Create() //Create account
|
||||
Respond(w, resp)
|
||||
respond(w, resp)
|
||||
}
|
||||
|
||||
// Authenticate middleware to authenticate the user
|
||||
var Authenticate = func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
account := &Account{}
|
||||
err := json.NewDecoder(r.Body).Decode(account) //decode the request body into struct and failed if any error occur
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
Respond(w, Message(false, "Invalid request"))
|
||||
respond(w, message(false, "Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
resp := Login(account.Email, account.Password)
|
||||
Respond(w, resp)
|
||||
respond(w, resp)
|
||||
}
|
||||
|
||||
func Message(status bool, message string) map[string]interface{} {
|
||||
log.Println(message)
|
||||
// Refresh middleware to authenticate the user
|
||||
var Refresh = func(w http.ResponseWriter, r *http.Request) {
|
||||
tk, err := ValidateToken(w, r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
response := message(true, "Refreshed expired")
|
||||
respond(w, response)
|
||||
return
|
||||
}
|
||||
// check if user exists in database
|
||||
acc := GetUser(tk.UserID)
|
||||
if acc == nil {
|
||||
respond(w, message(false, "User not found"))
|
||||
return
|
||||
}
|
||||
|
||||
acc.Password = ""
|
||||
acc.GenerateTokenPair()
|
||||
//account.CreateAuth(td)
|
||||
resp := make(map[string]interface{})
|
||||
resp["access_token"] = acc.Token
|
||||
resp["refresh_token"] = acc.RefreshToken
|
||||
|
||||
//update refresh expirity
|
||||
|
||||
// rt := time.Now().Add(time.Minute * time.Duration(refreshExpiration)).Unix()
|
||||
// log.Println(rt)
|
||||
// db.Model(&acc).Update(acc.TokenExpiresAt, rt)
|
||||
|
||||
respond(w, resp)
|
||||
}
|
||||
|
||||
// ValidateToken to check if token is valid
|
||||
func ValidateToken(w http.ResponseWriter, r *http.Request) (*TokenDetails, error) {
|
||||
var pwd string
|
||||
urltomatch := r.URL.String()
|
||||
if urltomatch == "/api/user/refresh" {
|
||||
pwd = "refresh_token_password"
|
||||
} else {
|
||||
pwd = "access_token_password"
|
||||
}
|
||||
|
||||
response := make(map[string]interface{})
|
||||
tk := &TokenDetails{}
|
||||
bearToken := r.Header.Get("Authorization") //Grab the token from the header
|
||||
|
||||
if bearToken == "" { //Token is missing, returns with error code 403 Unauthorized
|
||||
response = message(false, "Missing auth token")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
respond(w, response)
|
||||
return tk, errors.New("Missing auth token")
|
||||
}
|
||||
|
||||
splitted := strings.Split(bearToken, " ") //The token normally comes in format `Bearer {token-body}`, we check if the retrieved token matched this requirement
|
||||
if len(splitted) != 2 {
|
||||
response = message(false, "Invalid/Malformed auth token")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
respond(w, response)
|
||||
return tk, errors.New("Invalid/Malformed auth token")
|
||||
}
|
||||
|
||||
tokenPart := splitted[1] //Grab the token part, what we are truly interested in
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenPart, tk, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv(pwd)), nil
|
||||
})
|
||||
|
||||
if err != nil { //Malformed token, returns with http code 403 as usual
|
||||
response = message(false, "Malformed authentication token")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
respond(w, response)
|
||||
return tk, errors.New("Malformed authentication token")
|
||||
}
|
||||
|
||||
if !token.Valid { //Token is invalid, maybe not signed on this server
|
||||
response = message(false, "Token is not valid.")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
respond(w, response)
|
||||
return tk, errors.New("Token is not valid")
|
||||
}
|
||||
return tk, nil
|
||||
}
|
||||
func message(status bool, message string) map[string]interface{} {
|
||||
return map[string]interface{}{"status": status, "message": message}
|
||||
}
|
||||
|
||||
func Respond(w http.ResponseWriter, data map[string]interface{}) {
|
||||
log.Println(data)
|
||||
func respond(w http.ResponseWriter, data map[string]interface{}) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
@ -2,14 +2,9 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
// JwtAuthentication is a Jwt Auth controller with postgres database
|
||||
@ -17,8 +12,8 @@ var JwtAuthentication = func(next http.Handler) http.Handler {
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
notAuth := []string{"/api/user/new", "/api/user/login"} //List of endpoints that doesn't require auth
|
||||
requestPath := r.URL.Path //current request path
|
||||
notAuth := []string{"/api/user/new", "/api/user/login", "/api/user/refresh"} //List of endpoints that doesn't require auth
|
||||
requestPath := r.URL.Path //current request path
|
||||
|
||||
//check if request does not need authentication, serve the request if it doesn't need it
|
||||
for _, value := range notAuth {
|
||||
@ -28,53 +23,14 @@ var JwtAuthentication = func(next http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
response := make(map[string]interface{})
|
||||
tokenHeader := r.Header.Get("Authorization") //Grab the token from the header
|
||||
|
||||
if tokenHeader == "" { //Token is missing, returns with error code 403 Unauthorized
|
||||
response = Message(false, "Missing auth token")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
Respond(w, response)
|
||||
return
|
||||
}
|
||||
|
||||
splitted := strings.Split(tokenHeader, " ") //The token normally comes in format `Bearer {token-body}`, we check if the retrieved token matched this requirement
|
||||
if len(splitted) != 2 {
|
||||
response = Message(false, "Invalid/Malformed auth token")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
Respond(w, response)
|
||||
return
|
||||
}
|
||||
|
||||
tokenPart := splitted[1] //Grab the token part, what we are truly interested in
|
||||
tk := &Token{}
|
||||
log.Println(splitted)
|
||||
token, err := jwt.ParseWithClaims(tokenPart, tk, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(os.Getenv("token_password")), nil
|
||||
})
|
||||
|
||||
if err != nil { //Malformed token, returns with http code 403 as usual
|
||||
response = Message(false, "Malformed authentication token")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
Respond(w, response)
|
||||
return
|
||||
}
|
||||
|
||||
if !token.Valid { //Token is invalid, maybe not signed on this server
|
||||
response = Message(false, "Token is not valid.")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
Respond(w, response)
|
||||
tk, err := ValidateToken(w, r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
//Everything went well, proceed with the request and set the caller to the user retrieved from the parsed token
|
||||
fmt.Sprintf("User %", tk) //Useful for monitoring
|
||||
ctx := context.WithValue(r.Context(), "user", tk.UserId)
|
||||
log.Printf("User %v", tk) //Useful for monitoring
|
||||
ctx := context.WithValue(r.Context(), "user", tk.UserID)
|
||||
r = r.WithContext(ctx)
|
||||
next.ServeHTTP(w, r) //proceed in the middleware chain!
|
||||
})
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/jinzhu/gorm"
|
||||
_ "github.com/jinzhu/gorm/dialects/postgres"
|
||||
"github.com/joho/godotenv"
|
||||
@ -39,3 +40,12 @@ func init() {
|
||||
func GetDB() *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// BeforeCreate will set a UUID rather than numeric ID.
|
||||
func (accont *Account) BeforeCreate(scope *gorm.Scope) error {
|
||||
uuid, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return scope.SetColumn("ID", uuid)
|
||||
}
|
||||
|
Reference in New Issue
Block a user