package module

import (
	"context"
	"fmt"
	"sync"

	"forge.cadoles.com/arcad/edge/pkg/app"
	"forge.cadoles.com/arcad/edge/pkg/bus"
	"forge.cadoles.com/arcad/edge/pkg/module/util"
	"github.com/dop251/goja"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

// RPCRequest is a JSON-RPC 2.0 request extracted from a client message by
// isRPCRequest. Method names the registered callback to invoke; Params and
// ID are the raw "params" and "id" values of the message, copied as-is
// (either may be nil). ID is echoed back in the response.
type RPCRequest struct {
	Method string
	Params interface{}
	ID     interface{}
}

// RPCError is the "error" member of a JSON-RPC 2.0 response. Code is the
// numeric error code (this file emits -32601 for an unknown method and
// -32603 for a failed call), Message is a human-readable description and
// Data is optional extra information, which this file never sets.
type RPCError struct {
	Code    int         `json:"code"`
	Message string      `json:"message"`
	Data    interface{} `json:"data"`
}

// RPCResponse holds the outcome of an RPC call before sendResponse wraps it
// in a JSON-RPC 2.0 envelope. ID echoes the request ID. In this file,
// successful responses leave Error nil and error responses leave Result nil.
type RPCResponse struct {
	Result interface{}
	Error  *RPCError
	ID     interface{}
}

// RPCModule lets scripts register JSON-RPC methods and serves requests for
// those methods that arrive on the bus.
//
// server executes the registered callbacks, bus carries client requests and
// server responses, and callbacks maps method names (string) to the
// goja.Callable registered for them. callbacks is a sync.Map because it is
// written by register/unregister and read by the handleMessages goroutine
// (presumably running concurrently with the script runtime — verify).
type RPCModule struct {
	server    *app.Server
	bus       bus.Bus
	callbacks sync.Map
}

// Name returns the module's name, "rpc".
func (m *RPCModule) Name() string {
	const moduleName = "rpc"

	return moduleName
}

// Export installs the "register" and "unregister" functions, in that order,
// on the given export object. It panics if either property cannot be set.
func (m *RPCModule) Export(export *goja.Object) {
	bindings := []struct {
		name string
		fn   func(goja.FunctionCall, *goja.Runtime) goja.Value
	}{
		{name: "register", fn: m.register},
		{name: "unregister", fn: m.unregister},
	}

	for _, binding := range bindings {
		if err := export.Set(binding.name, binding.fn); err != nil {
			panic(errors.Wrapf(err, "could not set '%s' function", binding.name))
		}
	}
}

// register binds a script function to an RPC method name.
//
// Script signature: register(name[, fn]). When fn is omitted, the function
// is looked up in the runtime's global context under name. A TypeError is
// thrown in the runtime when the resolved value is not callable. Registering
// an existing name replaces the previous callback.
func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
	methodName := util.AssertString(call.Argument(0), rt)

	var target goja.Value
	if len(call.Arguments) > 1 {
		target = call.Argument(1)
	} else {
		target = rt.Get(methodName)
	}

	callable, isFunc := goja.AssertFunction(target)
	if !isFunc {
		panic(rt.NewTypeError("method should be a valid function"))
	}

	logger.Debug(context.Background(), "registering method", logger.F("method", methodName))

	m.callbacks.Store(methodName, callable)

	return nil
}

// unregister removes the callback bound to the RPC method name given as the
// first script argument. Removing an unknown name is a no-op.
func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
	m.callbacks.Delete(util.AssertString(call.Argument(0), rt))

	return nil
}

