diff --git a/main.go b/main.go index e912156..212b64c 100644 --- a/main.go +++ b/main.go @@ -32,6 +32,7 @@ type ShellRunner struct { mu sync.Mutex isStopped bool activeCmds map[*exec.Cmd]struct{} // Track active commands for cancellation + timeout time.Duration } // NewShellRunner creates a new ShellRunner instance with the default shell (`sh`). @@ -49,9 +50,14 @@ func NewShellRunnerWithShell(shell string) *ShellRunner { shell: shell, activeCmds: make(map[*exec.Cmd]struct{}), isStopped: true, + timeout: 0, } } +func (sr *ShellRunner) SetTimeout(timeout time.Duration) { + sr.timeout = timeout +} + // Start begins processing shell commands asynchronously. func (sr *ShellRunner) Start() { go func() { @@ -85,6 +91,11 @@ func (sr *ShellRunner) Start() { // The callback is executed asynchronously after the command has completed. // The order of command execution and callback invocation can be expected to be preserved. func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) error { + cmd, cancel := sr.newShellCommand(command) + return sr.AddCmd(cmd, callback, cancel) +} + +func (sr *ShellRunner) AddCmd(cmd *exec.Cmd, callback func(*CommandResult), cancel context.CancelFunc) error { sr.mu.Lock() defer sr.mu.Unlock() @@ -93,10 +104,12 @@ func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) } sr.cmdQueue <- func() *CommandResult { - result := sr.executeCommand(command) + result := sr.executeCommand(cmd, cancel) if callback != nil { + sr.running.Add(1) sr.callbacks <- func() { callback(result) + sr.running.Done() } } return result @@ -135,15 +148,6 @@ func (sr *ShellRunner) Stop() { // Kill stops the ShellRunner immediately, terminating all running commands. func (sr *ShellRunner) Kill() { sr.mu.Lock() - if sr.isStopped { - sr.mu.Unlock() - return - } - - sr.isStopped = true - - close(sr.cmdQueue) // Prevent further commands - close(sr.stopChan) // Terminate all active commands for cmd := range sr.activeCmds { @@ -151,7 +155,7 @@ func (sr *ShellRunner) Kill() { } sr.mu.Unlock() - sr.running.Wait() + sr.Stop() } // KillWithTimeout attempts to stop the ShellRunner, killing commands if the duration is exceeded. @@ -171,12 +175,28 @@ func (sr *ShellRunner) KillWithTimeout(timeout time.Duration) error { } } +func (sr *ShellRunner) newShellCommand(command string) (*exec.Cmd, context.CancelFunc) { + var ctx context.Context + + var cancel context.CancelFunc + + if sr.timeout > 0 { + ctx, cancel = context.WithTimeout(context.Background(), sr.timeout) + } else { + ctx = context.Background() + } + + return exec.CommandContext(ctx, sr.shell, "-c", command), cancel +} + // executeCommand runs a shell command asynchronously, capturing stdout, stderr, and return code. -func (sr *ShellRunner) executeCommand(command string) *CommandResult { +func (sr *ShellRunner) executeCommand(cmd *exec.Cmd, cancel context.CancelFunc) *CommandResult { + if cancel != nil { + defer cancel() + } + var outBuf, errBuf bytes.Buffer - ctx := context.Background() - cmd := exec.CommandContext(ctx, sr.shell, "-c", command) cmd.Stdout = &outBuf cmd.Stderr = &errBuf @@ -188,7 +208,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult { err := cmd.Start() if err != nil { return &CommandResult{ - Command: command, + Command: cmd.String(), ReturnCode: -1, ErrOutput: err.Error(), } @@ -202,7 +222,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult { sr.mu.Unlock() result := &CommandResult{ - Command: command, + Command: cmd.String(), Output: outBuf.String(), ErrOutput: errBuf.String(), } diff --git a/main_test.go b/main_test.go index dabeb15..694710b 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,8 @@ package tortoise_test import ( + "context" + "os/exec" "sync" "testing" "time" @@ -8,6 +10,10 @@ import ( "git.iamthefij.com/iamthefij/tortoise" ) +const ( + TaskStartWait = 10 * time.Millisecond +) + func TestShellRunnerNoCallback(t *testing.T) { t.Parallel() @@ -33,10 +39,17 @@ func TestShellRunnerNoCallback(t *testing.T) { t.Fatalf("unexpected error adding command: %v", err) } + // Wait a sec for the worker to pick up the task + time.Sleep(TaskStartWait) + runner.Stop() result := runner.GetResults() - if result == nil || result.Output != c.output || result.ReturnCode != c.ReturnCode { + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.Output != c.output || result.ReturnCode != c.ReturnCode { t.Fatalf("expected output '%s' and return code %d, got '%s' and %d", c.output, c.ReturnCode, result.Output, result.ReturnCode) } }) @@ -65,6 +78,9 @@ func TestShellRunnerCallback(t *testing.T) { t.Fatalf("unexpected error adding command: %v", err) } + // Wait a sec for the worker to pick up the task + time.Sleep(TaskStartWait) + callbackWait.Add(1) if err := runner.AddCommand("echo callback b", func(result *tortoise.CommandResult) { @@ -91,7 +107,7 @@ func TestShellRunnerCallback(t *testing.T) { } if outputString != "ab" { - t.Fatal("callbacks was not reached in order:", outputString) + t.Fatal("callbacks were not reached in order:", outputString) } runner.Stop() @@ -119,8 +135,8 @@ func TestShellRunnerKillWithTimeout(t *testing.T) { t.Fatalf("unexpected error adding command: %v", err) } - // Wait one second to make sure the command starts running - time.Sleep(1 * time.Second) + // Wait a sec for the worker to pick up the task + time.Sleep(TaskStartWait) if err := runner.KillWithTimeout(1 * time.Second); err == nil { t.Fatal("expected error when killing commands, but got none") @@ -151,3 +167,60 @@ func TestAddingPriorToStart(t *testing.T) { t.Fatal("Should have failed to add prior to starting runner") } } + +func TestAddCmdWithTimeout(t *testing.T) { + t.Parallel() + + runner := tortoise.NewShellRunner() + runner.Start() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + + cmd := exec.CommandContext(ctx, "sleep", "10") + + err := runner.AddCmd(cmd, nil, cancel) + if err != nil { + t.Fatalf("unexpected error adding command: %v", err) + } + + // Wait a sec for the worker to pick up the task + time.Sleep(TaskStartWait) + + runner.Stop() + + result := runner.GetResults() + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.ReturnCode != -1 { + t.Fatalf("expected return code -1, got %d", result.ReturnCode) + } +} + +func TestShellWithTimeout(t *testing.T) { + t.Parallel() + + runner := tortoise.NewShellRunner() + runner.SetTimeout(1 * time.Second) + runner.Start() + + err := runner.AddCommand("sleep 10", nil) + if err != nil { + t.Fatalf("unexpected error adding command: %v", err) + } + + // Wait a sec for the worker to pick up the task + time.Sleep(TaskStartWait) + + runner.Stop() + + result := runner.GetResults() + if result == nil { + t.Fatal("expected result, got nil") + } + + if result.ReturnCode != -1 { + t.Fatalf("expected return code -1, got %d", result.ReturnCode) + } +}