2020-07-16 13:56:57 +02:00
package middleware
2020-07-16 10:51:50 +02:00
import (
"encoding/json"
2020-07-21 11:08:01 +02:00
"errors"
"fmt"
2020-07-16 10:51:50 +02:00
"net/http"
2020-07-21 11:08:01 +02:00
"strconv"
"time"
2020-07-16 10:51:50 +02:00
"github.com/dgrijalva/jwt-go"
2020-07-21 11:08:01 +02:00
"github.com/gofrs/uuid"
"github.com/joho/godotenv"
2020-07-16 10:51:50 +02:00
"strings"
"github.com/jinzhu/gorm"
"os"
"golang.org/x/crypto/bcrypt"
)
2020-07-21 11:08:01 +02:00
// Token settings, read from the environment (optionally populated from a
// .env file) by init below.
var (
	// accessExpires is the access-token lifetime in seconds.
	accessExpires int
	// refreshExpiration is the refresh-token lifetime in minutes.
	refreshExpiration int
	// audience and subject are copied into the Audience and Subject claims of
	// access tokens.
	audience string
	subject  string
)

// init loads a .env file if one exists and parses the token settings.
// It panics when an expiration value is missing or not an integer, because
// no token could be issued with a meaningful lifetime.
func init() {
	// A missing .env is only reported: the variables may already be set in
	// the process environment.
	if err := godotenv.Load(); err != nil {
		fmt.Print(err)
	}

	var err error
	accessExpires, err = strconv.Atoi(os.Getenv("access_expiration"))
	if err != nil {
		panic(fmt.Errorf("invalid access_expiration: %v", err))
	}

	// The refresh lifetime has its own variable. Fall back to
	// access_expiration (the value previously reused for both) so existing
	// configurations that only define it keep starting up unchanged.
	refreshSetting := os.Getenv("refresh_expiration")
	if refreshSetting == "" {
		refreshSetting = os.Getenv("access_expiration")
	}
	refreshExpiration, err = strconv.Atoi(refreshSetting)
	if err != nil {
		panic(fmt.Errorf("invalid refresh_expiration: %v", err))
	}

	audience = os.Getenv("audience")
	subject = os.Getenv("subject")
}
// TokenDetails struct
type TokenDetails struct {
UserID uuid . UUID ` json:"user_id" `
2020-07-16 10:51:50 +02:00
jwt . StandardClaims
}
2020-07-21 11:08:01 +02:00
//Account struct to rep user account
2020-07-16 10:51:50 +02:00
type Account struct {
2020-07-21 11:08:01 +02:00
ID uuid . UUID ` gorm:"type:uuid;primary_key;" `
Email string ` json:"email" `
Password string ` json:"password,omitempty" `
Token string ` json:"access_token,omitempty" sql:"-" `
RefreshToken string ` json:"refresh_token,omitempty" sql:"-" `
TokenExpiresAt string ` json:"-" `
2020-07-16 10:51:50 +02:00
}
//Validate incoming user details...
func ( account * Account ) Validate ( ) ( map [ string ] interface { } , bool ) {
if ! strings . Contains ( account . Email , "@" ) {
2020-07-21 11:08:01 +02:00
return message ( false , "Email address is required" ) , false
2020-07-16 10:51:50 +02:00
}
if len ( account . Password ) < 1 {
2020-07-21 11:08:01 +02:00
return message ( false , "Password is required" ) , false
2020-07-16 10:51:50 +02:00
}
//Email must be unique
temp := & Account { }
//check for errors and duplicate emails
err := GetDB ( ) . Table ( "accounts" ) . Where ( "email = ?" , account . Email ) . First ( temp ) . Error
if err != nil && err != gorm . ErrRecordNotFound {
2020-07-21 11:08:01 +02:00
return message ( false , "Connection error. Please retry" ) , false
2020-07-16 10:51:50 +02:00
}
if temp . Email != "" {
2020-07-21 11:08:01 +02:00
return message ( false , "Email address already in use by another user." ) , false
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
return message ( false , "Requirement passed" ) , true
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// Create user account
2020-07-16 10:51:50 +02:00
func ( account * Account ) Create ( ) map [ string ] interface { } {
if resp , ok := account . Validate ( ) ; ! ok {
return resp
}
hashedPassword , _ := bcrypt . GenerateFromPassword ( [ ] byte ( account . Password ) , bcrypt . DefaultCost )
account . Password = string ( hashedPassword )
GetDB ( ) . Create ( account )
2020-07-21 11:08:01 +02:00
if account . ID == uuid . Nil {
return message ( false , "Failed to create account, connection error." )
2020-07-16 10:51:50 +02:00
}
account . Password = "" //delete password
2020-07-21 11:08:01 +02:00
response := message ( true , "Account has been created" )
2020-07-16 10:51:50 +02:00
response [ "account" ] = account
return response
}
2020-07-21 11:08:01 +02:00
// Login and provides user token
2020-07-16 10:51:50 +02:00
func Login ( email , password string ) map [ string ] interface { } {
account := & Account { }
err := GetDB ( ) . Table ( "accounts" ) . Where ( "email = ?" , email ) . First ( account ) . Error
if err != nil {
if err == gorm . ErrRecordNotFound {
2020-07-21 11:08:01 +02:00
return message ( false , "Email address not found" )
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
return message ( false , "Connection error. Please retry" )
2020-07-16 10:51:50 +02:00
}
err = bcrypt . CompareHashAndPassword ( [ ] byte ( account . Password ) , [ ] byte ( password ) )
if err != nil && err == bcrypt . ErrMismatchedHashAndPassword { //Password does not match!
2020-07-21 11:08:01 +02:00
return message ( false , "Invalid login credentials. Please try again" )
2020-07-16 10:51:50 +02:00
}
//Worked! Logged In
account . Password = ""
2020-07-21 11:08:01 +02:00
account . GenerateTokenPair ( )
//account.CreateAuth(td)
2020-07-16 10:51:50 +02:00
2020-07-21 11:08:01 +02:00
resp := make ( map [ string ] interface { } )
resp [ "access_token" ] = account . Token
resp [ "refresh_token" ] = account . RefreshToken
return resp
}
// GenerateTokenPair return access_token and refresh_token pair
func ( account * Account ) GenerateTokenPair ( ) ( * TokenDetails , error ) {
2020-07-16 10:51:50 +02:00
//Create JWT token
2020-07-21 11:08:01 +02:00
tk := & TokenDetails {
account . ID ,
jwt . StandardClaims {
IssuedAt : time . Now ( ) . Unix ( ) ,
ExpiresAt : time . Now ( ) . Add ( time . Second * time . Duration ( accessExpires ) ) . Unix ( ) ,
Audience : audience ,
Subject : subject ,
} ,
}
2020-07-16 10:51:50 +02:00
token := jwt . NewWithClaims ( jwt . GetSigningMethod ( "HS256" ) , tk )
2020-07-21 11:08:01 +02:00
tokenString , err := token . SignedString ( [ ] byte ( os . Getenv ( "access_token_password" ) ) )
if err != nil {
return nil , err
}
2020-07-16 10:51:50 +02:00
account . Token = tokenString //Store the token in the response
2020-07-21 11:08:01 +02:00
reftk := & TokenDetails {
account . ID ,
jwt . StandardClaims {
ExpiresAt : time . Now ( ) . Add ( time . Minute * time . Duration ( refreshExpiration ) ) . Unix ( ) ,
} ,
}
refreshtoken := jwt . NewWithClaims ( jwt . GetSigningMethod ( "HS256" ) , reftk )
refreshtokenString , err := refreshtoken . SignedString ( [ ] byte ( os . Getenv ( "refresh_token_password" ) ) )
if err != nil {
return nil , err
}
account . RefreshToken = refreshtokenString //Store the token in the response
return reftk , nil
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// CreateAuth stores refresh_token expires
// func (account *Account) CreateAuth(td *TokenDetails) {
// rt := time.Unix(td.ExpiresAt, 0)
// db.Model(&account).Update(account.TokenExpiresAt, rt.String())
// }
// GetUser in the database
func GetUser ( u uuid . UUID ) * Account {
2020-07-16 10:51:50 +02:00
acc := & Account { }
GetDB ( ) . Table ( "accounts" ) . Where ( "id = ?" , u ) . First ( acc )
if acc . Email == "" { //User not found!
return nil
}
acc . Password = ""
return acc
}
2020-07-21 11:08:01 +02:00
// CreateAccount middleware to create user account
2020-07-16 10:51:50 +02:00
var CreateAccount = func ( w http . ResponseWriter , r * http . Request ) {
account := & Account { }
err := json . NewDecoder ( r . Body ) . Decode ( account ) //decode the request body into struct and failed if any error occur
if err != nil {
2020-07-21 11:08:01 +02:00
respond ( w , message ( false , "Invalid request" ) )
2020-07-16 10:51:50 +02:00
return
}
resp := account . Create ( ) //Create account
2020-07-21 11:08:01 +02:00
respond ( w , resp )
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// Authenticate middleware to authenticate the user
2020-07-16 10:51:50 +02:00
var Authenticate = func ( w http . ResponseWriter , r * http . Request ) {
account := & Account { }
err := json . NewDecoder ( r . Body ) . Decode ( account ) //decode the request body into struct and failed if any error occur
if err != nil {
2020-07-21 11:08:01 +02:00
respond ( w , message ( false , "Invalid request" ) )
2020-07-16 10:51:50 +02:00
return
}
resp := Login ( account . Email , account . Password )
2020-07-21 11:08:01 +02:00
respond ( w , resp )
2020-07-16 10:51:50 +02:00
}
2020-07-21 11:08:01 +02:00
// Refresh is an HTTP handler that, given a valid refresh token in the
// Authorization header, issues a new access/refresh token pair for the
// token's user and writes them as "access_token" and "refresh_token".
var Refresh = func(w http.ResponseWriter, r *http.Request) {
	tk, err := ValidateToken(w, r)
	if err != nil {
		// ValidateToken has already written the 403 status and JSON error
		// body. Writing again would emit a superfluous WriteHeader and a
		// second JSON document.
		return
	}

	// The token may be validly signed for a user that no longer exists.
	acc := GetUser(tk.UserID)
	if acc == nil {
		respond(w, message(false, "User not found"))
		return
	}

	if _, err := acc.GenerateTokenPair(); err != nil {
		respond(w, message(false, "Failed to generate token. Please retry"))
		return
	}

	resp := make(map[string]interface{})
	resp["access_token"] = acc.Token
	resp["refresh_token"] = acc.RefreshToken
	respond(w, resp)
}
// ValidateToken to check if token is valid
func ValidateToken ( w http . ResponseWriter , r * http . Request ) ( * TokenDetails , error ) {
var pwd string
urltomatch := r . URL . String ( )
2020-07-21 11:20:06 +02:00
if strings . Contains ( urltomatch , "/api/user/refresh" ) {
2020-07-21 11:08:01 +02:00
pwd = "refresh_token_password"
} else {
pwd = "access_token_password"
}
response := make ( map [ string ] interface { } )
tk := & TokenDetails { }
bearToken := r . Header . Get ( "Authorization" ) //Grab the token from the header
if bearToken == "" { //Token is missing, returns with error code 403 Unauthorized
response = message ( false , "Missing auth token" )
w . WriteHeader ( http . StatusForbidden )
w . Header ( ) . Add ( "Content-Type" , "application/json" )
respond ( w , response )
2020-07-21 11:21:39 +02:00
return tk , errors . New ( "Missing auth token" )
2020-07-21 11:08:01 +02:00
}
splitted := strings . Split ( bearToken , " " ) //The token normally comes in format `Bearer {token-body}`, we check if the retrieved token matched this requirement
if len ( splitted ) != 2 {
response = message ( false , "Invalid/Malformed auth token" )
w . WriteHeader ( http . StatusForbidden )
w . Header ( ) . Add ( "Content-Type" , "application/json" )
respond ( w , response )
return tk , errors . New ( "Invalid/Malformed auth token" )
}
tokenPart := splitted [ 1 ] //Grab the token part, what we are truly interested in
token , err := jwt . ParseWithClaims ( tokenPart , tk , func ( token * jwt . Token ) ( interface { } , error ) {
return [ ] byte ( os . Getenv ( pwd ) ) , nil
} )
if err != nil { //Malformed token, returns with http code 403 as usual
response = message ( false , "Malformed authentication token" )
w . WriteHeader ( http . StatusForbidden )
w . Header ( ) . Add ( "Content-Type" , "application/json" )
respond ( w , response )
return tk , errors . New ( "Malformed authentication token" )
}
if ! token . Valid { //Token is invalid, maybe not signed on this server
response = message ( false , "Token is not valid." )
w . WriteHeader ( http . StatusForbidden )
w . Header ( ) . Add ( "Content-Type" , "application/json" )
respond ( w , response )
return tk , errors . New ( "Token is not valid" )
}
return tk , nil
}
// message builds the standard response payload: a map with a boolean
// "status" and a human-readable "message".
func message(status bool, message string) map[string]interface{} {
	payload := make(map[string]interface{}, 2)
	payload["status"] = status
	payload["message"] = message
	return payload
}
2020-07-21 11:08:01 +02:00
// respond writes data as a JSON response body (followed by a newline) with
// Content-Type application/json.
//
// If data cannot be encoded (for example it holds a channel or func), a 500
// with a generic JSON error body is sent instead of an empty 200.
func respond(w http.ResponseWriter, data map[string]interface{}) {
	// Set, not Add, so a Content-Type already set by the caller is not duplicated.
	w.Header().Set("Content-Type", "application/json")

	body, err := json.Marshal(data)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		body = []byte(`{"message":"Internal server error","status":false}`)
	}
	// Trailing newline matches json.Encoder output. A write error means the
	// client is gone, and there is nobody left to report it to.
	_, _ = w.Write(append(body, '\n'))
}