diff --git a/util/set/set.go b/util/set/set.go index df4b1fa3a..c3d2350a7 100644 --- a/util/set/set.go +++ b/util/set/set.go @@ -7,6 +7,8 @@ import ( "encoding/json" "maps" + "reflect" + "sort" ) // Set is a set of T. @@ -53,16 +55,53 @@ func (s *Set[T]) Make() { } } -// Slice returns the elements of the set as a slice. The elements will not be -// in any particular order. +// Slice returns the elements of the set as a slice. If the element type is +// ordered (integers, floats, or strings), the elements are returned in sorted +// order. Otherwise, the order is not defined. func (s Set[T]) Slice() []T { es := make([]T, 0, s.Len()) for k := range s { es = append(es, k) } + if f := genOrderedSwapper(reflect.TypeFor[T]()); f != nil { + sort.Slice(es, f(reflect.ValueOf(es))) + } return es } +// genOrderedSwapper returns a generator for a swap function that can be used to +// sort a slice of the given type. If rt is not an ordered type, +// genOrderedSwapper returns nil. +func genOrderedSwapper(rt reflect.Type) func(reflect.Value) func(i, j int) bool { + switch rt.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return func(rv reflect.Value) func(i, j int) bool { + return func(i, j int) bool { + return rv.Index(i).Uint() < rv.Index(j).Uint() + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return func(rv reflect.Value) func(i, j int) bool { + return func(i, j int) bool { + return rv.Index(i).Int() < rv.Index(j).Int() + } + } + case reflect.Float32, reflect.Float64: + return func(rv reflect.Value) func(i, j int) bool { + return func(i, j int) bool { + return rv.Index(i).Float() < rv.Index(j).Float() + } + } + case reflect.String: + return func(rv reflect.Value) func(i, j int) bool { + return func(i, j int) bool { + return rv.Index(i).String() < rv.Index(j).String() + } + } + } + return nil +} + // Delete removes e from the set. func (s Set[T]) Delete(e T) { delete(s, e) } diff --git a/util/set/set_test.go b/util/set/set_test.go index 4afaeea57..2188cbb4d 100644 --- a/util/set/set_test.go +++ b/util/set/set_test.go @@ -159,6 +159,39 @@ func TestSetJSONRoundTrip(t *testing.T) { } } +func checkSliceSorted[T comparable](t *testing.T, s Set[T], want []T) { + t.Helper() + got := s.Slice() + if !slices.Equal(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + +func TestSliceSorted(t *testing.T) { + t.Run("int", func(t *testing.T) { + checkSliceSorted(t, Of(3, 1, 4, 1, 5), []int{1, 3, 4, 5}) + }) + t.Run("int8", func(t *testing.T) { + checkSliceSorted(t, Of[int8](-1, 3, -100, 50), []int8{-100, -1, 3, 50}) + }) + t.Run("uint16", func(t *testing.T) { + checkSliceSorted(t, Of[uint16](300, 1, 65535, 0), []uint16{0, 1, 300, 65535}) + }) + t.Run("float64", func(t *testing.T) { + checkSliceSorted(t, Of(2.7, 1.0, 3.14), []float64{1.0, 2.7, 3.14}) + }) + t.Run("float32", func(t *testing.T) { + checkSliceSorted(t, Of[float32](2.5, 1.0, 3.0), []float32{1.0, 2.5, 3.0}) + }) + t.Run("string", func(t *testing.T) { + checkSliceSorted(t, Of("banana", "apple", "cherry"), []string{"apple", "banana", "cherry"}) + }) + t.Run("named-uint", func(t *testing.T) { + type Port uint16 + checkSliceSorted(t, Of[Port](443, 80, 8080), []Port{80, 443, 8080}) + }) +} + func TestMake(t *testing.T) { var s Set[int] s.Make()