From 6e0010e957f5301aaff06f7ff7bee740a8fdfde4 Mon Sep 17 00:00:00 2001 From: Ben Visness Date: Wed, 1 Jun 2022 20:38:24 -0500 Subject: [PATCH] Fix querying possibly-nil scalars Our db code for querying scalars didn't do anything to account for pointer destinations (which you need if querying a field that may be nil!) --- src/db/db.go | 18 +++++++++ src/db/db_test.go | 98 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/src/db/db.go b/src/db/db.go index 1035ba60..19791ce7 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -585,10 +585,28 @@ func (it *Iterator[T]) Next() (*T, bool) { } } +// Takes a value from a database query (reflected) and assigns it to the +// destination. If the destination is a pointer, and the value is non-nil, it +// will initialize the destination before assigning. func setValueFromDB(dest reflect.Value, value reflect.Value) { + if dest.Kind() == reflect.Pointer { + valueIsNilPointer := value.Kind() == reflect.Ptr && value.IsNil() + if !value.IsValid() || valueIsNilPointer { + dest.Set(reflect.Zero(dest.Type())) // nil to nil, the end + return + } else { + // initialize dest + dest.Set(reflect.New(dest.Type().Elem())) + dest = dest.Elem() + } + } + switch dest.Kind() { case reflect.Int: dest.SetInt(value.Int()) + case reflect.String: + dest.SetString(value.String()) + // TODO(ben): More kinds? All the kinds? It kind of feels like we should be able to assign to any destination whose underlying type is a primitive. default: dest.Set(value) } diff --git a/src/db/db_test.go b/src/db/db_test.go index edc647ed..06e4aa99 100644 --- a/src/db/db_test.go +++ b/src/db/db_test.go @@ -4,6 +4,7 @@ import ( "reflect" "strings" "testing" + "unsafe" "github.com/stretchr/testify/assert" ) @@ -169,3 +170,100 @@ func TestQueryBuilder(t *testing.T) { }) }) } + +func TestSetValueFromDB(t *testing.T) { + t.Run("ints", func(t *testing.T) { + t.Run("int to int", func(t *testing.T) { + var dest int + var value int = 3 + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, 3, dest) + }) + t.Run("int32 to int", func(t *testing.T) { + var dest int + var value int32 = 3 + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, 3, dest) + }) + t.Run("int to *int", func(t *testing.T) { + var dest *int + var value int = 3 + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, 3, *dest) + }) + t.Run("int32 to *int", func(t *testing.T) { + var dest *int + var value int32 = 3 + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, 3, *dest) + }) + t.Run("pointer nil to *int", func(t *testing.T) { + var dest *int + var value *int = nil + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Nil(t, dest) + }) + t.Run("interface nil to *int", func(t *testing.T) { + var dest *int + var value interface{} = nil + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Nil(t, dest) + }) + }) + + t.Run("strings", func(t *testing.T) { + type myString string + t.Run("string to string", func(t *testing.T) { + var dest string + var value string = "handmade" + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, "handmade", dest) + }) + t.Run("custom string to string", func(t *testing.T) { + var dest string + var value myString = "handmade" + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, "handmade", dest) + }) + t.Run("string to *string", func(t *testing.T) { + var dest *string + var value string = "handmade" + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, "handmade", *dest) + }) + t.Run("custom string to *int", func(t *testing.T) { + var dest *string + var value myString = "handmade" + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Equal(t, "handmade", *dest) + }) + t.Run("pointer nil to *string", func(t *testing.T) { + var dest *string + var value *string = nil + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Nil(t, dest) + }) + t.Run("interface nil to *string", func(t *testing.T) { + var dest *string + var value interface{} = nil + + setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value)) + assert.Nil(t, dest) + }) + }) +} + +func reflectPtr[T any](dest *T) reflect.Value { + return reflect.NewAt(reflect.TypeOf(*dest), unsafe.Pointer(dest)).Elem() +}