// handleMessages subscribes to client messages on the bus and dispatches
// every JSON-RPC 2.0 request to the callback registered for its method,
// publishing the result (or an error response) back on the bus.
//
// It blocks until the subscription channel is closed and panics if the
// subscription itself fails. Bus messages that are not *ClientMessage are
// logged and skipped; client messages that are not JSON-RPC requests are
// silently ignored. When a callback returns a promise, the promise is awaited
// in a separate goroutine so the dispatch loop is not blocked.
func (m *RPCModule) handleMessages() {
	ctx := context.Background()

	clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient)
	if err != nil {
		panic(errors.WithStack(err))
	}

	defer func() {
		m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages)
	}()

	// sendRes publishes a successful response carrying result's exported value.
	sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) {
		// goja.Value is an interface: guard against a nil value so that a
		// missing result is reported as null instead of panicking on Export.
		var exported interface{}
		if result != nil {
			exported = result.Export()
		}

		res := &RPCResponse{
			ID:     req.ID,
			Result: exported,
		}

		logger.Debug(ctx, "sending rpc response", logger.F("response", res))

		if err := m.sendResponse(ctx, res); err != nil {
			logger.Error(
				ctx, "could not send response",
				logger.E(errors.WithStack(err)),
				logger.F("response", res),
				logger.F("request", req),
			)
		}
	}

	// sendNotFound replies with a "method not found" error, logging on failure.
	sendNotFound := func(msgCtx context.Context, req *RPCRequest) {
		if err := m.sendMethodNotFoundResponse(msgCtx, req); err != nil {
			logger.Error(
				ctx, "could not send method not found response",
				logger.E(errors.WithStack(err)),
				logger.F("request", req),
			)
		}
	}

	for msg := range clientMessages {
		clientMessage, ok := msg.(*ClientMessage)
		if !ok {
			logger.Warn(ctx, "unexpected bus message", logger.F("message", msg))

			continue
		}

		ok, req := m.isRPCRequest(clientMessage)
		if !ok {
			continue
		}

		logger.Debug(ctx, "received rpc request", logger.F("request", req))

		rawCallable, exists := m.callbacks.Load(req.Method)
		if !exists {
			logger.Debug(ctx, "method not found", logger.F("req", req))
			sendNotFound(clientMessage.Context, req)

			continue
		}

		callable, ok := rawCallable.(goja.Callable)
		if !ok {
			logger.Debug(ctx, "invalid method", logger.F("req", req))
			sendNotFound(clientMessage.Context, req)

			continue
		}

		// NOTE(review): presumably Exec forwards the trailing arguments to the
		// callback, so it receives (requestContext, params) — verify against
		// app.Server.Exec.
		result, execErr := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params)
		if execErr != nil {
			logger.Error(
				ctx, "rpc call error",
				logger.E(errors.WithStack(execErr)),
				logger.F("request", req),
			)

			// Keep the send error under a distinct name so "originalError"
			// reports the callback failure rather than the send failure.
			if sendErr := m.sendErrorResponse(clientMessage.Context, req, execErr); sendErr != nil {
				logger.Error(
					ctx, "could not send error response",
					logger.E(errors.WithStack(sendErr)),
					logger.F("originalError", execErr),
					logger.F("request", req),
				)
			}

			continue
		}

		if promise, isPromise := m.server.IsPromise(result); isPromise {
			// Await the promise off the dispatch loop so other requests keep flowing.
			go func(ctx context.Context, req *RPCRequest, promise *goja.Promise) {
				sendRes(ctx, req, m.server.WaitForPromise(promise))
			}(clientMessage.Context, req, promise)

			continue
		}

		sendRes(clientMessage.Context, req, result)
	}
}

// sendErrorResponse replies to req with the JSON-RPC 2.0 "Internal error"
// code (-32603), using err's text as the message. err must be non-nil.
func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error {
	rpcErr := &RPCError{
		Code:    -32603,
		Message: err.Error(),
	}

	return m.sendResponse(ctx, &RPCResponse{ID: req.ID, Error: rpcErr})
}

// sendMethodNotFoundResponse replies to req with the JSON-RPC 2.0
// "Method not found" error code (-32601).
func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error {
	const codeMethodNotFound = -32601

	// NOTE(review): fmt.Sprintf without arguments is redundant (staticcheck
	// S1039), but it appears to be the file's only use of fmt; replace it with
	// a plain literal together with removing the fmt import.
	message := fmt.Sprintf("method not found")

	res := &RPCResponse{ID: req.ID}
	res.Error = &RPCError{
		Code:    codeMethodNotFound,
		Message: message,
	}

	return m.sendResponse(ctx, res)
}

// sendResponse wraps res in a JSON-RPC 2.0 envelope ("jsonrpc", "id",
// "error", "result") and publishes it on the bus as a server message built
// from ctx. Publishing errors are returned with a stack trace attached.
func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error {
	envelope := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      res.ID,
		"error":   res.Error,
		"result":  res.Result,
	}

	err := m.bus.Publish(ctx, NewServerMessage(ctx, envelope))
	if err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// isRPCRequest reports whether msg carries a JSON-RPC 2.0 request and, if so,
// returns it. A message qualifies when its "jsonrpc" field equals "2.0" and
// its "method" field is a string. The "id" and "params" fields are copied
// as-is and may be nil.
func (m *RPCModule) isRPCRequest(msg *ClientMessage) (bool, *RPCRequest) {
	data := msg.Data

	if version, found := data["jsonrpc"]; !found || version != "2.0" {
		return false, nil
	}

	// A missing "method" key yields a nil interface, which fails the
	// assertion just like a non-string value does.
	method, isString := data["method"].(string)
	if !isString {
		return false, nil
	}

	return true, &RPCRequest{
		Method: method,
		Params: data["params"],
		ID:     data["id"],
	}
}

// RPCModuleFactory returns a factory that builds an RPCModule bound to the
// given bus and to the server passed to the factory. Each built module
// immediately starts a goroutine running handleMessages, which lasts until
// the bus subscription channel is closed.
func RPCModuleFactory(bus bus.Bus) app.ServerModuleFactory {
	return func(server *app.Server) app.ServerModule {
		mod := new(RPCModule)
		mod.server = server
		mod.bus = bus

		go mod.handleMessages()

		return mod
	}
}