Add timeouts and the ability to pass in a Cmd directly

This commit is contained in:
IamTheFij 2024-11-18 22:33:14 -08:00
parent 9f32c6c43f
commit 0adf15dc3f
2 changed files with 113 additions and 20 deletions

52
main.go
View File

@ -32,6 +32,7 @@ type ShellRunner struct {
mu sync.Mutex
isStopped bool
activeCmds map[*exec.Cmd]struct{} // Track active commands for cancellation
timeout time.Duration
}
// NewShellRunner creates a new ShellRunner instance with the default shell (`sh`).
@ -49,9 +50,14 @@ func NewShellRunnerWithShell(shell string) *ShellRunner {
shell: shell,
activeCmds: make(map[*exec.Cmd]struct{}),
isStopped: true,
timeout: 0,
}
}
func (sr *ShellRunner) SetTimeout(timeout time.Duration) {
sr.timeout = timeout
}
// Start begins processing shell commands asynchronously.
func (sr *ShellRunner) Start() {
go func() {
@ -85,6 +91,11 @@ func (sr *ShellRunner) Start() {
// The callback is executed asynchronously after the command has completed.
// The order of command execution and callback invocation can be expected to be preserved.
func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) error {
cmd, cancel := sr.newShellCommand(command)
return sr.AddCmd(cmd, callback, cancel)
}
func (sr *ShellRunner) AddCmd(cmd *exec.Cmd, callback func(*CommandResult), cancel context.CancelFunc) error {
sr.mu.Lock()
defer sr.mu.Unlock()
@ -93,10 +104,12 @@ func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult))
}
sr.cmdQueue <- func() *CommandResult {
result := sr.executeCommand(command)
result := sr.executeCommand(cmd, cancel)
if callback != nil {
sr.running.Add(1)
sr.callbacks <- func() {
callback(result)
sr.running.Done()
}
}
return result
@ -135,15 +148,6 @@ func (sr *ShellRunner) Stop() {
// Kill stops the ShellRunner immediately, terminating all running commands.
func (sr *ShellRunner) Kill() {
sr.mu.Lock()
if sr.isStopped {
sr.mu.Unlock()
return
}
sr.isStopped = true
close(sr.cmdQueue) // Prevent further commands
close(sr.stopChan)
// Terminate all active commands
for cmd := range sr.activeCmds {
@ -151,7 +155,7 @@ func (sr *ShellRunner) Kill() {
}
sr.mu.Unlock()
sr.running.Wait()
sr.Stop()
}
// KillWithTimeout attempts to stop the ShellRunner, killing commands if the duration is exceeded.
@ -171,12 +175,28 @@ func (sr *ShellRunner) KillWithTimeout(timeout time.Duration) error {
}
}
func (sr *ShellRunner) newShellCommand(command string) (*exec.Cmd, context.CancelFunc) {
var ctx context.Context
var cancel context.CancelFunc
if sr.timeout > 0 {
ctx, cancel = context.WithTimeout(context.Background(), sr.timeout)
} else {
ctx = context.Background()
}
return exec.CommandContext(ctx, sr.shell, "-c", command), cancel
}
// executeCommand runs a shell command asynchronously, capturing stdout, stderr, and return code.
func (sr *ShellRunner) executeCommand(command string) *CommandResult {
func (sr *ShellRunner) executeCommand(cmd *exec.Cmd, cancel context.CancelFunc) *CommandResult {
if cancel != nil {
defer cancel()
}
var outBuf, errBuf bytes.Buffer
ctx := context.Background()
cmd := exec.CommandContext(ctx, sr.shell, "-c", command)
cmd.Stdout = &outBuf
cmd.Stderr = &errBuf
@ -188,7 +208,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult {
err := cmd.Start()
if err != nil {
return &CommandResult{
Command: command,
Command: cmd.String(),
ReturnCode: -1,
ErrOutput: err.Error(),
}
@ -202,7 +222,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult {
sr.mu.Unlock()
result := &CommandResult{
Command: command,
Command: cmd.String(),
Output: outBuf.String(),
ErrOutput: errBuf.String(),
}

View File

@ -1,6 +1,8 @@
package tortoise_test
import (
"context"
"os/exec"
"sync"
"testing"
"time"
@ -8,6 +10,10 @@ import (
"git.iamthefij.com/iamthefij/tortoise"
)
const (
TaskStartWait = 10 * time.Millisecond
)
func TestShellRunnerNoCallback(t *testing.T) {
t.Parallel()
@ -33,10 +39,17 @@ func TestShellRunnerNoCallback(t *testing.T) {
t.Fatalf("unexpected error adding command: %v", err)
}
// Wait a sec for the worker to pick up the task
time.Sleep(TaskStartWait)
runner.Stop()
result := runner.GetResults()
if result == nil || result.Output != c.output || result.ReturnCode != c.ReturnCode {
if result == nil {
t.Fatal("expected result, got nil")
}
if result.Output != c.output || result.ReturnCode != c.ReturnCode {
t.Fatalf("expected output '%s' and return code %d, got '%s' and %d", c.output, c.ReturnCode, result.Output, result.ReturnCode)
}
})
@ -65,6 +78,9 @@ func TestShellRunnerCallback(t *testing.T) {
t.Fatalf("unexpected error adding command: %v", err)
}
// Wait a sec for the worker to pick up the task
time.Sleep(TaskStartWait)
callbackWait.Add(1)
if err := runner.AddCommand("echo callback b", func(result *tortoise.CommandResult) {
@ -91,7 +107,7 @@ func TestShellRunnerCallback(t *testing.T) {
}
if outputString != "ab" {
t.Fatal("callbacks was not reached in order:", outputString)
t.Fatal("callbacks were not reached in order:", outputString)
}
runner.Stop()
@ -119,8 +135,8 @@ func TestShellRunnerKillWithTimeout(t *testing.T) {
t.Fatalf("unexpected error adding command: %v", err)
}
// Wait one second to make sure the command starts running
time.Sleep(1 * time.Second)
// Wait a sec for the worker to pick up the task
time.Sleep(TaskStartWait)
if err := runner.KillWithTimeout(1 * time.Second); err == nil {
t.Fatal("expected error when killing commands, but got none")
@ -151,3 +167,60 @@ func TestAddingPriorToStart(t *testing.T) {
t.Fatal("Should have failed to add prior to starting runner")
}
}
func TestAddCmdWithTimeout(t *testing.T) {
t.Parallel()
runner := tortoise.NewShellRunner()
runner.Start()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
cmd := exec.CommandContext(ctx, "sleep", "10")
err := runner.AddCmd(cmd, nil, cancel)
if err != nil {
t.Fatalf("unexpected error adding command: %v", err)
}
// Wait a sec for the worker to pick up the task
time.Sleep(TaskStartWait)
runner.Stop()
result := runner.GetResults()
if result == nil {
t.Fatal("expected result, got nil")
}
if result.ReturnCode != -1 {
t.Fatalf("expected return code -1, got %d", result.ReturnCode)
}
}
func TestShellWithTimeout(t *testing.T) {
t.Parallel()
runner := tortoise.NewShellRunner()
runner.SetTimeout(1 * time.Second)
runner.Start()
err := runner.AddCommand("sleep 10", nil)
if err != nil {
t.Fatalf("unexpected error adding command: %v", err)
}
// Wait a sec for the worker to pick up the task
time.Sleep(TaskStartWait)
runner.Stop()
result := runner.GetResults()
if result == nil {
t.Fatal("expected result, got nil")
}
if result.ReturnCode != -1 {
t.Fatalf("expected return code -1, got %d", result.ReturnCode)
}
}