diff --git a/reach/reboot_now.go b/reach/reboot_now.go index 3ccdcf3..18eaaf9 100644 --- a/reach/reboot_now.go +++ b/reach/reboot_now.go @@ -1,6 +1,9 @@ package reach import ( + "sync" + + "forge.cadoles.com/Pyxis/golang-socketio" "github.com/pkg/errors" ) @@ -9,9 +12,21 @@ const ( ) // RebootNow asks the ReachRS module to reboot now -func (u *Updater) RebootNow() error { +func (u *Updater) RebootNow(waitDisconnect bool) error { var err error + var wg sync.WaitGroup + + if waitDisconnect { + wg.Add(1) + err = u.conn.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel) { + u.conn.Off(gosocketio.OnDisconnection) + wg.Done() + }) + if err != nil { + return errors.Wrapf(err, "error while binding to '%s' event", gosocketio.OnDisconnection) + } + } u.logf("sending '%s' event", eventReboot) if err = u.conn.Emit(eventReboot, nil); err != nil { @@ -19,6 +34,10 @@ func (u *Updater) RebootNow() error { } u.logf("'%s' event sent", eventReboot) + if waitDisconnect { + wg.Wait() + } + return err } diff --git a/reach/reboot_now_test.go b/reach/reboot_now_test.go index 41fbc8f..21bac03 100644 --- a/reach/reboot_now_test.go +++ b/reach/reboot_now_test.go @@ -18,7 +18,7 @@ func TestClientReboutNow(t *testing.T) { t.Fatal(err) } - if err := client.RebootNow(); err != nil { + if err := client.RebootNow(true); err != nil { t.Error(err) }