diff --git a/parameter_test.go b/parameter_test.go index f2bf40a..13e531b 100644 --- a/parameter_test.go +++ b/parameter_test.go @@ -37,6 +37,142 @@ func TestParameter_Inference(t *testing.T) { }) } +func TestParameter_BigInt(t *testing.T) { + t.Run("Should infer int64 as BIGINT", func(t *testing.T) { + maxInt64 := int64(9223372036854775807) + values := []driver.NamedValue{ + {Value: maxInt64}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "BIGINT", *parameters[0].Type) + require.Equal(t, "9223372036854775807", *parameters[0].Value.StringValue) + }) + + t.Run("Should infer uint64 as BIGINT", func(t *testing.T) { + largeUint64 := uint64(0x123456789ABCDEF0) + values := []driver.NamedValue{ + {Value: largeUint64}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "BIGINT", *parameters[0].Type) + require.Equal(t, "1311768467463790320", *parameters[0].Value.StringValue) + }) + + t.Run("Should infer negative int64 as BIGINT", func(t *testing.T) { + minInt64 := int64(-9223372036854775808) + values := []driver.NamedValue{ + {Value: minInt64}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "BIGINT", *parameters[0].Type) + require.Equal(t, "-9223372036854775808", *parameters[0].Value.StringValue) + }) + + t.Run("Should handle explicit BigInt Parameter with non-string value", func(t *testing.T) { + values := []driver.NamedValue{ + {Value: Parameter{Type: SqlBigInt, Value: int64(12345)}}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "BIGINT", *parameters[0].Type) + require.Equal(t, "12345", *parameters[0].Value.StringValue) + }) + + t.Run("Should preserve int32 as INTEGER", func(t *testing.T) { + values := []driver.NamedValue{ + {Value: int32(2147483647)}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "INTEGER", *parameters[0].Type) + require.Equal(t, "2147483647", *parameters[0].Value.StringValue) + }) +} + +func TestParameter_Float(t *testing.T) { + t.Run("Should infer float64 as DOUBLE", func(t *testing.T) { + value := float64(3.141592653589793) + values := []driver.NamedValue{ + {Value: value}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "DOUBLE", *parameters[0].Type) + require.Equal(t, "3.141592653589793", *parameters[0].Value.StringValue) + }) + + t.Run("Should infer float32 as FLOAT", func(t *testing.T) { + value := float32(3.14) + values := []driver.NamedValue{ + {Value: value}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "FLOAT", *parameters[0].Type) + require.Equal(t, "3.14", *parameters[0].Value.StringValue) + }) + + t.Run("Should handle large float64 values", func(t *testing.T) { + // Value beyond float32 range + value := float64(1e200) + values := []driver.NamedValue{ + {Value: value}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "DOUBLE", *parameters[0].Type) + }) + + t.Run("Should handle small float64 values", func(t *testing.T) { + // Value below float32 precision + value := float64(1e-300) + values := []driver.NamedValue{ + {Value: value}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "DOUBLE", *parameters[0].Type) + }) + + t.Run("Should handle explicit Double Parameter with non-string value", func(t *testing.T) { + values := []driver.NamedValue{ + {Value: Parameter{Type: SqlDouble, Value: float64(3.14159)}}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "DOUBLE", *parameters[0].Type) + require.Equal(t, "3.14159", *parameters[0].Value.StringValue) + }) + + t.Run("Should format large float64 consistently when using explicit type", func(t *testing.T) { + // This tests that explicit Parameter with large float64 uses decimal notation + // (strconv.FormatFloat) instead of scientific notation (fmt.Sprintf) + value := float64(1e20) + values := []driver.NamedValue{ + {Value: Parameter{Type: SqlDouble, Value: value}}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "DOUBLE", *parameters[0].Type) + // Should be decimal notation, not "1e+20" + require.Equal(t, "100000000000000000000", *parameters[0].Value.StringValue) + }) + + t.Run("Should format float32 consistently when using explicit type", func(t *testing.T) { + value := float32(3.14159) + values := []driver.NamedValue{ + {Value: Parameter{Type: SqlFloat, Value: value}}, + } + parameters, err := convertNamedValuesToSparkParams(values) + require.NoError(t, err) + require.Equal(t, "FLOAT", *parameters[0].Type) + require.Equal(t, "3.14159", *parameters[0].Value.StringValue) + }) +} + func TestParameters_ConvertToSpark(t *testing.T) { t.Run("Should convert names parameters", func(t *testing.T) { values := [2]driver.NamedValue{ diff --git a/parameters.go b/parameters.go index d9c8980..c24027b 100644 --- a/parameters.go +++ b/parameters.go @@ -140,17 +140,17 @@ func inferType(param *Parameter) { param.Value = strconv.FormatUint(uint64(value), 10) param.Type = SqlInteger case int64: - param.Value = strconv.Itoa(int(value)) - param.Type = SqlInteger + param.Value = strconv.FormatInt(value, 10) + param.Type = SqlBigInt case uint64: - param.Value = strconv.FormatUint(uint64(value), 10) - param.Type = SqlInteger + param.Value = strconv.FormatUint(value, 10) + param.Type = SqlBigInt case float32: param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32) param.Type = SqlFloat case float64: - param.Value = strconv.FormatFloat(float64(value), 'f', -1, 64) - param.Type = SqlFloat + param.Value = strconv.FormatFloat(value, 'f', -1, 64) + param.Type = SqlDouble case time.Time: param.Value = value.Format(time.RFC3339Nano) param.Type = SqlTimestamp @@ -179,7 +179,21 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) ([]*cli_service if sqlParam.Type == SqlVoid { sparkValue = nil } else { - stringValue := sqlParam.Value.(string) + var stringValue string + switch v := sqlParam.Value.(type) { + case string: + stringValue = v + case float32: + stringValue = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + stringValue = strconv.FormatFloat(v, 'f', -1, 64) + case int64: + stringValue = strconv.FormatInt(v, 10) + case uint64: + stringValue = strconv.FormatUint(v, 10) + default: + stringValue = fmt.Sprintf("%v", sqlParam.Value) + } sparkValue = &cli_service.TSparkParameterValue{StringValue: &stringValue} }