Add timeouts and the ability to pass in a Cmd directly
This commit is contained in:
parent
9f32c6c43f
commit
0adf15dc3f
52
main.go
52
main.go
@ -32,6 +32,7 @@ type ShellRunner struct {
|
||||
mu sync.Mutex
|
||||
isStopped bool
|
||||
activeCmds map[*exec.Cmd]struct{} // Track active commands for cancellation
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewShellRunner creates a new ShellRunner instance with the default shell (`sh`).
|
||||
@ -49,9 +50,14 @@ func NewShellRunnerWithShell(shell string) *ShellRunner {
|
||||
shell: shell,
|
||||
activeCmds: make(map[*exec.Cmd]struct{}),
|
||||
isStopped: true,
|
||||
timeout: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *ShellRunner) SetTimeout(timeout time.Duration) {
|
||||
sr.timeout = timeout
|
||||
}
|
||||
|
||||
// Start begins processing shell commands asynchronously.
|
||||
func (sr *ShellRunner) Start() {
|
||||
go func() {
|
||||
@ -85,6 +91,11 @@ func (sr *ShellRunner) Start() {
|
||||
// The callback is executed asynchronously after the command has completed.
|
||||
// The order of command execution and callback invocation can be expected to be preserved.
|
||||
func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) error {
|
||||
cmd, cancel := sr.newShellCommand(command)
|
||||
return sr.AddCmd(cmd, callback, cancel)
|
||||
}
|
||||
|
||||
func (sr *ShellRunner) AddCmd(cmd *exec.Cmd, callback func(*CommandResult), cancel context.CancelFunc) error {
|
||||
sr.mu.Lock()
|
||||
defer sr.mu.Unlock()
|
||||
|
||||
@ -93,10 +104,12 @@ func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult))
|
||||
}
|
||||
|
||||
sr.cmdQueue <- func() *CommandResult {
|
||||
result := sr.executeCommand(command)
|
||||
result := sr.executeCommand(cmd, cancel)
|
||||
if callback != nil {
|
||||
sr.running.Add(1)
|
||||
sr.callbacks <- func() {
|
||||
callback(result)
|
||||
sr.running.Done()
|
||||
}
|
||||
}
|
||||
return result
|
||||
@ -135,15 +148,6 @@ func (sr *ShellRunner) Stop() {
|
||||
// Kill stops the ShellRunner immediately, terminating all running commands.
|
||||
func (sr *ShellRunner) Kill() {
|
||||
sr.mu.Lock()
|
||||
if sr.isStopped {
|
||||
sr.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
sr.isStopped = true
|
||||
|
||||
close(sr.cmdQueue) // Prevent further commands
|
||||
close(sr.stopChan)
|
||||
|
||||
// Terminate all active commands
|
||||
for cmd := range sr.activeCmds {
|
||||
@ -151,7 +155,7 @@ func (sr *ShellRunner) Kill() {
|
||||
}
|
||||
sr.mu.Unlock()
|
||||
|
||||
sr.running.Wait()
|
||||
sr.Stop()
|
||||
}
|
||||
|
||||
// KillWithTimeout attempts to stop the ShellRunner, killing commands if the duration is exceeded.
|
||||
@ -171,12 +175,28 @@ func (sr *ShellRunner) KillWithTimeout(timeout time.Duration) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *ShellRunner) newShellCommand(command string) (*exec.Cmd, context.CancelFunc) {
|
||||
var ctx context.Context
|
||||
|
||||
var cancel context.CancelFunc
|
||||
|
||||
if sr.timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), sr.timeout)
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
return exec.CommandContext(ctx, sr.shell, "-c", command), cancel
|
||||
}
|
||||
|
||||
// executeCommand runs a shell command asynchronously, capturing stdout, stderr, and return code.
|
||||
func (sr *ShellRunner) executeCommand(command string) *CommandResult {
|
||||
func (sr *ShellRunner) executeCommand(cmd *exec.Cmd, cancel context.CancelFunc) *CommandResult {
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var outBuf, errBuf bytes.Buffer
|
||||
|
||||
ctx := context.Background()
|
||||
cmd := exec.CommandContext(ctx, sr.shell, "-c", command)
|
||||
cmd.Stdout = &outBuf
|
||||
cmd.Stderr = &errBuf
|
||||
|
||||
@ -188,7 +208,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult {
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
return &CommandResult{
|
||||
Command: command,
|
||||
Command: cmd.String(),
|
||||
ReturnCode: -1,
|
||||
ErrOutput: err.Error(),
|
||||
}
|
||||
@ -202,7 +222,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult {
|
||||
sr.mu.Unlock()
|
||||
|
||||
result := &CommandResult{
|
||||
Command: command,
|
||||
Command: cmd.String(),
|
||||
Output: outBuf.String(),
|
||||
ErrOutput: errBuf.String(),
|
||||
}
|
||||
|
81
main_test.go
81
main_test.go
@ -1,6 +1,8 @@
|
||||
package tortoise_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -8,6 +10,10 @@ import (
|
||||
"git.iamthefij.com/iamthefij/tortoise"
|
||||
)
|
||||
|
||||
const (
|
||||
TaskStartWait = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
func TestShellRunnerNoCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -33,10 +39,17 @@ func TestShellRunnerNoCallback(t *testing.T) {
|
||||
t.Fatalf("unexpected error adding command: %v", err)
|
||||
}
|
||||
|
||||
// Wait a sec for the worker to pick up the task
|
||||
time.Sleep(TaskStartWait)
|
||||
|
||||
runner.Stop()
|
||||
|
||||
result := runner.GetResults()
|
||||
if result == nil || result.Output != c.output || result.ReturnCode != c.ReturnCode {
|
||||
if result == nil {
|
||||
t.Fatal("expected result, got nil")
|
||||
}
|
||||
|
||||
if result.Output != c.output || result.ReturnCode != c.ReturnCode {
|
||||
t.Fatalf("expected output '%s' and return code %d, got '%s' and %d", c.output, c.ReturnCode, result.Output, result.ReturnCode)
|
||||
}
|
||||
})
|
||||
@ -65,6 +78,9 @@ func TestShellRunnerCallback(t *testing.T) {
|
||||
t.Fatalf("unexpected error adding command: %v", err)
|
||||
}
|
||||
|
||||
// Wait a sec for the worker to pick up the task
|
||||
time.Sleep(TaskStartWait)
|
||||
|
||||
callbackWait.Add(1)
|
||||
|
||||
if err := runner.AddCommand("echo callback b", func(result *tortoise.CommandResult) {
|
||||
@ -91,7 +107,7 @@ func TestShellRunnerCallback(t *testing.T) {
|
||||
}
|
||||
|
||||
if outputString != "ab" {
|
||||
t.Fatal("callbacks was not reached in order:", outputString)
|
||||
t.Fatal("callbacks were not reached in order:", outputString)
|
||||
}
|
||||
|
||||
runner.Stop()
|
||||
@ -119,8 +135,8 @@ func TestShellRunnerKillWithTimeout(t *testing.T) {
|
||||
t.Fatalf("unexpected error adding command: %v", err)
|
||||
}
|
||||
|
||||
// Wait one second to make sure the command starts running
|
||||
time.Sleep(1 * time.Second)
|
||||
// Wait a sec for the worker to pick up the task
|
||||
time.Sleep(TaskStartWait)
|
||||
|
||||
if err := runner.KillWithTimeout(1 * time.Second); err == nil {
|
||||
t.Fatal("expected error when killing commands, but got none")
|
||||
@ -151,3 +167,60 @@ func TestAddingPriorToStart(t *testing.T) {
|
||||
t.Fatal("Should have failed to add prior to starting runner")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddCmdWithTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runner := tortoise.NewShellRunner()
|
||||
runner.Start()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sleep", "10")
|
||||
|
||||
err := runner.AddCmd(cmd, nil, cancel)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding command: %v", err)
|
||||
}
|
||||
|
||||
// Wait a sec for the worker to pick up the task
|
||||
time.Sleep(TaskStartWait)
|
||||
|
||||
runner.Stop()
|
||||
|
||||
result := runner.GetResults()
|
||||
if result == nil {
|
||||
t.Fatal("expected result, got nil")
|
||||
}
|
||||
|
||||
if result.ReturnCode != -1 {
|
||||
t.Fatalf("expected return code -1, got %d", result.ReturnCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellWithTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runner := tortoise.NewShellRunner()
|
||||
runner.SetTimeout(1 * time.Second)
|
||||
runner.Start()
|
||||
|
||||
err := runner.AddCommand("sleep 10", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding command: %v", err)
|
||||
}
|
||||
|
||||
// Wait a sec for the worker to pick up the task
|
||||
time.Sleep(TaskStartWait)
|
||||
|
||||
runner.Stop()
|
||||
|
||||
result := runner.GetResults()
|
||||
if result == nil {
|
||||
t.Fatal("expected result, got nil")
|
||||
}
|
||||
|
||||
if result.ReturnCode != -1 {
|
||||
t.Fatalf("expected return code -1, got %d", result.ReturnCode)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user