diff --git a/Makefile b/Makefile index bae228a..8c30008 100644 --- a/Makefile +++ b/Makefile @@ -15,5 +15,9 @@ release: go build -ldflags "${LDFLAGS}" -o bin/${BINARY} cmd/${BINARY}/main.go (cd bin && tar czf ${BINARY}-${APPVERSION}-${GOOS}-${GOARCH}.tar.gz ${BINARY}) +test: + go test -cover base/* + go test -cover terminator/* + clean: rm -rf bin diff --git a/README.md b/README.md index 920670e..f8d018d 100644 --- a/README.md +++ b/README.md @@ -65,47 +65,55 @@ Print usage: pgterminate -help ``` -# Filtering users +# Filters -`pgterminate` is able to include or exclude users from being terminated. +`pgterminate` is able to include or exclude from being terminated: +- users +- databases ## Configuration + ### List -Arguments `-include-user` or `-exclude-user` can be used multiple times for multiple users: + +The following arguments can be used called multiple times: +- `-include-user` +- `-exclude-user` +- `-include-database` +- `-exclude-database` + +Example: ``` pgterminate -include-user user1 -include-user user2 ``` -Or in configuration file: + +Or in configuration file (mind the plural form): ``` include-users: user1 user2 ``` -Same applies for `-exclude-user` (argument) and `exclude-users` (file). ### Regexes + Regexes can be configured: ``` pgterminate -include-users-regex "(user1|user2)" ``` + Or in configuration file: ``` include-users-regex: "(user1|user2)" ``` -Same applies for `-exclude-users-regex` (argument) and `exclude-users-regex` (file). +## Inclusion and exclusion priority -## Include users - -When include users list or regex is set, `pgterminate` will focus on included users only. It could terminate excluded users if any. If you want to exclude users, use exclude options only. - -## Exclude users - -When exclude users list or regex is set and no include option is set, `pgterminate` will terminate all sessions except excluded users. +Include filters are applied before exclude filters. If a user or a database is +both in the include and exclude filters, the user or database will be ignored +by `pgterminate`. # Listeners diff --git a/VERSION b/VERSION index c946ee6..7dea76e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.6 +1.0.1 diff --git a/base/config.go b/base/config.go index 2099652..34759ce 100644 --- a/base/config.go +++ b/base/config.go @@ -17,31 +17,42 @@ var AppName string // Config receives configuration options type Config struct { - mutex sync.Mutex - File string - Host string `yaml:"host"` - Port int `yaml:"port"` - User string `yaml:"user"` - Password string `yaml:"password"` - Database string `yaml:"database"` - Interval float64 `yaml:"interval"` - ConnectTimeout int `yaml:"connect-timeout"` - IdleTimeout float64 `yaml:"idle-timeout"` - ActiveTimeout float64 `yaml:"active-timeout"` - LogDestination string `yaml:"log-destination"` - LogFile string `yaml:"log-file"` - LogFormat string `yaml:"log-format"` - PidFile string `yaml:"pid-file"` - SyslogIdent string `yaml:"syslog-ident"` - SyslogFacility string `yaml:"syslog-facility"` - IncludeUsers StringFlags `yaml:"include-users"` - IncludeUsersRegex string `yaml:"include-users-regex"` - IncludeUsersRegexCompiled *regexp.Regexp - ExcludeUsers StringFlags `yaml:"exclude-users"` - ExcludeUsersRegex string `yaml:"exclude-users-regex"` - ExcludeUsersRegexCompiled *regexp.Regexp - ExcludeListeners bool `yaml:"exclude-listeners"` - Cancel bool `yaml:"cancel"` + mutex sync.Mutex + File string + Host string `yaml:"host"` + Port int `yaml:"port"` + User string `yaml:"user"` + Password string `yaml:"password"` + Database string `yaml:"database"` + SSLMode string `yaml:"sslmode"` + Interval float64 `yaml:"interval"` + ConnectTimeout int `yaml:"connect-timeout"` + IdleTimeout float64 `yaml:"idle-timeout"` + ActiveTimeout float64 `yaml:"active-timeout"` + LogDestination string `yaml:"log-destination"` + LogFile string `yaml:"log-file"` + LogFormat string `yaml:"log-format"` + PidFile string `yaml:"pid-file"` + SyslogIdent string `yaml:"syslog-ident"` + SyslogFacility string `yaml:"syslog-facility"` + IncludeUsers StringFlags `yaml:"include-users"` + IncludeUsersRegex string `yaml:"include-users-regex"` + IncludeUsersRegexCompiled *regexp.Regexp + IncludeUsersFilters []Filter + ExcludeUsers StringFlags `yaml:"exclude-users"` + ExcludeUsersRegex string `yaml:"exclude-users-regex"` + ExcludeUsersRegexCompiled *regexp.Regexp + ExcludeUsersFilters []Filter + IncludeDatabases StringFlags `yaml:"include-databases"` + IncludeDatabasesRegex string `yaml:"include-databases-regex"` + IncludeDatabasesRegexCompiled *regexp.Regexp + IncludeDatabasesFilters []Filter + ExcludeDatabases StringFlags `yaml:"exclude-databases"` + ExcludeDatabasesRegex string `yaml:"exclude-databases-regex"` + ExcludeDatabasesRegexCompiled *regexp.Regexp + ExcludeDatabasesFilters []Filter + ExcludeListeners bool `yaml:"exclude-listeners"` + Cancel bool `yaml:"cancel"` } func init() { @@ -83,6 +94,7 @@ func (c *Config) Reload() { } err := c.CompileRegexes() Panic(err) + c.CompileFilters() } // Dsn formats a connection string based on Config @@ -106,6 +118,9 @@ func (c *Config) Dsn() string { if c.ConnectTimeout != 0 { parameters = append(parameters, fmt.Sprintf("connect_timeout=%d", c.ConnectTimeout)) } + if c.SSLMode != "" { + parameters = append(parameters, fmt.Sprintf("sslmode=%s", c.SSLMode)) + } if AppName != "" { parameters = append(parameters, fmt.Sprintf("application_name=%s", AppName)) } @@ -126,9 +141,58 @@ func (c *Config) CompileRegexes() (err error) { return err } } + if c.IncludeDatabasesRegex != "" { + c.IncludeDatabasesRegexCompiled, err = regexp.Compile(c.IncludeDatabasesRegex) + if err != nil { + return err + } + } + if c.ExcludeDatabasesRegex != "" { + c.ExcludeDatabasesRegexCompiled, err = regexp.Compile(c.ExcludeDatabasesRegex) + if err != nil { + return err + } + } return nil } +// CompileFilters creates Filter objects based on patterns and compiled regexp +func (c *Config) CompileFilters() { + + c.IncludeUsersFilters = nil + if c.IncludeUsers != nil { + c.IncludeUsersFilters = append(c.IncludeUsersFilters, NewIncludeFilter(c.IncludeUsers)) + } + if c.IncludeUsersRegexCompiled != nil { + c.IncludeUsersFilters = append(c.IncludeUsersFilters, NewIncludeFilterRegex(c.IncludeUsersRegexCompiled)) + } + + c.ExcludeUsersFilters = nil + if c.ExcludeUsers != nil { + c.ExcludeUsersFilters = append(c.ExcludeUsersFilters, NewExcludeFilter(c.ExcludeUsers)) + } + if c.ExcludeUsersRegexCompiled != nil { + c.ExcludeUsersFilters = append(c.ExcludeUsersFilters, NewExcludeFilterRegex(c.ExcludeUsersRegexCompiled)) + } + + c.IncludeDatabasesFilters = nil + if c.IncludeDatabases != nil { + c.IncludeDatabasesFilters = append(c.IncludeDatabasesFilters, NewIncludeFilter(c.IncludeDatabases)) + } + if c.IncludeDatabasesRegexCompiled != nil { + c.IncludeDatabasesFilters = append(c.IncludeDatabasesFilters, NewIncludeFilterRegex(c.IncludeDatabasesRegexCompiled)) + } + + c.ExcludeDatabasesFilters = nil + if c.ExcludeDatabases != nil { + c.ExcludeDatabasesFilters = append(c.ExcludeDatabasesFilters, NewExcludeFilter(c.ExcludeDatabases)) + } + if c.ExcludeDatabasesRegexCompiled != nil { + c.ExcludeDatabasesFilters = append(c.ExcludeDatabasesFilters, NewExcludeFilterRegex(c.ExcludeDatabasesRegexCompiled)) + } + +} + // StringFlags append multiple string flags into a string slice type StringFlags []string diff --git a/base/filter.go b/base/filter.go new file mode 100644 index 0000000..7798acd --- /dev/null +++ b/base/filter.go @@ -0,0 +1,122 @@ +package base + +import ( + "fmt" + "reflect" + "regexp" +) + +// Filter interface to tell if a string should be included or not +type Filter interface { + Include(string) bool + String() string +} + +// IncludeFilter to include a string when it's included in a list of strings +type IncludeFilter struct { + patterns []string +} + +// NewIncludeFilter to create an IncludeFilter +func NewIncludeFilter(patterns []string) IncludeFilter { + return IncludeFilter{ + patterns: patterns, + } +} + +// Include returns true when a string is included in a list of patterns +// Implements the Filter interface +func (f IncludeFilter) Include(s string) bool { + // No or empty patterns must include + if f.patterns == nil || reflect.DeepEqual(f.patterns, []string{""}) { + return true + } + return InSlice(s, f.patterns) +} + +// String to pretty print an IncludeFilter +// Implements the Filter interface +func (f IncludeFilter) String() string { + return fmt.Sprintf("", f.patterns) +} + +// IncludeFilterRegex to include a string when it matches a regex +type IncludeFilterRegex struct { + regex *regexp.Regexp +} + +// NewIncludeFilterRegex to create an IncludeFilterRegex +func NewIncludeFilterRegex(regex *regexp.Regexp) IncludeFilterRegex { + return IncludeFilterRegex{ + regex: regex, + } +} + +// Include returns true when the string matches the regex +// Implements the Filter interface +func (f IncludeFilterRegex) Include(s string) bool { + if f.regex == nil || f.regex.MatchString(s) { + return true + } + return false +} + +// String to pretty print an IncludeFilterRegex +// Implements the Filter interface +func (f IncludeFilterRegex) String() string { + return fmt.Sprintf("", f.regex.String()) +} + +// ExcludeFilter to include a string when it's not included in a list of strings +type ExcludeFilter struct { + patterns []string +} + +// NewExcludeFilter to create an ExcludeFilter +func NewExcludeFilter(patterns []string) ExcludeFilter { + return ExcludeFilter{ + patterns: patterns, + } +} + +// Include returns true when the string is not included in the patterns +// Implements the Filter interface +func (f ExcludeFilter) Include(s string) bool { + return !InSlice(s, f.patterns) +} + +// String to pretty print an ExcludeFilter +// Implements the Filter interface +func (f ExcludeFilter) String() string { + return fmt.Sprintf("", f.patterns) +} + +// ExcludeFilterRegex to include a string when it doesnn't match a regex +type ExcludeFilterRegex struct { + regex *regexp.Regexp +} + +// NewExcludeFilterRegex to create an ExcludeFilterRegex +func NewExcludeFilterRegex(regex *regexp.Regexp) ExcludeFilterRegex { + return ExcludeFilterRegex{ + regex: regex, + } +} + +// Include returns true when the string doesn't match the regex +// Implements the Filter interface +func (f ExcludeFilterRegex) Include(s string) bool { + if f.regex == nil || f.regex.MatchString("") { + return true + } + if f.regex.MatchString(s) { + return false + } + return true +} + +// String to pretty print an ExcludeFilterRegex +// Implements the Filter interface +func (f ExcludeFilterRegex) String() string { + return fmt.Sprintf("", f.regex.String()) +} diff --git a/base/filter_test.go b/base/filter_test.go new file mode 100644 index 0000000..bf75eb5 --- /dev/null +++ b/base/filter_test.go @@ -0,0 +1,124 @@ +package base + +import ( + "fmt" + "regexp" + "testing" +) + +func TestIncludeFilter(t *testing.T) { + tests := []struct { + name string + value string + patterns []string + wanted bool + }{ + {"No filter", "test", nil, true}, + {"Empty filter", "test", []string{""}, true}, + {"Single pattern matching", "test", []string{"test"}, true}, + {"Multiple patterns matching", "test", []string{"test", "postgres"}, true}, + {"Single pattern with no match", "nomatch", []string{"test"}, false}, + {"Multiple patterns with no match", "nomatch", []string{"test", "postgres"}, false}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { + f := NewIncludeFilter(tc.patterns) + + if got := f.Include(tc.value); got != tc.wanted { + t.Errorf("Included must be %t for patterns '%s'", tc.wanted, tc.patterns) + } else { + t.Logf("Included is %t for patterns '%s'", tc.wanted, tc.patterns) + } + }) + } +} + +func TestIncludeFilterRegex(t *testing.T) { + tests := []struct { + name string + value string + regex string + wanted bool + }{ + {"No filter", "test", "", true}, + {"String pattern matching", "test", "test", true}, + {"Regex patterns matching", "test", "^t(.*)$", true}, + {"String pattern with no match", "nomatch", "test", false}, + {"Regex patterns with no match", "nomatch", "^t(.*)$", false}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { + compiledRegex, err := regexp.Compile(tc.regex) + if err != nil { + t.Fatalf("Regex '%s' doesn't compile: %v", tc.regex, err) + } + + f := NewIncludeFilterRegex(compiledRegex) + if got := f.Include(tc.value); got != tc.wanted { + t.Errorf("Included must be %t for regex '%s'", tc.wanted, tc.regex) + } else { + t.Logf("Included is %t for regex '%s'", tc.wanted, tc.regex) + } + }) + } +} + +func TestExcludeFilter(t *testing.T) { + tests := []struct { + name string + value string + patterns []string + wanted bool + }{ + {"No filter", "test", nil, true}, + {"Empty filter", "test", []string{""}, true}, + {"Single pattern matching", "test", []string{"test"}, false}, + {"Multiple patterns matching", "test", []string{"test", "postgres"}, false}, + {"Single pattern with no match", "nomatch", []string{"test"}, true}, + {"Multiple patterns with no match", "nomatch", []string{"test", "postgres"}, true}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { + f := NewExcludeFilter(tc.patterns) + if got := f.Include(tc.value); got != tc.wanted { + t.Errorf("Included must be %t for patterns '%s'", tc.wanted, tc.patterns) + } else { + t.Logf("Included is %t for patterns '%s'", tc.wanted, tc.patterns) + } + }) + } +} + +func TestExcludeFilterRegex(t *testing.T) { + tests := []struct { + name string + value string + regex string + wanted bool + }{ + {"No filter", "test", "", true}, + {"String pattern matching", "test", "test", false}, + {"Regex patterns matching", "test", "^t(.*)$", false}, + {"String pattern with no match", "nomatch", "test", true}, + {"Regex patterns with no match", "nomatch", "^t(.*)$", true}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { + compiledRegex, err := regexp.Compile(tc.regex) + if err != nil { + t.Fatalf("Regex '%s' doesn't compile: %v", tc.regex, err) + } + + f := NewExcludeFilterRegex(compiledRegex) + if got := f.Include(tc.value); got != tc.wanted { + t.Errorf("Included must be %t for regex '%s'", tc.wanted, tc.regex) + } else { + t.Logf("Included is %t for regex '%s'", tc.wanted, tc.regex) + } + }) + } +} diff --git a/base/session.go b/base/session.go index 29e3611..45cc98e 100644 --- a/base/session.go +++ b/base/session.go @@ -60,3 +60,21 @@ func (s *Session) IsIdle() bool { } return false } + +// Equal returns true when two sessions share the same process id +func (s *Session) Equal(session *Session) bool { + if s.Pid == 0 { + return s.User == session.User && s.Db == session.Db && s.Client == session.Client + } + return s.Pid == session.Pid +} + +// InSlice returns true when this sessions in present in the slice +func (s *Session) InSlice(sessions []*Session) bool { + for _, session := range sessions { + if s.Equal(session) { + return true + } + } + return false +} diff --git a/base/session_test.go b/base/session_test.go new file mode 100644 index 0000000..10b56c5 --- /dev/null +++ b/base/session_test.go @@ -0,0 +1,139 @@ +package base + +import ( + "testing" +) + +func TestSessionEqual(t *testing.T) { + tests := []struct { + name string + first *Session + second *Session + want bool + }{ + { + "Empty sessions", + &Session{}, + &Session{}, + true, + }, + { + "Identical process id", + &Session{Pid: 1}, + &Session{Pid: 1}, + true, + }, + { + "Different process id", + &Session{Pid: 1}, + &Session{Pid: 2}, + false, + }, + { + "Identical users", + &Session{User: "test"}, + &Session{User: "test"}, + true, + }, + { + "Different users", + &Session{User: "test"}, + &Session{User: "random"}, + false, + }, + { + "Identical databases", + &Session{Db: "test"}, + &Session{Db: "test"}, + true, + }, + { + "Different databases", + &Session{Db: "test"}, + &Session{Db: "random"}, + false, + }, + { + "Identical users and databases", + &Session{User: "test", Db: "test"}, + &Session{User: "test", Db: "test"}, + true, + }, + { + "Different users and same databases", + &Session{User: "test_1", Db: "test"}, + &Session{User: "test_2", Db: "test"}, + false, + }, + { + "Different databases and same user", + &Session{User: "test", Db: "test_1"}, + &Session{User: "test", Db: "test_2"}, + false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.first.Equal(tc.second) + if got != tc.want { + t.Errorf("got %t; want %t", got, tc.want) + } else { + t.Logf("got %t; want %t", got, tc.want) + } + }) + } +} + +func TestSessionInSlice(t *testing.T) { + sessions := []*Session{ + {User: "test"}, + {User: "test_1"}, + {User: "test_2"}, + {User: "postgres"}, + {Db: "test"}, + } + + tests := []struct { + name string + input *Session + want bool + }{ + { + "Empty session", + &Session{}, + false, + }, + { + "Session with user in slice", + &Session{User: "test"}, + true, + }, + { + "Session with user not in slice", + &Session{User: "random"}, + false, + }, + { + "Session with db in slice", + &Session{Db: "test"}, + true, + }, + { + "Session with db not in slice", + &Session{Db: "random"}, + false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.input.InSlice(sessions) + if got != tc.want { + t.Errorf("got %t; want %t", got, tc.want) + } else { + t.Logf("got %t; want %t", got, tc.want) + } + }) + } +} diff --git a/cmd/pgterminate/main.go b/cmd/pgterminate/main.go index cecc413..9375b3b 100644 --- a/cmd/pgterminate/main.go +++ b/cmd/pgterminate/main.go @@ -43,6 +43,7 @@ func main() { flag.StringVar(&config.User, "user", "", "Instance username") flag.StringVar(&config.Password, "password", "", "Instance password") flag.StringVar(&config.Database, "database", "", "Instance database") + flag.StringVar(&config.SSLMode, "sslmode", "", "SSL mode (see https://www.postgresql.org/docs/current/libpq-ssl.html)") prompt := flag.Bool("prompt-password", false, "Prompt for password") flag.Float64Var(&config.Interval, "interval", 1, "Time to sleep between iterations in seconds") flag.IntVar(&config.ConnectTimeout, "connect-timeout", 3, "Connection timeout in seconds") @@ -58,6 +59,10 @@ func main() { flag.StringVar(&config.IncludeUsersRegex, "include-users-regex", "", "Terminate users matching this regexp") flag.Var(&config.ExcludeUsers, "exclude-user", "Ignore this user (can be called multiple times)") flag.StringVar(&config.ExcludeUsersRegex, "exclude-users-regex", "", "Ignore users matching this regexp") + flag.Var(&config.IncludeDatabases, "include-database", "Terminate only this database (can be called multiple times)") + flag.StringVar(&config.IncludeDatabasesRegex, "include-databases-regex", "", "Terminate databases matching this regexp") + flag.Var(&config.ExcludeDatabases, "exclude-database", "Ignore this database (can be called multiple times)") + flag.StringVar(&config.ExcludeDatabasesRegex, "exclude-databases-regex", "", "Ignore databases matching this regexp") flag.BoolVar(&config.ExcludeListeners, "exclude-listeners", false, "Ignore sessions listening for events") flag.BoolVar(&config.Cancel, "cancel", false, "Cancel sessions instead of terminate") flag.Parse() @@ -112,6 +117,7 @@ func main() { err = config.CompileRegexes() base.Panic(err) + config.CompileFilters() if config.PidFile != "" { writePid(config.PidFile) diff --git a/config.yaml.example b/config.yaml.example index fb97dc1..7acce28 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -14,11 +14,19 @@ #syslog-ident: pgterminate #syslog-facility: LOCAL0 #include-users: -# user1 -# user2 +# - user1 +# - user2 #include-users-regex: "(user1|user2)" #exclude-users: -# user1 -# user2 +# - user1 +# - user2 #exclude-users-regex: "(user1|user2)" -#cancel \ No newline at end of file +#include-databases: +# - db1 +# - db2 +#include-databases-regex: "(db1|db2)" +#exclude-databases: +# - db1 +# - db2 +#exclude-databases-regex: "(db1|db2)" +#cancel: true \ No newline at end of file diff --git a/go.mod b/go.mod index f21c6c2..1fa04b1 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,14 @@ module github.com/jouir/pgterminate -go 1.13 +go 1.19 require ( - github.com/lib/pq v1.10.0 - golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b + github.com/lib/pq v1.10.7 + golang.org/x/crypto v0.5.0 gopkg.in/yaml.v2 v2.4.0 ) + +require ( + golang.org/x/sys v0.4.0 // indirect + golang.org/x/term v0.4.0 // indirect +) diff --git a/go.sum b/go.sum index 564e5f3..53b86a7 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,11 @@ -github.com/lib/pq v1.10.0 h1:Zx5DJFEYQXio93kgXnQ09fXNiUKsqv4OUEu2UtGcB1E= -github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b h1:wSOdpTq0/eI46Ez/LkDwIsAKA71YP2SRKBODiRWM0as= -golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68 h1:nxC68pudNYkKU6jWhgrqdreuFiOQWj1Fs7T3VrH4Pjw= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/terminator/terminator.go b/terminator/terminator.go index f5a6c87..03010e2 100644 --- a/terminator/terminator.go +++ b/terminator/terminator.go @@ -1,10 +1,11 @@ package terminator import ( - "github.com/jouir/pgterminate/base" - "github.com/jouir/pgterminate/log" "strings" "time" + + "github.com/jouir/pgterminate/base" + "github.com/jouir/pgterminate/log" ) // Terminator looks for sessions, filters actives and idles, terminate them and notify sessions channel @@ -71,30 +72,6 @@ func (t *Terminator) notify(sessions []*base.Session) { } } -// filterUsers removes sessions according to include and exclude users settings -// when include users slice and regex are not set, append all sessions except excluded users -// otherwise, append included users -func (t *Terminator) filterUsers(sessions []*base.Session) (filtered []*base.Session) { - includeUsers, includeRegex := t.config.IncludeUsers, t.config.IncludeUsersRegexCompiled - excludeUsers, excludeRegex := t.config.ExcludeUsers, t.config.ExcludeUsersRegexCompiled - - for _, session := range sessions { - if t.config.IncludeUsers == nil && includeRegex == nil { - // append all sessions except excluded users - if !base.InSlice(session.User, excludeUsers) && (excludeRegex != nil && !excludeRegex.MatchString(session.User)) { - filtered = append(filtered, session) - } - } else { - // append included users only - if base.InSlice(session.User, includeUsers) || (includeRegex != nil && includeRegex.MatchString(session.User)) { - filtered = append(filtered, session) - } - } - } - - return filtered -} - // filterListeners excludes sessions with last query starting with "LISTEN" func (t *Terminator) filterListeners(sessions []*base.Session) (filtered []*base.Session) { for _, session := range sessions { @@ -105,11 +82,79 @@ func (t *Terminator) filterListeners(sessions []*base.Session) (filtered []*base return filtered } +// filterUsers include and exclude users based on filters +func (t *Terminator) filterUsers(sessions []*base.Session) []*base.Session { + var included []*base.Session + for _, filter := range t.config.IncludeUsersFilters { + for _, session := range sessions { + if filter.Include(session.User) && !session.InSlice(included) { + included = append(included, session) + } + } + } + + var excluded []*base.Session + for _, filter := range t.config.ExcludeUsersFilters { + for _, session := range sessions { + if !filter.Include(session.User) && !session.InSlice(excluded) { + excluded = append(excluded, session) + } + } + } + + if included == nil { + included = sessions + } + + var filtered []*base.Session + for _, session := range included { + if !session.InSlice(excluded) && !session.InSlice(filtered) { + filtered = append(filtered, session) + } + } + + return filtered +} + +// filterDatabases include and exclude databases based on filters +func (t *Terminator) filterDatabases(sessions []*base.Session) []*base.Session { + var included []*base.Session + for _, filter := range t.config.IncludeDatabasesFilters { + for _, session := range sessions { + if filter.Include(session.Db) && !session.InSlice(included) { + included = append(included, session) + } + } + } + + var excluded []*base.Session + for _, filter := range t.config.ExcludeDatabasesFilters { + for _, session := range sessions { + if !filter.Include(session.Db) && !session.InSlice(excluded) { + excluded = append(excluded, session) + } + } + } + + if included == nil { + included = sessions + } + + var filtered []*base.Session + for _, session := range included { + if !session.InSlice(excluded) && !session.InSlice(filtered) { + filtered = append(filtered, session) + } + } + + return filtered +} + // filter executes all filter functions on a list of sessions func (t *Terminator) filter(sessions []*base.Session) (filtered []*base.Session) { - filtered = sessions + filtered = t.filterListeners(sessions) filtered = t.filterUsers(filtered) - filtered = t.filterListeners(filtered) + filtered = t.filterDatabases(filtered) return filtered } diff --git a/terminator/terminator_test.go b/terminator/terminator_test.go new file mode 100644 index 0000000..dc54feb --- /dev/null +++ b/terminator/terminator_test.go @@ -0,0 +1,184 @@ +package terminator + +import ( + "reflect" + "testing" + + "github.com/jouir/pgterminate/base" +) + +func TestFilterUsers(t *testing.T) { + + sessions := []*base.Session{ + {User: "test"}, + {User: "test_1"}, + {User: "test_2"}, + {User: "postgres"}, + } + + tests := []struct { + name string + config *base.Config + want []*base.Session + }{ + { + "No filter", + &base.Config{}, + sessions, + }, + { + "Include a single user", + &base.Config{IncludeUsers: []string{"test"}}, + []*base.Session{{User: "test"}}, + }, + { + "Include multiple users", + &base.Config{IncludeUsers: []string{"test_1", "test_2"}}, + []*base.Session{{User: "test_1"}, {User: "test_2"}}, + }, + { + "Exclude a single user", + &base.Config{ExcludeUsers: []string{"test"}}, + []*base.Session{{User: "test_1"}, {User: "test_2"}, {User: "postgres"}}, + }, + { + "Exclude multiple users", + &base.Config{ExcludeUsers: []string{"test_1", "test_2"}}, + []*base.Session{{User: "test"}, {User: "postgres"}}, + }, + { + "Include multiple users and exclude one", + &base.Config{IncludeUsers: []string{"test", "test_1", "test_2"}, ExcludeUsers: []string{"test"}}, + []*base.Session{{User: "test_1"}, {User: "test_2"}}, + }, + { + "Include users from list and regex", + &base.Config{ + IncludeUsers: []string{"test"}, + IncludeUsersRegex: "^test_[0-9]$", + }, + []*base.Session{{User: "test"}, {User: "test_1"}, {User: "test_2"}}, + }, + { + "Exclude users from list and regex", + &base.Config{ + ExcludeUsers: []string{"test"}, + ExcludeUsersRegex: "^test_[0-9]$", + }, + []*base.Session{{User: "postgres"}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.config.CompileRegexes() + if err != nil { + t.Errorf("Failed to compile regex: %v", err) + } + tc.config.CompileFilters() + terminator := &Terminator{config: tc.config} + got := terminator.filterUsers(sessions) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want)) + } else { + t.Logf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want)) + } + }) + } +} + +// ListUsers extract usernames from a list of sessions +func ListUsers(sessions []*base.Session) (users []string) { + for _, session := range sessions { + users = append(users, session.User) + } + return users +} + +func TestFilterDatabases(t *testing.T) { + + sessions := []*base.Session{ + {Db: "test"}, + {Db: "test_1"}, + {Db: "test_2"}, + {Db: "postgres"}, + } + + tests := []struct { + name string + config *base.Config + want []*base.Session + }{ + { + "No filter", + &base.Config{}, + sessions, + }, + { + "Include a single database", + &base.Config{IncludeDatabases: []string{"test"}}, + []*base.Session{{Db: "test"}}, + }, + { + "Include multiple databases", + &base.Config{IncludeDatabases: []string{"test_1", "test_2"}}, + []*base.Session{{Db: "test_1"}, {Db: "test_2"}}, + }, + { + "Exclude a single database", + &base.Config{ExcludeDatabases: []string{"test"}}, + []*base.Session{{Db: "test_1"}, {Db: "test_2"}, {Db: "postgres"}}, + }, + { + "Exclude multiple databases", + &base.Config{ExcludeDatabases: []string{"test_1", "test_2"}}, + []*base.Session{{Db: "test"}, {Db: "postgres"}}, + }, + { + "Include multiple databases and exclude one", + &base.Config{IncludeDatabases: []string{"test", "test_1", "test_2"}, ExcludeDatabases: []string{"test"}}, + []*base.Session{{Db: "test_1"}, {Db: "test_2"}}, + }, + { + "Include databases from list and regex", + &base.Config{ + IncludeDatabases: []string{"test"}, + IncludeDatabasesRegex: "^test_[0-9]$", + }, + []*base.Session{{Db: "test"}, {Db: "test_1"}, {Db: "test_2"}}, + }, + { + "Exclude databases from list and regex", + &base.Config{ + ExcludeDatabases: []string{"test"}, + ExcludeDatabasesRegex: "^test_[0-9]$", + }, + []*base.Session{{Db: "postgres"}}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.config.CompileRegexes() + if err != nil { + t.Errorf("Failed to compile regex: %v", err) + } + tc.config.CompileFilters() + terminator := &Terminator{config: tc.config} + got := terminator.filterDatabases(sessions) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want)) + } else { + t.Logf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want)) + } + }) + } +} + +// ListDatabases extract databases from a list of sessions +func ListDatabases(sessions []*base.Session) (databases []string) { + for _, session := range sessions { + databases = append(databases, session.Db) + } + return databases +}