From 9cff85102139d745a02a8bfca1f2679d99138dd1 Mon Sep 17 00:00:00 2001
From: Julien Riou <julien@riou.xyz>
Date: Fri, 20 Jan 2023 11:17:44 +0100
Subject: [PATCH] fix: Filters from list and regex both match (#2)

Signed-off-by: Julien Riou <julien@riou.xyz>
---
 Makefile                      |   4 +-
 base/config.go                |  12 +++
 base/session.go               |  18 +++++
 base/session_test.go          | 139 ++++++++++++++++++++++++++++++++++
 terminator/terminator.go      |  78 ++++++++++---------
 terminator/terminator_test.go |  44 ++++++++++-
 6 files changed, 254 insertions(+), 41 deletions(-)
 create mode 100644 base/session_test.go

diff --git a/Makefile b/Makefile
index 6261d94..8c30008 100644
--- a/Makefile
+++ b/Makefile
@@ -16,8 +16,8 @@ release:
 	(cd bin && tar czf ${BINARY}-${APPVERSION}-${GOOS}-${GOARCH}.tar.gz ${BINARY})
 
 test:
-	go test base/*
-	go test terminator/*
+	go test -cover base/*
+	go test -cover terminator/*
 
 clean:
 	rm -rf bin
diff --git a/base/config.go b/base/config.go
index 7019078..34759ce 100644
--- a/base/config.go
+++ b/base/config.go
@@ -141,6 +141,18 @@ func (c *Config) CompileRegexes() (err error) {
 			return err
 		}
 	}
+	if c.IncludeDatabasesRegex != "" {
+		c.IncludeDatabasesRegexCompiled, err = regexp.Compile(c.IncludeDatabasesRegex)
+		if err != nil {
+			return err
+		}
+	}
+	if c.ExcludeDatabasesRegex != "" {
+		c.ExcludeDatabasesRegexCompiled, err = regexp.Compile(c.ExcludeDatabasesRegex)
+		if err != nil {
+			return err
+		}
+	}
 	return nil
 }
 
diff --git a/base/session.go b/base/session.go
index 29e3611..45cc98e 100644
--- a/base/session.go
+++ b/base/session.go
@@ -60,3 +60,21 @@ func (s *Session) IsIdle() bool {
 	}
 	return false
 }
+
+// Equal returns true when two sessions share the same process id
+func (s *Session) Equal(session *Session) bool {
+	if s.Pid == 0 {
+		return s.User == session.User && s.Db == session.Db && s.Client == session.Client
+	}
+	return s.Pid == session.Pid
+}
+
+// InSlice returns true when this sessions in present in the slice
+func (s *Session) InSlice(sessions []*Session) bool {
+	for _, session := range sessions {
+		if s.Equal(session) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/base/session_test.go b/base/session_test.go
new file mode 100644
index 0000000..10b56c5
--- /dev/null
+++ b/base/session_test.go
@@ -0,0 +1,139 @@
+package base
+
+import (
+	"testing"
+)
+
+func TestSessionEqual(t *testing.T) {
+	tests := []struct {
+		name   string
+		first  *Session
+		second *Session
+		want   bool
+	}{
+		{
+			"Empty sessions",
+			&Session{},
+			&Session{},
+			true,
+		},
+		{
+			"Identical process id",
+			&Session{Pid: 1},
+			&Session{Pid: 1},
+			true,
+		},
+		{
+			"Different process id",
+			&Session{Pid: 1},
+			&Session{Pid: 2},
+			false,
+		},
+		{
+			"Identical users",
+			&Session{User: "test"},
+			&Session{User: "test"},
+			true,
+		},
+		{
+			"Different users",
+			&Session{User: "test"},
+			&Session{User: "random"},
+			false,
+		},
+		{
+			"Identical databases",
+			&Session{Db: "test"},
+			&Session{Db: "test"},
+			true,
+		},
+		{
+			"Different databases",
+			&Session{Db: "test"},
+			&Session{Db: "random"},
+			false,
+		},
+		{
+			"Identical users and databases",
+			&Session{User: "test", Db: "test"},
+			&Session{User: "test", Db: "test"},
+			true,
+		},
+		{
+			"Different users and same databases",
+			&Session{User: "test_1", Db: "test"},
+			&Session{User: "test_2", Db: "test"},
+			false,
+		},
+		{
+			"Different databases and same user",
+			&Session{User: "test", Db: "test_1"},
+			&Session{User: "test", Db: "test_2"},
+			false,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			got := tc.first.Equal(tc.second)
+			if got != tc.want {
+				t.Errorf("got %t; want %t", got, tc.want)
+			} else {
+				t.Logf("got %t; want %t", got, tc.want)
+			}
+		})
+	}
+}
+
+func TestSessionInSlice(t *testing.T) {
+	sessions := []*Session{
+		{User: "test"},
+		{User: "test_1"},
+		{User: "test_2"},
+		{User: "postgres"},
+		{Db: "test"},
+	}
+
+	tests := []struct {
+		name  string
+		input *Session
+		want  bool
+	}{
+		{
+			"Empty session",
+			&Session{},
+			false,
+		},
+		{
+			"Session with user in slice",
+			&Session{User: "test"},
+			true,
+		},
+		{
+			"Session with user not in slice",
+			&Session{User: "random"},
+			false,
+		},
+		{
+			"Session with db in slice",
+			&Session{Db: "test"},
+			true,
+		},
+		{
+			"Session with db not in slice",
+			&Session{Db: "random"},
+			false,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			got := tc.input.InSlice(sessions)
+			if got != tc.want {
+				t.Errorf("got %t; want %t", got, tc.want)
+			} else {
+				t.Logf("got %t; want %t", got, tc.want)
+			}
+		})
+	}
+}
diff --git a/terminator/terminator.go b/terminator/terminator.go
index f6b6f57..03010e2 100644
--- a/terminator/terminator.go
+++ b/terminator/terminator.go
@@ -84,65 +84,69 @@ func (t *Terminator) filterListeners(sessions []*base.Session) (filtered []*base
 
 // filterUsers include and exclude users based on filters
 func (t *Terminator) filterUsers(sessions []*base.Session) []*base.Session {
-
 	var included []*base.Session
-	if t.config.IncludeUsersFilters == nil {
-		included = sessions
-	} else {
-		for _, filter := range t.config.IncludeUsersFilters {
-			for _, session := range sessions {
-				if filter.Include(session.User) {
-					included = append(included, session)
-				}
+	for _, filter := range t.config.IncludeUsersFilters {
+		for _, session := range sessions {
+			if filter.Include(session.User) && !session.InSlice(included) {
+				included = append(included, session)
 			}
 		}
 	}
 
-	var filtered []*base.Session
-	if t.config.ExcludeUsersFilters == nil {
-		filtered = included
-	} else {
-		for _, filter := range t.config.ExcludeUsersFilters {
-			for _, session := range included {
-				if filter.Include(session.User) {
-					filtered = append(filtered, session)
-				}
+	var excluded []*base.Session
+	for _, filter := range t.config.ExcludeUsersFilters {
+		for _, session := range sessions {
+			if !filter.Include(session.User) && !session.InSlice(excluded) {
+				excluded = append(excluded, session)
 			}
 		}
 	}
 
+	if included == nil {
+		included = sessions
+	}
+
+	var filtered []*base.Session
+	for _, session := range included {
+		if !session.InSlice(excluded) && !session.InSlice(filtered) {
+			filtered = append(filtered, session)
+		}
+	}
+
 	return filtered
 }
 
-// filterUsers include and exclude databases based on filters
+// filterDatabases include and exclude databases based on filters
 func (t *Terminator) filterDatabases(sessions []*base.Session) []*base.Session {
-
 	var included []*base.Session
-	if t.config.IncludeDatabasesFilters == nil {
-		included = sessions
-	} else {
-		for _, filter := range t.config.IncludeDatabasesFilters {
-			for _, session := range sessions {
-				if filter.Include(session.Db) {
-					included = append(included, session)
-				}
+	for _, filter := range t.config.IncludeDatabasesFilters {
+		for _, session := range sessions {
+			if filter.Include(session.Db) && !session.InSlice(included) {
+				included = append(included, session)
 			}
 		}
 	}
 
-	var filtered []*base.Session
-	if t.config.ExcludeDatabasesFilters == nil {
-		filtered = included
-	} else {
-		for _, filter := range t.config.ExcludeDatabasesFilters {
-			for _, session := range included {
-				if filter.Include(session.Db) {
-					filtered = append(filtered, session)
-				}
+	var excluded []*base.Session
+	for _, filter := range t.config.ExcludeDatabasesFilters {
+		for _, session := range sessions {
+			if !filter.Include(session.Db) && !session.InSlice(excluded) {
+				excluded = append(excluded, session)
 			}
 		}
 	}
 
+	if included == nil {
+		included = sessions
+	}
+
+	var filtered []*base.Session
+	for _, session := range included {
+		if !session.InSlice(excluded) && !session.InSlice(filtered) {
+			filtered = append(filtered, session)
+		}
+	}
+
 	return filtered
 }
 
diff --git a/terminator/terminator_test.go b/terminator/terminator_test.go
index f68b528..230425e 100644
--- a/terminator/terminator_test.go
+++ b/terminator/terminator_test.go
@@ -51,17 +51,37 @@ func TestFilterUsers(t *testing.T) {
 			&base.Config{IncludeUsers: []string{"test", "test_1", "test_2"}, ExcludeUsers: []string{"test"}},
 			[]*base.Session{{User: "test_1"}, {User: "test_2"}},
 		},
+		{
+			"Include users from list and regex",
+			&base.Config{
+				IncludeUsers:      []string{"test"},
+				IncludeUsersRegex: "^test_[0-9]$",
+			},
+			[]*base.Session{{User: "test"}, {User: "test_1"}, {User: "test_2"}},
+		},
+		{
+			"Exclude users from list and regex",
+			&base.Config{
+				ExcludeUsers:      []string{"test"},
+				ExcludeUsersRegex: "^test_[0-9]$",
+			},
+			[]*base.Session{{User: "postgres"}},
+		},
 	}
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
+			err := tc.config.CompileRegexes()
+			if err != nil {
+				t.Errorf("Failed to compile regex: %v", err)
+			}
 			tc.config.CompileFilters()
 			terminator := &Terminator{config: tc.config}
 			got := terminator.filterUsers(sessions)
 			if !reflect.DeepEqual(got, tc.want) {
 				t.Errorf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want))
 			} else {
-				t.Logf("Success")
+				t.Logf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want))
 			}
 		})
 	}
@@ -119,17 +139,37 @@ func TestFilterDatabases(t *testing.T) {
 			&base.Config{IncludeDatabases: []string{"test", "test_1", "test_2"}, ExcludeDatabases: []string{"test"}},
 			[]*base.Session{{Db: "test_1"}, {Db: "test_2"}},
 		},
+		{
+			"Include databases from list and regex",
+			&base.Config{
+				IncludeDatabases:      []string{"test"},
+				IncludeDatabasesRegex: "^test_[0-9]$",
+			},
+			[]*base.Session{{Db: "test"}, {Db: "test_1"}, {Db: "test_2"}},
+		},
+		{
+			"Exclude databases from list and regex",
+			&base.Config{
+				ExcludeDatabases:      []string{"test"},
+				ExcludeDatabasesRegex: "^test_[0-9]$",
+			},
+			[]*base.Session{{Db: "postgres"}},
+		},
 	}
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
+			err := tc.config.CompileRegexes()
+			if err != nil {
+				t.Errorf("Failed to compile regex: %v", err)
+			}
 			tc.config.CompileFilters()
 			terminator := &Terminator{config: tc.config}
 			got := terminator.filterDatabases(sessions)
 			if !reflect.DeepEqual(got, tc.want) {
 				t.Errorf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want))
 			} else {
-				t.Logf("Success")
+				t.Logf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want))
 			}
 		})
 	}