fake-smtp/internal/command/relay_email.go

165 lines
2.9 KiB
Go
Raw Permalink Normal View History

2020-11-05 19:48:18 +01:00
package command
import (
"context"
"crypto/tls"
"io"
"sync"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"forge.cadoles.com/wpetit/fake-smtp/internal/config"
"github.com/jhillyerd/enmime"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/cqrs"
"gitlab.com/wpetit/goweb/middleware/container"
)
// RelayEmailRequest is the request payload expected by HandleRelayEmail.
type RelayEmailRequest struct {
// Envelope is the message that gets forwarded to the configured relay.
Envelope *enmime.Envelope
}
// HandleRelayEmail is a cqrs command handler that forwards the envelope
// carried by a *RelayEmailRequest to the relay declared in the configuration.
//
// It returns cqrs.ErrUnexpectedRequest when the command does not carry a
// *RelayEmailRequest, and a wrapped error when the service container or the
// config service cannot be resolved from ctx, or when forwarding fails.
func HandleRelayEmail(ctx context.Context, cmd cqrs.Command) error {
	relayReq, isRelayReq := cmd.Request().(*RelayEmailRequest)
	if !isRelayReq {
		return cqrs.ErrUnexpectedRequest
	}

	services, ctnErr := container.From(ctx)
	if ctnErr != nil {
		return errors.Wrap(ctnErr, "could not retrieve service container")
	}

	settings, confErr := config.From(services)
	if confErr != nil {
		return errors.Wrap(confErr, "could not retrieve config service")
	}

	forwardErr := forwardMail(relayReq.Envelope, settings.Relay)

	// errors.Wrap yields nil for a nil error, so the success path returns nil.
	return errors.Wrap(forwardErr, "could not forward mail")
}
// forwardMail sends env to the SMTP server described by conf.
//
// The connection is opened with implicit TLS when conf.UseTLS is set and in
// clear text otherwise; in both cases STARTTLS is issued when the server
// advertises it. Authentication is attempted only when a username or a
// password is configured and the server advertises AUTH: conf.Anonymous
// selects the SASL ANONYMOUS mechanism, otherwise PLAIN is used.
//
// The envelope sender is conf.FromOverride when set, otherwise the value of
// the message "From" header. Recipients are the values of the "To" header.
//
// Any failure is returned with a stack trace attached. If encoding or
// transmitting the body fails, the DATA command is not terminated and the
// deferred client.Close drops the connection.
func forwardMail(env *enmime.Envelope, conf config.RelayConfig) error {
	// tlsConfig stays nil unless certificate verification is explicitly
	// disabled in the configuration.
	var tlsConfig *tls.Config
	if conf.InsecureSkipVerify {
		tlsConfig = &tls.Config{
			// nolint: gosec
			InsecureSkipVerify: true,
		}
	}

	var (
		client *smtp.Client
		err    error
	)

	if conf.UseTLS {
		client, err = smtp.DialTLS(conf.Address, tlsConfig)
	} else {
		client, err = smtp.Dial(conf.Address)
	}

	if err != nil {
		return errors.WithStack(err)
	}

	defer client.Close()

	if err = client.Hello("localhost"); err != nil {
		return errors.WithStack(err)
	}

	if ok, _ := client.Extension("STARTTLS"); ok {
		if err = client.StartTLS(tlsConfig); err != nil {
			return errors.WithStack(err)
		}
	}

	if conf.Username != "" || conf.Password != "" {
		if ok, _ := client.Extension("AUTH"); ok {
			var auth sasl.Client
			if conf.Anonymous {
				auth = sasl.NewAnonymousClient("fakesmtp")
			} else {
				auth = sasl.NewPlainClient(conf.Identity, conf.Username, conf.Password)
			}

			if err = client.Auth(auth); err != nil {
				return errors.WithStack(err)
			}
		}
	}

	from := conf.FromOverride
	if from == "" {
		from = env.GetHeader("From")
	}

	if err = client.Mail(from, nil); err != nil {
		return errors.WithStack(err)
	}

	for _, rcpt := range env.GetHeaderValues("To") {
		if err = client.Rcpt(rcpt); err != nil {
			return errors.WithStack(err)
		}
	}

	w, err := client.Data()
	if err != nil {
		return errors.WithStack(err)
	}

	// The message is encoded in a separate goroutine and streamed to the
	// server through a pipe. The goroutine shares no variable with this
	// function: its outcome travels through the pipe itself.
	pr, pw := io.Pipe()

	var wg sync.WaitGroup

	wg.Add(1)

	go func() {
		defer wg.Done()

		// CloseWithError(nil) is a plain Close (the reader sees io.EOF); a
		// non-nil encoding error is surfaced to the reader, so io.Copy below
		// fails instead of reporting a successful but truncated message.
		_ = pw.CloseWithError(env.Root.Encode(pw))
	}()

	_, copyErr := io.Copy(w, pr)

	// Closing the read end unblocks the encoder if the copy stopped early
	// (e.g. network write failure), so the goroutine can never leak. It is a
	// no-op for the encoder when the copy ran to completion.
	_ = pr.CloseWithError(copyErr)

	wg.Wait()

	if copyErr != nil {
		// Do not close w here: that would terminate DATA and make the server
		// accept an incomplete message.
		return errors.WithStack(copyErr)
	}

	if err = w.Close(); err != nil {
		return errors.WithStack(err)
	}

	if err = client.Quit(); err != nil {
		return errors.WithStack(err)
	}

	return nil
}