diff --git a/Makefile b/Makefile index 6261d94..8c30008 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,8 @@ release: (cd bin && tar czf ${BINARY}-${APPVERSION}-${GOOS}-${GOARCH}.tar.gz ${BINARY}) test: - go test base/* - go test terminator/* + go test -cover base/* + go test -cover terminator/* clean: rm -rf bin diff --git a/base/config.go b/base/config.go index 7019078..34759ce 100644 --- a/base/config.go +++ b/base/config.go @@ -141,6 +141,18 @@ func (c *Config) CompileRegexes() (err error) { return err } } + if c.IncludeDatabasesRegex != "" { + c.IncludeDatabasesRegexCompiled, err = regexp.Compile(c.IncludeDatabasesRegex) + if err != nil { + return err + } + } + if c.ExcludeDatabasesRegex != "" { + c.ExcludeDatabasesRegexCompiled, err = regexp.Compile(c.ExcludeDatabasesRegex) + if err != nil { + return err + } + } return nil } diff --git a/base/session.go b/base/session.go index 29e3611..45cc98e 100644 --- a/base/session.go +++ b/base/session.go @@ -60,3 +60,21 @@ func (s *Session) IsIdle() bool { } return false } + +// Equal returns true when two sessions share the same process id +func (s *Session) Equal(session *Session) bool { + if s.Pid == 0 { + return s.User == session.User && s.Db == session.Db && s.Client == session.Client + } + return s.Pid == session.Pid +} + +// InSlice returns true when this sessions in present in the slice +func (s *Session) InSlice(sessions []*Session) bool { + for _, session := range sessions { + if s.Equal(session) { + return true + } + } + return false +} diff --git a/base/session_test.go b/base/session_test.go new file mode 100644 index 0000000..10b56c5 --- /dev/null +++ b/base/session_test.go @@ -0,0 +1,139 @@ +package base + +import ( + "testing" +) + +func TestSessionEqual(t *testing.T) { + tests := []struct { + name string + first *Session + second *Session + want bool + }{ + { + "Empty sessions", + &Session{}, + &Session{}, + true, + }, + { + "Identical process id", + &Session{Pid: 1}, + &Session{Pid: 1}, + true, + }, + { + "Different process id", + &Session{Pid: 1}, + &Session{Pid: 2}, + false, + }, + { + "Identical users", + &Session{User: "test"}, + &Session{User: "test"}, + true, + }, + { + "Different users", + &Session{User: "test"}, + &Session{User: "random"}, + false, + }, + { + "Identical databases", + &Session{Db: "test"}, + &Session{Db: "test"}, + true, + }, + { + "Different databases", + &Session{Db: "test"}, + &Session{Db: "random"}, + false, + }, + { + "Identical users and databases", + &Session{User: "test", Db: "test"}, + &Session{User: "test", Db: "test"}, + true, + }, + { + "Different users and same databases", + &Session{User: "test_1", Db: "test"}, + &Session{User: "test_2", Db: "test"}, + false, + }, + { + "Different databases and same user", + &Session{User: "test", Db: "test_1"}, + &Session{User: "test", Db: "test_2"}, + false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.first.Equal(tc.second) + if got != tc.want { + t.Errorf("got %t; want %t", got, tc.want) + } else { + t.Logf("got %t; want %t", got, tc.want) + } + }) + } +} + +func TestSessionInSlice(t *testing.T) { + sessions := []*Session{ + {User: "test"}, + {User: "test_1"}, + {User: "test_2"}, + {User: "postgres"}, + {Db: "test"}, + } + + tests := []struct { + name string + input *Session + want bool + }{ + { + "Empty session", + &Session{}, + false, + }, + { + "Session with user in slice", + &Session{User: "test"}, + true, + }, + { + "Session with user not in slice", + &Session{User: "random"}, + false, + }, + { + "Session with db in slice", + &Session{Db: "test"}, + true, + }, + { + "Session with db not in slice", + &Session{Db: "random"}, + false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.input.InSlice(sessions) + if got != tc.want { + t.Errorf("got %t; want %t", got, tc.want) + } else { + t.Logf("got %t; want %t", got, tc.want) + } + }) + } +} diff --git a/terminator/terminator.go b/terminator/terminator.go index f6b6f57..03010e2 100644 --- a/terminator/terminator.go +++ b/terminator/terminator.go @@ -84,65 +84,69 @@ func (t *Terminator) filterListeners(sessions []*base.Session) (filtered []*base // filterUsers include and exclude users based on filters func (t *Terminator) filterUsers(sessions []*base.Session) []*base.Session { - var included []*base.Session - if t.config.IncludeUsersFilters == nil { - included = sessions - } else { - for _, filter := range t.config.IncludeUsersFilters { - for _, session := range sessions { - if filter.Include(session.User) { - included = append(included, session) - } + for _, filter := range t.config.IncludeUsersFilters { + for _, session := range sessions { + if filter.Include(session.User) && !session.InSlice(included) { + included = append(included, session) } } } - var filtered []*base.Session - if t.config.ExcludeUsersFilters == nil { - filtered = included - } else { - for _, filter := range t.config.ExcludeUsersFilters { - for _, session := range included { - if filter.Include(session.User) { - filtered = append(filtered, session) - } + var excluded []*base.Session + for _, filter := range t.config.ExcludeUsersFilters { + for _, session := range sessions { + if !filter.Include(session.User) && !session.InSlice(excluded) { + excluded = append(excluded, session) } } } + if included == nil { + included = sessions + } + + var filtered []*base.Session + for _, session := range included { + if !session.InSlice(excluded) && !session.InSlice(filtered) { + filtered = append(filtered, session) + } + } + return filtered } -// filterUsers include and exclude databases based on filters +// filterDatabases include and exclude databases based on filters func (t *Terminator) filterDatabases(sessions []*base.Session) []*base.Session { - var included []*base.Session - if t.config.IncludeDatabasesFilters == nil { - included = sessions - } else { - for _, filter := range t.config.IncludeDatabasesFilters { - for _, session := range sessions { - if filter.Include(session.Db) { - included = append(included, session) - } + for _, filter := range t.config.IncludeDatabasesFilters { + for _, session := range sessions { + if filter.Include(session.Db) && !session.InSlice(included) { + included = append(included, session) } } } - var filtered []*base.Session - if t.config.ExcludeDatabasesFilters == nil { - filtered = included - } else { - for _, filter := range t.config.ExcludeDatabasesFilters { - for _, session := range included { - if filter.Include(session.Db) { - filtered = append(filtered, session) - } + var excluded []*base.Session + for _, filter := range t.config.ExcludeDatabasesFilters { + for _, session := range sessions { + if !filter.Include(session.Db) && !session.InSlice(excluded) { + excluded = append(excluded, session) } } } + if included == nil { + included = sessions + } + + var filtered []*base.Session + for _, session := range included { + if !session.InSlice(excluded) && !session.InSlice(filtered) { + filtered = append(filtered, session) + } + } + return filtered } diff --git a/terminator/terminator_test.go b/terminator/terminator_test.go index f68b528..230425e 100644 --- a/terminator/terminator_test.go +++ b/terminator/terminator_test.go @@ -51,17 +51,37 @@ func TestFilterUsers(t *testing.T) { &base.Config{IncludeUsers: []string{"test", "test_1", "test_2"}, ExcludeUsers: []string{"test"}}, []*base.Session{{User: "test_1"}, {User: "test_2"}}, }, + { + "Include users from list and regex", + &base.Config{ + IncludeUsers: []string{"test"}, + IncludeUsersRegex: "^test_[0-9]$", + }, + []*base.Session{{User: "test"}, {User: "test_1"}, {User: "test_2"}}, + }, + { + "Exclude users from list and regex", + &base.Config{ + ExcludeUsers: []string{"test"}, + ExcludeUsersRegex: "^test_[0-9]$", + }, + []*base.Session{{User: "postgres"}}, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + err := tc.config.CompileRegexes() + if err != nil { + t.Errorf("Failed to compile regex: %v", err) + } tc.config.CompileFilters() terminator := &Terminator{config: tc.config} got := terminator.filterUsers(sessions) if !reflect.DeepEqual(got, tc.want) { t.Errorf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want)) } else { - t.Logf("Success") + t.Logf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want)) } }) } @@ -119,17 +139,37 @@ func TestFilterDatabases(t *testing.T) { &base.Config{IncludeDatabases: []string{"test", "test_1", "test_2"}, ExcludeDatabases: []string{"test"}}, []*base.Session{{Db: "test_1"}, {Db: "test_2"}}, }, + { + "Include databases from list and regex", + &base.Config{ + IncludeDatabases: []string{"test"}, + IncludeDatabasesRegex: "^test_[0-9]$", + }, + []*base.Session{{Db: "test"}, {Db: "test_1"}, {Db: "test_2"}}, + }, + { + "Exclude databases from list and regex", + &base.Config{ + ExcludeDatabases: []string{"test"}, + ExcludeDatabasesRegex: "^test_[0-9]$", + }, + []*base.Session{{Db: "postgres"}}, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + err := tc.config.CompileRegexes() + if err != nil { + t.Errorf("Failed to compile regex: %v", err) + } tc.config.CompileFilters() terminator := &Terminator{config: tc.config} got := terminator.filterDatabases(sessions) if !reflect.DeepEqual(got, tc.want) { t.Errorf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want)) } else { - t.Logf("Success") + t.Logf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want)) } }) }