fix: Filters from list and regex both match (#2)
Signed-off-by: Julien Riou <julien@riou.xyz>
This commit is contained in:
parent
763359c3d6
commit
9cff851021
6 changed files with 254 additions and 41 deletions
|
@ -84,65 +84,69 @@ func (t *Terminator) filterListeners(sessions []*base.Session) (filtered []*base
|
|||
|
||||
// filterUsers include and exclude users based on filters
|
||||
func (t *Terminator) filterUsers(sessions []*base.Session) []*base.Session {
|
||||
|
||||
var included []*base.Session
|
||||
if t.config.IncludeUsersFilters == nil {
|
||||
included = sessions
|
||||
} else {
|
||||
for _, filter := range t.config.IncludeUsersFilters {
|
||||
for _, session := range sessions {
|
||||
if filter.Include(session.User) {
|
||||
included = append(included, session)
|
||||
}
|
||||
for _, filter := range t.config.IncludeUsersFilters {
|
||||
for _, session := range sessions {
|
||||
if filter.Include(session.User) && !session.InSlice(included) {
|
||||
included = append(included, session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var filtered []*base.Session
|
||||
if t.config.ExcludeUsersFilters == nil {
|
||||
filtered = included
|
||||
} else {
|
||||
for _, filter := range t.config.ExcludeUsersFilters {
|
||||
for _, session := range included {
|
||||
if filter.Include(session.User) {
|
||||
filtered = append(filtered, session)
|
||||
}
|
||||
var excluded []*base.Session
|
||||
for _, filter := range t.config.ExcludeUsersFilters {
|
||||
for _, session := range sessions {
|
||||
if !filter.Include(session.User) && !session.InSlice(excluded) {
|
||||
excluded = append(excluded, session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if included == nil {
|
||||
included = sessions
|
||||
}
|
||||
|
||||
var filtered []*base.Session
|
||||
for _, session := range included {
|
||||
if !session.InSlice(excluded) && !session.InSlice(filtered) {
|
||||
filtered = append(filtered, session)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// filterUsers include and exclude databases based on filters
|
||||
// filterDatabases include and exclude databases based on filters
|
||||
func (t *Terminator) filterDatabases(sessions []*base.Session) []*base.Session {
|
||||
|
||||
var included []*base.Session
|
||||
if t.config.IncludeDatabasesFilters == nil {
|
||||
included = sessions
|
||||
} else {
|
||||
for _, filter := range t.config.IncludeDatabasesFilters {
|
||||
for _, session := range sessions {
|
||||
if filter.Include(session.Db) {
|
||||
included = append(included, session)
|
||||
}
|
||||
for _, filter := range t.config.IncludeDatabasesFilters {
|
||||
for _, session := range sessions {
|
||||
if filter.Include(session.Db) && !session.InSlice(included) {
|
||||
included = append(included, session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var filtered []*base.Session
|
||||
if t.config.ExcludeDatabasesFilters == nil {
|
||||
filtered = included
|
||||
} else {
|
||||
for _, filter := range t.config.ExcludeDatabasesFilters {
|
||||
for _, session := range included {
|
||||
if filter.Include(session.Db) {
|
||||
filtered = append(filtered, session)
|
||||
}
|
||||
var excluded []*base.Session
|
||||
for _, filter := range t.config.ExcludeDatabasesFilters {
|
||||
for _, session := range sessions {
|
||||
if !filter.Include(session.Db) && !session.InSlice(excluded) {
|
||||
excluded = append(excluded, session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if included == nil {
|
||||
included = sessions
|
||||
}
|
||||
|
||||
var filtered []*base.Session
|
||||
for _, session := range included {
|
||||
if !session.InSlice(excluded) && !session.InSlice(filtered) {
|
||||
filtered = append(filtered, session)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
|
|
|
@ -51,17 +51,37 @@ func TestFilterUsers(t *testing.T) {
|
|||
&base.Config{IncludeUsers: []string{"test", "test_1", "test_2"}, ExcludeUsers: []string{"test"}},
|
||||
[]*base.Session{{User: "test_1"}, {User: "test_2"}},
|
||||
},
|
||||
{
|
||||
"Include users from list and regex",
|
||||
&base.Config{
|
||||
IncludeUsers: []string{"test"},
|
||||
IncludeUsersRegex: "^test_[0-9]$",
|
||||
},
|
||||
[]*base.Session{{User: "test"}, {User: "test_1"}, {User: "test_2"}},
|
||||
},
|
||||
{
|
||||
"Exclude users from list and regex",
|
||||
&base.Config{
|
||||
ExcludeUsers: []string{"test"},
|
||||
ExcludeUsersRegex: "^test_[0-9]$",
|
||||
},
|
||||
[]*base.Session{{User: "postgres"}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.config.CompileRegexes()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to compile regex: %v", err)
|
||||
}
|
||||
tc.config.CompileFilters()
|
||||
terminator := &Terminator{config: tc.config}
|
||||
got := terminator.filterUsers(sessions)
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Errorf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want))
|
||||
} else {
|
||||
t.Logf("Success")
|
||||
t.Logf("got %+v; want %+v", ListUsers(got), ListUsers(tc.want))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -119,17 +139,37 @@ func TestFilterDatabases(t *testing.T) {
|
|||
&base.Config{IncludeDatabases: []string{"test", "test_1", "test_2"}, ExcludeDatabases: []string{"test"}},
|
||||
[]*base.Session{{Db: "test_1"}, {Db: "test_2"}},
|
||||
},
|
||||
{
|
||||
"Include databases from list and regex",
|
||||
&base.Config{
|
||||
IncludeDatabases: []string{"test"},
|
||||
IncludeDatabasesRegex: "^test_[0-9]$",
|
||||
},
|
||||
[]*base.Session{{Db: "test"}, {Db: "test_1"}, {Db: "test_2"}},
|
||||
},
|
||||
{
|
||||
"Exclude databases from list and regex",
|
||||
&base.Config{
|
||||
ExcludeDatabases: []string{"test"},
|
||||
ExcludeDatabasesRegex: "^test_[0-9]$",
|
||||
},
|
||||
[]*base.Session{{Db: "postgres"}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.config.CompileRegexes()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to compile regex: %v", err)
|
||||
}
|
||||
tc.config.CompileFilters()
|
||||
terminator := &Terminator{config: tc.config}
|
||||
got := terminator.filterDatabases(sessions)
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Errorf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want))
|
||||
} else {
|
||||
t.Logf("Success")
|
||||
t.Logf("got %+v; want %+v", ListDatabases(got), ListDatabases(tc.want))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue