Add timeouts and the ability to pass in a Cmd directly
This commit is contained in:
parent
9f32c6c43f
commit
0adf15dc3f
52
main.go
52
main.go
@ -32,6 +32,7 @@ type ShellRunner struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
isStopped bool
|
isStopped bool
|
||||||
activeCmds map[*exec.Cmd]struct{} // Track active commands for cancellation
|
activeCmds map[*exec.Cmd]struct{} // Track active commands for cancellation
|
||||||
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewShellRunner creates a new ShellRunner instance with the default shell (`sh`).
|
// NewShellRunner creates a new ShellRunner instance with the default shell (`sh`).
|
||||||
@ -49,9 +50,14 @@ func NewShellRunnerWithShell(shell string) *ShellRunner {
|
|||||||
shell: shell,
|
shell: shell,
|
||||||
activeCmds: make(map[*exec.Cmd]struct{}),
|
activeCmds: make(map[*exec.Cmd]struct{}),
|
||||||
isStopped: true,
|
isStopped: true,
|
||||||
|
timeout: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sr *ShellRunner) SetTimeout(timeout time.Duration) {
|
||||||
|
sr.timeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
// Start begins processing shell commands asynchronously.
|
// Start begins processing shell commands asynchronously.
|
||||||
func (sr *ShellRunner) Start() {
|
func (sr *ShellRunner) Start() {
|
||||||
go func() {
|
go func() {
|
||||||
@ -85,6 +91,11 @@ func (sr *ShellRunner) Start() {
|
|||||||
// The callback is executed asynchronously after the command has completed.
|
// The callback is executed asynchronously after the command has completed.
|
||||||
// The order of command execution and callback invocation can be expected to be preserved.
|
// The order of command execution and callback invocation can be expected to be preserved.
|
||||||
func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) error {
|
func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult)) error {
|
||||||
|
cmd, cancel := sr.newShellCommand(command)
|
||||||
|
return sr.AddCmd(cmd, callback, cancel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sr *ShellRunner) AddCmd(cmd *exec.Cmd, callback func(*CommandResult), cancel context.CancelFunc) error {
|
||||||
sr.mu.Lock()
|
sr.mu.Lock()
|
||||||
defer sr.mu.Unlock()
|
defer sr.mu.Unlock()
|
||||||
|
|
||||||
@ -93,10 +104,12 @@ func (sr *ShellRunner) AddCommand(command string, callback func(*CommandResult))
|
|||||||
}
|
}
|
||||||
|
|
||||||
sr.cmdQueue <- func() *CommandResult {
|
sr.cmdQueue <- func() *CommandResult {
|
||||||
result := sr.executeCommand(command)
|
result := sr.executeCommand(cmd, cancel)
|
||||||
if callback != nil {
|
if callback != nil {
|
||||||
|
sr.running.Add(1)
|
||||||
sr.callbacks <- func() {
|
sr.callbacks <- func() {
|
||||||
callback(result)
|
callback(result)
|
||||||
|
sr.running.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
@ -135,15 +148,6 @@ func (sr *ShellRunner) Stop() {
|
|||||||
// Kill stops the ShellRunner immediately, terminating all running commands.
|
// Kill stops the ShellRunner immediately, terminating all running commands.
|
||||||
func (sr *ShellRunner) Kill() {
|
func (sr *ShellRunner) Kill() {
|
||||||
sr.mu.Lock()
|
sr.mu.Lock()
|
||||||
if sr.isStopped {
|
|
||||||
sr.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sr.isStopped = true
|
|
||||||
|
|
||||||
close(sr.cmdQueue) // Prevent further commands
|
|
||||||
close(sr.stopChan)
|
|
||||||
|
|
||||||
// Terminate all active commands
|
// Terminate all active commands
|
||||||
for cmd := range sr.activeCmds {
|
for cmd := range sr.activeCmds {
|
||||||
@ -151,7 +155,7 @@ func (sr *ShellRunner) Kill() {
|
|||||||
}
|
}
|
||||||
sr.mu.Unlock()
|
sr.mu.Unlock()
|
||||||
|
|
||||||
sr.running.Wait()
|
sr.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// KillWithTimeout attempts to stop the ShellRunner, killing commands if the duration is exceeded.
|
// KillWithTimeout attempts to stop the ShellRunner, killing commands if the duration is exceeded.
|
||||||
@ -171,12 +175,28 @@ func (sr *ShellRunner) KillWithTimeout(timeout time.Duration) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sr *ShellRunner) newShellCommand(command string) (*exec.Cmd, context.CancelFunc) {
|
||||||
|
var ctx context.Context
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
|
||||||
|
if sr.timeout > 0 {
|
||||||
|
ctx, cancel = context.WithTimeout(context.Background(), sr.timeout)
|
||||||
|
} else {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
return exec.CommandContext(ctx, sr.shell, "-c", command), cancel
|
||||||
|
}
|
||||||
|
|
||||||
// executeCommand runs a shell command asynchronously, capturing stdout, stderr, and return code.
|
// executeCommand runs a shell command asynchronously, capturing stdout, stderr, and return code.
|
||||||
func (sr *ShellRunner) executeCommand(command string) *CommandResult {
|
func (sr *ShellRunner) executeCommand(cmd *exec.Cmd, cancel context.CancelFunc) *CommandResult {
|
||||||
|
if cancel != nil {
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
var outBuf, errBuf bytes.Buffer
|
var outBuf, errBuf bytes.Buffer
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
cmd := exec.CommandContext(ctx, sr.shell, "-c", command)
|
|
||||||
cmd.Stdout = &outBuf
|
cmd.Stdout = &outBuf
|
||||||
cmd.Stderr = &errBuf
|
cmd.Stderr = &errBuf
|
||||||
|
|
||||||
@ -188,7 +208,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult {
|
|||||||
err := cmd.Start()
|
err := cmd.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &CommandResult{
|
return &CommandResult{
|
||||||
Command: command,
|
Command: cmd.String(),
|
||||||
ReturnCode: -1,
|
ReturnCode: -1,
|
||||||
ErrOutput: err.Error(),
|
ErrOutput: err.Error(),
|
||||||
}
|
}
|
||||||
@ -202,7 +222,7 @@ func (sr *ShellRunner) executeCommand(command string) *CommandResult {
|
|||||||
sr.mu.Unlock()
|
sr.mu.Unlock()
|
||||||
|
|
||||||
result := &CommandResult{
|
result := &CommandResult{
|
||||||
Command: command,
|
Command: cmd.String(),
|
||||||
Output: outBuf.String(),
|
Output: outBuf.String(),
|
||||||
ErrOutput: errBuf.String(),
|
ErrOutput: errBuf.String(),
|
||||||
}
|
}
|
||||||
|
81
main_test.go
81
main_test.go
@ -1,6 +1,8 @@
|
|||||||
package tortoise_test
|
package tortoise_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -8,6 +10,10 @@ import (
|
|||||||
"git.iamthefij.com/iamthefij/tortoise"
|
"git.iamthefij.com/iamthefij/tortoise"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskStartWait = 10 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
func TestShellRunnerNoCallback(t *testing.T) {
|
func TestShellRunnerNoCallback(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -33,10 +39,17 @@ func TestShellRunnerNoCallback(t *testing.T) {
|
|||||||
t.Fatalf("unexpected error adding command: %v", err)
|
t.Fatalf("unexpected error adding command: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait a sec for the worker to pick up the task
|
||||||
|
time.Sleep(TaskStartWait)
|
||||||
|
|
||||||
runner.Stop()
|
runner.Stop()
|
||||||
|
|
||||||
result := runner.GetResults()
|
result := runner.GetResults()
|
||||||
if result == nil || result.Output != c.output || result.ReturnCode != c.ReturnCode {
|
if result == nil {
|
||||||
|
t.Fatal("expected result, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Output != c.output || result.ReturnCode != c.ReturnCode {
|
||||||
t.Fatalf("expected output '%s' and return code %d, got '%s' and %d", c.output, c.ReturnCode, result.Output, result.ReturnCode)
|
t.Fatalf("expected output '%s' and return code %d, got '%s' and %d", c.output, c.ReturnCode, result.Output, result.ReturnCode)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -65,6 +78,9 @@ func TestShellRunnerCallback(t *testing.T) {
|
|||||||
t.Fatalf("unexpected error adding command: %v", err)
|
t.Fatalf("unexpected error adding command: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wait a sec for the worker to pick up the task
|
||||||
|
time.Sleep(TaskStartWait)
|
||||||
|
|
||||||
callbackWait.Add(1)
|
callbackWait.Add(1)
|
||||||
|
|
||||||
if err := runner.AddCommand("echo callback b", func(result *tortoise.CommandResult) {
|
if err := runner.AddCommand("echo callback b", func(result *tortoise.CommandResult) {
|
||||||
@ -91,7 +107,7 @@ func TestShellRunnerCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if outputString != "ab" {
|
if outputString != "ab" {
|
||||||
t.Fatal("callbacks was not reached in order:", outputString)
|
t.Fatal("callbacks were not reached in order:", outputString)
|
||||||
}
|
}
|
||||||
|
|
||||||
runner.Stop()
|
runner.Stop()
|
||||||
@ -119,8 +135,8 @@ func TestShellRunnerKillWithTimeout(t *testing.T) {
|
|||||||
t.Fatalf("unexpected error adding command: %v", err)
|
t.Fatalf("unexpected error adding command: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait one second to make sure the command starts running
|
// Wait a sec for the worker to pick up the task
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(TaskStartWait)
|
||||||
|
|
||||||
if err := runner.KillWithTimeout(1 * time.Second); err == nil {
|
if err := runner.KillWithTimeout(1 * time.Second); err == nil {
|
||||||
t.Fatal("expected error when killing commands, but got none")
|
t.Fatal("expected error when killing commands, but got none")
|
||||||
@ -151,3 +167,60 @@ func TestAddingPriorToStart(t *testing.T) {
|
|||||||
t.Fatal("Should have failed to add prior to starting runner")
|
t.Fatal("Should have failed to add prior to starting runner")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddCmdWithTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
runner := tortoise.NewShellRunner()
|
||||||
|
runner.Start()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "sleep", "10")
|
||||||
|
|
||||||
|
err := runner.AddCmd(cmd, nil, cancel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error adding command: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a sec for the worker to pick up the task
|
||||||
|
time.Sleep(TaskStartWait)
|
||||||
|
|
||||||
|
runner.Stop()
|
||||||
|
|
||||||
|
result := runner.GetResults()
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected result, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ReturnCode != -1 {
|
||||||
|
t.Fatalf("expected return code -1, got %d", result.ReturnCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShellWithTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
runner := tortoise.NewShellRunner()
|
||||||
|
runner.SetTimeout(1 * time.Second)
|
||||||
|
runner.Start()
|
||||||
|
|
||||||
|
err := runner.AddCommand("sleep 10", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error adding command: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait a sec for the worker to pick up the task
|
||||||
|
time.Sleep(TaskStartWait)
|
||||||
|
|
||||||
|
runner.Stop()
|
||||||
|
|
||||||
|
result := runner.GetResults()
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected result, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ReturnCode != -1 {
|
||||||
|
t.Fatalf("expected return code -1, got %d", result.ReturnCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user