From 9f32c6c43f98195d55b18b3a6422eb926bba7035 Mon Sep 17 00:00:00 2001 From: Ian Fijolek Date: Wed, 23 Oct 2024 13:34:44 -0700 Subject: [PATCH] Unbuffer callbacks so they are executed in order --- main.go | 4 +++- main_test.go | 42 +++++++++++++++++++++++++++++++----------- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index 22297ff..e912156 100644 --- a/main.go +++ b/main.go @@ -45,7 +45,7 @@ func NewShellRunnerWithShell(shell string) *ShellRunner { cmdQueue: make(chan func() *CommandResult), results: make(chan *CommandResult, MAX_RESULTS), stopChan: make(chan struct{}), - callbacks: make(chan func(), MAX_RESULTS), + callbacks: make(chan func()), shell: shell, activeCmds: make(map[*exec.Cmd]struct{}), isStopped: true, @@ -82,6 +82,8 @@ func (sr *ShellRunner) Start() { // AddCommand adds a shell command to be executed with an optional callback. // No commands can be added if the runner has been stopped or not yet started. +// The callback is executed asynchronously after the command has completed. +// The order of command execution and callback invocation can be expected to be preserved. func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) error { sr.mu.Lock() defer sr.mu.Unlock() diff --git a/main_test.go b/main_test.go index 8d1052d..dabeb15 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package tortoise_test import ( + "sync" "testing" "time" @@ -49,29 +50,48 @@ func TestShellRunnerCallback(t *testing.T) { runner.Start() // Test command with callback - done := make(chan struct{}) + outputString := "" + callbackWait := sync.WaitGroup{} - callbackReached := false + callbackWait.Add(1) - if err := runner.AddCommand("echo callback test", func(result *tortoise.CommandResult) { - callbackReached = true - if result.Output != "callback test\n" { - t.Fatalf("expected 'callback test', got '%s'", result.Output) + if err := runner.AddCommand("echo callback a", func(result *tortoise.CommandResult) { + if result.Output != "callback a\n" { + t.Fatalf("expected 'callback a', got '%s'", result.Output) } - close(done) + outputString = outputString + "a" + callbackWait.Done() }); err != nil { t.Fatalf("unexpected error adding command: %v", err) } - // Timeout waiting for callback + callbackWait.Add(1) + + if err := runner.AddCommand("echo callback b", func(result *tortoise.CommandResult) { + if result.Output != "callback b\n" { + t.Fatalf("expected 'callback b', got '%s'", result.Output) + } + outputString = outputString + "b" + callbackWait.Done() + }); err != nil { + t.Fatalf("unexpected error adding command: %v", err) + } + + // Timeout waiting for callbacks + done := make(chan struct{}) + go func() { + callbackWait.Wait() + close(done) + }() + select { case <-done: case <-time.After(2 * time.Second): - t.Fatal("callback timed out") + t.Fatal("callbacks timed out") } - if !callbackReached { - t.Fatal("callback was not reached") + if outputString != "ab" { + t.Fatal("callbacks was not reached in order:", outputString) } runner.Stop()