From 808030099fceb7a6fe1714c3dccc8bcb929652f1 Mon Sep 17 00:00:00 2001 From: Ian Fijolek Date: Wed, 23 Feb 2022 14:13:00 -0800 Subject: [PATCH] Add validation and a lot more testing --- .golangci.yml | 1 + go.mod | 1 + go.sum | 2 + job.go | 316 +++++++++++++++++++++++++++++++------- job_test.go | 175 +++++++++++++++++++++ main.go | 2 +- restic.go | 76 +++++---- restic_test.go | 101 ++++++++++++ shell.go | 8 +- test.hcl => test/test.hcl | 0 utils.go | 13 +- utils_test.go | 81 ++++++++++ 12 files changed, 671 insertions(+), 105 deletions(-) create mode 100644 job_test.go rename test.hcl => test/test.hcl (100%) create mode 100644 utils_test.go diff --git a/.golangci.yml b/.golangci.yml index 90c2bb3..7374c0d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -107,5 +107,6 @@ issues: linters: - errcheck - gosec + - funlen # Enable autofix fix: true diff --git a/go.mod b/go.mod index 7ce04fd..cc9f31c 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/google/go-cmp v0.3.1 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect github.com/zclconf/go-cty v1.8.0 // indirect golang.org/x/text v0.3.5 // indirect ) diff --git a/go.sum b/go.sum index cc4e4e8..f806fe7 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LE github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 h1:DpOJ2HYzCv8LZP15IdmG+YdwD2luVPHITV96TkirNBM= github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZXFpozHsX6DPmq2I0TCekCxypsnAUbP2oI0UX1GXzOo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/spf13/pflag v1.0.2/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= diff --git a/job.go b/job.go index 912dc82..8cf136d 100644 --- a/job.go +++ b/job.go @@ -1,14 +1,25 @@ package main import ( + "errors" "fmt" "log" "os" "path/filepath" + "strings" + + "github.com/robfig/cron/v3" ) const WorkDirPerms = 0o666 +var ( + ErrNoJobsFound = errors.New("no jobs found and at least one job is required") + ErrMissingField = errors.New("missing config field") + ErrMissingBlock = errors.New("missing config block") + ErrMutuallyExclusive = errors.New("mutually exclusive values not valid") +) + type TaskConfig struct { JobDir string Env map[string]string @@ -24,6 +35,24 @@ type ResticConfig struct { GlobalOpts *ResticGlobalOpts `hcl:"options,block"` } +func (r ResticConfig) Validate() error { + if r.Passphrase == "" && (r.GlobalOpts == nil || r.GlobalOpts.PasswordFile == "") { + return fmt.Errorf( + "either config { Passphrase = string } or config { options { PasswordFile = string } } must be set: %w", + ErrMutuallyExclusive, + ) + } + + if r.Passphrase != "" && r.GlobalOpts != nil && r.GlobalOpts.PasswordFile != "" { + return fmt.Errorf( + "only one of config { Passphrase = string } or config { options { PasswordFile = string } } may be set: %w", + ErrMutuallyExclusive, + ) + } + + return nil +} + // ExecutableTask is a task to be run before or after backup/retore. type ExecutableTask interface { RunBackup(cfg TaskConfig) error @@ -40,13 +69,17 @@ type JobTaskScript struct { name string } -// RunBackup runs script on backup. -func (t JobTaskScript) RunBackup(cfg TaskConfig) error { - env := MergeEnv(cfg.Env, t.env) +func (t JobTaskScript) run(script string, cfg TaskConfig) error { + if script == "" { + return nil + } + + env := MergeEnvMap(cfg.Env, t.env) if env == nil { env = map[string]string{} } + // Inject the job directory to the running task env["RESTIC_JOB_DIR"] = cfg.JobDir cwd := "" @@ -54,32 +87,21 @@ func (t JobTaskScript) RunBackup(cfg TaskConfig) error { cwd = cfg.JobDir } - if err := RunShell(t.OnBackup, cwd, env, cfg.Logger); err != nil { + if err := RunShell(script, cwd, env, cfg.Logger); err != nil { return fmt.Errorf("failed running task script %s: %w", t.Name(), err) } return nil } +// RunBackup runs script on backup. +func (t JobTaskScript) RunBackup(cfg TaskConfig) error { + return t.run(t.OnBackup, cfg) +} + // RunRestore script on restore. func (t JobTaskScript) RunRestore(cfg TaskConfig) error { - env := MergeEnv(cfg.Env, t.env) - if env == nil { - env = map[string]string{} - } - - env["RESTIC_JOB_DIR"] = cfg.JobDir - - cwd := "" - if t.FromJobDir { - cwd = cfg.JobDir - } - - if err := RunShell(t.OnRestore, cwd, env, cfg.Logger); err != nil { - return fmt.Errorf("failed running task script %s: %w", t.Name(), err) - } - - return nil + return t.run(t.OnRestore, cfg) } func (t JobTaskScript) Name() string { @@ -92,43 +114,82 @@ func (t *JobTaskScript) SetName(name string) { // JobTaskMySQL is a sqlite backup task that performs required pre and post tasks. type JobTaskMySQL struct { - Name string `hcl:"name,label"` - Hostname string `hcl:"hostname,optional"` - Database string `hcl:"database,optional"` - Username string `hcl:"username,optional"` - Password string `hcl:"password,optional"` + Name string `hcl:"name,label"` + Hostname string `hcl:"hostname,optional"` + Database string `hcl:"database,optional"` + Username string `hcl:"username,optional"` + Password string `hcl:"password,optional"` + Tables []string `hcl:"tables,optional"` +} + +func (t JobTaskMySQL) Filename() string { + return fmt.Sprintf("%s.sql", t.Name) +} + +func (t JobTaskMySQL) Validate() error { + if invalidChars := "'\";"; strings.ContainsAny(t.Name, invalidChars) { + return fmt.Errorf("mysql task %s has an invalid name. The name may not contain %s", t.Name, invalidChars) + } + + if len(t.Tables) > 0 && t.Database == "" { + return fmt.Errorf("mysql task %s is invalid. Must specify a database to use tables: %w", t.Name, ErrMissingField) + } + + return nil } func (t JobTaskMySQL) GetPreTask() ExecutableTask { + command := []string{"mysqldump", "--result-file", fmt.Sprintf("'./%s'", t.Filename())} + + if t.Hostname != "" { + command = append(command, "--host", t.Hostname) + } + + if t.Username != "" { + command = append(command, "--user", t.Username) + } + + if t.Password != "" { + command = append(command, "--password", t.Password) + } + + if t.Database != "" { + command = append(command, t.Database) + } + + command = append(command, t.Tables...) + return JobTaskScript{ - name: t.Name, - env: nil, - OnBackup: fmt.Sprintf( - "mysqldump -h '%s' -u '%s' -p '%s' '%s' > './%s.sql'", - t.Hostname, - t.Username, - t.Password, - t.Database, - t.Name, - ), + name: t.Name, + env: nil, + OnBackup: strings.Join(command, " "), OnRestore: "", FromJobDir: true, } } func (t JobTaskMySQL) GetPostTask() ExecutableTask { + command := []string{"mysql"} + + if t.Hostname != "" { + command = append(command, "--host", t.Hostname) + } + + if t.Username != "" { + command = append(command, "--user", t.Username) + } + + if t.Password != "" { + command = append(command, "--password", t.Password) + } + + command = append(command, "<", fmt.Sprintf("'./%s'", t.Filename())) + return JobTaskScript{ - name: t.Name, - env: nil, - OnBackup: "", - OnRestore: fmt.Sprintf( - "mysql -h '%s' -u '%s' -p '%s' '%s' << './%s.sql'", - t.Hostname, - t.Username, - t.Password, - t.Database, - t.Name, - ), + name: t.Name, + env: nil, + OnBackup: "", + OnRestore: strings.Join(command, " "), FromJobDir: true, } } @@ -139,13 +200,25 @@ type JobTaskSqlite struct { Path string `hcl:"path"` } +func (t JobTaskSqlite) Filename() string { + return fmt.Sprintf("%s.db.bak", t.Name) +} + +func (t JobTaskSqlite) Validate() error { + if invalidChars := "'\";"; strings.ContainsAny(t.Name, invalidChars) { + return fmt.Errorf("sqlite task %s has an invalid name. The name may not contain %s", t.Name, invalidChars) + } + + return nil +} + func (t JobTaskSqlite) GetPreTask() ExecutableTask { return JobTaskScript{ name: t.Name, env: nil, OnBackup: fmt.Sprintf( - "sqlite3 %s '.backup $RESTIC_JOB_DIR/%s.bak'", - t.Path, t.Name, + "sqlite3 %s '.backup $RESTIC_JOB_DIR/%s'", + t.Path, t.Filename(), ), OnRestore: "", FromJobDir: false, @@ -157,7 +230,7 @@ func (t JobTaskSqlite) GetPostTask() ExecutableTask { name: t.Name, env: nil, OnBackup: "", - OnRestore: fmt.Sprintf("cp '$RESTIC_JOB_DIR/%s.bak' '%s'", t.Name, t.Path), + OnRestore: fmt.Sprintf("cp '$RESTIC_JOB_DIR/%s' '%s'", t.Filename(), t.Path), FromJobDir: false, } } @@ -170,7 +243,11 @@ type BackupFilesTask struct { } func (t BackupFilesTask) RunBackup(cfg TaskConfig) error { - if err := cfg.Restic.Backup(t.Files, t.BackupOpts); err != nil { + if t.BackupOpts == nil { + t.BackupOpts = &BackupOpts{} // nolint:exhaustivestruct + } + + if err := cfg.Restic.Backup(t.Files, *t.BackupOpts); err != nil { err = fmt.Errorf("failed backing up files: %w", err) cfg.Logger.Fatal(err) @@ -181,7 +258,11 @@ func (t BackupFilesTask) RunBackup(cfg TaskConfig) error { } func (t BackupFilesTask) RunRestore(cfg TaskConfig) error { - if err := cfg.Restic.Restore("latest", t.RestoreOpts); err != nil { + if t.RestoreOpts == nil { + t.RestoreOpts = &RestoreOpts{} // nolint:exhaustivestruct + } + + if err := cfg.Restic.Restore("latest", *t.RestoreOpts); err != nil { err = fmt.Errorf("failed restoring files: %w", err) cfg.Logger.Fatal(err) @@ -206,6 +287,26 @@ type JobTask struct { Backup *BackupFilesTask `hcl:"backup,block"` } +func (t JobTask) Validate() error { + if len(t.Scripts) > 0 && t.Backup != nil { + return fmt.Errorf( + "task %s is invalid. script and backup blocks are mutually exclusive: %w", + t.Name, + ErrMutuallyExclusive, + ) + } + + if len(t.Scripts) == 0 && t.Backup == nil { + return fmt.Errorf( + "task %s is invalid. Ether script or backup blocks must be provided: %w", + t.Name, + ErrMutuallyExclusive, + ) + } + + return nil +} + func (t JobTask) GetTasks() []ExecutableTask { allTasks := []ExecutableTask{} @@ -229,7 +330,6 @@ type Job struct { Schedule string `hcl:"schedule"` Config ResticConfig `hcl:"config,block"` Tasks []JobTask `hcl:"task,block"` - Validate bool `hcl:"validate,optional"` Forget *ForgetOpts `hcl:"forget,block"` // Meta Tasks @@ -237,6 +337,62 @@ type Job struct { Sqlite []JobTaskSqlite `hcl:"sqlite,block"` } +func (j Job) validateTasks() error { + if len(j.Tasks) == 0 { + return fmt.Errorf("job %s is missing tasks: %w", j.Name, ErrMissingBlock) + } + + foundBackup := false + + for _, task := range j.Tasks { + if task.Backup != nil { + foundBackup = true + } + + if err := task.Validate(); err != nil { + return fmt.Errorf("job %s has an inavalid task: %w", j.Name, err) + } + } + + if !foundBackup { + return fmt.Errorf("job %s is missing a backup task: %w", j.Name, ErrMissingBlock) + } + + return nil +} + +func (j Job) Validate() error { + if j.Name == "" { + return fmt.Errorf("job is missing name: %w", ErrMissingField) + } + + if _, err := cron.ParseStandard(j.Schedule); err != nil { + return fmt.Errorf("job %s has an invalid schedule: %w", j.Name, err) + } + + if err := j.Config.Validate(); err != nil { + return fmt.Errorf("job %s has invalid config: %w", j.Name, err) + } + + if err := j.validateTasks(); err != nil { + return err + } + + for _, mysql := range j.MySQL { + if err := mysql.Validate(); err != nil { + return fmt.Errorf("job %s has an inavalid task: %w", j.Name, err) + } + } + + for _, sqlite := range j.Sqlite { + if err := sqlite.Validate(); err != nil { + return fmt.Errorf("job %s has an inavalid task: %w", j.Name, err) + } + } + + return nil +} + func (j Job) AllTasks() []ExecutableTask { allTasks := []ExecutableTask{} @@ -273,7 +429,18 @@ func (j Job) JobDir() string { return cwd } -func (j Job) RunTasks() error { +/* + * func NewTaskConfig(jobDir string, jobLogger *log.Logger, restic *ResticCmd, taskName string) TaskConfig { + * return TaskConfig{ + * JobDir: jobDir, + * Logger: GetChildLogger(jobLogger, taskName), + * Restic: restic, + * Env: nil, + * } + * } + */ + +func (j Job) RunBackup() error { logger := GetLogger(j.Name) restic := j.NewRestic() jobDir := j.JobDir() @@ -296,7 +463,7 @@ func (j Job) RunTasks() error { } if j.Forget != nil { - if err := restic.Forget(j.Forget); err != nil { + if err := restic.Forget(*j.Forget); err != nil { return fmt.Errorf("failed forgetting and pruning job %s: %w", j.Name, err) } } @@ -304,6 +471,31 @@ func (j Job) RunTasks() error { return nil } +func (j Job) RunRestore() error { + logger := GetLogger(j.Name) + restic := j.NewRestic() + jobDir := j.JobDir() + + if err := restic.RunRestic("snapshots", NoOpts{}); err != nil { + return fmt.Errorf("no repository or snapshots for job %s: %w", j.Name, err) + } + + for _, exTask := range j.AllTasks() { + taskCfg := TaskConfig{ + JobDir: jobDir, + Logger: GetChildLogger(logger, exTask.Name()), + Restic: restic, + Env: nil, + } + + if err := exTask.RunRestore(taskCfg); err != nil { + return fmt.Errorf("failed running job %s: %w", j.Name, err) + } + } + + return nil +} + func (j Job) NewRestic() *ResticCmd { return &ResticCmd{ Logger: GetLogger(j.Name), @@ -319,6 +511,20 @@ type Config struct { Jobs []Job `hcl:"job,block"` } +func (c Config) Validate() error { + if len(c.Jobs) == 0 { + return ErrNoJobsFound + } + + for _, job := range c.Jobs { + if err := job.Validate(); err != nil { + return err + } + } + + return nil +} + /*** job "My App" { diff --git a/job_test.go b/job_test.go new file mode 100644 index 0000000..8875484 --- /dev/null +++ b/job_test.go @@ -0,0 +1,175 @@ +package main_test + +import ( + "bytes" + "errors" + "log" + "testing" + + main "git.iamthefij.com/iamthefij/restic-scheduler" +) + +func TestResticConfigValidate(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + config main.ResticConfig + expectedErr error + }{ + { + name: "missing passphrase", + expectedErr: main.ErrMutuallyExclusive, + config: main.ResticConfig{}, // nolint:exhaustivestruct + }, + { + name: "passphrase no file", + expectedErr: nil, + // nolint:exhaustivestruct + config: main.ResticConfig{ + Passphrase: "shh", + }, + }, + { + name: "file no passphrase", + expectedErr: nil, + // nolint:exhaustivestruct + config: main.ResticConfig{ + GlobalOpts: &main.ResticGlobalOpts{ + PasswordFile: "file", + }, + }, + }, + { + name: "file and passphrase", + expectedErr: main.ErrMutuallyExclusive, + // nolint:exhaustivestruct + config: main.ResticConfig{ + Passphrase: "shh", + GlobalOpts: &main.ResticGlobalOpts{ + PasswordFile: "file", + }, + }, + }, + } + + for _, c := range cases { + testCase := c + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + actual := testCase.config.Validate() + + if !errors.Is(actual, testCase.expectedErr) { + t.Errorf("expected error to wrap %v but found %v", testCase.expectedErr, actual) + } + }) + } +} + +func NewBufferedLogger(prefix string) (*bytes.Buffer, *log.Logger) { + outputBuffer := bytes.Buffer{} + logger := log.New(&outputBuffer, prefix, 0) + + return &outputBuffer, logger +} + +func TestJobTaskScript(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + script main.JobTaskScript + config main.TaskConfig + expectedErr error + expectedOutput string + }{ + { + name: "simple", + config: main.TaskConfig{ + JobDir: "./test", + Env: nil, + Logger: nil, + Restic: nil, + }, + script: main.JobTaskScript{ + OnBackup: "echo yass", + OnRestore: "echo yass", + FromJobDir: false, + }, + expectedErr: nil, + expectedOutput: "t yass\nt \n", + }, + { + name: "check job dir", + config: main.TaskConfig{ + JobDir: "./test", + Env: nil, + Logger: nil, + Restic: nil, + }, + script: main.JobTaskScript{ + OnBackup: "echo $RESTIC_JOB_DIR", + OnRestore: "echo $RESTIC_JOB_DIR", + FromJobDir: false, + }, + expectedErr: nil, + expectedOutput: "t ./test\nt \n", + }, + { + name: "check from job dir", + config: main.TaskConfig{ + JobDir: "./test", + Env: nil, + Logger: nil, + Restic: nil, + }, + script: main.JobTaskScript{ + OnBackup: "basename `pwd`", + OnRestore: "basename `pwd`", + FromJobDir: true, + }, + expectedErr: nil, + expectedOutput: "t test\nt \n", + }, + { + name: "check env", + config: main.TaskConfig{ + JobDir: "./test", + Env: map[string]string{"TEST": "OK"}, + Logger: nil, + Restic: nil, + }, + script: main.JobTaskScript{ + OnBackup: "echo $TEST", + OnRestore: "echo $TEST", + FromJobDir: false, + }, + expectedErr: nil, + expectedOutput: "t OK\nt \n", + }, + } + + for _, c := range cases { + testCase := c + + buf, logger := NewBufferedLogger("t") + testCase.config.Logger = logger + + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + actual := testCase.script.RunBackup(testCase.config) + + if !errors.Is(actual, testCase.expectedErr) { + t.Errorf("expected error to wrap %v but found %v", testCase.expectedErr, actual) + } + + output := buf.String() + + if testCase.expectedOutput != output { + t.Errorf("Unexpected output. expected: %s actual: %s", testCase.expectedOutput, output) + } + }) + } +} diff --git a/main.go b/main.go index 71ff109..9391909 100644 --- a/main.go +++ b/main.go @@ -42,7 +42,7 @@ func main() { } for _, job := range config.Jobs { - if err := job.RunTasks(); err != nil { + if err := job.RunBackup(); err != nil { log.Fatalf("%v", err) } } diff --git a/restic.go b/restic.go index 717bd8b..82ad4c8 100644 --- a/restic.go +++ b/restic.go @@ -122,7 +122,7 @@ func (rcmd ResticCmd) RunRestic(command string, options CommandOptions, commandA cmd.Dir = rcmd.Cwd if err := cmd.Run(); err != nil { - return fmt.Errorf("error running restic: %w", err) + return fmt.Errorf("error running restic %s: %w", command, err) } return nil @@ -136,6 +136,14 @@ type BackupOpts struct { } func (bo BackupOpts) ToArgs() (args []string) { + for _, exclude := range bo.Exclude { + args = append(args, "--exclude", exclude) + } + + for _, include := range bo.Include { + args = append(args, "--include", include) + } + for _, tag := range bo.Tags { args = append(args, "--tag", tag) } @@ -147,22 +155,16 @@ func (bo BackupOpts) ToArgs() (args []string) { return } -func (rcmd ResticCmd) Backup(files []string, options *BackupOpts) error { - if options == nil { - options = &BackupOpts{} // nolint:exhaustivestruct - } - - err := rcmd.RunRestic("backup", options, files...) - - return err +func (rcmd ResticCmd) Backup(files []string, opts BackupOpts) error { + return rcmd.RunRestic("backup", opts, files...) } type RestoreOpts struct { Exclude []string `hcl:"Exclude,optional"` - Host []string `hcl:"Host,optional"` Include []string `hcl:"Include,optional"` - Path string `hcl:"Path,optional"` + Host []string `hcl:"Host,optional"` Tags []string `hcl:"Tags,optional"` + Path string `hcl:"Path,optional"` Target string `hcl:"Target,optional"` Verify bool `hcl:"Verify,optional"` } @@ -180,14 +182,14 @@ func (ro RestoreOpts) ToArgs() (args []string) { args = append(args, "--host", host) } - if ro.Path != "" { - args = append(args, "--path", ro.Path) - } - for _, tag := range ro.Tags { args = append(args, "--tag", tag) } + if ro.Path != "" { + args = append(args, "--path", ro.Path) + } + if ro.Target != "" { args = append(args, "--target", ro.Target) } @@ -199,14 +201,14 @@ func (ro RestoreOpts) ToArgs() (args []string) { return } -func (rcmd ResticCmd) Restore(snapshot string, opts *RestoreOpts) error { - if opts == nil { - opts = &RestoreOpts{} // nolint:exhaustivestruct - } +func (rcmd ResticCmd) Restore(snapshot string, opts RestoreOpts) error { + return rcmd.RunRestic("restore", opts, snapshot) +} - err := rcmd.RunRestic("restore", opts, snapshot) +type TagList []string - return err +func (t TagList) String() string { + return strings.Join(t, ",") } type ForgetOpts struct { @@ -224,8 +226,8 @@ type ForgetOpts struct { KeepWithinMonthly time.Duration `hcl:"KeepWithinMonthly,optional"` KeepWithinYearly time.Duration `hcl:"KeepWithinYearly,optional"` - Tags []string `hcl:"Tags,optional"` - KeepTags []string `hcl:"KeepTags,optional"` + Tags []TagList `hcl:"Tags,optional"` + KeepTags []TagList `hcl:"KeepTags,optional"` Prune bool `hcl:"Prune,optional"` } @@ -257,11 +259,12 @@ func (fo ForgetOpts) ToArgs() (args []string) { args = append(args, "--keep-yearly", fmt.Sprint(fo.KeepYearly)) } + // Add keep-within-* + if fo.KeepWithin > 0 { - args = append(args, "--keep-within", fmt.Sprint(fo.KeepWithin)) + args = append(args, "--keep-within", fo.KeepWithin.String()) } - // Add keep-within-* if fo.KeepWithinHourly > 0 { args = append(args, "--keep-within-hourly", fo.KeepWithinHourly.String()) } @@ -283,16 +286,15 @@ func (fo ForgetOpts) ToArgs() (args []string) { } // Add tags - if len(fo.Tags) > 0 { - args = append(args, "--tag", strings.Join(fo.Tags, ",")) + for _, tagList := range fo.Tags { + args = append(args, "--tag", tagList.String()) } - if len(fo.KeepTags) > 0 { - args = append(args, "--keep-tag", strings.Join(fo.Tags, ",")) + for _, tagList := range fo.KeepTags { + args = append(args, "--keep-tag", tagList.String()) } // Add prune options - if fo.Prune { args = append(args, "--prune") } @@ -300,20 +302,12 @@ func (fo ForgetOpts) ToArgs() (args []string) { return args } -func (rcmd ResticCmd) Forget(forgetOpts *ForgetOpts) error { - if forgetOpts == nil { - forgetOpts = &ForgetOpts{} // nolint:exhaustivestruct - } - - err := rcmd.RunRestic("forget", forgetOpts) - - return err +func (rcmd ResticCmd) Forget(forgetOpts ForgetOpts) error { + return rcmd.RunRestic("forget", forgetOpts) } func (rcmd ResticCmd) Check() error { - err := rcmd.RunRestic("check", NoOpts{}) - - return err + return rcmd.RunRestic("check", NoOpts{}) } func (rcmd ResticCmd) EnsureInit() error { diff --git a/restic_test.go b/restic_test.go index 13f251a..0a3a0d4 100644 --- a/restic_test.go +++ b/restic_test.go @@ -3,6 +3,7 @@ package main_test import ( "os" "testing" + "time" main "git.iamthefij.com/iamthefij/restic-scheduler" "github.com/go-test/deep" @@ -42,6 +43,106 @@ func TestGlobalOptions(t *testing.T) { } } +func TestBackupOpts(t *testing.T) { + t.Parallel() + + args := main.BackupOpts{ + Exclude: []string{"file1", "file2"}, + Include: []string{"directory"}, + Tags: []string{"thing"}, + Host: "steve", + }.ToArgs() + + expected := []string{ + "--exclude", "file1", + "--exclude", "file2", + "--include", "directory", + "--tag", "thing", + "--host", "steve", + } + + if diff := deep.Equal(args, expected); diff != nil { + t.Errorf("args didn't match %v", diff) + } +} + +func TestRestoreOpts(t *testing.T) { + t.Parallel() + + args := main.RestoreOpts{ + Exclude: []string{"file1", "file2"}, + Include: []string{"directory"}, + Host: []string{"steve"}, + Tags: []string{"thing"}, + Path: "directory", + Target: "directory", + Verify: true, + }.ToArgs() + + expected := []string{ + "--exclude", "file1", + "--exclude", "file2", + "--include", "directory", + "--host", "steve", + "--tag", "thing", + "--path", "directory", + "--target", "directory", + "--verify", + } + + if diff := deep.Equal(args, expected); diff != nil { + t.Errorf("args didn't match %v", diff) + } +} + +func TestForgetOpts(t *testing.T) { + t.Parallel() + + args := main.ForgetOpts{ + KeepLast: 1, + KeepHourly: 1, + KeepDaily: 1, + KeepWeekly: 1, + KeepMonthly: 1, + KeepYearly: 1, + KeepWithin: 1 * time.Second, + KeepWithinHourly: 1 * time.Second, + KeepWithinDaily: 1 * time.Second, + KeepWithinWeekly: 1 * time.Second, + KeepWithinMonthly: 1 * time.Second, + KeepWithinYearly: 1 * time.Second, + Tags: []main.TagList{ + {"thing1", "thing2"}, + {"otherthing"}, + }, + KeepTags: []main.TagList{{"thing"}}, + Prune: true, + }.ToArgs() + + expected := []string{ + "--keep-last", "1", + "--keep-hourly", "1", + "--keep-daily", "1", + "--keep-weekly", "1", + "--keep-monthly", "1", + "--keep-yearly", "1", + "--keep-within", "1s", + "--keep-within-hourly", "1s", + "--keep-within-daily", "1s", + "--keep-within-weekly", "1s", + "--keep-within-monthly", "1s", + "--keep-within-yearly", "1s", + "--tag", "thing1,thing2", + "--tag", "otherthing", + "--keep-tag", "thing", + "--prune", + } + + if diff := deep.Equal(args, expected); diff != nil { + t.Errorf("args didn't match %v", diff) + } +} + func TestBuildEnv(t *testing.T) { t.Parallel() diff --git a/shell.go b/shell.go index 3b6c2b5..416a15d 100644 --- a/shell.go +++ b/shell.go @@ -51,8 +51,6 @@ func RunShell(script string, cwd string, env map[string]string, logger *log.Logg cmd := exec.Command("sh", "-c", strings.TrimSpace(script)) // nolint:gosec // Make both stderr and stdout go to logger - // fmt.Println("LOGGER PREFIX", logger.Prefix()) - // logger.Println("From logger") cmd.Stdout = NewLogWriter(logger) cmd.Stderr = cmd.Stdout @@ -62,11 +60,7 @@ func RunShell(script string, cwd string, env map[string]string, logger *log.Logg // Convert env to list if values provided if len(env) > 0 { envList := os.Environ() - - for name, value := range env { - envList = append(envList, fmt.Sprintf("%s=%s", name, value)) - } - + envList = append(envList, EnvMapToList(env)...) cmd.Env = envList } diff --git a/test.hcl b/test/test.hcl similarity index 100% rename from test.hcl rename to test/test.hcl diff --git a/utils.go b/utils.go index af82978..f4d9359 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,8 @@ package main -func MergeEnv(parent, child map[string]string) map[string]string { +import "fmt" + +func MergeEnvMap(parent, child map[string]string) map[string]string { result := map[string]string{} for key, value := range parent { @@ -13,3 +15,12 @@ func MergeEnv(parent, child map[string]string) map[string]string { return result } + +func EnvMapToList(envMap map[string]string) []string { + envList := []string{} + for name, value := range envMap { + envList = append(envList, fmt.Sprintf("%s=%s", name, value)) + } + + return envList +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..c84f7fe --- /dev/null +++ b/utils_test.go @@ -0,0 +1,81 @@ +package main_test + +import ( + "testing" + + main "git.iamthefij.com/iamthefij/restic-scheduler" + "github.com/go-test/deep" +) + +func TestMergeEnvMap(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + parent map[string]string + child map[string]string + expected map[string]string + }{ + { + name: "No child", + parent: map[string]string{ + "key": "value", + }, + child: nil, + expected: map[string]string{ + "key": "value", + }, + }, + { + name: "No parent", + parent: nil, + child: map[string]string{ + "key": "value", + }, + expected: map[string]string{ + "key": "value", + }, + }, + { + name: "Overwrite value", + parent: map[string]string{ + "key": "old", + "other": "other", + }, + child: map[string]string{ + "key": "new", + }, + expected: map[string]string{ + "key": "new", + "other": "other", + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + actual := main.MergeEnvMap(c.parent, c.child) + if diff := deep.Equal(c.expected, actual); diff != nil { + t.Error(diff) + } + }) + } +} + +func TestEnvMapToList(t *testing.T) { + t.Parallel() + + env := map[string]string{ + "key": "value", + } + expected := []string{ + "key=value", + } + actual := main.EnvMapToList(env) + + if diff := deep.Equal(expected, actual); diff != nil { + t.Error(diff) + } +}