From 9367e5e79e135246eefd230ec32052ed089902b3 Mon Sep 17 00:00:00 2001 From: William Petit Date: Wed, 26 Sep 2018 12:05:55 +0200 Subject: [PATCH] Use context.Context to provide timeout detection to websocket calls --- emlid/client.go | 22 +++++++++++++--- emlid/updater/example_test.go | 9 ++++++- emlid/updater/reachview_version.go | 6 +++-- emlid/updater/reachview_version_test.go | 7 +++-- emlid/updater/reboot_now.go | 21 ++++++++++++--- emlid/updater/reboot_now_test.go | 6 ++++- emlid/updater/receiver_upgrade.go | 6 +++-- emlid/updater/receiver_upgrade_test.go | 6 ++++- emlid/updater/test_results.go | 6 +++-- emlid/updater/test_results_test.go | 6 ++++- emlid/updater/time_sync.go | 6 +++-- emlid/updater/time_sync_test.go | 6 ++++- emlid/updater/update.go | 6 +++-- emlid/updater/update_test.go | 6 ++++- emlid/updater/{test.go => util_test.go} | 0 emlid/updater/wifi_networks.go | 33 ++++++++++++++++------- emlid/updater/wifi_networks_test.go | 30 ++++++++++++++++----- example/updater/main.go | 35 ++++++++++++++++++++----- 18 files changed, 170 insertions(+), 47 deletions(-) rename emlid/updater/{test.go => util_test.go} (100%) diff --git a/emlid/client.go b/emlid/client.go index 093d1b4..0b6325c 100644 --- a/emlid/client.go +++ b/emlid/client.go @@ -1,6 +1,7 @@ package emlid import ( + "context" "sync" "forge.cadoles.com/Pyxis/golang-socketio" @@ -98,23 +99,36 @@ func (c *Client) Off(event string) { } // ReqResp emits an event with the given data and waits for a response -func (c *Client) ReqResp(requestEvent string, requestData interface{}, responseEvent string, res interface{}) error { +func (c *Client) ReqResp(ctx context.Context, + requestEvent string, requestData interface{}, + responseEvent string, res interface{}) error { var err error var wg sync.WaitGroup + var once sync.Once + + done := func() { + c.conn.Off(responseEvent) + wg.Done() + } wg.Add(1) + go func() { + <-ctx.Done() + err = ctx.Err() + once.Do(done) + }() + err = c.conn.On(responseEvent, func(_ *gosocketio.Channel, data interface{}) { err = mapstructure.Decode(data, res) - c.conn.Off(responseEvent) - wg.Done() + once.Do(done) }) if err != nil { return errors.Wrapf(err, "error while binding to '%s' event", responseEvent) } - if err := c.Emit(requestEvent, requestData); err != nil { + if err = c.Emit(requestEvent, requestData); err != nil { return errors.Wrapf(err, "error while emitting event '%s'", requestEvent) } diff --git a/emlid/updater/example_test.go b/emlid/updater/example_test.go index 4e4f1e9..29e59fc 100644 --- a/emlid/updater/example_test.go +++ b/emlid/updater/example_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "log" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -17,7 +19,12 @@ func Example_usage() { log.Fatal(err) } - networks, err := updater.WifiNetworks() + // We create a context for the API call with a 10 second delay + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Retrieve the Wifi networks + networks, err := updater.WifiNetworks(ctx) if err != nil { log.Fatal(err) } diff --git a/emlid/updater/reachview_version.go b/emlid/updater/reachview_version.go index c0640d8..e7929a9 100644 --- a/emlid/updater/reachview_version.go +++ b/emlid/updater/reachview_version.go @@ -1,5 +1,7 @@ package updater +import "context" + const ( eventGetReachViewVersion = "get reachview version" eventReachViewVersionResults = "current reachview version" @@ -10,9 +12,9 @@ type reachViewVersion struct { } // ReachViewVersion returns the ReachRS module ReachView version -func (c *Client) ReachViewVersion() (string, error) { +func (c *Client) ReachViewVersion(ctx context.Context) (string, error) { res := &reachViewVersion{} - if err := c.ReqResp(eventGetReachViewVersion, nil, eventReachViewVersionResults, res); err != nil { + if err := c.ReqResp(ctx, eventGetReachViewVersion, nil, eventReachViewVersionResults, res); err != nil { return "", err } return res.Version, nil diff --git a/emlid/updater/reachview_version_test.go b/emlid/updater/reachview_version_test.go index 9ae46a4..48008c0 100644 --- a/emlid/updater/reachview_version_test.go +++ b/emlid/updater/reachview_version_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "testing" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -19,8 +21,9 @@ func TestClientReachViewVersion(t *testing.T) { if err := client.Connect(); err != nil { t.Fatal(err) } - - version, err := client.ReachViewVersion() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + version, err := client.ReachViewVersion(ctx) if err != nil { t.Error(err) } diff --git a/emlid/updater/reboot_now.go b/emlid/updater/reboot_now.go index 3003c53..de09855 100644 --- a/emlid/updater/reboot_now.go +++ b/emlid/updater/reboot_now.go @@ -1,6 +1,7 @@ package updater import ( + "context" "sync" "forge.cadoles.com/Pyxis/golang-socketio" @@ -12,16 +13,30 @@ const ( ) // RebootNow asks the ReachRS module to reboot now -func (c *Client) RebootNow(waitDisconnect bool) error { +func (c *Client) RebootNow(ctx context.Context, waitDisconnect bool) error { var err error var wg sync.WaitGroup if waitDisconnect { - wg.Add(1) - err = c.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel) { + + var once sync.Once + + done := func() { c.Off(gosocketio.OnDisconnection) wg.Done() + } + + wg.Add(1) + + go func() { + <-ctx.Done() + err = ctx.Err() + once.Do(done) + }() + + err = c.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel) { + once.Do(done) }) if err != nil { return errors.Wrapf(err, "error while binding to '%s' event", gosocketio.OnDisconnection) diff --git a/emlid/updater/reboot_now_test.go b/emlid/updater/reboot_now_test.go index d3a531a..0527776 100644 --- a/emlid/updater/reboot_now_test.go +++ b/emlid/updater/reboot_now_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "testing" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -24,7 +26,9 @@ func TestClientRebootNow(t *testing.T) { t.Fatal(err) } - if err := client.RebootNow(true); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := client.RebootNow(ctx, true); err != nil { t.Error(err) } diff --git a/emlid/updater/receiver_upgrade.go b/emlid/updater/receiver_upgrade.go index 6ceba52..07c8675 100644 --- a/emlid/updater/receiver_upgrade.go +++ b/emlid/updater/receiver_upgrade.go @@ -1,5 +1,7 @@ package updater +import "context" + const ( eventIsReceiverUpgradeAvailable = "is receiver upgrade available" eventReceiverUpgradeAvailable = "receiver upgrade available" @@ -11,9 +13,9 @@ type receiverUpgreAvailable struct { } // ReceiverUpgradeAvailable checks if an upgrade is avaialable/running for the ReachRS module -func (c *Client) ReceiverUpgradeAvailable() (bool, bool, error) { +func (c *Client) ReceiverUpgradeAvailable(ctx context.Context) (bool, bool, error) { res := &receiverUpgreAvailable{} - if err := c.ReqResp(eventIsReceiverUpgradeAvailable, nil, eventReceiverUpgradeAvailable, res); err != nil { + if err := c.ReqResp(ctx, eventIsReceiverUpgradeAvailable, nil, eventReceiverUpgradeAvailable, res); err != nil { return false, false, err } c.Logf("receiver upgrade result: available: %v, running: %v", res.Available, res.Running) diff --git a/emlid/updater/receiver_upgrade_test.go b/emlid/updater/receiver_upgrade_test.go index 65913bd..fa8b565 100644 --- a/emlid/updater/receiver_upgrade_test.go +++ b/emlid/updater/receiver_upgrade_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "testing" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -20,7 +22,9 @@ func TestClientReceiverUpgradeAvailable(t *testing.T) { t.Fatal(err) } - _, _, err := client.ReceiverUpgradeAvailable() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, _, err := client.ReceiverUpgradeAvailable(ctx) if err != nil { t.Error(err) } diff --git a/emlid/updater/test_results.go b/emlid/updater/test_results.go index e5342ff..1ad75aa 100644 --- a/emlid/updater/test_results.go +++ b/emlid/updater/test_results.go @@ -1,5 +1,7 @@ package updater +import "context" + const ( eventGetTestResults = "get test results" eventTestResults = "test results" @@ -16,9 +18,9 @@ type TestResults struct { } // TestResults returns the ReachRS module tests results -func (c *Client) TestResults() (*TestResults, error) { +func (c *Client) TestResults(ctx context.Context) (*TestResults, error) { res := &TestResults{} - if err := c.ReqResp(eventGetTestResults, nil, eventTestResults, res); err != nil { + if err := c.ReqResp(ctx, eventGetTestResults, nil, eventTestResults, res); err != nil { return nil, err } return res, nil diff --git a/emlid/updater/test_results_test.go b/emlid/updater/test_results_test.go index 57bf5a3..4a6888f 100644 --- a/emlid/updater/test_results_test.go +++ b/emlid/updater/test_results_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "testing" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -20,7 +22,9 @@ func TestClientTestResults(t *testing.T) { t.Fatal(err) } - results, err := client.TestResults() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + results, err := client.TestResults(ctx) if err != nil { t.Error(err) } diff --git a/emlid/updater/time_sync.go b/emlid/updater/time_sync.go index afaefd2..3613605 100644 --- a/emlid/updater/time_sync.go +++ b/emlid/updater/time_sync.go @@ -1,5 +1,7 @@ package updater +import "context" + const ( eventGetTimeSyncStatus = "get time sync status" eventTimeSyncResults = "time sync status" @@ -11,9 +13,9 @@ type timeSyncStatus struct { // TimeSynced returns the ReachRS module time synchronization status. // A true response means that the module has synchronized its clock. -func (c *Client) TimeSynced() (bool, error) { +func (c *Client) TimeSynced(ctx context.Context) (bool, error) { res := &timeSyncStatus{} - if err := c.ReqResp(eventGetTimeSyncStatus, nil, eventTimeSyncResults, res); err != nil { + if err := c.ReqResp(ctx, eventGetTimeSyncStatus, nil, eventTimeSyncResults, res); err != nil { return false, err } c.Logf("time sync result: %v", res.Status) diff --git a/emlid/updater/time_sync_test.go b/emlid/updater/time_sync_test.go index c625425..39ed4cd 100644 --- a/emlid/updater/time_sync_test.go +++ b/emlid/updater/time_sync_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "testing" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -20,7 +22,9 @@ func TestClientTimeSync(t *testing.T) { t.Fatal(err) } - _, err := client.TimeSynced() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := client.TimeSynced(ctx) if err != nil { t.Error(err) } diff --git a/emlid/updater/update.go b/emlid/updater/update.go index 3fc5b44..4faa983 100644 --- a/emlid/updater/update.go +++ b/emlid/updater/update.go @@ -1,5 +1,7 @@ package updater +import "context" + const ( eventUpdate = "update" eventOPKGUpdateResult = "opkg update result" @@ -13,9 +15,9 @@ type UpdateStatus struct { } // Update asks the ReachRS module to start an OPKG update -func (c *Client) Update() (*UpdateStatus, error) { +func (c *Client) Update(ctx context.Context) (*UpdateStatus, error) { res := &UpdateStatus{} - if err := c.ReqResp(eventUpdate, nil, eventOPKGUpdateResult, res); err != nil { + if err := c.ReqResp(ctx, eventUpdate, nil, eventOPKGUpdateResult, res); err != nil { return nil, err } c.Logf( diff --git a/emlid/updater/update_test.go b/emlid/updater/update_test.go index 39aa897..d54b121 100644 --- a/emlid/updater/update_test.go +++ b/emlid/updater/update_test.go @@ -1,7 +1,9 @@ package updater import ( + "context" "testing" + "time" "forge.cadoles.com/Pyxis/orion/emlid" ) @@ -20,7 +22,9 @@ func TestClientOPKGUpdate(t *testing.T) { t.Fatal(err) } - _, err := client.Update() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := client.Update(ctx) if err != nil { t.Error(err) } diff --git a/emlid/updater/test.go b/emlid/updater/util_test.go similarity index 100% rename from emlid/updater/test.go rename to emlid/updater/util_test.go diff --git a/emlid/updater/wifi_networks.go b/emlid/updater/wifi_networks.go index dbc6326..7683e9b 100644 --- a/emlid/updater/wifi_networks.go +++ b/emlid/updater/wifi_networks.go @@ -1,6 +1,7 @@ package updater import ( + "context" "sync" "forge.cadoles.com/Pyxis/golang-socketio" @@ -41,48 +42,62 @@ type WifiNetwork struct { } // WifiNetworks returns the ReachRS module wifi networks -func (c *Client) WifiNetworks() ([]WifiNetwork, error) { +func (c *Client) WifiNetworks(ctx context.Context) ([]WifiNetwork, error) { res := make([]WifiNetwork, 0) - if err := c.ReqResp(eventGetSavedWifiNetworks, nil, eventSavedWifiNetworkResults, &res); err != nil { + if err := c.ReqResp(ctx, eventGetSavedWifiNetworks, nil, eventSavedWifiNetworkResults, &res); err != nil { return nil, err } return res, nil } // AddWifiNetwork asks the ReachRS module to save the given wifi network informations -func (c *Client) AddWifiNetwork(ssid string, security WifiSecurity, password string) (bool, error) { +func (c *Client) AddWifiNetwork(ctx context.Context, ssid string, security WifiSecurity, password string) (bool, error) { res := false network := &WifiNetwork{ SSID: ssid, Security: security, Password: password, } - if err := c.ReqResp(eventAddWifiNetwork, network, eventAddWifiNetworkResults, &res); err != nil { + if err := c.ReqResp(ctx, eventAddWifiNetwork, network, eventAddWifiNetworkResults, &res); err != nil { return false, err } return res, nil } // RemoveWifiNetwork asks the ReachRS module to remove the given WiFi network -func (c *Client) RemoveWifiNetwork(ssid string) (bool, error) { +func (c *Client) RemoveWifiNetwork(ctx context.Context, ssid string) (bool, error) { res := false - if err := c.ReqResp(eventRemoveWifiNetwork, ssid, eventRemoveWifiNetworkResults, &res); err != nil { + if err := c.ReqResp(ctx, eventRemoveWifiNetwork, ssid, eventRemoveWifiNetworkResults, &res); err != nil { return false, err } return res, nil } // JoinWifiNetwork asks the ReachRS module to join the given WiFi network -func (c *Client) JoinWifiNetwork(ssid string, waitDisconnect bool) error { +func (c *Client) JoinWifiNetwork(ctx context.Context, ssid string, waitDisconnect bool) error { var err error var wg sync.WaitGroup if waitDisconnect { - wg.Add(1) - err = c.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel) { + + var once sync.Once + + done := func() { c.Off(gosocketio.OnDisconnection) wg.Done() + } + + wg.Add(1) + + go func() { + <-ctx.Done() + err = ctx.Err() + once.Do(done) + }() + + err = c.On(gosocketio.OnDisconnection, func(h *gosocketio.Channel) { + once.Do(done) }) if err != nil { return errors.Wrapf(err, "error while binding to '%s' event", gosocketio.OnDisconnection) diff --git a/emlid/updater/wifi_networks_test.go b/emlid/updater/wifi_networks_test.go index 134552e..7ca5dda 100644 --- a/emlid/updater/wifi_networks_test.go +++ b/emlid/updater/wifi_networks_test.go @@ -1,6 +1,7 @@ package updater import ( + "context" "fmt" "math/rand" "testing" @@ -23,7 +24,9 @@ func TestClientSavedWiFiNetworks(t *testing.T) { t.Fatal(err) } - _, err := client.WifiNetworks() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err := client.WifiNetworks(ctx) if err != nil { t.Error(err) } @@ -48,7 +51,11 @@ func TestClientCRUDWiFiNetwork(t *testing.T) { ssid := fmt.Sprintf("wifi_test_%d", rand.Uint32()) - done, err := client.AddWifiNetwork(ssid, SecurityOpen, "") + ctx := context.Background() + + addWifiContext, addWifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer addWifiCancel() + done, err := client.AddWifiNetwork(addWifiContext, ssid, SecurityOpen, "") if err != nil { t.Error(err) } @@ -57,7 +64,9 @@ func TestClientCRUDWiFiNetwork(t *testing.T) { t.Errorf("AddWifiNetwork() -> done: got '%v', expected '%v'", g, e) } - networks, err := client.WifiNetworks() + wifiContext, wifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer wifiCancel() + networks, err := client.WifiNetworks(wifiContext) if err != nil { t.Error(err) } @@ -74,7 +83,9 @@ func TestClientCRUDWiFiNetwork(t *testing.T) { t.Errorf("wifi network '%s' should exists", ssid) } - done, err = client.RemoveWifiNetwork(ssid) + removeWifiContext, removeWifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer removeWifiCancel() + done, err = client.RemoveWifiNetwork(removeWifiContext, ssid) if err != nil { t.Error(err) } @@ -107,7 +118,12 @@ func TestClientWifiNetworkJoin(t *testing.T) { ssid := fmt.Sprintf("wifi_test_%d", rand.Uint32()) - done, err := client.AddWifiNetwork(ssid, SecurityOpen, "") + ctx := context.Background() + + addWifiContext, addWifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer addWifiCancel() + + done, err := client.AddWifiNetwork(addWifiContext, ssid, SecurityOpen, "") if err != nil { t.Error(err) } @@ -116,7 +132,9 @@ func TestClientWifiNetworkJoin(t *testing.T) { t.Errorf("AddWifiNetwork() -> done: got '%v', expected '%v'", g, e) } - if err := client.JoinWifiNetwork(ssid, true); err != nil { + joinWifiContext, joinWifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer joinWifiCancel() + if err := client.JoinWifiNetwork(joinWifiContext, ssid, true); err != nil { t.Error(err) } diff --git a/example/updater/main.go b/example/updater/main.go index 94c5600..6b4a14c 100644 --- a/example/updater/main.go +++ b/example/updater/main.go @@ -1,10 +1,12 @@ package main import ( + "context" "flag" "fmt" "log" "strings" + "time" "forge.cadoles.com/Pyxis/orion/emlid" "forge.cadoles.com/Pyxis/orion/emlid/updater" @@ -78,8 +80,13 @@ func configureWifi() { c := connect() defer c.Close() + ctx := context.Background() + log.Println("checking module status") - results, err := c.TestResults() + + ctx, testResultsCancel := context.WithTimeout(ctx, 5*time.Second) + defer testResultsCancel() + results, err := c.TestResults(ctx) if err != nil { log.Fatal(err) } @@ -92,7 +99,9 @@ func configureWifi() { log.Printf("adding wifi network '%s'", ssid) - done, err := c.AddWifiNetwork(ssid, updater.WifiSecurity(security), password) + ctx, addWifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer addWifiCancel() + done, err := c.AddWifiNetwork(ctx, ssid, updater.WifiSecurity(security), password) if err != nil { log.Fatal(err) } @@ -102,7 +111,9 @@ func configureWifi() { } log.Println("connecting module to wifi network") - if err := c.JoinWifiNetwork(ssid, true); err != nil { + ctx, joinWifiCancel := context.WithTimeout(ctx, 5*time.Second) + defer joinWifiCancel() + if err := c.JoinWifiNetwork(ctx, ssid, true); err != nil { log.Fatal(err) } log.Printf("you can now switch to the wifi network and start phase '%s'", phaseUpdateThenReboot) @@ -114,22 +125,30 @@ func updateThenReboot() { c := connect() defer c.Close() + ctx := context.Background() + log.Println("checking time sync") - synced, err := c.TimeSynced() + ctx, timeSyncedCancel := context.WithTimeout(ctx, 5*time.Second) + defer timeSyncedCancel() + synced, err := c.TimeSynced(ctx) if err != nil { log.Fatal(err) } log.Printf("time synced ? %v", synced) log.Println("checking reachview version") - version, err := c.ReachViewVersion() + ctx, reachviewVersionCancel := context.WithTimeout(ctx, 5*time.Second) + defer reachviewVersionCancel() + version, err := c.ReachViewVersion(ctx) if err != nil { log.Fatal(err) } log.Printf("reachview version ? '%s'", version) log.Println("checking for update") - status, err := c.Update() + ctx, updateCancel := context.WithTimeout(ctx, 5*time.Second) + defer updateCancel() + status, err := c.Update(ctx) if err != nil { log.Fatal(err) } @@ -143,7 +162,9 @@ func updateThenReboot() { } log.Println("rebooting device") - if err := c.RebootNow(true); err != nil { + ctx, rebootCancel := context.WithTimeout(ctx, 5*time.Second) + defer rebootCancel() + if err := c.RebootNow(ctx, true); err != nil { log.Fatal(err) }