Fix querying possibly-nil scalars

Our db code for querying scalars didn't do anything to account for
pointer destinations (which you need if querying a field that may be
nil!)
This commit is contained in:
Ben Visness 2022-06-01 20:38:24 -05:00
parent ac2d00aca7
commit 6e0010e957
2 changed files with 116 additions and 0 deletions

View File

@ -585,10 +585,28 @@ func (it *Iterator[T]) Next() (*T, bool) {
} }
} }
// Takes a value from a database query (reflected) and assigns it to the
// destination. If the destination is a pointer, and the value is non-nil, it
// will initialize the destination before assigning.
func setValueFromDB(dest reflect.Value, value reflect.Value) { func setValueFromDB(dest reflect.Value, value reflect.Value) {
if dest.Kind() == reflect.Pointer {
valueIsNilPointer := value.Kind() == reflect.Ptr && value.IsNil()
if !value.IsValid() || valueIsNilPointer {
dest.Set(reflect.Zero(dest.Type())) // nil to nil, the end
return
} else {
// initialize dest
dest.Set(reflect.New(dest.Type().Elem()))
dest = dest.Elem()
}
}
switch dest.Kind() { switch dest.Kind() {
case reflect.Int: case reflect.Int:
dest.SetInt(value.Int()) dest.SetInt(value.Int())
case reflect.String:
dest.SetString(value.String())
// TODO(ben): More kinds? All the kinds? It kind of feels like we should be able to assign to any destination whose underlying type is a primitive.
default: default:
dest.Set(value) dest.Set(value)
} }

View File

@ -4,6 +4,7 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"unsafe"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -169,3 +170,100 @@ func TestQueryBuilder(t *testing.T) {
}) })
}) })
} }
func TestSetValueFromDB(t *testing.T) {
t.Run("ints", func(t *testing.T) {
t.Run("int to int", func(t *testing.T) {
var dest int
var value int = 3
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, 3, dest)
})
t.Run("int32 to int", func(t *testing.T) {
var dest int
var value int32 = 3
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, 3, dest)
})
t.Run("int to *int", func(t *testing.T) {
var dest *int
var value int = 3
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, 3, *dest)
})
t.Run("int32 to *int", func(t *testing.T) {
var dest *int
var value int32 = 3
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, 3, *dest)
})
t.Run("pointer nil to *int", func(t *testing.T) {
var dest *int
var value *int = nil
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Nil(t, dest)
})
t.Run("interface nil to *int", func(t *testing.T) {
var dest *int
var value interface{} = nil
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Nil(t, dest)
})
})
t.Run("strings", func(t *testing.T) {
type myString string
t.Run("string to string", func(t *testing.T) {
var dest string
var value string = "handmade"
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, "handmade", dest)
})
t.Run("custom string to string", func(t *testing.T) {
var dest string
var value myString = "handmade"
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, "handmade", dest)
})
t.Run("string to *string", func(t *testing.T) {
var dest *string
var value string = "handmade"
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, "handmade", *dest)
})
t.Run("custom string to *int", func(t *testing.T) {
var dest *string
var value myString = "handmade"
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Equal(t, "handmade", *dest)
})
t.Run("pointer nil to *string", func(t *testing.T) {
var dest *string
var value *string = nil
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Nil(t, dest)
})
t.Run("interface nil to *string", func(t *testing.T) {
var dest *string
var value interface{} = nil
setValueFromDB(reflectPtr(&dest), reflect.ValueOf(value))
assert.Nil(t, dest)
})
})
}
func reflectPtr[T any](dest *T) reflect.Value {
return reflect.NewAt(reflect.TypeOf(*dest), unsafe.Pointer(dest)).Elem()
}