From c33a753a413e305983c833c0044c8ba623fbd25e Mon Sep 17 00:00:00 2001 From: "s.hardt" Date: Wed, 30 Oct 2019 12:17:19 +0100 Subject: [PATCH 1/3] Update mssqldb and prometheus dependency --- Makefile | 2 +- .../denisenkom/go-mssqldb/README.md | 283 ++++- .../denisenkom/go-mssqldb/appveyor.yml | 50 + .../github.com/denisenkom/go-mssqldb/buf.go | 200 ++-- .../denisenkom/go-mssqldb/bulkcopy.go | 583 +++++++++++ .../denisenkom/go-mssqldb/bulkcopy_sql.go | 93 ++ .../denisenkom/go-mssqldb/collation.go | 39 - .../denisenkom/go-mssqldb/conn_str.go | 456 +++++++++ .../denisenkom/go-mssqldb/convert.go | 306 ++++++ .../denisenkom/go-mssqldb/decimal.go | 115 --- .../github.com/denisenkom/go-mssqldb/doc.go | 14 + .../github.com/denisenkom/go-mssqldb/error.go | 34 + .../github.com/denisenkom/go-mssqldb/go.mod | 8 + .../github.com/denisenkom/go-mssqldb/go.sum | 5 + .../go-mssqldb/{ => internal/cp}/charset.go | 8 +- .../go-mssqldb/internal/cp/collation.go | 20 + .../go-mssqldb/{ => internal/cp}/cp1250.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1251.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1252.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1253.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1254.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1255.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1256.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1257.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp1258.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp437.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp850.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp874.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp932.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp936.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp949.go | 2 +- .../go-mssqldb/{ => internal/cp}/cp950.go | 2 +- .../go-mssqldb/internal/decimal/decimal.go | 252 +++++ .../{ => internal/querytext}/parser.go | 52 +- .../github.com/denisenkom/go-mssqldb/log.go | 21 +- .../github.com/denisenkom/go-mssqldb/mssql.go | 966 +++++++++++++----- .../denisenkom/go-mssqldb/mssql_go1.3.go | 11 - .../denisenkom/go-mssqldb/mssql_go1.3pre.go | 11 - .../denisenkom/go-mssqldb/mssql_go110.go | 52 + .../denisenkom/go-mssqldb/mssql_go110pre.go | 31 + .../denisenkom/go-mssqldb/mssql_go19.go | 196 ++++ .../denisenkom/go-mssqldb/mssql_go19pre.go | 16 + .../github.com/denisenkom/go-mssqldb/net.go | 155 ++- .../github.com/denisenkom/go-mssqldb/ntlm.go | 84 +- .../github.com/denisenkom/go-mssqldb/rpc.go | 47 +- .../denisenkom/go-mssqldb/sspi_windows.go | 2 +- .../github.com/denisenkom/go-mssqldb/tds.go | 407 +++----- .../github.com/denisenkom/go-mssqldb/token.go | 466 ++++++++- .../denisenkom/go-mssqldb/token_string.go | 53 + .../github.com/denisenkom/go-mssqldb/tran.go | 27 +- .../denisenkom/go-mssqldb/tvp_go19.go | 231 +++++ .../github.com/denisenkom/go-mssqldb/types.go | 869 ++++++++++++++-- .../denisenkom/go-mssqldb/uniqueidentifier.go | 80 ++ .../golang-sql/civil/CONTRIBUTING.md | 73 ++ vendor/github.com/golang-sql/civil/LICENSE | 202 ++++ vendor/github.com/golang-sql/civil/README.md | 15 + vendor/github.com/golang-sql/civil/civil.go | 277 +++++ vendor/vendor.json | 41 +- 58 files changed, 5825 insertions(+), 1060 deletions(-) create mode 100644 vendor/github.com/denisenkom/go-mssqldb/appveyor.yml create mode 100644 vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go delete mode 100644 vendor/github.com/denisenkom/go-mssqldb/collation.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/conn_str.go create mode 100644 
vendor/github.com/denisenkom/go-mssqldb/convert.go delete mode 100644 vendor/github.com/denisenkom/go-mssqldb/decimal.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/doc.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/go.mod create mode 100644 vendor/github.com/denisenkom/go-mssqldb/go.sum rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/charset.go (94%) create mode 100644 vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1250.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1251.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1252.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1253.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1254.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1255.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1256.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1257.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp1258.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp437.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp850.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp874.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp932.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp936.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp949.go (99%) rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/cp}/cp950.go (99%) create mode 100644 vendor/github.com/denisenkom/go-mssqldb/internal/decimal/decimal.go rename vendor/github.com/denisenkom/go-mssqldb/{ => internal/querytext}/parser.go (69%) delete mode 100644 vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3.go delete mode 100644 vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3pre.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/mssql_go110pre.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/token_string.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go create mode 100644 vendor/github.com/denisenkom/go-mssqldb/uniqueidentifier.go create mode 100644 vendor/github.com/golang-sql/civil/CONTRIBUTING.md create mode 100644 vendor/github.com/golang-sql/civil/LICENSE create mode 100644 vendor/github.com/golang-sql/civil/README.md create mode 100644 vendor/github.com/golang-sql/civil/civil.go diff --git a/Makefile b/Makefile index 96ed227..a11b82c 100644 --- a/Makefile +++ b/Makefile @@ -26,4 +26,4 @@ format: .PHONY: docker docker: - docker run --rm -v "$$PWD":/go/src/github.com/iamseth/azure_sql_exporter -w /go/src/github.com/iamseth/azure_sql_exporter golang:1.6 bash -c make + docker run --rm -v "$$PWD":/go/src/github.com/iamseth/azure_sql_exporter -w /go/src/github.com/iamseth/azure_sql_exporter golang:1.8 bash -c make diff --git a/vendor/github.com/denisenkom/go-mssqldb/README.md b/vendor/github.com/denisenkom/go-mssqldb/README.md index 86e2d6b..b655176 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/README.md +++ 
b/vendor/github.com/denisenkom/go-mssqldb/README.md @@ -1,78 +1,244 @@ # A pure Go MSSQL driver for Go's database/sql package +[![GoDoc](https://godoc.org/github.com/denisenkom/go-mssqldb?status.svg)](http://godoc.org/github.com/denisenkom/go-mssqldb) +[![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/denisenkom/go-mssqldb) +[![codecov](https://codecov.io/gh/denisenkom/go-mssqldb/branch/master/graph/badge.svg)](https://codecov.io/gh/denisenkom/go-mssqldb) + ## Install - go get github.com/denisenkom/go-mssqldb +Requires Go 1.8 or above. -## Tests +Install with `go get github.com/denisenkom/go-mssqldb`. -`go test` is used for testing. A running instance of MSSQL server is required. -Environment variables are used to pass login information. +## Connection Parameters and DSN -Example: +The recommended connection string uses a URL format: +`sqlserver://username:password@host/instance?param1=value&param2=value` +Other supported formats are listed below. + +### Common parameters: + +* `user id` - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. +* `password` +* `database` +* `connection timeout` - in seconds (default is 0 for no timeout). Recommended to leave at 0 and use context to manage query and connection timeouts. +* `dial timeout` - in seconds (default is 15), set to 0 for no timeout +* `encrypt` + * `disable` - Data sent between client and server is not encrypted. + * `false` - Data sent between client and server is not encrypted beyond the login packet. (Default) + * `true` - Data sent between client and server is encrypted. +* `app name` - The application name (default is go-mssqldb) + +### Connection parameters for ODBC and ADO style connection strings: + +* `server` - host or host\instance (default localhost) +* `port` - used only when there is no instance in server (default 1433) - env HOST=localhost SQLUSER=sa SQLPASSWORD=sa DATABASE=test go test - -## Connection Parameters - -* "server" - host or host\instance (default localhost) -* "port" - used only when there is no instance in server (default 1433) -* "failoverpartner" - host or host\instance (default is no partner). Used only until a successful connection has been made; thereafter, the partner provided in the first successful connection is used. -* "failoverport" - used only when there is no instance in failoverpartner (default 1433) -* "user id" - enter the SQL Server Authentication user id or the Windows Authentication user id in the DOMAIN\User format. On Windows, if user id is empty or missing Single-Sign-On is used. -* "password" -* "database" -* "connection timeout" - in seconds (default is 30) -* "dial timeout" - in seconds (default is 5) -* "keepAlive" - in seconds; 0 to disable (default is 0) -* "log" - logging flags (default 0/no logging, 63 for full logging) +### Less common parameters: + +* `keepAlive` - in seconds; 0 to disable (default is 30) +* `failoverpartner` - host or host\instance (default is no partner). 
+* `failoverport` - used only when there is no instance in failoverpartner (default 1433) +* `packet size` - in bytes; 512 to 32767 (default is 4096) + * Encrypted connections have a maximum packet size of 16383 bytes + * Further information on usage: https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option +* `log` - logging flags (default 0/no logging, 63 for full logging) * 1 log errors * 2 log messages * 4 log rows affected * 8 trace sql statements * 16 log statement parameters * 32 log transaction begin/end -* "encrypt" - * disable - Data send between client and server is not encrypted. - * false - Data sent between client and server is not encrypted beyond the login packet. (Default) - * true - Data sent between client and server is encrypted. -* "TrustServerCertificate" +* `TrustServerCertificate` * false - Server certificate is checked. Default is false if encrypt is specified. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. -* "certificate" - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. -* "hostNameInCertificate" - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -* "ServerSPN" - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. -* "Workstation ID" - The workstation name (default is the host name) -* "app name" - The application name (default is go-mssqldb) -* "ApplicationIntent" - Can be given the value "ReadOnly" to initiate a read-only connection to an Availability Group listener. +* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. +* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. +* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +* `Workstation ID` - The workstation name (default is the host name) +* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. -Example: +### The connection string can be specified in one of three formats: + + +1. URL: with `sqlserver` scheme. Username and password appear before the host. Any instance appears as + the first segment in the path. All other options are query parameters. Examples: + + * `sqlserver://username:password@host/instance?param1=value&param2=value` + * `sqlserver://username:password@host:port?param1=value&param2=value` + * `sqlserver://sa@localhost/SQLExpress?database=master&connection+timeout=30` // `SQLExpress` instance. + * `sqlserver://sa:mypass@localhost?database=master&connection+timeout=30` // username=sa, password=mypass. + * `sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30` // port 1234 on localhost. 
+ * `sqlserver://sa:my%7Bpass@somehost?connection+timeout=30` // password is "my{pass" + + A string of this format can be constructed using the `URL` type in the `net/url` package. ```go - db, err := sql.Open("mssql", "server=localhost;user id=sa") + query := url.Values{} + query.Add("app name", "MyAppName") + + u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword(username, password), + Host: fmt.Sprintf("%s:%d", hostname, port), + // Path: instance, // if connecting to an instance instead of a port + RawQuery: query.Encode(), + } + db, err := sql.Open("sqlserver", u.String()) ``` -## Statement Parameters +2. ADO: `key=value` pairs separated by `;`. Values may not contain `;`, leading and trailing whitespace is ignored. + Examples: + + * `server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + * `server=localhost;user id=sa;database=master;app name=MyAppName` -In the SQL statement text, literals may be replaced by a parameter that matches one of the following: +3. ODBC: Prefix with `odbc`, `key=value` pairs separated by `;`. Allow `;` by wrapping + values in `{}`. Examples: + + * `odbc:server=localhost\\SQLExpress;user id=sa;database=master;app name=MyAppName` + * `odbc:server=localhost;user id=sa;database=master;app name=MyAppName` + * `odbc:server=localhost;user id=sa;password={foo;bar}` // Value marked with `{}`, password is "foo;bar" + * `odbc:server=localhost;user id=sa;password={foo{bar}` // Value marked with `{}`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password={foobar }` // Value marked with `{}`, password is "foobar " + * `odbc:server=localhost;user id=sa;password=foo{bar` // Literal `{`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password=foo}bar` // Literal `}`, password is "foo}bar" + * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" + * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar" -* ? -* ?nnn -* :nnn -* $nnn +## Executing Stored Procedures + +To run a stored procedure, set the query text to the procedure name: +```go +var account = "abc" +_, err := db.ExecContext(ctx, "sp_RunMe", + sql.Named("ID", 123), + sql.Named("Account", sql.Out{Dest: &account}), +) +``` + +## Reading Output Parameters from a Stored Procedure with Resultset + +To read output parameters from a stored procedure with resultset, make sure you read all the rows before reading the output parameters: +```go +sqltextcreate := ` +CREATE PROCEDURE spwithoutputandrows + @bitparam BIT OUTPUT +AS BEGIN + SET @bitparam = 1 + SELECT 'Row 1' +END +` +var bitout int64 +rows, err := db.QueryContext(ctx, "spwithoutputandrows", sql.Named("bitparam", sql.Out{Dest: &bitout})) +var strrow string +for rows.Next() { + err = rows.Scan(&strrow) +} +fmt.Printf("bitparam is %d", bitout) +``` -where nnn represents an integer that specifies a 1-indexed positional parameter. Ex: +## Caveat for local temporary tables + +Due to protocol limitations, temporary tables will only be allocated on the connection +as a result of executing a query with zero parameters. 
The following query +will, due to the use of a parameter, execute in its own session, +and `#mytemp` will be de-allocated right away: ```go -db.Query("SELECT * FROM t WHERE a = ?3, b = ?2, c = ?1", "x", "y", "z") +conn, err := pool.Conn(ctx) +defer conn.Close() +_, err := conn.ExecContext(ctx, "select @p1 as x into #mytemp", 1) +// at this point #mytemp is already dropped again as the session of the ExecContext is over ``` -will expand to roughly +To work around this, always explicitly create the local temporary +table in a query without any parameters. As a special case, the driver +will then be able to execute the query directly on the +connection-scoped session. The following example works: + +```go +conn, err := pool.Conn(ctx) + +// Set us up so that temp table is always cleaned up, since conn.Close() +// merely returns conn to pool, rather than actually closing the connection. +defer func() { + _, _ = conn.ExecContext(ctx, "drop table #mytemp") // always clean up + conn.Close() // merely returns conn to pool +}() + + +// Since we do not pass any parameters below, the query will execute in the scope of +// the connection and succeed in creating the table. +_, err := conn.ExecContext(ctx, "create table #mytemp ( x int )") + +// #mytemp is now available even if you pass parameters +_, err := conn.ExecContext(ctx, "insert into #mytemp (x) values (@p1)", 1) -```sql -SELECT * FROM t WHERE a = 'z', b = 'y', c = 'x' ``` +## Return Status + +To get the procedure return status, pass into the parameters a +`*mssql.ReturnStatus`. For example: +``` +var rs mssql.ReturnStatus +_, err := db.ExecContext(ctx, "theproc", &rs) +log.Printf("status=%d", rs) +``` + +or + +``` +var rs mssql.ReturnStatus +_, err := db.QueryContext(ctx, "theproc", &rs) +for rows.Next() { + err = rows.Scan(&val) +} +log.Printf("status=%d", rs) +``` + +Limitation: ReturnStatus cannot be retrieved using `QueryRow`. + +## Parameters + +The `sqlserver` driver uses normal MS SQL Server syntax and expects parameters in +the sql query to be in the form of either `@Name` or `@p1` to `@pN` (ordinal position). + +```go +db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob") +``` + +### Parameter Types + +To pass specific types to the query parameters, say `varchar` or `date` types, +you must convert the value to the corresponding Go type before passing it in. The following types +are supported: + + * string -> nvarchar + * mssql.VarChar -> varchar + * time.Time -> datetimeoffset or datetime (TDS version dependent) + * mssql.DateTime1 -> datetime + * mssql.DateTimeOffset -> datetimeoffset + * "github.com/golang-sql/civil".Date -> date + * "github.com/golang-sql/civil".DateTime -> datetime2 + * "github.com/golang-sql/civil".Time -> time + * mssql.TVP -> Table Value Parameter (TDS version dependent) + +## Important Notes + + * [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should + not be used with this driver (or SQL Server) due to how the TDS protocol + works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql) + or add a `select ID = convert(bigint, SCOPE_IDENTITY());` to the end of your + query (ref [SCOPE_IDENTITY](https://docs.microsoft.com/en-us/sql/t-sql/functions/scope-identity-transact-sql)). + This will ensure you are getting the correct ID and will prevent a network round trip. 
+ * [NewConnector](https://godoc.org/github.com/denisenkom/go-mssqldb#NewConnector) + may be used with [OpenDB](https://golang.org/pkg/database/sql/#OpenDB). + * [Connector.SessionInitSQL](https://godoc.org/github.com/denisenkom/go-mssqldb#Connector.SessionInitSQL) + may be set to apply any driver specific session settings after the session + has been reset. If empty the session will still be reset but use the database + defaults in Go1.10+. ## Features @@ -85,6 +251,35 @@ SELECT * FROM t WHERE a = 'z', b = 'y', c = 'x' * Supports SQL Server and Windows Authentication * Supports Single-Sign-On on Windows * Supports connections to AlwaysOn Availability Group listeners, including re-direction to read-only replicas. +* Supports query notifications + +## Tests + +`go test` is used for testing. A running instance of MSSQL server is required. +Environment variables are used to pass login information. + +Example: + + env SQLSERVER_DSN=sqlserver://user:pass@hostname/instance?database=test1 go test + +## Deprecated + +These features still exist in the driver, but they are deprecated. + +### Query Parameter Token Replace (driver "mssql") + +If you use the driver name "mssql" (rather than "sqlserver") the SQL text +will be loosely parsed and an attempt to extract identifiers using one of + +* ? +* ?nnn +* :nnn +* $nnn + +will be used. This is not recommended with SQL Server. +There is at least one existing `won't fix` issue with the query parsing. + +Use the native "@Name" parameters instead with the "sqlserver" driver name. ## Known Issues diff --git a/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml b/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml new file mode 100644 index 0000000..c4d2bb0 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/appveyor.yml @@ -0,0 +1,50 @@ +version: 1.0.{build} + +os: Windows Server 2012 R2 + +clone_folder: c:\gopath\src\github.com\denisenkom\go-mssqldb + +environment: + GOPATH: c:\gopath + HOST: localhost + SQLUSER: sa + SQLPASSWORD: Password12! + DATABASE: test + GOVERSION: 111 + matrix: + - GOVERSION: 18 + SQLINSTANCE: SQL2016 + - GOVERSION: 19 + SQLINSTANCE: SQL2016 + - GOVERSION: 110 + SQLINSTANCE: SQL2016 + - GOVERSION: 111 + SQLINSTANCE: SQL2016 + - SQLINSTANCE: SQL2014 + - SQLINSTANCE: SQL2012SP1 + - SQLINSTANCE: SQL2008R2SP2 + +install: + - set GOROOT=c:\go%GOVERSION% + - set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH% + - go version + - go env + - go get -u github.com/golang-sql/civil + +build_script: + - go build + +before_test: + # setup SQL Server + - ps: | $instanceName = $env:SQLINSTANCE Start-Service "MSSQL`$$instanceName" Start-Service "SQLBrowser" + - sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;" + - sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version" + - pip install codecov + + +test_script: + - go test -race -cpu 4 -coverprofile=coverage.txt -covermode=atomic + - codecov -f coverage.txt diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go index 4fb2c47..ba39b40 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/buf.go +++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go @@ -2,11 +2,14 @@ package mssql import ( "encoding/binary" + "errors" "io" ) +type packetType uint8 + type header struct { - PacketType uint8 + PacketType packetType Status uint8 Size uint16 Spid uint16 @@ -14,118 +17,159 @@ type header struct { Pad uint8 } +// tdsBuffer reads and writes TDS packets of data to the transport. 
+// The write and read buffers are separate to make sending attn signals +// possible without locks. Currently attn signals are only sent during +// reads, not writes. type tdsBuffer struct { - buf []byte - pos uint16 - transport io.ReadWriteCloser - size uint16 + transport io.ReadWriteCloser + + packetSize int + + // Write fields. + wbuf []byte + wpos int + wPacketSeq byte + wPacketType packetType + + // Read fields. + rbuf []byte + rpos int + rsize int final bool - packet_type uint8 - afterFirst func() + rPacketType packetType + + // afterFirst is assigned to right after tdsBuffer is created and + // before the first use. It is executed after the first packet is + // written and then removed. + afterFirst func() +} + +func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { + return &tdsBuffer{ + packetSize: int(bufsize), + wbuf: make([]byte, 1<<16), + rbuf: make([]byte, 1<<16), + rpos: 8, + transport: transport, + } } -func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer { - buf := make([]byte, bufsize) - w := new(tdsBuffer) - w.buf = buf - w.pos = 8 - w.transport = transport - w.size = 0 - return w +func (rw *tdsBuffer) ResizeBuffer(packetSize int) { + rw.packetSize = packetSize +} + +func (w *tdsBuffer) PackageSize() int { + return w.packetSize } func (w *tdsBuffer) flush() (err error) { - binary.BigEndian.PutUint16(w.buf[2:], w.pos) - if _, err = w.transport.Write(w.buf[:w.pos]); err != nil { + // Write packet size. + w.wbuf[0] = byte(w.wPacketType) + binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos)) + w.wbuf[6] = w.wPacketSeq + + // Write packet into underlying transport. + if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil { return err } + // It is possible to create a whole new buffer after a flush. + // Useful for debugging. Normally reuse the buffer. + // w.wbuf = make([]byte, 1<<16) + + // Execute afterFirst hook if it is set. if w.afterFirst != nil { w.afterFirst() w.afterFirst = nil } - w.pos = 8 - w.buf[6] += 1 + + w.wpos = 8 + w.wPacketSeq++ return nil } -func (w *tdsBuffer) Write(p []byte) (nn int, err error) { - total := 0 +func (w *tdsBuffer) Write(p []byte) (total int, err error) { for { - copied := copy(w.buf[w.pos:], p) - w.pos += uint16(copied) + copied := copy(w.wbuf[w.wpos:w.packetSize], p) + w.wpos += copied total += copied if copied == len(p) { - break + return } if err = w.flush(); err != nil { - return total, err + return } p = p[copied:] } - return total, nil } func (w *tdsBuffer) WriteByte(b byte) error { - if int(w.pos) == len(w.buf) { + if int(w.wpos) == len(w.wbuf) || w.wpos == w.packetSize { if err := w.flush(); err != nil { return err } } - w.buf[w.pos] = b - w.pos += 1 + w.wbuf[w.wpos] = b + w.wpos += 1 return nil } -func (w *tdsBuffer) BeginPacket(packet_type byte) { - w.buf[0] = packet_type - w.buf[1] = 0 // packet is incomplete - w.buf[4] = 0 // spid - w.buf[5] = 0 - w.buf[6] = 1 // packet id - w.buf[7] = 0 // window - w.pos = 8 +func (w *tdsBuffer) BeginPacket(packetType packetType, resetSession bool) { + status := byte(0) + if resetSession { + switch packetType { + // Reset session can only be set on the following packet types. + case packSQLBatch, packRPCRequest, packTransMgrReq: + status = 0x8 + } + } + w.wbuf[1] = status // Packet is incomplete. This byte is set again in FinishPacket. 
+ w.wpos = 8 + w.wPacketSeq = 1 + w.wPacketType = packetType } -func (w *tdsBuffer) FinishPacket() (err error) { - w.buf[1] = 1 // packet is complete - binary.BigEndian.PutUint16(w.buf[2:], w.pos) - _, err = w.transport.Write(w.buf[:w.pos]) - if w.afterFirst != nil { - w.afterFirst() - w.afterFirst = nil - } - return err +func (w *tdsBuffer) FinishPacket() error { + w.wbuf[1] |= 1 // Mark this as the last packet in the message. + return w.flush() } +var headerSize = binary.Size(header{}) + func (r *tdsBuffer) readNextPacket() error { - header := header{} + h := header{} var err error - err = binary.Read(r.transport, binary.BigEndian, &header) + err = binary.Read(r.transport, binary.BigEndian, &h) if err != nil { return err } - offset := uint16(binary.Size(header)) - _, err = io.ReadFull(r.transport, r.buf[offset:header.Size]) + if int(h.Size) > r.packetSize { + return errors.New("Invalid packet size, it is longer than buffer size") + } + if headerSize > int(h.Size) { + return errors.New("Invalid packet size, it is shorter than header size") + } + _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) if err != nil { return err } - r.pos = offset - r.size = header.Size - r.final = header.Status != 0 - r.packet_type = header.PacketType + r.rpos = headerSize + r.rsize = int(h.Size) + r.final = h.Status != 0 + r.rPacketType = h.PacketType return nil } -func (r *tdsBuffer) BeginRead() (uint8, error) { +func (r *tdsBuffer) BeginRead() (packetType, error) { err := r.readNextPacket() if err != nil { return 0, err } - return r.packet_type, nil + return r.rPacketType, nil } func (r *tdsBuffer) ReadByte() (res byte, err error) { - if r.pos == r.size { + if r.rpos == r.rsize { if r.final { return 0, io.EOF } @@ -134,8 +178,8 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) { return 0, err } } - res = r.buf[r.pos] - r.pos++ + res = r.rbuf[r.rpos] + r.rpos++ return res, nil } @@ -177,36 +221,42 @@ func (r *tdsBuffer) uint16() uint16 { } func (r *tdsBuffer) BVarChar() string { - l := int(r.byte()) - return r.readUcs2(l) + return readBVarCharOrPanic(r) } -func (r *tdsBuffer) UsVarChar() string { - l := int(r.uint16()) - return r.readUcs2(l) +func readBVarCharOrPanic(r io.Reader) string { + s, err := readBVarChar(r) + if err != nil { + badStreamPanic(err) + } + return s } -func (r *tdsBuffer) readUcs2(numchars int) string { - b := make([]byte, numchars*2) - r.ReadFull(b) - res, err := ucs22str(b) +func readUsVarCharOrPanic(r io.Reader) string { + s, err := readUsVarChar(r) if err != nil { badStreamPanic(err) } - return res + return s } -func (r *tdsBuffer) Read(buf []byte) (n int, err error) { - if r.pos == r.size { +func (r *tdsBuffer) UsVarChar() string { + return readUsVarCharOrPanic(r) +} + +func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { + copied = 0 + err = nil + if r.rpos == r.rsize { if r.final { return 0, io.EOF } err = r.readNextPacket() if err != nil { - return 0, err + return } } - copied := copy(buf, r.buf[r.pos:r.size]) - r.pos += uint16(copied) - return copied, nil + copied = copy(buf, r.rbuf[r.rpos:r.rsize]) + r.rpos += copied + return } diff --git a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go new file mode 100644 index 0000000..1d5eacb --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go @@ -0,0 +1,583 @@ +package mssql + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "math" + "reflect" + "strings" + "time" + + "github.com/denisenkom/go-mssqldb/internal/decimal" +) + 
+type Bulk struct { + // ctx is used only for AddRow and Done methods. + // This could be removed if AddRow and Done accepted + // a ctx field as well, which is available with the + // database/sql call. + ctx context.Context + + cn *Conn + metadata []columnStruct + bulkColumns []columnStruct + columnsName []string + tablename string + numRows int + + headerSent bool + Options BulkOptions + Debug bool +} +type BulkOptions struct { + CheckConstraints bool + FireTriggers bool + KeepNulls bool + KilobytesPerBatch int + RowsPerBatch int + Order []string + Tablock bool +} + +type DataValue interface{} + +const ( + sqlDateFormat = "2006-01-02" + sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00" +) + +func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { + b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns} + b.Debug = false + return &b +} + +func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) { + b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns} + b.Debug = false + return &b +} + +func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { + //get table columns info + err = b.getMetadata(ctx) + if err != nil { + return err + } + + //match the columns + for _, colname := range b.columnsName { + var bulkCol *columnStruct + + for _, m := range b.metadata { + if m.ColName == colname { + bulkCol = &m + break + } + } + if bulkCol != nil { + + if bulkCol.ti.TypeId == typeUdt { + //send udt as binary + bulkCol.ti.TypeId = typeBigVarBin + } + b.bulkColumns = append(b.bulkColumns, *bulkCol) + b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId) + } else { + return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename) + } + } + + //create the bulk command + + //columns definitions + var col_defs bytes.Buffer + for i, col := range b.bulkColumns { + if i != 0 { + col_defs.WriteString(", ") + } + col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti)) + } + + //options + var with_opts []string + + if b.Options.CheckConstraints { + with_opts = append(with_opts, "CHECK_CONSTRAINTS") + } + if b.Options.FireTriggers { + with_opts = append(with_opts, "FIRE_TRIGGERS") + } + if b.Options.KeepNulls { + with_opts = append(with_opts, "KEEP_NULLS") + } + if b.Options.KilobytesPerBatch > 0 { + with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch)) + } + if b.Options.RowsPerBatch > 0 { + with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch)) + } + if len(b.Options.Order) > 0 { + with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ","))) + } + if b.Options.Tablock { + with_opts = append(with_opts, "TABLOCK") + } + var with_part string + if len(with_opts) > 0 { + with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ",")) + } + + query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part) + + stmt, err := b.cn.PrepareContext(ctx, query) + if err != nil { + return fmt.Errorf("Prepare failed: %s", err.Error()) + } + b.dlogf(query) + + _, err = stmt.(*Stmt).ExecContext(ctx, nil) + if err != nil { + return err + } + + b.headerSent = true + + var buf = b.cn.sess.buf + buf.BeginPacket(packBulkLoadBCP, false) + + // Send the columns metadata. 
+ columnMetadata := b.createColMetadata() + _, err = buf.Write(columnMetadata) + + return +} + +// AddRow immediately writes the row to the destination table. +// The arguments are the row values in the order they were specified. +func (b *Bulk) AddRow(row []interface{}) (err error) { + if !b.headerSent { + err = b.sendBulkCommand(b.ctx) + if err != nil { + return + } + } + + if len(row) != len(b.bulkColumns) { + return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d", + len(row), len(b.bulkColumns)) + } + + bytes, err := b.makeRowData(row) + if err != nil { + return + } + + _, err = b.cn.sess.buf.Write(bytes) + if err != nil { + return + } + + b.numRows = b.numRows + 1 + return +} + +func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte(byte(tokenRow)) + + var logcol bytes.Buffer + for i, col := range b.bulkColumns { + + if b.Debug { + logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i])) + } + param, err := b.makeParam(row[i], col) + if err != nil { + return nil, fmt.Errorf("bulkcopy: %s", err.Error()) + } + + if col.ti.Writer == nil { + return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x", + col.ColName, col.ti.TypeId) + } + err = col.ti.Writer(buf, param.ti, param.buffer) + if err != nil { + return nil, fmt.Errorf("bulkcopy: %s", err.Error()) + } + } + + b.dlogf("row[%d] %s\n", b.numRows, logcol.String()) + + return buf.Bytes(), nil +} + +func (b *Bulk) Done() (rowcount int64, err error) { + if b.headerSent == false { + //no rows had been sent + return 0, nil + } + var buf = b.cn.sess.buf + buf.WriteByte(byte(tokenDone)) + + binary.Write(buf, binary.LittleEndian, uint16(doneFinal)) + binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd + + if b.cn.sess.loginAck.TDSVersion >= verTDS72 { + binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0 + } else { + binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0 + } + + buf.FinishPacket() + + tokchan := make(chan tokenStruct, 5) + go processResponse(b.ctx, b.cn.sess, tokchan, nil) + + var rowCount int64 + for token := range tokchan { + switch token := token.(type) { + case doneStruct: + if token.Status&doneCount != 0 { + rowCount = int64(token.RowCount) + } + if token.isError() { + return 0, token.getError() + } + case error: + return 0, b.cn.checkBadConn(token) + } + } + return rowCount, nil +} + +func (b *Bulk) createColMetadata() []byte { + buf := new(bytes.Buffer) + buf.WriteByte(byte(tokenColMetadata)) // token + binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count + + for i, col := range b.bulkColumns { + + if b.cn.sess.loginAck.TDSVersion >= verTDS72 { + binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0? 
+ } else { + binary.Write(buf, binary.LittleEndian, uint16(col.UserType)) + } + binary.Write(buf, binary.LittleEndian, uint16(col.Flags)) + + writeTypeInfo(buf, &b.bulkColumns[i].ti) + + if col.ti.TypeId == typeNText || + col.ti.TypeId == typeText || + col.ti.TypeId == typeImage { + + tablename_ucs2 := str2ucs2(b.tablename) + binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2)) + buf.Write(tablename_ucs2) + } + colname_ucs2 := str2ucs2(col.ColName) + buf.WriteByte(uint8(len(colname_ucs2) / 2)) + buf.Write(colname_ucs2) + } + + return buf.Bytes() +} + +func (b *Bulk) getMetadata(ctx context.Context) (err error) { + stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON") + if err != nil { + return + } + + _, err = stmt.ExecContext(ctx, nil) + if err != nil { + return + } + + // Get columns info. + stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) + if err != nil { + return + } + rows, err := stmt.QueryContext(ctx, nil) + if err != nil { + return fmt.Errorf("get columns info failed: %v", err) + } + b.metadata = rows.(*Rows).cols + + if b.Debug { + for _, col := range b.metadata { + b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n", + col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec, + col.Flags, col.ti.Collation.LcidAndFlags) + } + } + + return rows.Close() +} + +func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) { + res.ti.Size = col.ti.Size + res.ti.TypeId = col.ti.TypeId + + if val == nil { + res.ti.Size = 0 + return + } + + switch col.ti.TypeId { + + case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN: + var intvalue int64 + + switch val := val.(type) { + case int: + intvalue = int64(val) + case int32: + intvalue = int64(val) + case int64: + intvalue = val + default: + err = fmt.Errorf("mssql: invalid type for int column: %T", val) + return + } + + res.buffer = make([]byte, res.ti.Size) + if col.ti.Size == 1 { + res.buffer[0] = byte(intvalue) + } else if col.ti.Size == 2 { + binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue)) + } else if col.ti.Size == 4 { + binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue)) + } else if col.ti.Size == 8 { + binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue)) + } + case typeFlt4, typeFlt8, typeFltN: + var floatvalue float64 + + switch val := val.(type) { + case float32: + floatvalue = float64(val) + case float64: + floatvalue = val + case int: + floatvalue = float64(val) + case int64: + floatvalue = float64(val) + default: + err = fmt.Errorf("mssql: invalid type for float column: %T %s", val, val) + return + } + + if col.ti.Size == 4 { + res.buffer = make([]byte, 4) + binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue))) + } else if col.ti.Size == 8 { + res.buffer = make([]byte, 8) + binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue)) + } + case typeNVarChar, typeNText, typeNChar: + + switch val := val.(type) { + case string: + res.buffer = str2ucs2(val) + case []byte: + res.buffer = val + default: + err = fmt.Errorf("mssql: invalid type for nvarchar column: %T %s", val, val) + return + } + res.ti.Size = len(res.buffer) + + case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar: + switch val := val.(type) { + case string: + res.buffer = []byte(val) + case []byte: + res.buffer = val + default: + err = fmt.Errorf("mssql: invalid type for varchar column: %T %s", val, val) + return + } + res.ti.Size = len(res.buffer) + + case typeBit, 
typeBitN: + if reflect.TypeOf(val).Kind() != reflect.Bool { + err = fmt.Errorf("mssql: invalid type for bit column: %T %s", val, val) + return + } + res.ti.TypeId = typeBitN + res.ti.Size = 1 + res.buffer = make([]byte, 1) + if val.(bool) { + res.buffer[0] = 1 + } + case typeDateTime2N: + switch val := val.(type) { + case time.Time: + res.buffer = encodeDateTime2(val, int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + case string: + var t time.Time + if t, err = time.Parse(sqlTimeFormat, val); err != nil { + return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) + } + res.buffer = encodeDateTime2(t, int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + default: + err = fmt.Errorf("mssql: invalid type for datetime2 column: %T %s", val, val) + return + } + case typeDateTimeOffsetN: + switch val := val.(type) { + case time.Time: + res.buffer = encodeDateTimeOffset(val, int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + case string: + var t time.Time + if t, err = time.Parse(sqlTimeFormat, val); err != nil { + return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) + } + res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale)) + res.ti.Size = len(res.buffer) + default: + err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %T %s", val, val) + return + } + case typeDateN: + switch val := val.(type) { + case time.Time: + res.buffer = encodeDate(val) + res.ti.Size = len(res.buffer) + case string: + var t time.Time + if t, err = time.ParseInLocation(sqlDateFormat, val, time.UTC); err != nil { + return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) + } + res.buffer = encodeDate(t) + res.ti.Size = len(res.buffer) + default: + err = fmt.Errorf("mssql: invalid type for date column: %T %s", val, val) + return + } + case typeDateTime, typeDateTimeN, typeDateTim4: + var t time.Time + switch val := val.(type) { + case time.Time: + t = val + case string: + if t, err = time.Parse(sqlTimeFormat, val); err != nil { + return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) + } + default: + err = fmt.Errorf("mssql: invalid type for datetime column: %T %s", val, val) + return + } + + if col.ti.Size == 4 { + res.buffer = encodeDateTim4(t) + res.ti.Size = len(res.buffer) + } else if col.ti.Size == 8 { + res.buffer = encodeDateTime(t) + res.ti.Size = len(res.buffer) + } else { + err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size) + } + + // case typeMoney, typeMoney4, typeMoneyN: + case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: + prec := col.ti.Prec + scale := col.ti.Scale + var dec decimal.Decimal + switch v := val.(type) { + case int: + dec = decimal.Int64ToDecimalScale(int64(v), 0) + case int8: + dec = decimal.Int64ToDecimalScale(int64(v), 0) + case int16: + dec = decimal.Int64ToDecimalScale(int64(v), 0) + case int32: + dec = decimal.Int64ToDecimalScale(int64(v), 0) + case int64: + dec = decimal.Int64ToDecimalScale(int64(v), 0) + case float32: + dec, err = decimal.Float64ToDecimalScale(float64(v), scale) + case float64: + dec, err = decimal.Float64ToDecimalScale(float64(v), scale) + case string: + dec, err = decimal.StringToDecimalScale(v, scale) + default: + return res, fmt.Errorf("unknown value for decimal: %T %#v", v, v) + } + + if err != nil { + return res, err + } + dec.SetPrec(prec) + + var length byte + switch { + case prec <= 9: + length = 4 + case prec <= 19: + length = 8 + case prec <= 28: + length = 12 + default: + length = 16 + } + + buf := make([]byte, length+1) + // first byte length 
written by typeInfo.writer + res.ti.Size = int(length) + 1 + // second byte sign + if !dec.IsPositive() { + buf[0] = 0 + } else { + buf[0] = 1 + } + + ub := dec.UnscaledBytes() + l := len(ub) + if l > int(length) { + err = fmt.Errorf("decimal out of range: %s", dec) + return res, err + } + // reverse the bytes + for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 { + buf[i] = ub[j] + } + res.buffer = buf + case typeBigVarBin, typeBigBinary: + switch val := val.(type) { + case []byte: + res.ti.Size = len(val) + res.buffer = val + default: + err = fmt.Errorf("mssql: invalid type for Binary column: %T %s", val, val) + return + } + case typeGuid: + switch val := val.(type) { + case []byte: + res.ti.Size = len(val) + res.buffer = val + default: + err = fmt.Errorf("mssql: invalid type for Guid column: %T %s", val, val) + return + } + + default: + err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId) + } + return + +} + +func (b *Bulk) dlogf(format string, v ...interface{}) { + if b.Debug { + b.cn.sess.log.Printf(format, v...) + } +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go new file mode 100644 index 0000000..709505b --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy_sql.go @@ -0,0 +1,93 @@ +package mssql + +import ( + "context" + "database/sql/driver" + "encoding/json" + "errors" +) + +type copyin struct { + cn *Conn + bulkcopy *Bulk + closed bool +} + +type serializableBulkConfig struct { + TableName string + ColumnsName []string + Options BulkOptions +} + +func (d *Driver) OpenConnection(dsn string) (*Conn, error) { + return d.open(context.Background(), dsn) +} + +func (c *Conn) prepareCopyIn(ctx context.Context, query string) (_ driver.Stmt, err error) { + config_json := query[11:] + + bulkconfig := serializableBulkConfig{} + err = json.Unmarshal([]byte(config_json), &bulkconfig) + if err != nil { + return + } + + bulkcopy := c.CreateBulkContext(ctx, bulkconfig.TableName, bulkconfig.ColumnsName) + bulkcopy.Options = bulkconfig.Options + + ci := ©in{ + cn: c, + bulkcopy: bulkcopy, + } + + return ci, nil +} + +func CopyIn(table string, options BulkOptions, columns ...string) string { + bulkconfig := &serializableBulkConfig{TableName: table, Options: options, ColumnsName: columns} + + config_json, err := json.Marshal(bulkconfig) + if err != nil { + panic(err) + } + + stmt := "INSERTBULK " + string(config_json) + + return stmt +} + +func (ci *copyin) NumInput() int { + return -1 +} + +func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { + panic("should never be called") +} + +func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { + if ci.closed { + return nil, errors.New("copyin query is closed") + } + + if len(v) == 0 { + rowCount, err := ci.bulkcopy.Done() + ci.closed = true + return driver.RowsAffected(rowCount), err + } + + t := make([]interface{}, len(v)) + for i, val := range v { + t[i] = val + } + + err = ci.bulkcopy.AddRow(t) + if err != nil { + return + } + + return driver.RowsAffected(0), nil +} + +func (ci *copyin) Close() (err error) { + return nil +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/collation.go b/vendor/github.com/denisenkom/go-mssqldb/collation.go deleted file mode 100644 index ac9cf20..0000000 --- a/vendor/github.com/denisenkom/go-mssqldb/collation.go +++ /dev/null @@ -1,39 +0,0 @@ -package mssql - -import ( - "encoding/binary" - "io" -) - -// http://msdn.microsoft.com/en-us/library/dd340437.aspx - -type collation struct 
{ - lcidAndFlags uint32 - sortId uint8 -} - -func (c collation) getLcid() uint32 { - return c.lcidAndFlags & 0x000fffff -} - -func (c collation) getFlags() uint32 { - return (c.lcidAndFlags & 0x0ff00000) >> 20 -} - -func (c collation) getVersion() uint32 { - return (c.lcidAndFlags & 0xf0000000) >> 28 -} - -func readCollation(r *tdsBuffer) (res collation) { - res.lcidAndFlags = r.uint32() - res.sortId = r.byte() - return -} - -func writeCollation(w io.Writer, col collation) (err error) { - if err = binary.Write(w, binary.LittleEndian, col.lcidAndFlags); err != nil { - return - } - err = binary.Write(w, binary.LittleEndian, col.sortId) - return -} diff --git a/vendor/github.com/denisenkom/go-mssqldb/conn_str.go b/vendor/github.com/denisenkom/go-mssqldb/conn_str.go new file mode 100644 index 0000000..0742755 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/conn_str.go @@ -0,0 +1,456 @@ +package mssql + +import ( + "fmt" + "net" + "net/url" + "os" + "strconv" + "strings" + "time" + "unicode" +) + +type connectParams struct { + logFlags uint64 + port uint64 + host string + instance string + database string + user string + password string + dial_timeout time.Duration + conn_timeout time.Duration + keepAlive time.Duration + encrypt bool + disableEncryption bool + trustServerCertificate bool + certificate string + hostInCertificate string + hostInCertificateProvided bool + serverSPN string + workstation string + appname string + typeFlags uint8 + failOverPartner string + failOverPort uint64 + packetSize uint16 +} + +func parseConnectParams(dsn string) (connectParams, error) { + var p connectParams + + var params map[string]string + if strings.HasPrefix(dsn, "odbc:") { + parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):]) + if err != nil { + return p, err + } + params = parameters + } else if strings.HasPrefix(dsn, "sqlserver://") { + parameters, err := splitConnectionStringURL(dsn) + if err != nil { + return p, err + } + params = parameters + } else { + params = splitConnectionString(dsn) + } + + strlog, ok := params["log"] + if ok { + var err error + p.logFlags, err = strconv.ParseUint(strlog, 10, 64) + if err != nil { + return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) + } + } + server := params["server"] + parts := strings.SplitN(server, `\`, 2) + p.host = parts[0] + if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" { + p.host = "localhost" + } + if len(parts) > 1 { + p.instance = parts[1] + } + p.database = params["database"] + p.user = params["user id"] + p.password = params["password"] + + p.port = 1433 + strport, ok := params["port"] + if ok { + var err error + p.port, err = strconv.ParseUint(strport, 10, 16) + if err != nil { + f := "Invalid tcp port '%v': %v" + return p, fmt.Errorf(f, strport, err.Error()) + } + } + + // https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option + // Default packet size remains at 4096 bytes + p.packetSize = 4096 + strpsize, ok := params["packet size"] + if ok { + var err error + psize, err := strconv.ParseUint(strpsize, 0, 16) + if err != nil { + f := "Invalid packet size '%v': %v" + return p, fmt.Errorf(f, strpsize, err.Error()) + } + + // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes + // NOTE: Encrypted connections have a maximum size of 16383 bytes. 
If you request + // a higher packet size, the server will respond with an ENVCHANGE request to + // alter the packet size to 16383 bytes. + p.packetSize = uint16(psize) + if p.packetSize < 512 { + p.packetSize = 512 + } else if p.packetSize > 32767 { + p.packetSize = 32767 + } + } + + // https://msdn.microsoft.com/en-us/library/dd341108.aspx + // + // Do not set a connection timeout. Use Context to manage such things. + // Default to zero, but still allow it to be set. + if strconntimeout, ok := params["connection timeout"]; ok { + timeout, err := strconv.ParseUint(strconntimeout, 10, 64) + if err != nil { + f := "Invalid connection timeout '%v': %v" + return p, fmt.Errorf(f, strconntimeout, err.Error()) + } + p.conn_timeout = time.Duration(timeout) * time.Second + } + p.dial_timeout = 15 * time.Second + if strdialtimeout, ok := params["dial timeout"]; ok { + timeout, err := strconv.ParseUint(strdialtimeout, 10, 64) + if err != nil { + f := "Invalid dial timeout '%v': %v" + return p, fmt.Errorf(f, strdialtimeout, err.Error()) + } + p.dial_timeout = time.Duration(timeout) * time.Second + } + + // default keep alive should be 30 seconds according to spec: + // https://msdn.microsoft.com/en-us/library/dd341108.aspx + p.keepAlive = 30 * time.Second + if keepAlive, ok := params["keepalive"]; ok { + timeout, err := strconv.ParseUint(keepAlive, 10, 64) + if err != nil { + f := "Invalid keepAlive value '%s': %s" + return p, fmt.Errorf(f, keepAlive, err.Error()) + } + p.keepAlive = time.Duration(timeout) * time.Second + } + encrypt, ok := params["encrypt"] + if ok { + if strings.EqualFold(encrypt, "DISABLE") { + p.disableEncryption = true + } else { + var err error + p.encrypt, err = strconv.ParseBool(encrypt) + if err != nil { + f := "Invalid encrypt '%s': %s" + return p, fmt.Errorf(f, encrypt, err.Error()) + } + } + } else { + p.trustServerCertificate = true + } + trust, ok := params["trustservercertificate"] + if ok { + var err error + p.trustServerCertificate, err = strconv.ParseBool(trust) + if err != nil { + f := "Invalid trust server certificate '%s': %s" + return p, fmt.Errorf(f, trust, err.Error()) + } + } + p.certificate = params["certificate"] + p.hostInCertificate, ok = params["hostnameincertificate"] + if ok { + p.hostInCertificateProvided = true + } else { + p.hostInCertificate = p.host + p.hostInCertificateProvided = false + } + + serverSPN, ok := params["serverspn"] + if ok { + p.serverSPN = serverSPN + } else { + p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port) + } + + workstation, ok := params["workstation id"] + if ok { + p.workstation = workstation + } else { + workstation, err := os.Hostname() + if err == nil { + p.workstation = workstation + } + } + + appname, ok := params["app name"] + if !ok { + appname = "go-mssqldb" + } + p.appname = appname + + appintent, ok := params["applicationintent"] + if ok { + if appintent == "ReadOnly" { + if p.database == "" { + return p, fmt.Errorf("Database must be specified when ApplicationIntent is ReadOnly") + } + p.typeFlags |= fReadOnlyIntent + } + } + + failOverPartner, ok := params["failoverpartner"] + if ok { + p.failOverPartner = failOverPartner + } + + failOverPort, ok := params["failoverport"] + if ok { + var err error + p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16) + if err != nil { + f := "Invalid tcp port '%v': %v" + return p, fmt.Errorf(f, failOverPort, err.Error()) + } + } + + return p, nil +} + +func splitConnectionString(dsn string) (res map[string]string) { + res = map[string]string{} + parts := 
strings.Split(dsn, ";") + for _, part := range parts { + if len(part) == 0 { + continue + } + lst := strings.SplitN(part, "=", 2) + name := strings.TrimSpace(strings.ToLower(lst[0])) + if len(name) == 0 { + continue + } + var value string = "" + if len(lst) > 1 { + value = strings.TrimSpace(lst[1]) + } + res[name] = value + } + return res +} + +// Splits a URL of the form sqlserver://username:password@host/instance?param1=value¶m2=value +func splitConnectionStringURL(dsn string) (map[string]string, error) { + res := map[string]string{} + + u, err := url.Parse(dsn) + if err != nil { + return res, err + } + + if u.Scheme != "sqlserver" { + return res, fmt.Errorf("scheme %s is not recognized", u.Scheme) + } + + if u.User != nil { + res["user id"] = u.User.Username() + p, exists := u.User.Password() + if exists { + res["password"] = p + } + } + + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Host + } + + if len(u.Path) > 0 { + res["server"] = host + "\\" + u.Path[1:] + } else { + res["server"] = host + } + + if len(port) > 0 { + res["port"] = port + } + + query := u.Query() + for k, v := range query { + if len(v) > 1 { + return res, fmt.Errorf("key %s provided more than once", k) + } + res[strings.ToLower(k)] = v[0] + } + + return res, nil +} + +// Splits a URL in the ODBC format +func splitConnectionStringOdbc(dsn string) (map[string]string, error) { + res := map[string]string{} + + type parserState int + const ( + // Before the start of a key + parserStateBeforeKey parserState = iota + + // Inside a key + parserStateKey + + // Beginning of a value. May be bare or braced + parserStateBeginValue + + // Inside a bare value + parserStateBareValue + + // Inside a braced value + parserStateBracedValue + + // A closing brace inside a braced value. + // May be the end of the value or an escaped closing brace, depending on the next character + parserStateBracedValueClosingBrace + + // After a value. Next character should be a semicolon or whitespace. + parserStateEndValue + ) + + var state = parserStateBeforeKey + + var key string + var value string + + for i, c := range dsn { + switch state { + case parserStateBeforeKey: + switch { + case c == '=': + return res, fmt.Errorf("Unexpected character = at index %d. 
Expected start of key or semi-colon or whitespace.", i) + case !unicode.IsSpace(c) && c != ';': + state = parserStateKey + key += string(c) + } + + case parserStateKey: + switch c { + case '=': + key = normalizeOdbcKey(key) + state = parserStateBeginValue + + case ';': + // Key without value + key = normalizeOdbcKey(key) + res[key] = value + key = "" + value = "" + state = parserStateBeforeKey + + default: + key += string(c) + } + + case parserStateBeginValue: + switch { + case c == '{': + state = parserStateBracedValue + case c == ';': + // Empty value + res[key] = value + key = "" + state = parserStateBeforeKey + case unicode.IsSpace(c): + // Ignore whitespace + default: + state = parserStateBareValue + value += string(c) + } + + case parserStateBareValue: + if c == ';' { + res[key] = strings.TrimRightFunc(value, unicode.IsSpace) + key = "" + value = "" + state = parserStateBeforeKey + } else { + value += string(c) + } + + case parserStateBracedValue: + if c == '}' { + state = parserStateBracedValueClosingBrace + } else { + value += string(c) + } + + case parserStateBracedValueClosingBrace: + if c == '}' { + // Escaped closing brace + value += string(c) + state = parserStateBracedValue + continue + } + + // End of braced value + res[key] = value + key = "" + value = "" + + // This character is the first character past the end, + // so it needs to be parsed like the parserStateEndValue state. + state = parserStateEndValue + switch { + case c == ';': + state = parserStateBeforeKey + case unicode.IsSpace(c): + // Ignore whitespace + default: + return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) + } + + case parserStateEndValue: + switch { + case c == ';': + state = parserStateBeforeKey + case unicode.IsSpace(c): + // Ignore whitespace + default: + return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i) + } + } + } + + switch state { + case parserStateBeforeKey: // Okay + case parserStateKey: // Unfinished key. Treat as key without value. + key = normalizeOdbcKey(key) + res[key] = value + case parserStateBeginValue: // Empty value + res[key] = value + case parserStateBareValue: + res[key] = strings.TrimRightFunc(value, unicode.IsSpace) + case parserStateBracedValue: + return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn)) + case parserStateBracedValueClosingBrace: // End of braced value + res[key] = value + case parserStateEndValue: // Okay + } + + return res, nil +} + +// Normalizes the given string as an ODBC-format key +func normalizeOdbcKey(s string) string { + return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace)) +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/convert.go b/vendor/github.com/denisenkom/go-mssqldb/convert.go new file mode 100644 index 0000000..51bd4ee --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/convert.go @@ -0,0 +1,306 @@ +package mssql + +import "errors" + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +// This file was imported from database.sql.convert for go 1.10.3 with minor modifications to get +// convertAssign function +// This function is used internally by sql to convert values during call to Scan, we need same +// logic to return values for OUTPUT parameters. 
+// TODO: sql library should instead expose function defaultCheckNamedValue to be callable by drivers + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "strconv" + "time" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *sql.RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = sql.RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. 
+ // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } else { + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } else { + c := make([]byte, len(b)) + copy(c, b) + return c + } +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/decimal.go b/vendor/github.com/denisenkom/go-mssqldb/decimal.go deleted file mode 100644 index 76f3a6b..0000000 --- a/vendor/github.com/denisenkom/go-mssqldb/decimal.go +++ /dev/null @@ -1,115 +0,0 @@ -package mssql - -import ( - "encoding/binary" - "errors" - "math" - "math/big" -) - -// 
http://msdn.microsoft.com/en-us/library/ee780893.aspx -type Decimal struct { - integer [4]uint32 - positive bool - prec uint8 - scale uint8 -} - -var scaletblflt64 [39]float64 - -func (d Decimal) ToFloat64() float64 { - val := float64(0) - for i := 3; i >= 0; i-- { - val *= 0x100000000 - val += float64(d.integer[i]) - } - if !d.positive { - val = -val - } - if d.scale != 0 { - val /= scaletblflt64[d.scale] - } - return val -} - -func Float64ToDecimal(f float64) (Decimal, error) { - var dec Decimal - if math.IsNaN(f) { - return dec, errors.New("NaN") - } - if math.IsInf(f, 0) { - return dec, errors.New("Infinity can't be converted to decimal") - } - dec.positive = f >= 0 - if !dec.positive { - f = math.Abs(f) - } - if f > 3.402823669209385e+38 { - return dec, errors.New("Float value is out of range") - } - dec.prec = 20 - var integer float64 - for dec.scale = 0; dec.scale <= 20; dec.scale++ { - integer = f * scaletblflt64[dec.scale] - _, frac := math.Modf(integer) - if frac == 0 { - break - } - } - for i := 0; i < 4; i++ { - mod := math.Mod(integer, 0x100000000) - integer -= mod - integer /= 0x100000000 - dec.integer[i] = uint32(mod) - } - return dec, nil -} - -func init() { - var acc float64 = 1 - for i := 0; i <= 38; i++ { - scaletblflt64[i] = acc - acc *= 10 - } -} - -func (d Decimal) Bytes() []byte { - bytes := make([]byte, 16) - binary.BigEndian.PutUint32(bytes[0:4], d.integer[3]) - binary.BigEndian.PutUint32(bytes[4:8], d.integer[2]) - binary.BigEndian.PutUint32(bytes[8:12], d.integer[1]) - binary.BigEndian.PutUint32(bytes[12:16], d.integer[0]) - var x big.Int - x.SetBytes(bytes) - if !d.positive { - x.Neg(&x) - } - return scaleBytes(x.String(), d.scale) -} - -func scaleBytes(s string, scale uint8) []byte { - z := make([]byte, 0, len(s)+1) - if s[0] == '-' || s[0] == '+' { - z = append(z, byte(s[0])) - s = s[1:] - } - pos := len(s) - int(scale) - if pos <= 0 { - z = append(z, byte('0')) - } else if pos > 0 { - z = append(z, []byte(s[:pos])...) - } - if scale > 0 { - z = append(z, byte('.')) - for pos < 0 { - z = append(z, byte('0')) - pos++ - } - z = append(z, []byte(s[pos:])...) - } - return z -} - -func (d Decimal) String() string { - return string(d.Bytes()) -} diff --git a/vendor/github.com/denisenkom/go-mssqldb/doc.go b/vendor/github.com/denisenkom/go-mssqldb/doc.go new file mode 100644 index 0000000..2e54929 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/doc.go @@ -0,0 +1,14 @@ +// package mssql implements the TDS protocol used to connect to MS SQL Server (sqlserver) +// database servers. +// +// This package registers the driver: +// sqlserver: uses native "@" parameter placeholder names and does no pre-processing. +// +// If the ordinal position is used for query parameters, identifiers will be named +// "@p1", "@p2", ... "@pN". +// +// Please refer to the README for the format of the DSN. There are multiple DSN +// formats accepted: ADO style, ODBC style, and URL style. The following is an +// example of a URL style DSN: +// sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30 +package mssql diff --git a/vendor/github.com/denisenkom/go-mssqldb/error.go b/vendor/github.com/denisenkom/go-mssqldb/error.go index 20a0bb9..2e5bace 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/error.go +++ b/vendor/github.com/denisenkom/go-mssqldb/error.go @@ -4,6 +4,11 @@ import ( "fmt" ) +// Error represents an SQL Server error. 
This +// type includes methods for reading the contents +// of the struct, which allows calling programs +// to check for specific error conditions without +// having to import this package directly. type Error struct { Number int32 State uint8 @@ -18,6 +23,35 @@ func (e Error) Error() string { return "mssql: " + e.Message } +// SQLErrorNumber returns the SQL Server error number. +func (e Error) SQLErrorNumber() int32 { + return e.Number +} + +func (e Error) SQLErrorState() uint8 { + return e.State +} + +func (e Error) SQLErrorClass() uint8 { + return e.Class +} + +func (e Error) SQLErrorMessage() string { + return e.Message +} + +func (e Error) SQLErrorServerName() string { + return e.ServerName +} + +func (e Error) SQLErrorProcName() string { + return e.ProcName +} + +func (e Error) SQLErrorLineNo() int32 { + return e.LineNo +} + type StreamError struct { Message string } diff --git a/vendor/github.com/denisenkom/go-mssqldb/go.mod b/vendor/github.com/denisenkom/go-mssqldb/go.mod new file mode 100644 index 0000000..ebc02ab --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/go.mod @@ -0,0 +1,8 @@ +module github.com/denisenkom/go-mssqldb + +go 1.11 + +require ( + github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe + golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c +) diff --git a/vendor/github.com/denisenkom/go-mssqldb/go.sum b/vendor/github.com/denisenkom/go-mssqldb/go.sum new file mode 100644 index 0000000..1887801 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/go.sum @@ -0,0 +1,5 @@ +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/vendor/github.com/denisenkom/go-mssqldb/charset.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/charset.go similarity index 94% rename from vendor/github.com/denisenkom/go-mssqldb/charset.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/charset.go index f1cc247..8dc2279 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/charset.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/charset.go @@ -1,14 +1,14 @@ -package mssql +package cp type charsetMap struct { sb [256]rune // single byte runes, -1 for a double byte character lead byte db map[int]rune // double byte runes } -func collation2charset(col collation) *charsetMap { +func collation2charset(col Collation) *charsetMap { // http://msdn.microsoft.com/en-us/library/ms144250.aspx // http://msdn.microsoft.com/en-us/library/ms144250(v=sql.105).aspx - switch col.sortId { + switch col.SortId { case 30, 31, 32, 33, 34: return cp437 case 40, 41, 42, 44, 49, 55, 56, 57, 58, 59, 60, 61: @@ -86,7 +86,7 @@ func collation2charset(col collation) *charsetMap { return cp1252 } -func charset2utf8(col collation, s []byte) string { +func CharsetToUTF8(col Collation, s []byte) string { cm := collation2charset(col) if cm == nil { return string(s) diff --git a/vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go new file mode 100644 index 0000000..ae7b03b --- /dev/null 
+++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/collation.go @@ -0,0 +1,20 @@ +package cp + +// http://msdn.microsoft.com/en-us/library/dd340437.aspx + +type Collation struct { + LcidAndFlags uint32 + SortId uint8 +} + +func (c Collation) getLcid() uint32 { + return c.LcidAndFlags & 0x000fffff +} + +func (c Collation) getFlags() uint32 { + return (c.LcidAndFlags & 0x0ff00000) >> 20 +} + +func (c Collation) getVersion() uint32 { + return (c.LcidAndFlags & 0xf0000000) >> 28 +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1250.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1250.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1250.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1250.go index 8207366..5c8094e 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1250.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1250.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1250 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1251.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1251.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1251.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1251.go index f5b81c3..dc58967 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1251.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1251.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1251 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1252.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1252.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1252.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1252.go index ed705d3..5ae8703 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1252.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1252.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1252 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1253.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1253.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1253.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1253.go index cb1e1a7..52c8e07 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1253.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1253.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1253 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1254.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1254.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1254.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1254.go index a4b09bb..5d8864a 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1254.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1254.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1254 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1255.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1255.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1255.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1255.go index 97f9ee9..6061989 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1255.go +++ 
b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1255.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1255 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1256.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1256.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1256.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1256.go index e91241b..ffd04b3 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1256.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1256.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1256 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1257.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1257.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1257.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1257.go index bd93e6f..492da72 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1257.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1257.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1257 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp1258.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1258.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp1258.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1258.go index 4e1f8ac..80be52c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp1258.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp1258.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp1258 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp437.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp437.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp437.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp437.go index f47f8ec..76dedfb 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp437.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp437.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp437 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp850.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp850.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp850.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp850.go index e6b3d16..927ab24 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp850.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp850.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp850 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp874.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp874.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp874.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp874.go index 9d691a1..723bf6c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp874.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp874.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp874 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp932.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp932.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp932.go rename to 
vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp932.go index 980c55d..5fc1377 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp932.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp932.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp932 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp936.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp936.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp936.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp936.go index fca5da7..d1fac12 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp936.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp936.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp936 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp949.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp949.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp949.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp949.go index cddfcbc..52c708d 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp949.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp949.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp949 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/cp950.go b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp950.go similarity index 99% rename from vendor/github.com/denisenkom/go-mssqldb/cp950.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp950.go index cbf25cb..1301cd0 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/cp950.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/cp/cp950.go @@ -1,4 +1,4 @@ -package mssql +package cp var cp950 *charsetMap = &charsetMap{ sb: [256]rune{ diff --git a/vendor/github.com/denisenkom/go-mssqldb/internal/decimal/decimal.go b/vendor/github.com/denisenkom/go-mssqldb/internal/decimal/decimal.go new file mode 100644 index 0000000..7da0375 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/decimal/decimal.go @@ -0,0 +1,252 @@ +package decimal + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "math/big" + "strings" +) + +// Decimal represents decimal type in the Microsoft Open Specifications: http://msdn.microsoft.com/en-us/library/ee780893.aspx +type Decimal struct { + integer [4]uint32 // Little-endian + positive bool + prec uint8 + scale uint8 +} + +var ( + scaletblflt64 [39]float64 + int10 big.Int + int1e5 big.Int +) + +func init() { + var acc float64 = 1 + for i := 0; i <= 38; i++ { + scaletblflt64[i] = acc + acc *= 10 + } + + int10.SetInt64(10) + int1e5.SetInt64(1e5) +} + +const autoScale = 100 + +// SetInteger sets the ind'th element in the integer array +func (d *Decimal) SetInteger(integer uint32, ind uint8) { + d.integer[ind] = integer +} + +// SetPositive sets the positive member +func (d *Decimal) SetPositive(positive bool) { + d.positive = positive +} + +// SetPrec sets the prec member +func (d *Decimal) SetPrec(prec uint8) { + d.prec = prec +} + +// SetScale sets the scale member +func (d *Decimal) SetScale(scale uint8) { + d.scale = scale +} + +// IsPositive returns true if the Decimal is positive +func (d *Decimal) IsPositive() bool { + return d.positive +} + +// ToFloat64 converts decimal to a float64 +func (d Decimal) ToFloat64() float64 { + val := float64(0) + for i := 3; i >= 0; i-- { + val *= 0x100000000 + val += float64(d.integer[i]) + } 
+ if !d.positive { + val = -val + } + if d.scale != 0 { + val /= scaletblflt64[d.scale] + } + return val +} + +// BigInt converts decimal to a bigint +func (d Decimal) BigInt() big.Int { + bytes := make([]byte, 16) + binary.BigEndian.PutUint32(bytes[0:4], d.integer[3]) + binary.BigEndian.PutUint32(bytes[4:8], d.integer[2]) + binary.BigEndian.PutUint32(bytes[8:12], d.integer[1]) + binary.BigEndian.PutUint32(bytes[12:16], d.integer[0]) + var x big.Int + x.SetBytes(bytes) + if !d.positive { + x.Neg(&x) + } + return x +} + +// Bytes converts decimal to a scaled byte slice +func (d Decimal) Bytes() []byte { + x := d.BigInt() + return ScaleBytes(x.String(), d.scale) +} + +// UnscaledBytes converts decimal to a unscaled byte slice +func (d Decimal) UnscaledBytes() []byte { + x := d.BigInt() + return x.Bytes() +} + +// String converts decimal to a string +func (d Decimal) String() string { + return string(d.Bytes()) +} + +// Float64ToDecimal converts float64 to decimal +func Float64ToDecimal(f float64) (Decimal, error) { + return Float64ToDecimalScale(f, autoScale) +} + +// Float64ToDecimalScale converts float64 to decimal; user can specify the scale +func Float64ToDecimalScale(f float64, scale uint8) (Decimal, error) { + var dec Decimal + if math.IsNaN(f) { + return dec, errors.New("NaN") + } + if math.IsInf(f, 0) { + return dec, errors.New("Infinity can't be converted to decimal") + } + dec.positive = f >= 0 + if !dec.positive { + f = math.Abs(f) + } + if f > 3.402823669209385e+38 { + return dec, errors.New("Float value is out of range") + } + dec.prec = 20 + var integer float64 + for dec.scale = 0; dec.scale <= scale; dec.scale++ { + integer = f * scaletblflt64[dec.scale] + _, frac := math.Modf(integer) + if frac == 0 && scale == autoScale { + break + } + } + for i := 0; i < 4; i++ { + mod := math.Mod(integer, 0x100000000) + integer -= mod + integer /= 0x100000000 + dec.integer[i] = uint32(mod) + if mod-math.Trunc(mod) >= 0.5 { + dec.integer[i] = uint32(mod) + 1 + } + } + return dec, nil +} + +// Int64ToDecimalScale converts float64 to decimal; user can specify the scale +func Int64ToDecimalScale(v int64, scale uint8) Decimal { + positive := v >= 0 + if !positive { + if v == math.MinInt64 { + // Special case - can't negate + return Decimal{ + integer: [4]uint32{0, 0x80000000, 0, 0}, + positive: false, + prec: 20, + scale: 0, + } + } + v = -v + } + return Decimal{ + integer: [4]uint32{uint32(v), uint32(v >> 32), 0, 0}, + positive: positive, + prec: 20, + scale: scale, + } +} + +// StringToDecimalScale converts string to decimal +func StringToDecimalScale(v string, outScale uint8) (Decimal, error) { + var r big.Int + var unscaled string + var inScale int + + point := strings.LastIndexByte(v, '.') + if point == -1 { + inScale = 0 + unscaled = v + } else { + inScale = len(v) - point - 1 + unscaled = v[:point] + v[point+1:] + } + if inScale > math.MaxUint8 { + return Decimal{}, fmt.Errorf("can't parse %q as a decimal number: scale too large", v) + } + + _, ok := r.SetString(unscaled, 10) + if !ok { + return Decimal{}, fmt.Errorf("can't parse %q as a decimal number", v) + } + + if inScale > int(outScale) { + return Decimal{}, fmt.Errorf("can't parse %q as a decimal number: scale %d is larger than the scale %d of the target column", v, inScale, outScale) + } + for inScale < int(outScale) { + if int(outScale)-inScale >= 5 { + r.Mul(&r, &int1e5) + inScale += 5 + } else { + r.Mul(&r, &int10) + inScale++ + } + } + + bytes := r.Bytes() + if len(bytes) > 16 { + return Decimal{}, fmt.Errorf("can't parse %q 
as a decimal number: precision too large", v) + } + var out [4]uint32 + for i, b := range bytes { + pos := len(bytes) - i - 1 + out[pos/4] += uint32(b) << uint(pos%4*8) + } + return Decimal{ + integer: out, + positive: r.Sign() >= 0, + prec: 20, + scale: uint8(inScale), + }, nil +} + +// ScaleBytes converts a stringified decimal to a scaled byte slice +func ScaleBytes(s string, scale uint8) []byte { + z := make([]byte, 0, len(s)+1) + if s[0] == '-' || s[0] == '+' { + z = append(z, byte(s[0])) + s = s[1:] + } + pos := len(s) - int(scale) + if pos <= 0 { + z = append(z, byte('0')) + } else if pos > 0 { + z = append(z, []byte(s[:pos])...) + } + if scale > 0 { + z = append(z, byte('.')) + for pos < 0 { + z = append(z, byte('0')) + pos++ + } + z = append(z, []byte(s[pos:])...) + } + return z +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/parser.go b/vendor/github.com/denisenkom/go-mssqldb/internal/querytext/parser.go similarity index 69% rename from vendor/github.com/denisenkom/go-mssqldb/parser.go rename to vendor/github.com/denisenkom/go-mssqldb/internal/querytext/parser.go index 9e37c16..14650e3 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/parser.go +++ b/vendor/github.com/denisenkom/go-mssqldb/internal/querytext/parser.go @@ -1,4 +1,8 @@ -package mssql +// Package querytext is the old query parser and parameter substitute process. +// Do not use on new code. +// +// This package is not subject to any API compatibility guarantee. +package querytext import ( "bytes" @@ -11,6 +15,9 @@ type parser struct { w bytes.Buffer paramCount int paramMax int + + // using map as a set + namedParams map[string]bool } func (p *parser) next() (rune, bool) { @@ -37,15 +44,20 @@ func (p *parser) write(ch rune) { type stateFunc func(*parser) stateFunc -func parseParams(query string) (string, int) { +// ParseParams rewrites the query from using "?" placeholders +// to using "@pN" parameter names that SQL Server will accept. +// +// This function and package is not subject to any API compatibility guarantee. +func ParseParams(query string) (string, int) { p := &parser{ - r: bytes.NewReader([]byte(query)), + r: bytes.NewReader([]byte(query)), + namedParams: map[string]bool{}, } state := parseNormal for state != nil { state = state(p) } - return p.w.String(), p.paramMax + return p.w.String(), p.paramMax + len(p.namedParams) } func parseNormal(p *parser) stateFunc { @@ -55,7 +67,7 @@ func parseNormal(p *parser) stateFunc { return nil } if ch == '?' 
{ - return parseParameter + return parseOrdinalParameter } else if ch == '$' || ch == ':' { ch2, ok := p.next() if !ok { @@ -64,7 +76,9 @@ func parseNormal(p *parser) stateFunc { } p.unread() if ch2 >= '0' && ch2 <= '9' { - return parseParameter + return parseOrdinalParameter + } else if 'a' <= ch2 && ch2 <= 'z' || 'A' <= ch2 && ch2 <= 'Z' { + return parseNamedParameter } } p.write(ch) @@ -83,7 +97,7 @@ func parseNormal(p *parser) stateFunc { } } -func parseParameter(p *parser) stateFunc { +func parseOrdinalParameter(p *parser) stateFunc { var paramN int var ok bool for { @@ -113,6 +127,30 @@ func parseParameter(p *parser) stateFunc { return parseNormal } +func parseNamedParameter(p *parser) stateFunc { + var paramName string + var ok bool + for { + var ch rune + ch, ok = p.next() + if ok && (ch >= '0' && ch <= '9' || 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z') { + paramName = paramName + string(ch) + } else { + break + } + } + if ok { + p.unread() + } + p.namedParams[paramName] = true + p.w.WriteString("@") + p.w.WriteString(paramName) + if !ok { + return nil + } + return parseNormal +} + func parseQuote(p *parser) stateFunc { for { ch, ok := p.next() diff --git a/vendor/github.com/denisenkom/go-mssqldb/log.go b/vendor/github.com/denisenkom/go-mssqldb/log.go index f350aed..9b8c551 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/log.go +++ b/vendor/github.com/denisenkom/go-mssqldb/log.go @@ -4,19 +4,26 @@ import ( "log" ) -type Logger log.Logger +type Logger interface { + Printf(format string, v ...interface{}) + Println(v ...interface{}) +} + +type optionalLogger struct { + logger Logger +} -func (logger *Logger) Printf(format string, v ...interface{}) { - if logger != nil { - (*log.Logger)(logger).Printf(format, v...) +func (o optionalLogger) Printf(format string, v ...interface{}) { + if o.logger != nil { + o.logger.Printf(format, v...) } else { log.Printf(format, v...) } } -func (logger *Logger) Println(v ...interface{}) { - if logger != nil { - (*log.Logger)(logger).Println(v...) +func (o optionalLogger) Println(v ...interface{}) { + if o.logger != nil { + o.logger.Println(v...) } else { log.Println(v...) } diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go index c70475a..e37109c 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql.go @@ -1,359 +1,806 @@ package mssql import ( + "context" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" - "log" "math" "net" + "reflect" "strings" - "sync" "time" + "unicode" + + "github.com/denisenkom/go-mssqldb/internal/querytext" ) -var partnersCache partners = partners{mu: sync.RWMutex{}, v: make(map[string]string)} +// ReturnStatus may be used to return the return value from a proc. 
+// +// var rs mssql.ReturnStatus +// _, err := db.Exec("theproc", &rs) +// log.Printf("return status = %d", rs) +type ReturnStatus int32 + +var driverInstance = &Driver{processQueryText: true} +var driverInstanceNoProcess = &Driver{processQueryText: false} func init() { - sql.Register("mssql", &MssqlDriver{}) + sql.Register("mssql", driverInstance) + sql.Register("sqlserver", driverInstanceNoProcess) + createDialer = func(p *connectParams) Dialer { + return netDialer{&net.Dialer{KeepAlive: p.keepAlive}} + } +} + +var createDialer func(p *connectParams) Dialer + +type netDialer struct { + nd *net.Dialer } -type MssqlDriver struct { - log *log.Logger +func (d netDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + return d.nd.DialContext(ctx, network, addr) } -func (d *MssqlDriver) SetLogger(logger *log.Logger) { - d.log = logger +type Driver struct { + log optionalLogger + + processQueryText bool } -func CheckBadConn(err error) error { - if err == io.EOF { - return driver.ErrBadConn +// OpenConnector opens a new connector. Useful to dial with a context. +func (d *Driver) OpenConnector(dsn string) (*Connector, error) { + params, err := parseConnectParams(dsn) + if err != nil { + return nil, err } - neterr, ok := err.(net.Error) - if !ok || (!neterr.Timeout() && neterr.Temporary()) { - return err + return &Connector{ + params: params, + driver: d, + }, nil +} + +func (d *Driver) Open(dsn string) (driver.Conn, error) { + return d.open(context.Background(), dsn) +} + +func SetLogger(logger Logger) { + driverInstance.SetLogger(logger) + driverInstanceNoProcess.SetLogger(logger) +} + +func (d *Driver) SetLogger(logger Logger) { + d.log = optionalLogger{logger} +} + +// NewConnector creates a new connector from a DSN. +// The returned connector may be used with sql.OpenDB. +func NewConnector(dsn string) (*Connector, error) { + params, err := parseConnectParams(dsn) + if err != nil { + return nil, err + } + c := &Connector{ + params: params, + driver: driverInstanceNoProcess, } - return driver.ErrBadConn + return c, nil } -type MssqlConn struct { - sess *tdsSession +// Connector holds the parsed DSN and is ready to make a new connection +// at any time. +// +// In the future, settings that cannot be passed through a string DSN +// may be set directly on the connector. +type Connector struct { + params connectParams + driver *Driver + + // SessionInitSQL is executed after marking a given session to be reset. + // When not present, the next query will still reset the session to the + // database defaults. + // + // When present the connection will immediately mark the session to + // be reset, then execute the SessionInitSQL text to setup the session + // that may be different from the base database defaults. + // + // For Example, the application relies on the following defaults + // but is not allowed to set them at the database system level. + // + // SET XACT_ABORT ON; + // SET TEXTSIZE -1; + // SET ANSI_NULLS ON; + // SET LOCK_TIMEOUT 10000; + // + // SessionInitSQL should not attempt to manually call sp_reset_connection. + // This will happen at the TDS layer. + // + // SessionInitSQL is optional. The session will be reset even if + // SessionInitSQL is empty. + SessionInitSQL string + + // Dialer sets a custom dialer for all network operations. + // If Dialer is not set, normal net dialers are used. 
+ Dialer Dialer } -func (c *MssqlConn) Commit() error { - headers := []headerStruct{ - {hdrtype: dataStmHdrTransDescr, - data: transDescrHdr{c.sess.tranid, 1}.pack()}, +type Dialer interface { + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) +} + +func (c *Connector) getDialer(p *connectParams) Dialer { + if c != nil && c.Dialer != nil { + return c.Dialer + } + return createDialer(p) +} + +type Conn struct { + connector *Connector + sess *tdsSession + transactionCtx context.Context + resetSession bool + + processQueryText bool + connectionGood bool + + outs map[string]interface{} + returnStatus *ReturnStatus +} + +func (c *Conn) setReturnStatus(s ReturnStatus) { + if c.returnStatus == nil { + return } - if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { + *c.returnStatus = s +} + +func (c *Conn) checkBadConn(err error) error { + // this is a hack to address Issue #275 + // we set connectionGood flag to false if + // error indicates that connection is not usable + // but we return actual error instead of ErrBadConn + // this will cause connection to stay in a pool + // but next request to this connection will return ErrBadConn + + // it might be possible to revise this hack after + // https://github.com/golang/go/issues/20807 + // is implemented + switch err { + case nil: + return nil + case io.EOF: + c.connectionGood = false + return driver.ErrBadConn + case driver.ErrBadConn: + // It is an internal programming error if driver.ErrBadConn + // is ever passed to this function. driver.ErrBadConn should + // only ever be returned in response to a *mssql.Conn.connectionGood == false + // check in the external facing API. + panic("driver.ErrBadConn in checkBadConn. This should not happen.") + } + + switch err.(type) { + case net.Error: + c.connectionGood = false + return err + case StreamError: + c.connectionGood = false + return err + default: return err } +} + +func (c *Conn) clearOuts() { + c.outs = nil +} +func (c *Conn) simpleProcessResp(ctx context.Context) error { tokchan := make(chan tokenStruct, 5) - go processResponse(c.sess, tokchan) + go processResponse(ctx, c.sess, tokchan, c.outs) + c.clearOuts() for tok := range tokchan { switch token := tok.(type) { + case doneStruct: + if token.isError() { + return c.checkBadConn(token.getError()) + } case error: - return token + return c.checkBadConn(token) } } return nil } -func (c *MssqlConn) Rollback() error { +func (c *Conn) Commit() error { + if !c.connectionGood { + return driver.ErrBadConn + } + if err := c.sendCommitRequest(); err != nil { + return c.checkBadConn(err) + } + return c.simpleProcessResp(c.transactionCtx) +} + +func (c *Conn) sendCommitRequest() error { headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, } - if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { - return err + reset := c.resetSession + c.resetSession = false + if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil { + if c.sess.logFlags&logErrors != 0 { + c.sess.log.Printf("Failed to send CommitXact with %v", err) + } + c.connectionGood = false + return fmt.Errorf("Faild to send CommitXact: %v", err) + } + return nil +} + +func (c *Conn) Rollback() error { + if !c.connectionGood { + return driver.ErrBadConn + } + if err := c.sendRollbackRequest(); err != nil { + return c.checkBadConn(err) } + return c.simpleProcessResp(c.transactionCtx) +} - tokchan := make(chan tokenStruct, 5) - go 
processResponse(c.sess, tokchan) - for tok := range tokchan { - switch token := tok.(type) { - case error: - return token +func (c *Conn) sendRollbackRequest() error { + headers := []headerStruct{ + {hdrtype: dataStmHdrTransDescr, + data: transDescrHdr{c.sess.tranid, 1}.pack()}, + } + reset := c.resetSession + c.resetSession = false + if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil { + if c.sess.logFlags&logErrors != 0 { + c.sess.log.Printf("Failed to send RollbackXact with %v", err) } + c.connectionGood = false + return fmt.Errorf("Failed to send RollbackXact: %v", err) } return nil } -func (c *MssqlConn) Begin() (driver.Tx, error) { +func (c *Conn) Begin() (driver.Tx, error) { + return c.begin(context.Background(), isolationUseCurrent) +} + +func (c *Conn) begin(ctx context.Context, tdsIsolation isoLevel) (tx driver.Tx, err error) { + if !c.connectionGood { + return nil, driver.ErrBadConn + } + err = c.sendBeginRequest(ctx, tdsIsolation) + if err != nil { + return nil, c.checkBadConn(err) + } + tx, err = c.processBeginResponse(ctx) + if err != nil { + return nil, c.checkBadConn(err) + } + return +} + +func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error { + c.transactionCtx = ctx headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{0, 1}.pack()}, } - if err := sendBeginXact(c.sess.buf, headers, 0, ""); err != nil { - return nil, CheckBadConn(err) - } - tokchan := make(chan tokenStruct, 5) - go processResponse(c.sess, tokchan) - for tok := range tokchan { - switch token := tok.(type) { - case error: - if c.sess.tranid != 0 { - return nil, token - } - return nil, CheckBadConn(token) + reset := c.resetSession + c.resetSession = false + if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil { + if c.sess.logFlags&logErrors != 0 { + c.sess.log.Printf("Failed to send BeginXact with %v", err) } + c.connectionGood = false + return fmt.Errorf("Failed to send BeginXact: %v", err) } - // successful BEGINXACT request will return sess.tranid - // for started transaction - return c, nil + return nil } -func parseConnectionString(dsn string) (res map[string]string) { - res = map[string]string{} - parts := strings.Split(dsn, ";") - for _, part := range parts { - if len(part) == 0 { - continue - } - lst := strings.SplitN(part, "=", 2) - name := strings.TrimSpace(strings.ToLower(lst[0])) - if len(name) == 0 { - continue - } - var value string = "" - if len(lst) > 1 { - value = strings.TrimSpace(lst[1]) - } - res[name] = value +func (c *Conn) processBeginResponse(ctx context.Context) (driver.Tx, error) { + if err := c.simpleProcessResp(ctx); err != nil { + return nil, err } - return res + // successful BEGINXACT request will return sess.tranid + // for started transaction + return c, nil } -func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) { - params := parseConnectionString(dsn) - - conn, err := openConnection(dsn, params) +func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { + params, err := parseConnectParams(dsn) if err != nil { return nil, err } - - conn.sess.log = (*Logger)(d.log) - return conn, nil + return d.connect(ctx, nil, params) } -func openConnection(dsn string, params map[string]string) (*MssqlConn, error) { - sess, err := connect(params) +// connect to the server, using the provided context for dialing only. 
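The fail-over handling in this hunk is driven by two optional connection-string keys, failoverpartner and failoverport: when the primary server cannot be reached, the driver retries the connection against the partner. A minimal sketch, assuming an ADO-style DSN, an existing import of database/sql, and purely illustrative host names:

    // "primary" and "mirror" are placeholder host names, not values from this patch.
    dsn := "server=primary;failoverpartner=mirror;failoverport=1433;" +
        "user id=sa;password=mypass;database=master"
    db, err := sql.Open("mssql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()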
+func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) { + sess, err := connect(ctx, c, d.log, params) if err != nil { - partner := partnersCache.Get(dsn) - if partner == "" { - partner = params["failoverpartner"] - // remove the failoverpartner entry to prevent infinite recursion - delete(params, "failoverpartner") - if port, ok := params["failoverport"]; ok { - params["port"] = port - } + // main server failed, try fail-over partner + if params.failOverPartner == "" { + return nil, err } - if partner != "" { - params["server"] = partner - return openConnection(dsn, params) + params.host = params.failOverPartner + if params.failOverPort != 0 { + params.port = params.failOverPort } - return nil, err + sess, err = connect(ctx, c, d.log, params) + if err != nil { + // fail-over partner also failed, now fail + return nil, err + } } - if partner := sess.partner; partner != "" { - // append an instance so the port will be ignored when this value is used; - // tds does not provide the port number. - if !strings.Contains(partner, `\`) { - partner += `\.` - } - partnersCache.Set(dsn, partner) + conn := &Conn{ + connector: c, + sess: sess, + transactionCtx: context.Background(), + processQueryText: d.processQueryText, + connectionGood: true, } - return &MssqlConn{sess}, nil + return conn, nil } -func (c *MssqlConn) Close() error { +func (c *Conn) Close() error { return c.sess.buf.transport.Close() } -type MssqlStmt struct { - c *MssqlConn +type Stmt struct { + c *Conn query string paramCount int + notifSub *queryNotifSub +} + +type queryNotifSub struct { + msgText string + options string + timeout uint32 } -func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) { - q, paramCount := parseParams(query) - return &MssqlStmt{c, q, paramCount}, nil +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + if !c.connectionGood { + return nil, driver.ErrBadConn + } + if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { + return c.prepareCopyIn(context.Background(), query) + } + return c.prepareContext(context.Background(), query) } -func (s *MssqlStmt) Close() error { +func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) { + paramCount := -1 + if c.processQueryText { + query, paramCount = querytext.ParseParams(query) + } + return &Stmt{c, query, paramCount, nil}, nil +} + +func (s *Stmt) Close() error { return nil } -func (s *MssqlStmt) NumInput() int { +func (s *Stmt) SetQueryNotification(id, options string, timeout time.Duration) { + to := uint32(timeout / time.Second) + if to < 1 { + to = 1 + } + s.notifSub = &queryNotifSub{id, options, to} +} + +func (s *Stmt) NumInput() int { return s.paramCount } -func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) { +func (s *Stmt) sendQuery(args []namedValue) (err error) { headers := []headerStruct{ {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{s.c.sess.tranid, 1}.pack()}, } - if len(args) != s.paramCount { - return errors.New(fmt.Sprintf("sql: expected %d parameters, got %d", s.paramCount, len(args))) + + if s.notifSub != nil { + headers = append(headers, + headerStruct{ + hdrtype: dataStmHdrQueryNotif, + data: queryNotifHdr{ + s.notifSub.msgText, + s.notifSub.options, + s.notifSub.timeout, + }.pack(), + }) } - if s.c.sess.logFlags&logSQL != 0 { - s.c.sess.log.Println(s.query) + + conn := s.c + + // no need to check number of parameters here, it is checked by database/sql + if conn.sess.logFlags&logSQL != 0 { + conn.sess.log.Println(s.query) } - if 
s.c.sess.logFlags&logParams != 0 && len(args) > 0 { + if conn.sess.logFlags&logParams != 0 && len(args) > 0 { for i := 0; i < len(args); i++ { - s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i]) + if len(args[i].Name) > 0 { + s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value) + } else { + s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value) + } } - } + + reset := conn.resetSession + conn.resetSession = false if len(args) == 0 { - if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil { - if s.c.sess.tranid != 0 { - return err + if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil { + if conn.sess.logFlags&logErrors != 0 { + conn.sess.log.Printf("Failed to send SqlBatch with %v", err) } - return CheckBadConn(err) + conn.connectionGood = false + return fmt.Errorf("failed to send SQL Batch: %v", err) } } else { - params := make([]Param, len(args)+2) - decls := make([]string, len(args)) - params[0], err = s.makeParam(s.query) - if err != nil { - return - } - for i, val := range args { - params[i+2], err = s.makeParam(val) + proc := sp_ExecuteSql + var params []param + if isProc(s.query) { + proc.name = s.query + params, _, err = s.makeRPCParams(args, true) if err != nil { return } - name := fmt.Sprintf("@p%d", i+1) - params[i+2].Name = name - decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti)) - } - params[1], err = s.makeParam(strings.Join(decls, ",")) - if err != nil { - return + } else { + var decls []string + params, decls, err = s.makeRPCParams(args, false) + if err != nil { + return + } + params[0] = makeStrParam(s.query) + params[1] = makeStrParam(strings.Join(decls, ",")) } - if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil { - if s.c.sess.tranid != 0 { - return err + if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil { + if conn.sess.logFlags&logErrors != 0 { + conn.sess.log.Printf("Failed to send Rpc with %v", err) } - return CheckBadConn(err) + conn.connectionGood = false + return fmt.Errorf("Failed to send RPC: %v", err) } } return } -func (s *MssqlStmt) Query(args []driver.Value) (res driver.Rows, err error) { +// isProc takes the query text in s and determines if it is a stored proc name +// or SQL text. +func isProc(s string) bool { + if len(s) == 0 { + return false + } + const ( + outside = iota + text + escaped + ) + st := outside + var rn1, rPrev rune + for _, r := range s { + rPrev = rn1 + rn1 = r + switch r { + // No newlines or string sequences. 
+ case '\n', '\r', '\'', ';': + return false + } + switch st { + case outside: + switch { + case unicode.IsSpace(r): + return false + case r == '[': + st = escaped + continue + case r == ']' && rPrev == ']': + st = escaped + continue + case unicode.IsLetter(r): + st = text + } + case text: + switch { + case r == '.': + st = outside + continue + case unicode.IsSpace(r): + return false + } + case escaped: + switch { + case r == ']': + st = outside + continue + } + } + } + return true +} + +func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, error) { + var err error + var offset int + if !isProc { + offset = 2 + } + params := make([]param, len(args)+offset) + decls := make([]string, len(args)) + for i, val := range args { + params[i+offset], err = s.makeParam(val.Value) + if err != nil { + return nil, nil, err + } + var name string + if len(val.Name) > 0 { + name = "@" + val.Name + } else if !isProc { + name = fmt.Sprintf("@p%d", val.Ordinal) + } + params[i+offset].Name = name + decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+offset].ti)) + } + return params, decls, nil +} + +type namedValue struct { + Name string + Ordinal int + Value driver.Value +} + +func convertOldArgs(args []driver.Value) []namedValue { + list := make([]namedValue, len(args)) + for i, v := range args { + list[i] = namedValue{ + Ordinal: i + 1, + Value: v, + } + } + return list +} + +func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { + return s.queryContext(context.Background(), convertOldArgs(args)) +} + +func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver.Rows, err error) { + if !s.c.connectionGood { + return nil, driver.ErrBadConn + } if err = s.sendQuery(args); err != nil { - return + return nil, s.c.checkBadConn(err) } + return s.processQueryResponse(ctx) +} + +func (s *Stmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) { tokchan := make(chan tokenStruct, 5) - go processResponse(s.c.sess, tokchan) + ctx, cancel := context.WithCancel(ctx) + go processResponse(ctx, s.c.sess, tokchan, s.c.outs) + s.c.clearOuts() // process metadata - var cols []string + var cols []columnStruct loop: for tok := range tokchan { switch token := tok.(type) { - // by ignoring DONE token we effectively - // skip empty result-sets - // this improves results in queryes like that: + // By ignoring DONE token we effectively + // skip empty result-sets. 
+ // This improves results in queries like that: // set nocount on; select 1 // see TestIgnoreEmptyResults test //case doneStruct: - //break loop + //break loop case []columnStruct: - cols = make([]string, len(token)) - for i, col := range token { - cols[i] = col.ColName - } + cols = token break loop - case error: - if s.c.sess.tranid != 0 { - return nil, token + case doneStruct: + if token.isError() { + return nil, s.c.checkBadConn(token.getError()) } - return nil, CheckBadConn(token) + case ReturnStatus: + s.c.setReturnStatus(token) + case error: + return nil, s.c.checkBadConn(token) } } - return &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols}, nil + res = &Rows{stmt: s, tokchan: tokchan, cols: cols, cancel: cancel} + return } -func (s *MssqlStmt) Exec(args []driver.Value) (res driver.Result, err error) { +func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { + return s.exec(context.Background(), convertOldArgs(args)) +} + +func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result, err error) { + if !s.c.connectionGood { + return nil, driver.ErrBadConn + } if err = s.sendQuery(args); err != nil { - return + return nil, s.c.checkBadConn(err) } + if res, err = s.processExec(ctx); err != nil { + return nil, s.c.checkBadConn(err) + } + return +} + +func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { tokchan := make(chan tokenStruct, 5) - go processResponse(s.c.sess, tokchan) + go processResponse(ctx, s.c.sess, tokchan, s.c.outs) + s.c.clearOuts() var rowCount int64 for token := range tokchan { switch token := token.(type) { case doneInProcStruct: if token.Status&doneCount != 0 { - rowCount = int64(token.RowCount) + rowCount += int64(token.RowCount) } case doneStruct: if token.Status&doneCount != 0 { - rowCount = int64(token.RowCount) + rowCount += int64(token.RowCount) } - case error: - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Println("got error:", token) - } - if s.c.sess.tranid != 0 { - return nil, token + if token.isError() { + return nil, token.getError() } - return nil, CheckBadConn(token) + case ReturnStatus: + s.c.setReturnStatus(token) + case error: + return nil, token } } - return &MssqlResult{s.c, rowCount}, nil + return &Result{s.c, rowCount}, nil } -type MssqlRows struct { - sess *tdsSession - cols []string +type Rows struct { + stmt *Stmt + cols []columnStruct tokchan chan tokenStruct + + nextCols []columnStruct + + cancel func() } -func (rc *MssqlRows) Close() error { +func (rc *Rows) Close() error { + rc.cancel() for _ = range rc.tokchan { } rc.tokchan = nil return nil } -func (rc *MssqlRows) Columns() (res []string) { - return rc.cols +func (rc *Rows) Columns() (res []string) { + res = make([]string, len(rc.cols)) + for i, col := range rc.cols { + res[i] = col.ColName + } + return } -func (rc *MssqlRows) Next(dest []driver.Value) (err error) { +func (rc *Rows) Next(dest []driver.Value) error { + if !rc.stmt.c.connectionGood { + return driver.ErrBadConn + } + if rc.nextCols != nil { + return io.EOF + } for tok := range rc.tokchan { switch tokdata := tok.(type) { case []columnStruct: - return streamErrorf("Unexpected token COLMETADATA") + rc.nextCols = tokdata + return io.EOF case []interface{}: for i := range dest { dest[i] = tokdata[i] } return nil + case doneStruct: + if tokdata.isError() { + return rc.stmt.c.checkBadConn(tokdata.getError()) + } + case ReturnStatus: + rc.stmt.c.setReturnStatus(tokdata) case error: - return tokdata + return rc.stmt.c.checkBadConn(tokdata) } } return io.EOF } 
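Rows.Next above returns io.EOF and stashes the new column metadata when a fresh COLMETADATA token arrives; HasNextResultSet and NextResultSet below hand that buffered metadata to database/sql. A short usage sketch, assuming an open *sql.DB named db and a context ctx; the two-statement batch is an illustrative placeholder:

    rows, err := db.QueryContext(ctx, "select 1; select 2, 3;")
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()
    for {
        for rows.Next() {
            // Scan the current result set here.
        }
        // NextResultSet swaps in the columns buffered by the driver's nextCols field.
        if !rows.NextResultSet() {
            break
        }
    }
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }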
-func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) { +func (rc *Rows) HasNextResultSet() bool { + return rc.nextCols != nil +} + +func (rc *Rows) NextResultSet() error { + rc.cols = rc.nextCols + rc.nextCols = nil + if rc.cols == nil { + return io.EOF + } + return nil +} + +// It should return +// the value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". +func (r *Rows) ColumnTypeScanType(index int) reflect.Type { + return makeGoLangScanType(r.cols[index].ti) +} + +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". +func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { + return makeGoLangTypeName(r.cols[index].ti) +} + +// RowsColumnTypeLength may be implemented by Rows. It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. +// The following are examples of returned values for various types: +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +func (r *Rows) ColumnTypeLength(index int) (int64, bool) { + return makeGoLangTypeLength(r.cols[index].ti) +} + +// It should return +// the precision and scale for decimal types. If not applicable, ok should be false. +// The following are examples of returned values for various types: +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) +func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) { + return makeGoLangTypePrecisionScale(r.cols[index].ti) +} + +// The nullable value should +// be true if it is known the column may be null, or false if the column is known +// to be not nullable. +// If the column nullability is unknown, ok should be false. 
+func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) { + nullable = r.cols[index].Flags&colFlagNullable != 0 + ok = true + return +} + +func makeStrParam(val string) (res param) { + res.ti.TypeId = typeNVarChar + res.buffer = str2ucs2(val) + res.ti.Size = len(res.buffer) + return +} + +func (s *Stmt) makeParam(val driver.Value) (res param, err error) { if val == nil { - res.ti.TypeId = typeNVarChar + res.ti.TypeId = typeNull res.buffer = nil - res.ti.Size = 2 + res.ti.Size = 0 return } switch val := val.(type) { @@ -362,19 +809,34 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) { res.buffer = make([]byte, 8) res.ti.Size = 8 binary.LittleEndian.PutUint64(res.buffer, uint64(val)) + case sql.NullInt64: + // only null values should be getting here + res.ti.TypeId = typeIntN + res.ti.Size = 8 + res.buffer = []byte{} + case float64: res.ti.TypeId = typeFltN res.ti.Size = 8 res.buffer = make([]byte, 8) binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val)) + case sql.NullFloat64: + // only null values should be getting here + res.ti.TypeId = typeFltN + res.ti.Size = 8 + res.buffer = []byte{} + case []byte: res.ti.TypeId = typeBigVarBin res.ti.Size = len(val) res.buffer = val case string: + res = makeStrParam(val) + case sql.NullString: + // only null values should be getting here res.ti.TypeId = typeNVarChar - res.buffer = str2ucs2(val) - res.ti.Size = len(res.buffer) + res.buffer = nil + res.ti.Size = 8000 case bool: res.ti.TypeId = typeBitN res.ti.Size = 1 @@ -382,96 +844,114 @@ func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) { if val { res.buffer[0] = 1 } + case sql.NullBool: + // only null values should be getting here + res.ti.TypeId = typeBitN + res.ti.Size = 1 + res.buffer = []byte{} + case time.Time: if s.c.sess.loginAck.TDSVersion >= verTDS73 { res.ti.TypeId = typeDateTimeOffsetN res.ti.Scale = 7 - res.ti.Size = 10 - buf := make([]byte, 10) - res.buffer = buf - days, ns := dateTime2(val) - ns /= 100 - buf[0] = byte(ns) - buf[1] = byte(ns >> 8) - buf[2] = byte(ns >> 16) - buf[3] = byte(ns >> 24) - buf[4] = byte(ns >> 32) - buf[5] = byte(days) - buf[6] = byte(days >> 8) - buf[7] = byte(days >> 16) - _, offset := val.Zone() - offset /= 60 - buf[8] = byte(offset) - buf[9] = byte(offset >> 8) + res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale)) + res.ti.Size = len(res.buffer) } else { res.ti.TypeId = typeDateTimeN - res.ti.Size = 8 - res.buffer = make([]byte, 8) - ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) - dur := val.Sub(ref) - days := dur / (24 * time.Hour) - tm := (300 * (dur % (24 * time.Hour))) / time.Second - binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days)) - binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm)) + res.buffer = encodeDateTime(val) + res.ti.Size = len(res.buffer) } default: - err = fmt.Errorf("mssql: unknown type for %T", val) - return + return s.makeParamExtra(val) } return } -type MssqlResult struct { - c *MssqlConn +type Result struct { + c *Conn rowsAffected int64 } -func (r *MssqlResult) RowsAffected() (int64, error) { +func (r *Result) RowsAffected() (int64, error) { return r.rowsAffected, nil } -func (r *MssqlResult) LastInsertId() (int64, error) { - s, err := r.c.Prepare("select cast(@@identity as bigint)") - if err != nil { - return 0, err +var _ driver.Pinger = &Conn{} + +// Ping is used to check if the remote server is available and satisfies the Pinger interface. 
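
The ColumnTypeScanType, ColumnTypeDatabaseTypeName, ColumnTypeLength, ColumnTypePrecisionScale and ColumnTypeNullable methods above are the driver hooks behind database/sql's *sql.ColumnType. A short sketch of how that metadata surfaces to callers, assuming the usual context, database/sql and fmt imports; the query passed in is whatever result set you want to inspect:

    // describeColumns prints the metadata exposed through the ColumnType* hooks.
    func describeColumns(ctx context.Context, db *sql.DB, query string) error {
        rows, err := db.QueryContext(ctx, query)
        if err != nil {
            return err
        }
        defer rows.Close()

        cols, err := rows.ColumnTypes()
        if err != nil {
            return err
        }
        for _, c := range cols {
            length, hasLen := c.Length()           // backed by ColumnTypeLength
            prec, scale, hasDec := c.DecimalSize() // backed by ColumnTypePrecisionScale
            nullable, nullOK := c.Nullable()       // backed by ColumnTypeNullable
            fmt.Printf("%-20s %-12s scan=%v len=%d(%t) dec=%d,%d(%t) null=%t(%t)\n",
                c.Name(), c.DatabaseTypeName(), c.ScanType(),
                length, hasLen, prec, scale, hasDec, nullable, nullOK)
        }
        return rows.Err()
    }
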
+func (c *Conn) Ping(ctx context.Context) error { + if !c.connectionGood { + return driver.ErrBadConn } - defer s.Close() - rows, err := s.Query(nil) - if err != nil { - return 0, err + stmt := &Stmt{c, `select 1;`, 0, nil} + _, err := stmt.ExecContext(ctx, nil) + return err +} + +var _ driver.ConnBeginTx = &Conn{} + +// BeginTx satisfies ConnBeginTx. +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if !c.connectionGood { + return nil, driver.ErrBadConn + } + if opts.ReadOnly { + return nil, errors.New("Read-only transactions are not supported") + } + + var tdsIsolation isoLevel + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + tdsIsolation = isolationUseCurrent + case sql.LevelReadUncommitted: + tdsIsolation = isolationReadUncommited + case sql.LevelReadCommitted: + tdsIsolation = isolationReadCommited + case sql.LevelWriteCommitted: + return nil, errors.New("LevelWriteCommitted isolation level is not supported") + case sql.LevelRepeatableRead: + tdsIsolation = isolationRepeatableRead + case sql.LevelSnapshot: + tdsIsolation = isolationSnapshot + case sql.LevelSerializable: + tdsIsolation = isolationSerializable + case sql.LevelLinearizable: + return nil, errors.New("LevelLinearizable isolation level is not supported") + default: + return nil, errors.New("Isolation level is not supported or unknown") } - defer rows.Close() - dest := make([]driver.Value, 1) - err = rows.Next(dest) - if err != nil { - return 0, err + return c.begin(ctx, tdsIsolation) +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if !c.connectionGood { + return nil, driver.ErrBadConn } - if dest[0] == nil { - return -1, errors.New("There is no generated identity value") + if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { + return c.prepareCopyIn(ctx, query) } - lastInsertId := dest[0].(int64) - return lastInsertId, nil -} -type partners struct { - mu sync.RWMutex - v map[string]string + return c.prepareContext(ctx, query) } -func (p *partners) Set(key, value string) error { - p.mu.Lock() - defer p.mu.Unlock() - if _, ok := p.v[key]; ok { - return errors.New("key already exists") +func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if !s.c.connectionGood { + return nil, driver.ErrBadConn } - - p.v[key] = value - return nil + list := make([]namedValue, len(args)) + for i, nv := range args { + list[i] = namedValue(nv) + } + return s.queryContext(ctx, list) } -func (p *partners) Get(key string) (value string) { - p.mu.RLock() - value = p.v[key] - p.mu.RUnlock() - return +func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if !s.c.connectionGood { + return nil, driver.ErrBadConn + } + list := make([]namedValue, len(args)) + for i, nv := range args { + list[i] = namedValue(nv) + } + return s.exec(ctx, list) } diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3.go deleted file mode 100644 index 22c6891..0000000 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build go1.3 - -package mssql - -import ( - "net" -) - -func createDialer(p *connectParams) *net.Dialer { - return &net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive} -} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3pre.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3pre.go deleted file mode 100644 
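
The Conn.BeginTx implementation in the mssql.go hunk above maps each database/sql isolation level onto the corresponding TDS isolation level and rejects read-only transactions and the unsupported levels. A caller-side sketch, assuming the usual context and database/sql imports and that snapshot isolation is enabled on the target database; the table dbo.jobs is purely illustrative:

    // runSnapshotTx shows how the isolation mapping in Conn.BeginTx is selected
    // by the caller. Requesting an unsupported level such as
    // sql.LevelLinearizable would make BeginTx return an error instead.
    func runSnapshotTx(ctx context.Context, db *sql.DB) error {
        tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSnapshot})
        if err != nil {
            return err
        }
        if _, err := tx.ExecContext(ctx, "UPDATE dbo.jobs SET state = 'done' WHERE id = 1"); err != nil {
            tx.Rollback()
            return err
        }
        return tx.Commit()
    }
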
index 3c7e727..0000000 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql_go1.3pre.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build !go1.3 - -package mssql - -import ( - "net" -) - -func createDialer(p *connectParams) *net.Dialer { - return &net.Dialer{Timeout: p.dial_timeout} -} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go new file mode 100644 index 0000000..6d76fba --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110.go @@ -0,0 +1,52 @@ +// +build go1.10 + +package mssql + +import ( + "context" + "database/sql/driver" + "errors" +) + +var _ driver.Connector = &Connector{} +var _ driver.SessionResetter = &Conn{} + +func (c *Conn) ResetSession(ctx context.Context) error { + if !c.connectionGood { + return driver.ErrBadConn + } + c.resetSession = true + + if c.connector == nil || len(c.connector.SessionInitSQL) == 0 { + return nil + } + + s, err := c.prepareContext(ctx, c.connector.SessionInitSQL) + if err != nil { + return driver.ErrBadConn + } + _, err = s.exec(ctx, nil) + if err != nil { + return driver.ErrBadConn + } + + return nil +} + +// Connect to the server and return a TDS connection. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := c.driver.connect(ctx, c, c.params) + if err == nil { + err = conn.ResetSession(ctx) + } + return conn, err +} + +// Driver underlying the Connector. +func (c *Connector) Driver() driver.Driver { + return c.driver +} + +func (r *Result) LastInsertId() (int64, error) { + return -1, errors.New("LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query.") +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go110pre.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110pre.go new file mode 100644 index 0000000..3baae0c --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go110pre.go @@ -0,0 +1,31 @@ +// +build !go1.10 + +package mssql + +import ( + "database/sql/driver" + "errors" +) + +func (r *Result) LastInsertId() (int64, error) { + s, err := r.c.Prepare("select cast(@@identity as bigint)") + if err != nil { + return 0, err + } + defer s.Close() + rows, err := s.Query(nil) + if err != nil { + return 0, err + } + defer rows.Close() + dest := make([]driver.Value, 1) + err = rows.Next(dest) + if err != nil { + return 0, err + } + if dest[0] == nil { + return -1, errors.New("There is no generated identity value") + } + lastInsertId := dest[0].(int64) + return lastInsertId, nil +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go new file mode 100644 index 0000000..a2bd116 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19.go @@ -0,0 +1,196 @@ +// +build go1.9 + +package mssql + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "time" + + // "github.com/cockroachdb/apd" + "github.com/golang-sql/civil" +) + +// Type alias provided for compatibility. + +type MssqlDriver = Driver // Deprecated: users should transition to the new name when possible. +type MssqlBulk = Bulk // Deprecated: users should transition to the new name when possible. +type MssqlBulkOptions = BulkOptions // Deprecated: users should transition to the new name when possible. +type MssqlConn = Conn // Deprecated: users should transition to the new name when possible. 
+type MssqlResult = Result // Deprecated: users should transition to the new name when possible. +type MssqlRows = Rows // Deprecated: users should transition to the new name when possible. +type MssqlStmt = Stmt // Deprecated: users should transition to the new name when possible. + +var _ driver.NamedValueChecker = &Conn{} + +// VarChar parameter types. +type VarChar string + +type NVarCharMax string +type VarCharMax string + +// DateTime1 encodes parameters to original DateTime SQL types. +type DateTime1 time.Time + +// DateTimeOffset encodes parameters to DateTimeOffset, preserving the UTC offset. +type DateTimeOffset time.Time + +func convertInputParameter(val interface{}) (interface{}, error) { + switch v := val.(type) { + case VarChar: + return val, nil + case NVarCharMax: + return val, nil + case VarCharMax: + return val, nil + case DateTime1: + return val, nil + case DateTimeOffset: + return val, nil + case civil.Date: + return val, nil + case civil.DateTime: + return val, nil + case civil.Time: + return val, nil + // case *apd.Decimal: + // return nil + default: + return driver.DefaultParameterConverter.ConvertValue(v) + } +} + +func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error { + switch v := nv.Value.(type) { + case sql.Out: + if c.outs == nil { + c.outs = make(map[string]interface{}) + } + c.outs[nv.Name] = v.Dest + + if v.Dest == nil { + return errors.New("destination is a nil pointer") + } + + dest_info := reflect.ValueOf(v.Dest) + if dest_info.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + + if dest_info.IsNil() { + return errors.New("destination is a nil pointer") + } + + pointed_value := reflect.Indirect(dest_info) + + // don't allow pointer to a pointer, only pointer to a value can be handled + // correctly + if pointed_value.Kind() == reflect.Ptr { + return errors.New("destination is a pointer to a pointer") + } + + // Unwrap the Out value and check the inner value. + val := pointed_value.Interface() + if val == nil { + return errors.New("MSSQL does not allow NULL value without type for OUTPUT parameters") + } + conv, err := convertInputParameter(val) + if err != nil { + return err + } + if conv == nil { + // if we replace with nil we would lose type information + nv.Value = sql.Out{Dest: val} + } else { + nv.Value = sql.Out{Dest: conv} + } + return nil + case *ReturnStatus: + *v = 0 // By default the return value should be zero. 
+ c.returnStatus = v + return driver.ErrRemoveArgument + case TVP: + return nil + default: + var err error + nv.Value, err = convertInputParameter(nv.Value) + return err + } +} + +func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { + switch val := val.(type) { + case VarChar: + res.ti.TypeId = typeBigVarChar + res.buffer = []byte(val) + res.ti.Size = len(res.buffer) + case VarCharMax: + res.ti.TypeId = typeBigVarChar + res.buffer = []byte(val) + res.ti.Size = 0 // currently zero forces varchar(max) + case NVarCharMax: + res.ti.TypeId = typeNVarChar + res.buffer = str2ucs2(string(val)) + res.ti.Size = 0 // currently zero forces nvarchar(max) + case DateTime1: + t := time.Time(val) + res.ti.TypeId = typeDateTimeN + res.buffer = encodeDateTime(t) + res.ti.Size = len(res.buffer) + case DateTimeOffset: + res.ti.TypeId = typeDateTimeOffsetN + res.ti.Scale = 7 + res.buffer = encodeDateTimeOffset(time.Time(val), int(res.ti.Scale)) + res.ti.Size = len(res.buffer) + case civil.Date: + res.ti.TypeId = typeDateN + res.buffer = encodeDate(val.In(time.UTC)) + res.ti.Size = len(res.buffer) + case civil.DateTime: + res.ti.TypeId = typeDateTime2N + res.ti.Scale = 7 + res.buffer = encodeDateTime2(val.In(time.UTC), int(res.ti.Scale)) + res.ti.Size = len(res.buffer) + case civil.Time: + res.ti.TypeId = typeTimeN + res.ti.Scale = 7 + res.buffer = encodeTime(val.Hour, val.Minute, val.Second, val.Nanosecond, int(res.ti.Scale)) + res.ti.Size = len(res.buffer) + case sql.Out: + res, err = s.makeParam(val.Dest) + res.Flags = fByRevValue + case TVP: + err = val.check() + if err != nil { + return + } + schema, name, errGetName := getSchemeAndName(val.TypeName) + if errGetName != nil { + return + } + res.ti.UdtInfo.TypeName = name + res.ti.UdtInfo.SchemaName = schema + res.ti.TypeId = typeTvp + columnStr, tvpFieldIndexes, errCalTypes := val.columnTypes() + if errCalTypes != nil { + err = errCalTypes + return + } + res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes) + if err != nil { + return + } + res.ti.Size = len(res.buffer) + + default: + err = fmt.Errorf("mssql: unknown type for %T", val) + } + return +} + +func scanIntoOut(name string, fromServer, scanInto interface{}) error { + return convertAssign(scanInto, fromServer) +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go new file mode 100644 index 0000000..9680f51 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql_go19pre.go @@ -0,0 +1,16 @@ +// +build !go1.9 + +package mssql + +import ( + "database/sql/driver" + "fmt" +) + +func (s *Stmt) makeParamExtra(val driver.Value) (param, error) { + return param{}, fmt.Errorf("mssql: unknown type for %T", val) +} + +func scanIntoOut(name string, fromServer, scanInto interface{}) error { + return fmt.Errorf("mssql: unsupported OUTPUT type, use a newer Go version") +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/net.go b/vendor/github.com/denisenkom/go-mssqldb/net.go index 72a8734..94858cc 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/net.go +++ b/vendor/github.com/denisenkom/go-mssqldb/net.go @@ -9,12 +9,9 @@ import ( type timeoutConn struct { c net.Conn timeout time.Duration - buf *tdsBuffer - packetPending bool - continueRead bool } -func NewTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn { +func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn { return &timeoutConn{ c: conn, timeout: timeout, @@ -22,54 +19,21 @@ func NewTimeoutConn(conn 
net.Conn, timeout time.Duration) *timeoutConn { } func (c *timeoutConn) Read(b []byte) (n int, err error) { - if c.buf != nil { - if c.packetPending { - c.packetPending = false - err = c.buf.FinishPacket() - if err != nil { - err = fmt.Errorf("Cannot send handshake packet: %s", err.Error()) - return - } - c.continueRead = false - } - if !c.continueRead { - var packet uint8 - packet, err = c.buf.BeginRead() - if err != nil { - err = fmt.Errorf("Cannot read handshake packet: %s", err.Error()) - return - } - if packet != packPrelogin { - err = fmt.Errorf("unexpected packet %d, expecting prelogin", packet) - return - } - c.continueRead = true + if c.timeout > 0 { + err = c.c.SetDeadline(time.Now().Add(c.timeout)) + if err != nil { + return } - n, err = c.buf.Read(b) - return - } - err = c.c.SetDeadline(time.Now().Add(c.timeout)) - if err != nil { - return } return c.c.Read(b) } func (c *timeoutConn) Write(b []byte) (n int, err error) { - if c.buf != nil { - if !c.packetPending { - c.buf.BeginPacket(packPrelogin) - c.packetPending = true - } - n, err = c.buf.Write(b) + if c.timeout > 0 { + err = c.c.SetDeadline(time.Now().Add(c.timeout)) if err != nil { return } - return - } - err = c.c.SetDeadline(time.Now().Add(c.timeout)) - if err != nil { - return } return c.c.Write(b) } @@ -97,3 +61,108 @@ func (c timeoutConn) SetReadDeadline(t time.Time) error { func (c timeoutConn) SetWriteDeadline(t time.Time) error { panic("Not implemented") } + +// this connection is used during TLS Handshake +// TDS protocol requires TLS handshake messages to be sent inside TDS packets +type tlsHandshakeConn struct { + buf *tdsBuffer + packetPending bool + continueRead bool +} + +func (c *tlsHandshakeConn) Read(b []byte) (n int, err error) { + if c.packetPending { + c.packetPending = false + err = c.buf.FinishPacket() + if err != nil { + err = fmt.Errorf("Cannot send handshake packet: %s", err.Error()) + return + } + c.continueRead = false + } + if !c.continueRead { + var packet packetType + packet, err = c.buf.BeginRead() + if err != nil { + err = fmt.Errorf("Cannot read handshake packet: %s", err.Error()) + return + } + if packet != packPrelogin { + err = fmt.Errorf("unexpected packet %d, expecting prelogin", packet) + return + } + c.continueRead = true + } + return c.buf.Read(b) +} + +func (c *tlsHandshakeConn) Write(b []byte) (n int, err error) { + if !c.packetPending { + c.buf.BeginPacket(packPrelogin, false) + c.packetPending = true + } + return c.buf.Write(b) +} + +func (c *tlsHandshakeConn) Close() error { + panic("Not implemented") +} + +func (c *tlsHandshakeConn) LocalAddr() net.Addr { + panic("Not implemented") +} + +func (c *tlsHandshakeConn) RemoteAddr() net.Addr { + panic("Not implemented") +} + +func (c *tlsHandshakeConn) SetDeadline(t time.Time) error { + panic("Not implemented") +} + +func (c *tlsHandshakeConn) SetReadDeadline(t time.Time) error { + panic("Not implemented") +} + +func (c *tlsHandshakeConn) SetWriteDeadline(t time.Time) error { + panic("Not implemented") +} + +// this connection just delegates all methods to it's wrapped connection +// it also allows switching underlying connection on the fly +// it is needed because tls.Conn does not allow switching underlying connection +type passthroughConn struct { + c net.Conn +} + +func (c passthroughConn) Read(b []byte) (n int, err error) { + return c.c.Read(b) +} + +func (c passthroughConn) Write(b []byte) (n int, err error) { + return c.c.Write(b) +} + +func (c passthroughConn) Close() error { + return c.c.Close() +} + +func (c 
passthroughConn) LocalAddr() net.Addr { + panic("Not implemented") +} + +func (c passthroughConn) RemoteAddr() net.Addr { + panic("Not implemented") +} + +func (c passthroughConn) SetDeadline(t time.Time) error { + panic("Not implemented") +} + +func (c passthroughConn) SetReadDeadline(t time.Time) error { + panic("Not implemented") +} + +func (c passthroughConn) SetWriteDeadline(t time.Time) error { + panic("Not implemented") +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/ntlm.go b/vendor/github.com/denisenkom/go-mssqldb/ntlm.go index f853435..7c0cc4f 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/ntlm.go +++ b/vendor/github.com/denisenkom/go-mssqldb/ntlm.go @@ -15,56 +15,56 @@ import ( ) const ( - NEGOTIATE_MESSAGE = 1 - CHALLENGE_MESSAGE = 2 - AUTHENTICATE_MESSAGE = 3 + _NEGOTIATE_MESSAGE = 1 + _CHALLENGE_MESSAGE = 2 + _AUTHENTICATE_MESSAGE = 3 ) const ( - NEGOTIATE_UNICODE = 0x00000001 - NEGOTIATE_OEM = 0x00000002 - NEGOTIATE_TARGET = 0x00000004 - NEGOTIATE_SIGN = 0x00000010 - NEGOTIATE_SEAL = 0x00000020 - NEGOTIATE_DATAGRAM = 0x00000040 - NEGOTIATE_LMKEY = 0x00000080 - NEGOTIATE_NTLM = 0x00000200 - NEGOTIATE_ANONYMOUS = 0x00000800 - NEGOTIATE_OEM_DOMAIN_SUPPLIED = 0x00001000 - NEGOTIATE_OEM_WORKSTATION_SUPPLIED = 0x00002000 - NEGOTIATE_ALWAYS_SIGN = 0x00008000 - NEGOTIATE_TARGET_TYPE_DOMAIN = 0x00010000 - NEGOTIATE_TARGET_TYPE_SERVER = 0x00020000 - NEGOTIATE_EXTENDED_SESSIONSECURITY = 0x00080000 - NEGOTIATE_IDENTIFY = 0x00100000 - REQUEST_NON_NT_SESSION_KEY = 0x00400000 - NEGOTIATE_TARGET_INFO = 0x00800000 - NEGOTIATE_VERSION = 0x02000000 - NEGOTIATE_128 = 0x20000000 - NEGOTIATE_KEY_EXCH = 0x40000000 - NEGOTIATE_56 = 0x80000000 + _NEGOTIATE_UNICODE = 0x00000001 + _NEGOTIATE_OEM = 0x00000002 + _NEGOTIATE_TARGET = 0x00000004 + _NEGOTIATE_SIGN = 0x00000010 + _NEGOTIATE_SEAL = 0x00000020 + _NEGOTIATE_DATAGRAM = 0x00000040 + _NEGOTIATE_LMKEY = 0x00000080 + _NEGOTIATE_NTLM = 0x00000200 + _NEGOTIATE_ANONYMOUS = 0x00000800 + _NEGOTIATE_OEM_DOMAIN_SUPPLIED = 0x00001000 + _NEGOTIATE_OEM_WORKSTATION_SUPPLIED = 0x00002000 + _NEGOTIATE_ALWAYS_SIGN = 0x00008000 + _NEGOTIATE_TARGET_TYPE_DOMAIN = 0x00010000 + _NEGOTIATE_TARGET_TYPE_SERVER = 0x00020000 + _NEGOTIATE_EXTENDED_SESSIONSECURITY = 0x00080000 + _NEGOTIATE_IDENTIFY = 0x00100000 + _REQUEST_NON_NT_SESSION_KEY = 0x00400000 + _NEGOTIATE_TARGET_INFO = 0x00800000 + _NEGOTIATE_VERSION = 0x02000000 + _NEGOTIATE_128 = 0x20000000 + _NEGOTIATE_KEY_EXCH = 0x40000000 + _NEGOTIATE_56 = 0x80000000 ) -const NEGOTIATE_FLAGS = NEGOTIATE_UNICODE | - NEGOTIATE_NTLM | - NEGOTIATE_OEM_DOMAIN_SUPPLIED | - NEGOTIATE_OEM_WORKSTATION_SUPPLIED | - NEGOTIATE_ALWAYS_SIGN | - NEGOTIATE_EXTENDED_SESSIONSECURITY +const _NEGOTIATE_FLAGS = _NEGOTIATE_UNICODE | + _NEGOTIATE_NTLM | + _NEGOTIATE_OEM_DOMAIN_SUPPLIED | + _NEGOTIATE_OEM_WORKSTATION_SUPPLIED | + _NEGOTIATE_ALWAYS_SIGN | + _NEGOTIATE_EXTENDED_SESSIONSECURITY -type NTLMAuth struct { +type ntlmAuth struct { Domain string UserName string Password string Workstation string } -func getAuth(user, password, service, workstation string) (Auth, bool) { +func getAuth(user, password, service, workstation string) (auth, bool) { if !strings.ContainsRune(user, '\\') { return nil, false } domain_user := strings.SplitN(user, "\\", 2) - return &NTLMAuth{ + return &ntlmAuth{ Domain: domain_user[0], UserName: domain_user[1], Password: password, @@ -86,13 +86,13 @@ func utf16le(val string) []byte { return v } -func (auth *NTLMAuth) InitialBytes() ([]byte, error) { +func (auth *ntlmAuth) InitialBytes() ([]byte, error) { 
domain_len := len(auth.Domain) workstation_len := len(auth.Workstation) msg := make([]byte, 40+domain_len+workstation_len) copy(msg, []byte("NTLMSSP\x00")) - binary.LittleEndian.PutUint32(msg[8:], NEGOTIATE_MESSAGE) - binary.LittleEndian.PutUint32(msg[12:], NEGOTIATE_FLAGS) + binary.LittleEndian.PutUint32(msg[8:], _NEGOTIATE_MESSAGE) + binary.LittleEndian.PutUint32(msg[12:], _NEGOTIATE_FLAGS) // Domain Name Fields binary.LittleEndian.PutUint16(msg[16:], uint16(domain_len)) binary.LittleEndian.PutUint16(msg[18:], uint16(domain_len)) @@ -198,11 +198,11 @@ func ntlmSessionResponse(clientNonce [8]byte, serverChallenge [8]byte, password return response(hash, passwordHash) } -func (auth *NTLMAuth) NextBytes(bytes []byte) ([]byte, error) { +func (auth *ntlmAuth) NextBytes(bytes []byte) ([]byte, error) { if string(bytes[0:8]) != "NTLMSSP\x00" { return nil, errorNTLM } - if binary.LittleEndian.Uint32(bytes[8:12]) != CHALLENGE_MESSAGE { + if binary.LittleEndian.Uint32(bytes[8:12]) != _CHALLENGE_MESSAGE { return nil, errorNTLM } flags := binary.LittleEndian.Uint32(bytes[20:24]) @@ -210,7 +210,7 @@ func (auth *NTLMAuth) NextBytes(bytes []byte) ([]byte, error) { copy(challenge[:], bytes[24:32]) var lm, nt []byte - if (flags & NEGOTIATE_EXTENDED_SESSIONSECURITY) != 0 { + if (flags & _NEGOTIATE_EXTENDED_SESSIONSECURITY) != 0 { nonce := clientChallenge() var lm_bytes [24]byte copy(lm_bytes[:8], nonce[:]) @@ -235,7 +235,7 @@ func (auth *NTLMAuth) NextBytes(bytes []byte) ([]byte, error) { msg := make([]byte, 88+lm_len+nt_len+domain_len+user_len+workstation_len) copy(msg, []byte("NTLMSSP\x00")) - binary.LittleEndian.PutUint32(msg[8:], AUTHENTICATE_MESSAGE) + binary.LittleEndian.PutUint32(msg[8:], _AUTHENTICATE_MESSAGE) // Lm Challenge Response Fields binary.LittleEndian.PutUint16(msg[12:], uint16(lm_len)) binary.LittleEndian.PutUint16(msg[14:], uint16(lm_len)) @@ -279,5 +279,5 @@ func (auth *NTLMAuth) NextBytes(bytes []byte) ([]byte, error) { return msg, nil } -func (auth *NTLMAuth) Free() { +func (auth *ntlmAuth) Free() { } diff --git a/vendor/github.com/denisenkom/go-mssqldb/rpc.go b/vendor/github.com/denisenkom/go-mssqldb/rpc.go index 00b9b1e..4ca2257 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/rpc.go +++ b/vendor/github.com/denisenkom/go-mssqldb/rpc.go @@ -4,7 +4,7 @@ import ( "encoding/binary" ) -type ProcId struct { +type procId struct { id uint16 name string } @@ -15,24 +15,13 @@ const ( fDefaultValue = 2 ) -type Param struct { +type param struct { Name string Flags uint8 ti typeInfo buffer []byte } -func MakeProcId(name string) (res ProcId) { - res.name = name - if len(name) == 0 { - panic("Proc name shouln't be empty") - } - if len(name) >= 0xffff { - panic("Invalid length of procedure name, should be less than 0xffff") - } - return res -} - const ( fWithRecomp = 1 fNoMetaData = 2 @@ -40,25 +29,25 @@ const ( ) var ( - Sp_Cursor = ProcId{1, ""} - Sp_CursorOpen = ProcId{2, ""} - Sp_CursorPrepare = ProcId{3, ""} - Sp_CursorExecute = ProcId{4, ""} - Sp_CursorPrepExec = ProcId{5, ""} - Sp_CursorUnprepare = ProcId{6, ""} - Sp_CursorFetch = ProcId{7, ""} - Sp_CursorOption = ProcId{8, ""} - Sp_CursorClose = ProcId{9, ""} - Sp_ExecuteSql = ProcId{10, ""} - Sp_Prepare = ProcId{11, ""} - Sp_PrepExec = ProcId{13, ""} - Sp_PrepExecRpc = ProcId{14, ""} - Sp_Unprepare = ProcId{15, ""} + sp_Cursor = procId{1, ""} + sp_CursorOpen = procId{2, ""} + sp_CursorPrepare = procId{3, ""} + sp_CursorExecute = procId{4, ""} + sp_CursorPrepExec = procId{5, ""} + sp_CursorUnprepare = procId{6, ""} + sp_CursorFetch = 
procId{7, ""} + sp_CursorOption = procId{8, ""} + sp_CursorClose = procId{9, ""} + sp_ExecuteSql = procId{10, ""} + sp_Prepare = procId{11, ""} + sp_PrepExec = procId{13, ""} + sp_PrepExecRpc = procId{14, ""} + sp_Unprepare = procId{15, ""} ) // http://msdn.microsoft.com/en-us/library/dd357576.aspx -func sendRpc(buf *tdsBuffer, headers []headerStruct, proc ProcId, flags uint16, params []Param) (err error) { - buf.BeginPacket(packRPCRequest) +func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) (err error) { + buf.BeginPacket(packRPCRequest, resetSession) writeAllHeaders(buf, headers) if len(proc.name) == 0 { var idswitch uint16 = 0xffff diff --git a/vendor/github.com/denisenkom/go-mssqldb/sspi_windows.go b/vendor/github.com/denisenkom/go-mssqldb/sspi_windows.go index a6e9505..9b5bc68 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/sspi_windows.go +++ b/vendor/github.com/denisenkom/go-mssqldb/sspi_windows.go @@ -113,7 +113,7 @@ type SSPIAuth struct { ctxt SecHandle } -func getAuth(user, password, service, workstation string) (Auth, bool) { +func getAuth(user, password, service, workstation string) (auth, bool) { if user == "" { return &SSPIAuth{Service: service}, true } diff --git a/vendor/github.com/denisenkom/go-mssqldb/tds.go b/vendor/github.com/denisenkom/go-mssqldb/tds.go index bab15e5..5a9f53b 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tds.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tds.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "crypto/tls" "crypto/x509" "encoding/binary" @@ -9,11 +10,9 @@ import ( "io" "io/ioutil" "net" - "os" "sort" "strconv" "strings" - "time" "unicode/utf16" "unicode/utf8" ) @@ -47,13 +46,14 @@ func parseInstances(msg []byte) map[string]map[string]string { return results } -func getInstances(address string) (map[string]map[string]string, error) { - conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second) +func getInstances(ctx context.Context, d Dialer, address string) (map[string]map[string]string, error) { + conn, err := d.DialContext(ctx, "udp", address+":1434") if err != nil { return nil, err } defer conn.Close() - conn.SetDeadline(time.Now().Add(5*time.Second)) + deadline, _ := ctx.Deadline() + conn.SetDeadline(deadline) _, err = conn.Write([]byte{3}) if err != nil { return nil, err @@ -79,11 +79,16 @@ const ( ) // packet types +// https://msdn.microsoft.com/en-us/library/dd304214.aspx const ( - packSQLBatch = 1 - packRPCRequest = 3 - packReply = 4 - packCancel = 6 + packSQLBatch packetType = 1 + packRPCRequest = 3 + packReply = 4 + + // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx + // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx + packAttention = 6 + packBulkLoadBCP = 7 packTransMgrReq = 14 packNormal = 15 @@ -119,7 +124,7 @@ type tdsSession struct { columns []columnStruct tranid uint64 logFlags uint64 - log *Logger + log optionalLogger routedServer string routedPort uint16 } @@ -131,6 +136,7 @@ const ( logSQL = 8 logParams = 16 logTransaction = 32 + logDebug = 64 ) type columnStruct struct { @@ -140,19 +146,19 @@ type columnStruct struct { ti typeInfo } -type KeySlice []uint8 +type keySlice []uint8 -func (p KeySlice) Len() int { return len(p) } -func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] } -func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +func (p keySlice) Len() int { return len(p) } +func (p keySlice) Less(i, j int) bool { return p[i] < p[j] } +func 
(p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // http://msdn.microsoft.com/en-us/library/dd357559.aspx func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error { var err error - w.BeginPacket(packPrelogin) + w.BeginPacket(packPrelogin, false) offset := uint16(5*len(fields) + 1) - keys := make(KeySlice, 0, len(fields)) + keys := make(keySlice, 0, len(fields)) for k, _ := range fields { keys = append(keys, k) } @@ -340,7 +346,7 @@ func manglePassword(password string) []byte { // http://msdn.microsoft.com/en-us/library/dd304019.aspx func sendLogin(w *tdsBuffer, login login) error { - w.BeginPacket(packLogin7) + w.BeginPacket(packLogin7, false) hostname := str2ucs2(login.HostName) username := str2ucs2(login.UserName) password := manglePassword(login.Password) @@ -462,10 +468,9 @@ func readUcs2(r io.Reader, numchars int) (res string, err error) { } func readUsVarChar(r io.Reader) (res string, err error) { - var numchars uint16 - err = binary.Read(r, binary.LittleEndian, &numchars) + numchars, err := readUshort(r) if err != nil { - return "", err + return } return readUcs2(r, int(numchars)) } @@ -485,11 +490,15 @@ func writeUsVarChar(w io.Writer, s string) (err error) { } func readBVarChar(r io.Reader) (res string, err error) { - var numchars uint8 - err = binary.Read(r, binary.LittleEndian, &numchars) + numchars, err := readByte(r) if err != nil { return "", err } + + // A zero length could be returned, return an empty string + if numchars == 0 { + return "", nil + } return readUcs2(r, int(numchars)) } @@ -508,8 +517,7 @@ func writeBVarChar(w io.Writer, s string) (err error) { } func readBVarByte(r io.Reader) (res []byte, err error) { - var length uint8 - err = binary.Read(r, binary.LittleEndian, &length) + length, err := readByte(r) if err != nil { return } @@ -543,6 +551,36 @@ const ( dataStmHdrTraceActivity = 3 ) +// Query Notifications Header +// http://msdn.microsoft.com/en-us/library/dd304949.aspx +type queryNotifHdr struct { + notifyId string + ssbDeployment string + notifyTimeout uint32 +} + +func (hdr queryNotifHdr) pack() (res []byte) { + notifyId := str2ucs2(hdr.notifyId) + ssbDeployment := str2ucs2(hdr.ssbDeployment) + + res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4) + b := res + + binary.LittleEndian.PutUint16(b, uint16(len(notifyId))) + b = b[2:] + copy(b, notifyId) + b = b[len(notifyId):] + + binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment))) + b = b[2:] + copy(b, ssbDeployment) + b = b[len(ssbDeployment):] + + binary.LittleEndian.PutUint32(b, hdr.notifyTimeout) + + return res +} + // MARS Transaction Descriptor Header // http://msdn.microsoft.com/en-us/library/dd340515.aspx type transDescrHdr struct { @@ -558,7 +596,7 @@ func (hdr transDescrHdr) pack() (res []byte) { } func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { - // calculatint total length + // Calculating total length. 
var totallen uint32 = 4 for _, hdr := range headers { totallen += 4 + 2 + uint32(len(hdr.data)) @@ -586,186 +624,28 @@ func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) { return nil } -func sendSqlBatch72(buf *tdsBuffer, - sqltext string, - headers []headerStruct) (err error) { - buf.BeginPacket(packSQLBatch) +func sendSqlBatch72(buf *tdsBuffer, sqltext string, headers []headerStruct, resetSession bool) (err error) { + buf.BeginPacket(packSQLBatch, resetSession) - writeAllHeaders(buf, headers) + if err = writeAllHeaders(buf, headers); err != nil { + return + } _, err = buf.Write(str2ucs2(sqltext)) if err != nil { - return err + return } return buf.FinishPacket() } -type connectParams struct { - logFlags uint64 - port uint64 - host string - instance string - database string - user string - password string - dial_timeout time.Duration - conn_timeout time.Duration - keepAlive time.Duration - encrypt bool - disableEncryption bool - trustServerCertificate bool - certificate string - hostInCertificate string - serverSPN string - workstation string - appname string - typeFlags uint8 -} - -func parseConnectParams(params map[string]string) (*connectParams, error) { - var p connectParams - strlog, ok := params["log"] - if ok { - var err error - p.logFlags, err = strconv.ParseUint(strlog, 10, 0) - if err != nil { - return nil, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error()) - } - } - server := params["server"] - parts := strings.SplitN(server, "\\", 2) - p.host = parts[0] - if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" { - p.host = "localhost" - } - if len(parts) > 1 { - p.instance = parts[1] - } - p.database = params["database"] - p.user = params["user id"] - p.password = params["password"] - p.port = 1433 - if p.instance != "" { - p.instance = strings.ToUpper(p.instance) - instances, err := getInstances(p.host) - if err != nil { - f := "Unable to get instances from Sql Server Browser on host %v: %v" - return nil, fmt.Errorf(f, p.host, err.Error()) - } - strport, ok := instances[p.instance]["tcp"] - if !ok { - f := "No instance matching '%v' returned from host '%v'" - return nil, fmt.Errorf(f, p.instance, p.host) - } - p.port, err = strconv.ParseUint(strport, 0, 16) - if err != nil { - f := "Invalid tcp port returned from Sql Server Browser '%v': %v" - return nil, fmt.Errorf(f, strport, err.Error()) - } - } else { - strport, ok := params["port"] - if ok { - var err error - p.port, err = strconv.ParseUint(strport, 0, 16) - if err != nil { - f := "Invalid tcp port '%v': %v" - return nil, fmt.Errorf(f, strport, err.Error()) - } - } - } - p.dial_timeout = 5 * time.Second - p.conn_timeout = 30 * time.Second - strconntimeout, ok := params["connection timeout"] - if ok { - timeout, err := strconv.ParseUint(strconntimeout, 0, 16) - if err != nil { - f := "Invalid connection timeout '%v': %v" - return nil, fmt.Errorf(f, strconntimeout, err.Error()) - } - p.conn_timeout = time.Duration(timeout) * time.Second - } - strdialtimeout, ok := params["dial timeout"] - if ok { - timeout, err := strconv.ParseUint(strdialtimeout, 0, 16) - if err != nil { - f := "Invalid dial timeout '%v': %v" - return nil, fmt.Errorf(f, strdialtimeout, err.Error()) - } - p.dial_timeout = time.Duration(timeout) * time.Second - } - keepAlive, ok := params["keepalive"] - if ok { - timeout, err := strconv.ParseUint(keepAlive, 0, 16) - if err != nil { - f := "Invalid keepAlive value '%s': %s" - return nil, fmt.Errorf(f, keepAlive, err.Error()) - } - p.keepAlive = 
time.Duration(timeout) * time.Second - } - encrypt, ok := params["encrypt"] - if ok { - if strings.ToUpper(encrypt) == "DISABLE" { - p.disableEncryption = true - } else { - var err error - p.encrypt, err = strconv.ParseBool(encrypt) - if err != nil { - f := "Invalid encrypt '%s': %s" - return nil, fmt.Errorf(f, encrypt, err.Error()) - } - } - } else { - p.trustServerCertificate = true - } - trust, ok := params["trustservercertificate"] - if ok { - var err error - p.trustServerCertificate, err = strconv.ParseBool(trust) - if err != nil { - f := "Invalid trust server certificate '%s': %s" - return nil, fmt.Errorf(f, trust, err.Error()) - } - } - p.certificate = params["certificate"] - p.hostInCertificate, ok = params["hostnameincertificate"] - if !ok { - p.hostInCertificate = p.host - } - - serverSPN, ok := params["ServerSPN"] - if ok { - p.serverSPN = serverSPN - } else { - p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port) - } - - workstation, ok := params["Workstation ID"] - if ok { - p.workstation = workstation - } else { - workstation, err := os.Hostname() - if err == nil { - p.workstation = workstation - } - } - - appname, ok := params["app name"] - if !ok { - appname = "go-mssqldb" - } - p.appname = appname - - appintent, ok := params["applicationintent"] - if ok { - if appintent == "ReadOnly" { - p.typeFlags |= fReadOnlyIntent - } - } - - return &p, nil +// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx +// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx +func sendAttention(buf *tdsBuffer) error { + buf.BeginPacket(packAttention, false) + return buf.FinishPacket() } -type Auth interface { +type auth interface { InitialBytes() ([]byte, error) NextBytes([]byte) ([]byte, error) Free() @@ -774,7 +654,7 @@ type Auth interface { // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a // list of IP addresses. So if there is more than one, try them all and // use the first one that allows a connection. -func dialConnection(p *connectParams) (conn net.Conn, err error) { +func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn net.Conn, err error) { var ips []net.IP ips, err = net.LookupIP(p.host) if err != nil { @@ -785,19 +665,20 @@ func dialConnection(p *connectParams) (conn net.Conn, err error) { ips = []net.IP{ip} } if len(ips) == 1 { - d := createDialer(p) + d := c.getDialer(&p) addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port))) - conn, err = d.Dial("tcp", addr) + conn, err = d.DialContext(ctx, "tcp", addr) } else { //Try Dials in parallel to avoid waiting for timeouts. 
connChan := make(chan net.Conn, len(ips)) errChan := make(chan error, len(ips)) + portStr := strconv.Itoa(int(p.port)) for _, ip := range ips { go func(ip net.IP) { - d := createDialer(p) - addr := net.JoinHostPort(ip.String(), strconv.Itoa(int(p.port))) - conn, err := d.Dial("tcp", addr) + d := c.getDialer(&p) + addr := net.JoinHostPort(ip.String(), portStr) + conn, err := d.DialContext(ctx, "tcp", addr) if err == nil { connChan <- conn } else { @@ -832,27 +713,49 @@ func dialConnection(p *connectParams) (conn net.Conn, err error) { f := "Unable to open tcp connection with host '%v:%v': %v" return nil, fmt.Errorf(f, p.host, p.port, err.Error()) } - return conn, err } -func connect(params map[string]string) (res *tdsSession, err error) { - p, err := parseConnectParams(params) - if err != nil { - return nil, err +func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { + dialCtx := ctx + if p.dial_timeout > 0 { + var cancel func() + dialCtx, cancel = context.WithTimeout(ctx, p.dial_timeout) + defer cancel() + } + // if instance is specified use instance resolution service + if p.instance != "" { + p.instance = strings.ToUpper(p.instance) + d := c.getDialer(&p) + instances, err := getInstances(dialCtx, d, p.host) + if err != nil { + f := "Unable to get instances from Sql Server Browser on host %v: %v" + return nil, fmt.Errorf(f, p.host, err.Error()) + } + strport, ok := instances[p.instance]["tcp"] + if !ok { + f := "No instance matching '%v' returned from host '%v'" + return nil, fmt.Errorf(f, p.instance, p.host) + } + p.port, err = strconv.ParseUint(strport, 0, 16) + if err != nil { + f := "Invalid tcp port returned from Sql Server Browser '%v': %v" + return nil, fmt.Errorf(f, strport, err.Error()) + } } initiate_connection: - conn, err := dialConnection(p) + conn, err := dialConnection(dialCtx, c, p) if err != nil { return nil, err } - toconn := NewTimeoutConn(conn, p.conn_timeout) + toconn := newTimeoutConn(conn, p.conn_timeout) - outbuf := newTdsBuffer(4096, toconn) + outbuf := newTdsBuffer(p.packetSize, toconn) sess := tdsSession{ buf: outbuf, + log: log, logFlags: p.logFlags, } @@ -898,8 +801,7 @@ initiate_connection: if p.certificate != "" { pem, err := ioutil.ReadFile(p.certificate) if err != nil { - f := "Cannot read certificate '%s': %s" - return nil, fmt.Errorf(f, p.certificate, err.Error()) + return nil, fmt.Errorf("Cannot read certificate %q: %v", p.certificate, err) } certs := x509.NewCertPool() certs.AppendCertsFromPEM(pem) @@ -909,15 +811,20 @@ initiate_connection: config.InsecureSkipVerify = true } config.ServerName = p.hostInCertificate - outbuf.transport = conn - toconn.buf = outbuf - tlsConn := tls.Client(toconn, &config) + // fix for https://github.com/denisenkom/go-mssqldb/issues/166 + // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, + // while SQL Server seems to expect one TCP segment per encrypted TDS package. 
+ // Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package + config.DynamicRecordSizingDisabled = true + // setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream + handshakeConn := tlsHandshakeConn{buf: outbuf} + passthrough := passthroughConn{c: &handshakeConn} + tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() - toconn.buf = nil + passthrough.c = toconn outbuf.transport = tlsConn if err != nil { - f := "TLS Handshake failed: %s" - return nil, fmt.Errorf(f, err.Error()) + return nil, fmt.Errorf("TLS Handshake failed: %v", err) } if encrypt == encryptOff { outbuf.afterFirst = func() { @@ -928,7 +835,7 @@ initiate_connection: login := login{ TDSVersion: verTDS74, - PacketSize: uint32(len(outbuf.buf)), + PacketSize: uint32(outbuf.PackageSize()), Database: p.database, OptionFlags2: fODBC, // to get unlimited TEXTSIZE HostName: p.workstation, @@ -954,38 +861,43 @@ initiate_connection: } // processing login response - var sspi_msg []byte -continue_login: - tokchan := make(chan tokenStruct, 5) - go processResponse(&sess, tokchan) success := false - for tok := range tokchan { - switch token := tok.(type) { - case sspiMsg: - sspi_msg, err = auth.NextBytes(token) - if err != nil { - return nil, err + for { + tokchan := make(chan tokenStruct, 5) + go processResponse(context.Background(), &sess, tokchan, nil) + for tok := range tokchan { + switch token := tok.(type) { + case sspiMsg: + sspi_msg, err := auth.NextBytes(token) + if err != nil { + return nil, err + } + if sspi_msg != nil && len(sspi_msg) > 0 { + outbuf.BeginPacket(packSSPIMessage, false) + _, err = outbuf.Write(sspi_msg) + if err != nil { + return nil, err + } + err = outbuf.FinishPacket() + if err != nil { + return nil, err + } + sspi_msg = nil + } + case loginAckStruct: + success = true + sess.loginAck = token + case error: + return nil, fmt.Errorf("Login error: %s", token.Error()) + case doneStruct: + if token.isError() { + return nil, fmt.Errorf("Login error: %s", token.getError()) + } + goto loginEnd } - case loginAckStruct: - success = true - sess.loginAck = token - case error: - return nil, fmt.Errorf("Login error: %s", token.Error()) - } - } - if sspi_msg != nil { - outbuf.BeginPacket(packSSPIMessage) - _, err = outbuf.Write(sspi_msg) - if err != nil { - return nil, err } - err = outbuf.FinishPacket() - if err != nil { - return nil, err - } - sspi_msg = nil - goto continue_login } +loginEnd: if !success { return nil, fmt.Errorf("Login failed") } @@ -993,6 +905,9 @@ continue_login: toconn.Close() p.host = sess.routedServer p.port = uint64(sess.routedPort) + if !p.hostInCertificateProvided { + p.hostInCertificate = sess.routedServer + } goto initiate_connection } return &sess, nil diff --git a/vendor/github.com/denisenkom/go-mssqldb/token.go b/vendor/github.com/denisenkom/go-mssqldb/token.go index 3292bf7..1acac8a 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/token.go +++ b/vendor/github.com/denisenkom/go-mssqldb/token.go @@ -1,30 +1,40 @@ package mssql import ( + "context" "encoding/binary" + "errors" + "fmt" "io" + "net" "strconv" "strings" ) +//go:generate stringer -type token + +type token byte + // token ids const ( - tokenReturnStatus = 121 // 0x79 - tokenColMetadata = 129 // 0x81 - tokenOrder = 169 // 0xA9 - tokenError = 170 // 0xAA - tokenInfo = 171 // 0xAB - tokenLoginAck = 173 // 0xad - tokenRow = 209 // 0xd1 - tokenNbcRow = 210 // 0xd2 - tokenEnvChange = 227 // 0xE3 - tokenSSPI = 237 // 0xED - 
tokenDone = 253 // 0xFD - tokenDoneProc = 254 - tokenDoneInProc = 255 + tokenReturnStatus token = 121 // 0x79 + tokenColMetadata token = 129 // 0x81 + tokenOrder token = 169 // 0xA9 + tokenError token = 170 // 0xAA + tokenInfo token = 171 // 0xAB + tokenReturnValue token = 0xAC + tokenLoginAck token = 173 // 0xad + tokenRow token = 209 // 0xd1 + tokenNbcRow token = 210 // 0xd2 + tokenEnvChange token = 227 // 0xE3 + tokenSSPI token = 237 // 0xED + tokenDone token = 253 // 0xFD + tokenDoneProc token = 254 + tokenDoneInProc token = 255 ) // done flags +// https://msdn.microsoft.com/en-us/library/dd340421.aspx const ( doneFinal = 0 doneMore = 1 @@ -42,13 +52,30 @@ const ( envTypLanguage = 2 envTypCharset = 3 envTypPacketSize = 4 + envSortId = 5 + envSortFlags = 6 + envSqlCollation = 7 envTypBeginTran = 8 envTypCommitTran = 9 envTypRollbackTran = 10 + envEnlistDTC = 11 + envDefectTran = 12 envDatabaseMirrorPartner = 13 + envPromoteTran = 15 + envTranMgrAddr = 16 + envTranEnded = 17 + envResetConnAck = 18 + envStartedInstanceName = 19 envRouting = 20 ) +// COLMETADATA flags +// https://msdn.microsoft.com/en-us/library/dd357363.aspx +const ( + colFlagNullable = 1 + // TODO implement more flags +) + // interface for all tokens type tokenStruct interface{} @@ -60,6 +87,19 @@ type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 + errors []Error +} + +func (d doneStruct) isError() bool { + return d.Status&doneError != 0 || len(d.errors) > 0 +} + +func (d doneStruct) getError() Error { + if len(d.errors) > 0 { + return d.errors[len(d.errors)-1] + } else { + return Error{Message: "Request failed but didn't provide reason"} + } } type doneInProcStruct doneStruct @@ -109,6 +149,26 @@ func processEnvChg(sess *tdsSession) { if err != nil { badStreamPanic(err) } + case envTypLanguage: + // currently ignored + // new value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envTypCharset: + // currently ignored + // new value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } case envTypPacketSize: packetsize, err := readBVarChar(r) if err != nil { @@ -122,10 +182,57 @@ func processEnvChg(sess *tdsSession) { if err != nil { badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } - if len(sess.buf.buf) != packetsizei { - newbuf := make([]byte, packetsizei) - copy(newbuf, sess.buf.buf) - sess.buf.buf = newbuf + sess.buf.ResizeBuffer(packetsizei) + case envSortId: + // currently ignored + // new value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envSortFlags: + // currently ignored + // new value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envSqlCollation: + // currently ignored + var collationSize uint8 + err = binary.Read(r, binary.LittleEndian, &collationSize) + if err != nil { + badStreamPanic(err) + } + + // SQL Collation data should contain 5 bytes in length + if collationSize != 5 { + badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) + } + + // 4 bytes, contains: LCID ColFlags Version + var info uint32 + err = binary.Read(r, 
binary.LittleEndian, &info) + if err != nil { + badStreamPanic(err) + } + + // 1 byte, contains: sortID + var sortID uint8 + err = binary.Read(r, binary.LittleEndian, &sortID) + if err != nil { + badStreamPanic(err) + } + + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) } case envTypBeginTran: tranid, err := readBVarByte(r) @@ -160,6 +267,26 @@ func processEnvChg(sess *tdsSession) { } } sess.tranid = 0 + case envEnlistDTC: + // currently ignored + // new value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envDefectTran: + // currently ignored + // new value + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } case envDatabaseMirrorPartner: sess.partner, err = readBVarChar(r) if err != nil { @@ -169,6 +296,57 @@ func processEnvChg(sess *tdsSession) { if err != nil { badStreamPanic(err) } + case envPromoteTran: + // currently ignored + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // dtc token + // spec says it should be L_VARBYTE, so this code might be wrong + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envTranMgrAddr: + // currently ignored + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // XACT_MANAGER_ADDRESS = B_VARBYTE + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envTranEnded: + // currently ignored + // old value, B_VARBYTE + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envResetConnAck: + // currently ignored + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + case envStartedInstanceName: + // currently ignored + // old value, should be 0 + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // instance name + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } case envRouting: // RoutingData message is: // ValueLength USHORT @@ -199,25 +377,17 @@ func processEnvChg(sess *tdsSession) { sess.routedServer = newServer sess.routedPort = newPort default: - // ignore unknown env change types - _, err = readBVarByte(r) - if err != nil { - badStreamPanic(err) - } - _, err = readBVarByte(r) - if err != nil { - badStreamPanic(err) - } + // ignore rest of records because we don't know how to skip those + sess.log.Printf("WARN: Unknown ENVCHANGE record detected with type id = %d\n", envtype) + break } } } -type returnStatus int32 - // http://msdn.microsoft.com/en-us/library/dd358180.aspx -func parseReturnStatus(r *tdsBuffer) returnStatus { - return returnStatus(r.int32()) +func parseReturnStatus(r *tdsBuffer) ReturnStatus { + return ReturnStatus(r.int32()) } func parseOrder(r *tdsBuffer) (res orderStruct) { @@ -229,6 +399,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) { return res } +// https://msdn.microsoft.com/en-us/library/dd340421.aspx func parseDone(r *tdsBuffer) (res doneStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() @@ -236,6 +407,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) { return res } +// 
https://msdn.microsoft.com/en-us/library/dd340553.aspx func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() @@ -344,26 +516,57 @@ func parseInfo(r *tdsBuffer) (res Error) { return } -func processResponse(sess *tdsSession, ch chan tokenStruct) { +// https://msdn.microsoft.com/en-us/library/dd303881.aspx +func parseReturnValue(r *tdsBuffer) (nv namedValue) { + /* + ParamOrdinal + ParamName + Status + UserType + Flags + TypeInfo + CryptoMetadata + Value + */ + r.uint16() + nv.Name = r.BVarChar() + r.byte() + r.uint32() // UserType (uint16 prior to 7.2) + r.uint16() + ti := readTypeInfo(r) + nv.Value = ti.Reader(&ti, r) + return +} + +func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { defer func() { if err := recover(); err != nil { + if sess.logFlags&logErrors != 0 { + sess.log.Printf("ERROR: Intercepted panic %v", err) + } ch <- err } close(ch) }() + packet_type, err := sess.buf.BeginRead() if err != nil { + if sess.logFlags&logErrors != 0 { + sess.log.Printf("ERROR: BeginRead failed %v", err) + } ch <- err return } if packet_type != packReply { - badStreamPanicf("invalid response packet type, expected REPLY, actual: %d", packet_type) + badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct - var lastError Error - var failed bool + errs := make([]Error, 0, 5) for { - token := sess.buf.byte() + token := token(sess.buf.byte()) + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got token %v", token) + } switch token { case tokenSSPI: ch <- parseSSPIMsg(sess.buf) @@ -385,18 +588,17 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) { ch <- done case tokenDone, tokenDoneProc: done := parseDone(sess.buf) - if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { - sess.log.Printf("(%d row(s) affected)\n", done.RowCount) - } - if done.Status&doneError != 0 || failed { - ch <- lastError - return + done.errors = errs + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got DONE or DONEPROC status=%d", done.Status) } if done.Status&doneSrvError != 0 { - lastError.Message = "Server Error" - ch <- lastError + ch <- errors.New("SQL Server had internal error") return } + if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { + sess.log.Printf("(%d row(s) affected)\n", done.RowCount) + } ch <- done if done.Status&doneMore == 0 { return @@ -415,18 +617,188 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) { case tokenEnvChange: processEnvChg(sess) case tokenError: - lastError = parseError72(sess.buf) - failed = true + err := parseError72(sess.buf) + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got ERROR %d %s", err.Number, err.Message) + } + errs = append(errs, err) if sess.logFlags&logErrors != 0 { - sess.log.Println(lastError.Message) + sess.log.Println(err.Message) } case tokenInfo: info := parseInfo(sess.buf) + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got INFO %d %s", info.Number, info.Message) + } if sess.logFlags&logMessages != 0 { sess.log.Println(info.Message) } + case tokenReturnValue: + nv := parseReturnValue(sess.buf) + if len(nv.Name) > 0 { + name := nv.Name[1:] // Remove the leading "@". 
+ if ov, has := outs[name]; has { + err = scanIntoOut(name, nv.Value, ov) + if err != nil { + fmt.Println("scan error", err) + ch <- err + } + } + } default: - badStreamPanicf("Unknown token type: %d", token) + badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) + } + } +} + +type parseRespIter byte + +const ( + parseRespIterContinue parseRespIter = iota // Continue parsing current token. + parseRespIterNext // Fetch the next token. + parseRespIterDone // Done with parsing the response. +) + +type parseRespState byte + +const ( + parseRespStateNormal parseRespState = iota // Normal response state. + parseRespStateCancel // Query is canceled, wait for server to confirm. + parseRespStateClosing // Waiting for tokens to come through. +) + +type parseResp struct { + sess *tdsSession + ctxDone <-chan struct{} + state parseRespState + cancelError error +} + +func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter { + if err := sendAttention(ts.sess.buf); err != nil { + ts.dlogf("failed to send attention signal %v", err) + ch <- err + return parseRespIterDone + } + ts.state = parseRespStateCancel + return parseRespIterContinue +} + +func (ts *parseResp) dlog(msg string) { + if ts.sess.logFlags&logDebug != 0 { + ts.sess.log.Println(msg) + } +} +func (ts *parseResp) dlogf(f string, v ...interface{}) { + if ts.sess.logFlags&logDebug != 0 { + ts.sess.log.Printf(f, v...) + } +} + +func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter { + switch ts.state { + default: + panic("unknown state") + case parseRespStateNormal: + select { + case tok, ok := <-tokChan: + if !ok { + ts.dlog("response finished") + return parseRespIterDone + } + if err, ok := tok.(net.Error); ok && err.Timeout() { + ts.cancelError = err + ts.dlog("got timeout error, sending attention signal to server") + return ts.sendAttention(ch) + } + // Pass the token along. + ch <- tok + return parseRespIterContinue + + case <-ts.ctxDone: + ts.ctxDone = nil + ts.dlog("got cancel message, sending attention signal to server") + return ts.sendAttention(ch) + } + case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth + select { + case tok, ok := <-tokChan: + if !ok { + ts.dlog("response finished but waiting for attention ack") + return parseRespIterNext + } + switch tok := tok.(type) { + default: + // Ignore all other tokens while waiting. + // The TDS spec says other tokens may arrive after an attention + // signal is sent. Ignore these tokens and continue looking for + // a DONE with attention confirm mark. + case doneStruct: + if tok.Status&doneAttn != 0 { + ts.dlog("got cancellation confirmation from server") + if ts.cancelError != nil { + ch <- ts.cancelError + ts.cancelError = nil + } else { + ch <- ctx.Err() + } + return parseRespIterDone + } + + // If an error happens during cancel, pass it along and just stop. + // We are uncertain to receive more tokens. + case error: + ch <- tok + ts.state = parseRespStateClosing + } + return parseRespIterContinue + case <-ts.ctxDone: + ts.ctxDone = nil + ts.state = parseRespStateClosing + return parseRespIterContinue + } + case parseRespStateClosing: // Wait for current token chan to close. 
+ if _, ok := <-tokChan; !ok { + ts.dlog("response finished") + return parseRespIterDone + } + return parseRespIterContinue + } +} + +func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { + ts := &parseResp{ + sess: sess, + ctxDone: ctx.Done(), + } + defer func() { + // Ensure any remaining error is piped through + // or the query may look like it executed when it actually failed. + if ts.cancelError != nil { + ch <- ts.cancelError + ts.cancelError = nil + } + close(ch) + }() + + // Loop over multiple responses. + for { + ts.dlog("initiating response reading") + + tokChan := make(chan tokenStruct) + go processSingleResponse(sess, tokChan, outs) + + // Loop over multiple tokens in response. + tokensLoop: + for { + switch ts.iter(ctx, ch, tokChan) { + case parseRespIterContinue: + // Nothing, continue to next token. + case parseRespIterNext: + break tokensLoop + case parseRespIterDone: + return + } } } } diff --git a/vendor/github.com/denisenkom/go-mssqldb/token_string.go b/vendor/github.com/denisenkom/go-mssqldb/token_string.go new file mode 100644 index 0000000..c075b23 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/token_string.go @@ -0,0 +1,53 @@ +// Code generated by "stringer -type token"; DO NOT EDIT + +package mssql + +import "fmt" + +const ( + _token_name_0 = "tokenReturnStatus" + _token_name_1 = "tokenColMetadata" + _token_name_2 = "tokenOrdertokenErrortokenInfo" + _token_name_3 = "tokenLoginAck" + _token_name_4 = "tokenRowtokenNbcRow" + _token_name_5 = "tokenEnvChange" + _token_name_6 = "tokenSSPI" + _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" +) + +var ( + _token_index_0 = [...]uint8{0, 17} + _token_index_1 = [...]uint8{0, 16} + _token_index_2 = [...]uint8{0, 10, 20, 29} + _token_index_3 = [...]uint8{0, 13} + _token_index_4 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 14} + _token_index_6 = [...]uint8{0, 9} + _token_index_7 = [...]uint8{0, 9, 22, 37} +) + +func (i token) String() string { + switch { + case i == 121: + return _token_name_0 + case i == 129: + return _token_name_1 + case 169 <= i && i <= 171: + i -= 169 + return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] + case i == 173: + return _token_name_3 + case 209 <= i && i <= 210: + i -= 209 + return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + case i == 227: + return _token_name_5 + case i == 237: + return _token_name_6 + case 253 <= i && i <= 255: + i -= 253 + return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + default: + return fmt.Sprintf("token(%d)", i) + } +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/tran.go b/vendor/github.com/denisenkom/go-mssqldb/tran.go index ae38107..cb64368 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/tran.go +++ b/vendor/github.com/denisenkom/go-mssqldb/tran.go @@ -1,6 +1,7 @@ +package mssql + // Transaction Manager requests // http://msdn.microsoft.com/en-us/library/dd339887.aspx -package mssql import ( "encoding/binary" @@ -16,9 +17,19 @@ const ( tmSaveXact = 9 ) -func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation uint8, - name string) (err error) { - buf.BeginPacket(packTransMgrReq) +type isoLevel uint8 + +const ( + isolationUseCurrent isoLevel = 0 + isolationReadUncommited = 1 + isolationReadCommited = 2 + isolationRepeatableRead = 3 + isolationSerializable = 4 + isolationSnapshot = 5 +) + +func sendBeginXact(buf *tdsBuffer, headers []headerStruct, isolation isoLevel, name string, resetSession bool) (err error) { + 
buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmBeginXact err = binary.Write(buf, binary.LittleEndian, &rqtype) @@ -40,8 +51,8 @@ const ( fBeginXact = 1 ) -func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error { - buf.BeginPacket(packTransMgrReq) +func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { + buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmCommitXact err := binary.Write(buf, binary.LittleEndian, &rqtype) @@ -69,8 +80,8 @@ func sendCommitXact(buf *tdsBuffer, headers []headerStruct, name string, flags u return buf.FinishPacket() } -func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string) error { - buf.BeginPacket(packTransMgrReq) +func sendRollbackXact(buf *tdsBuffer, headers []headerStruct, name string, flags uint8, isolation uint8, newname string, resetSession bool) error { + buf.BeginPacket(packTransMgrReq, resetSession) writeAllHeaders(buf, headers) var rqtype uint16 = tmRollbackXact err := binary.Write(buf, binary.LittleEndian, &rqtype) diff --git a/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go new file mode 100644 index 0000000..64e5e21 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/tvp_go19.go @@ -0,0 +1,231 @@ +// +build go1.9 + +package mssql + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "reflect" + "strings" + "time" +) + +const ( + jsonTag = "json" + tvpTag = "tvp" + skipTagValue = "-" + sqlSeparator = "." +) + +var ( + ErrorEmptyTVPTypeName = errors.New("TypeName must not be empty") + ErrorTypeSlice = errors.New("TVP must be slice type") + ErrorTypeSliceIsEmpty = errors.New("TVP mustn't be null value") + ErrorSkip = errors.New("all fields mustn't skip") + ErrorObjectName = errors.New("wrong tvp name") + ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align") +) + +//TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server +type TVP struct { + //TypeName mustn't be default value + TypeName string + //Value must be the slice, mustn't be nil + Value interface{} +} + +func (tvp TVP) check() error { + if len(tvp.TypeName) == 0 { + return ErrorEmptyTVPTypeName + } + if !isProc(tvp.TypeName) { + return ErrorEmptyTVPTypeName + } + if sepCount := getCountSQLSeparators(tvp.TypeName); sepCount > 1 { + return ErrorObjectName + } + valueOf := reflect.ValueOf(tvp.Value) + if valueOf.Kind() != reflect.Slice { + return ErrorTypeSlice + } + if valueOf.IsNil() { + return ErrorTypeSliceIsEmpty + } + if reflect.TypeOf(tvp.Value).Elem().Kind() != reflect.Struct { + return ErrorTypeSlice + } + return nil +} + +func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) { + if len(columnStr) != len(tvpFieldIndexes) { + return nil, ErrorWrongTyping + } + preparedBuffer := make([]byte, 0, 20+(10*len(columnStr))) + buf := bytes.NewBuffer(preparedBuffer) + err := writeBVarChar(buf, "") + if err != nil { + return nil, err + } + + writeBVarChar(buf, schema) + writeBVarChar(buf, name) + binary.Write(buf, binary.LittleEndian, uint16(len(columnStr))) + + for i, column := range columnStr { + binary.Write(buf, binary.LittleEndian, uint32(column.UserType)) + 
binary.Write(buf, binary.LittleEndian, uint16(column.Flags)) + writeTypeInfo(buf, &columnStr[i].ti) + writeBVarChar(buf, "") + } + // The returned error is always nil + buf.WriteByte(_TVP_END_TOKEN) + + conn := new(Conn) + conn.sess = new(tdsSession) + conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73} + stmt := &Stmt{ + c: conn, + } + + val := reflect.ValueOf(tvp.Value) + for i := 0; i < val.Len(); i++ { + refStr := reflect.ValueOf(val.Index(i).Interface()) + buf.WriteByte(_TVP_ROW_TOKEN) + for columnStrIdx, fieldIdx := range tvpFieldIndexes { + field := refStr.Field(fieldIdx) + tvpVal := field.Interface() + valOf := reflect.ValueOf(tvpVal) + elemKind := field.Kind() + if elemKind == reflect.Ptr && valOf.IsNil() { + switch tvpVal.(type) { + case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int: + binary.Write(buf, binary.LittleEndian, uint8(0)) + continue + default: + binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) + continue + } + } + if elemKind == reflect.Slice && valOf.IsNil() { + binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) + continue + } + + cval, err := convertInputParameter(tvpVal) + if err != nil { + return nil, fmt.Errorf("failed to convert tvp parameter row col: %s", err) + } + param, err := stmt.makeParam(cval) + if err != nil { + return nil, fmt.Errorf("failed to make tvp parameter row col: %s", err) + } + columnStr[columnStrIdx].ti.Writer(buf, param.ti, param.buffer) + } + } + buf.WriteByte(_TVP_END_TOKEN) + return buf.Bytes(), nil +} + +func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { + val := reflect.ValueOf(tvp.Value) + var firstRow interface{} + if val.Len() != 0 { + firstRow = val.Index(0).Interface() + } else { + firstRow = reflect.New(reflect.TypeOf(tvp.Value).Elem()).Elem().Interface() + } + + tvpRow := reflect.TypeOf(firstRow) + columnCount := tvpRow.NumField() + defaultValues := make([]interface{}, 0, columnCount) + tvpFieldIndexes := make([]int, 0, columnCount) + for i := 0; i < columnCount; i++ { + field := tvpRow.Field(i) + tvpTagValue, isTvpTag := field.Tag.Lookup(tvpTag) + jsonTagValue, isJsonTag := field.Tag.Lookup(jsonTag) + if IsSkipField(tvpTagValue, isTvpTag, jsonTagValue, isJsonTag) { + continue + } + tvpFieldIndexes = append(tvpFieldIndexes, i) + if field.Type.Kind() == reflect.Ptr { + v := reflect.New(field.Type.Elem()) + defaultValues = append(defaultValues, v.Interface()) + continue + } + defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface()) + } + + if columnCount-len(tvpFieldIndexes) == columnCount { + return nil, nil, ErrorSkip + } + + conn := new(Conn) + conn.sess = new(tdsSession) + conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73} + stmt := &Stmt{ + c: conn, + } + + columnConfiguration := make([]columnStruct, 0, columnCount) + for index, val := range defaultValues { + cval, err := convertInputParameter(val) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err) + } + param, err := stmt.makeParam(cval) + if err != nil { + return nil, nil, err + } + column := columnStruct{ + ti: param.ti, + } + switch param.ti.TypeId { + case typeNVarChar, typeBigVarBin: + column.ti.Size = 0 + } + columnConfiguration = append(columnConfiguration, column) + } + + return columnConfiguration, tvpFieldIndexes, nil +} + +func IsSkipField(tvpTagValue string, isTvpValue bool, jsonTagValue string, isJsonTagValue bool) bool { + if !isTvpValue && !isJsonTagValue { + return false + } else if isTvpValue && 
tvpTagValue != skipTagValue { + return false + } else if !isTvpValue && isJsonTagValue && jsonTagValue != skipTagValue { + return false + } + return true +} + +func getSchemeAndName(tvpName string) (string, string, error) { + if len(tvpName) == 0 { + return "", "", ErrorEmptyTVPTypeName + } + splitVal := strings.Split(tvpName, ".") + if len(splitVal) > 2 { + return "", "", errors.New("wrong tvp name") + } + if len(splitVal) == 2 { + res := make([]string, 2) + for key, value := range splitVal { + tmp := strings.Replace(value, "[", "", -1) + tmp = strings.Replace(tmp, "]", "", -1) + res[key] = tmp + } + return res[0], res[1], nil + } + tmp := strings.Replace(splitVal[0], "[", "", -1) + tmp = strings.Replace(tmp, "]", "", -1) + + return "", tmp, nil +} + +func getCountSQLSeparators(str string) int { + return strings.Count(str, sqlSeparator) +} diff --git a/vendor/github.com/denisenkom/go-mssqldb/types.go b/vendor/github.com/denisenkom/go-mssqldb/types.go index c06d6c7..b6e7fb2 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/types.go +++ b/vendor/github.com/denisenkom/go-mssqldb/types.go @@ -6,8 +6,12 @@ import ( "fmt" "io" "math" + "reflect" "strconv" "time" + + "github.com/denisenkom/go-mssqldb/internal/cp" + "github.com/denisenkom/go-mssqldb/internal/decimal" ) // fixed-length data types @@ -59,6 +63,7 @@ const ( typeNChar = 0xef typeXml = 0xf1 typeUdt = 0xf0 + typeTvp = 0xf3 // long length types typeText = 0x23 @@ -66,6 +71,13 @@ const ( typeNText = 0x63 typeVariant = 0x62 ) +const _PLP_NULL = 0xFFFFFFFFFFFFFFFF +const _UNKNOWN_PLP_LEN = 0xFFFFFFFFFFFFFFFE +const _PLP_TERMINATOR = 0x00000000 + +// TVP COLUMN FLAGS +const _TVP_END_TOKEN = 0x00 +const _TVP_ROW_TOKEN = 0x01 // TYPE_INFO rule // http://msdn.microsoft.com/en-us/library/dd358284.aspx @@ -75,11 +87,32 @@ type typeInfo struct { Scale uint8 Prec uint8 Buffer []byte - Collation collation + Collation cp.Collation + UdtInfo udtInfo + XmlInfo xmlInfo Reader func(ti *typeInfo, r *tdsBuffer) (res interface{}) Writer func(w io.Writer, ti typeInfo, buf []byte) (err error) } +// Common Language Runtime (CLR) Instances +// http://msdn.microsoft.com/en-us/library/dd357962.aspx +type udtInfo struct { + //MaxByteSize uint32 + DBName string + SchemaName string + TypeName string + AssemblyQualifiedName string +} + +// XML Values +// http://msdn.microsoft.com/en-us/library/dd304764.aspx +type xmlInfo struct { + SchemaPresent uint8 + DBName string + OwningSchema string + XmlSchemaCollection string +} + func readTypeInfo(r *tdsBuffer) (res typeInfo) { res.TypeId = r.byte() switch res.TypeId { @@ -106,6 +139,7 @@ func readTypeInfo(r *tdsBuffer) (res typeInfo) { return } +// https://msdn.microsoft.com/en-us/library/dd358284.aspx func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) { err = binary.Write(w, binary.LittleEndian, ti.TypeId) if err != nil { @@ -114,7 +148,11 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) { switch ti.TypeId { case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4, typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8: - // those are fixed length types + // those are fixed length + // https://msdn.microsoft.com/en-us/library/dd341171.aspx + ti.Writer = writeFixedType + case typeTvp: + ti.Writer = writeFixedType default: // all others are VARLENTYPE err = writeVarLen(w, ti) if err != nil { @@ -124,19 +162,27 @@ func writeTypeInfo(w io.Writer, ti *typeInfo) (err error) { return } +func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) { + _, err = w.Write(buf) + 
return +} + +// https://msdn.microsoft.com/en-us/library/dd358341.aspx func writeVarLen(w io.Writer, ti *typeInfo) (err error) { switch ti.TypeId { - case typeDateN: + case typeDateN: + ti.Writer = writeByteLenType case typeTimeN, typeDateTime2N, typeDateTimeOffsetN: if err = binary.Write(w, binary.LittleEndian, ti.Scale); err != nil { return } ti.Writer = writeByteLenType - case typeGuid, typeIntN, typeDecimal, typeNumeric, + case typeIntN, typeDecimal, typeNumeric, typeBitN, typeDecimalN, typeNumericN, typeFltN, typeMoneyN, typeDateTimeN, typeChar, typeVarChar, typeBinary, typeVarBinary: + // byle len types if ti.Size > 0xff { panic("Invalid size for BYLELEN_TYPE") @@ -156,8 +202,17 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) { } } ti.Writer = writeByteLenType + case typeGuid: + if !(ti.Size == 0x10 || ti.Size == 0x00) { + panic("Invalid size for BYLELEN_TYPE") + } + if err = binary.Write(w, binary.LittleEndian, uint8(ti.Size)); err != nil { + return + } + ti.Writer = writeByteLenType case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar, typeNVarChar, typeNChar, typeXml, typeUdt: + // short len types if ti.Size > 8000 || ti.Size == 0 { if err = binary.Write(w, binary.LittleEndian, uint16(0xffff)); err != nil { @@ -176,14 +231,19 @@ func writeVarLen(w io.Writer, ti *typeInfo) (err error) { return } case typeXml: - var schemapresent uint8 = 0 - if err = binary.Write(w, binary.LittleEndian, schemapresent); err != nil { + if err = binary.Write(w, binary.LittleEndian, ti.XmlInfo.SchemaPresent); err != nil { return } } case typeText, typeImage, typeNText, typeVariant: // LONGLEN_TYPE - panic("LONGLEN_TYPE not implemented") + if err = binary.Write(w, binary.LittleEndian, uint32(ti.Size)); err != nil { + return + } + if err = writeCollation(w, ti.Collation); err != nil { + return + } + ti.Writer = writeLongLenType default: panic("Invalid type") } @@ -198,6 +258,48 @@ func decodeDateTim4(buf []byte) time.Time { 0, int(mins), 0, 0, time.UTC) } +func encodeDateTim4(val time.Time) (buf []byte) { + buf = make([]byte, 4) + + ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) + dur := val.Sub(ref) + days := dur / (24 * time.Hour) + mins := val.Hour()*60 + val.Minute() + if days < 0 { + days = 0 + mins = 0 + } + + binary.LittleEndian.PutUint16(buf[:2], uint16(days)) + binary.LittleEndian.PutUint16(buf[2:], uint16(mins)) + return +} + +// encodes datetime value +// type identifier is typeDateTimeN +func encodeDateTime(t time.Time) (res []byte) { + // base date in days since Jan 1st 1900 + basedays := gregorianDays(1900, 1) + // days since Jan 1st 1900 (same TZ as t) + days := gregorianDays(t.Year(), t.YearDay()) - basedays + tm := 300*(t.Second()+t.Minute()*60+t.Hour()*60*60) + t.Nanosecond()*300/1e9 + // minimum and maximum possible + mindays := gregorianDays(1753, 1) - basedays + maxdays := gregorianDays(9999, 365) - basedays + if days < mindays { + days = mindays + tm = 0 + } + if days > maxdays { + days = maxdays + tm = (23*60*60+59*60+59)*300 + 299 + } + res = make([]byte, 8) + binary.LittleEndian.PutUint32(res[0:4], uint32(days)) + binary.LittleEndian.PutUint32(res[4:8], uint32(tm)) + return +} + func decodeDateTime(buf []byte) time.Time { days := int32(binary.LittleEndian.Uint32(buf)) tm := binary.LittleEndian.Uint32(buf[4:]) @@ -207,7 +309,7 @@ func decodeDateTime(buf []byte) time.Time { 0, 0, secs, ns, time.UTC) } -func readFixedType(ti *typeInfo, r *tdsBuffer) (res interface{}) { +func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} { r.ReadFull(ti.Buffer) buf := 
ti.Buffer switch ti.TypeId { @@ -241,12 +343,7 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) (res interface{}) { panic("shoulnd't get here") } -func writeFixedType(w io.Writer, ti typeInfo, buf []byte) (err error) { - _, err = w.Write(buf) - return -} - -func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) { +func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} { size := r.byte() if size == 0 { return nil @@ -278,7 +375,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) { case 8: return int64(binary.LittleEndian.Uint64(buf)) default: - badStreamPanicf("Invalid size for INTNTYPE") + badStreamPanicf("Invalid size for INTNTYPE: %d", len(buf)) } case typeDecimal, typeNumeric, typeDecimalN, typeNumericN: return decodeDecimal(ti.Prec, ti.Scale, buf) @@ -305,6 +402,10 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) { default: badStreamPanicf("Invalid size for MONEYNTYPE") } + case typeDateTim4: + return decodeDateTim4(buf) + case typeDateTime: + return decodeDateTime(buf) case typeDateTimeN: switch len(buf) { case 4: @@ -333,7 +434,7 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { if ti.Size > 0xff { panic("Invalid size for BYTELEN_TYPE") } - err = binary.Write(w, binary.LittleEndian, uint8(ti.Size)) + err = binary.Write(w, binary.LittleEndian, uint8(len(buf))) if err != nil { return } @@ -341,7 +442,7 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readShortLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) { +func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} { size := r.uint16() if size == 0xffff { return nil @@ -384,12 +485,12 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { return } -func readLongLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) { +func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} { // information about this format can be found here: // http://msdn.microsoft.com/en-us/library/dd304783.aspx // and here: // http://msdn.microsoft.com/en-us/library/dd357254.aspx - textptrsize := r.byte() + textptrsize := int(r.byte()) if textptrsize == 0 { return nil } @@ -415,10 +516,51 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer) (res interface{}) { } panic("shoulnd't get here") } +func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) { + //textptr + err = binary.Write(w, binary.LittleEndian, byte(0x10)) + if err != nil { + return + } + err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF)) + if err != nil { + return + } + err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF)) + if err != nil { + return + } + //timestamp? 
+ err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF)) + if err != nil { + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(ti.Size)) + if err != nil { + return + } + _, err = w.Write(buf) + return +} + +func readCollation(r *tdsBuffer) (res cp.Collation) { + res.LcidAndFlags = r.uint32() + res.SortId = r.byte() + return +} + +func writeCollation(w io.Writer, col cp.Collation) (err error) { + if err = binary.Write(w, binary.LittleEndian, col.LcidAndFlags); err != nil { + return + } + err = binary.Write(w, binary.LittleEndian, col.SortId) + return +} // reads variant value // http://msdn.microsoft.com/en-us/library/dd303302.aspx -func readVariantType(ti *typeInfo, r *tdsBuffer) (res interface{}) { +func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} { size := r.int32() if size == 0 { return nil @@ -510,14 +652,14 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) (res interface{}) { // partially length prefixed stream // http://msdn.microsoft.com/en-us/library/dd340469.aspx -func readPLPType(ti *typeInfo, r *tdsBuffer) (res interface{}) { +func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} { size := r.uint64() var buf *bytes.Buffer switch size { - case 0xffffffffffffffff: + case _PLP_NULL: // null return nil - case 0xfffffffffffffffe: + case _UNKNOWN_PLP_LEN: // size unknown buf = bytes.NewBuffer(make([]byte, 0, 1000)) default: @@ -548,15 +690,16 @@ func readPLPType(ti *typeInfo, r *tdsBuffer) (res interface{}) { } func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) { - if err = binary.Write(w, binary.LittleEndian, uint64(len(buf))); err != nil { + if err = binary.Write(w, binary.LittleEndian, uint64(_UNKNOWN_PLP_LEN)); err != nil { return } for { chunksize := uint32(len(buf)) - if err = binary.Write(w, binary.LittleEndian, chunksize); err != nil { + if chunksize == 0 { + err = binary.Write(w, binary.LittleEndian, uint32(_PLP_TERMINATOR)) return } - if chunksize == 0 { + if err = binary.Write(w, binary.LittleEndian, chunksize); err != nil { return } if _, err = w.Write(buf[:chunksize]); err != nil { @@ -606,19 +749,27 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { } ti.Reader = readByteLenType case typeXml: - schemapresent := r.byte() - if schemapresent != 0 { - // just ignore this for now + ti.XmlInfo.SchemaPresent = r.byte() + if ti.XmlInfo.SchemaPresent != 0 { // dbname - r.BVarChar() + ti.XmlInfo.DBName = r.BVarChar() // owning schema - r.BVarChar() + ti.XmlInfo.OwningSchema = r.BVarChar() // xml schema collection - r.UsVarChar() + ti.XmlInfo.XmlSchemaCollection = r.UsVarChar() } ti.Reader = readPLPType + case typeUdt: + ti.Size = int(r.uint16()) + ti.UdtInfo.DBName = r.BVarChar() + ti.UdtInfo.SchemaName = r.BVarChar() + ti.UdtInfo.TypeName = r.BVarChar() + ti.UdtInfo.AssemblyQualifiedName = r.UsVarChar() + + ti.Buffer = make([]byte, ti.Size) + ti.Reader = readPLPType case typeBigVarBin, typeBigVarChar, typeBigBinary, typeBigChar, - typeNVarChar, typeNChar, typeUdt: + typeNVarChar, typeNChar: // short len types ti.Size = int(r.uint16()) switch ti.TypeId { @@ -650,8 +801,6 @@ func readVarLen(ti *typeInfo, r *tdsBuffer) { r.UsVarChar() } ti.Reader = readLongLenType - case typeXml: - panic("XMLTYPE not implemented") case typeVariant: ti.Reader = readVariantType } @@ -670,12 +819,12 @@ func decodeMoney(buf []byte) []byte { uint64(buf[1])<<40 | uint64(buf[2])<<48 | uint64(buf[3])<<56) - return scaleBytes(strconv.FormatInt(money, 10), 4) + return decimal.ScaleBytes(strconv.FormatInt(money, 10), 4) } func decodeMoney4(buf 
[]byte) []byte { money := int32(binary.LittleEndian.Uint32(buf[0:4])) - return scaleBytes(strconv.FormatInt(int64(money), 10), 4) + return decimal.ScaleBytes(strconv.FormatInt(int64(money), 10), 4) } func decodeGuid(buf []byte) []byte { @@ -687,15 +836,14 @@ func decodeGuid(buf []byte) []byte { func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte { var sign uint8 sign = buf[0] - dec := Decimal{ - positive: sign != 0, - prec: prec, - scale: scale, - } + var dec decimal.Decimal + dec.SetPositive(sign != 0) + dec.SetPrec(prec) + dec.SetScale(scale) buf = buf[1:] l := len(buf) / 4 for i := 0; i < l; i++ { - dec.integer[i] = binary.LittleEndian.Uint32(buf[0:4]) + dec.SetInteger(binary.LittleEndian.Uint32(buf[0:4]), uint8(i)) buf = buf[4:] } return dec.Bytes() @@ -703,13 +851,23 @@ func decodeDecimal(prec uint8, scale uint8, buf []byte) []byte { // http://msdn.microsoft.com/en-us/library/ee780895.aspx func decodeDateInt(buf []byte) (days int) { - return int(buf[0]) + int(buf[1])*256 + int(buf[2])*256*256 + days = int(buf[0]) + int(buf[1])*256 + int(buf[2])*256*256 + return } func decodeDate(buf []byte) time.Time { return time.Date(1, 1, 1+decodeDateInt(buf), 0, 0, 0, 0, time.UTC) } +func encodeDate(val time.Time) (buf []byte) { + days, _, _ := dateTime2(val) + buf = make([]byte, 3) + buf[0] = byte(days) + buf[1] = byte(days >> 8) + buf[2] = byte(days >> 16) + return +} + func decodeTimeInt(scale uint8, buf []byte) (sec int, ns int) { var acc uint64 = 0 for i := len(buf) - 1; i >= 0; i-- { @@ -725,11 +883,41 @@ func decodeTimeInt(scale uint8, buf []byte) (sec int, ns int) { return } +// calculate size of time field in bytes +func calcTimeSize(scale int) int { + if scale <= 2 { + return 3 + } else if scale <= 4 { + return 4 + } else { + return 5 + } +} + +// writes time value into a field buffer +// buffer should be at least calcTimeSize long +func encodeTimeInt(seconds, ns, scale int, buf []byte) { + ns_total := int64(seconds)*1000*1000*1000 + int64(ns) + t := ns_total / int64(math.Pow10(int(scale)*-1)*1e9) + buf[0] = byte(t) + buf[1] = byte(t >> 8) + buf[2] = byte(t >> 16) + buf[3] = byte(t >> 24) + buf[4] = byte(t >> 32) +} + func decodeTime(scale uint8, buf []byte) time.Time { sec, ns := decodeTimeInt(scale, buf) return time.Date(1, 1, 1, 0, 0, sec, ns, time.UTC) } +func encodeTime(hour, minute, second, ns, scale int) (buf []byte) { + seconds := hour*3600 + minute*60 + second + buf = make([]byte, calcTimeSize(scale)) + encodeTimeInt(seconds, ns, scale, buf) + return +} + func decodeDateTime2(scale uint8, buf []byte) time.Time { timesize := len(buf) - 3 sec, ns := decodeTimeInt(scale, buf[:timesize]) @@ -737,6 +925,17 @@ func decodeDateTime2(scale uint8, buf []byte) time.Time { return time.Date(1, 1, 1+days, 0, 0, sec, ns, time.UTC) } +func encodeDateTime2(val time.Time, scale int) (buf []byte) { + days, seconds, ns := dateTime2(val) + timesize := calcTimeSize(scale) + buf = make([]byte, 3+timesize) + encodeTimeInt(seconds, ns, scale, buf) + buf[timesize] = byte(days) + buf[timesize+1] = byte(days >> 8) + buf[timesize+2] = byte(days >> 16) + return +} + func decodeDateTimeOffset(scale uint8, buf []byte) time.Time { timesize := len(buf) - 3 - 2 sec, ns := decodeTimeInt(scale, buf[:timesize]) @@ -748,29 +947,48 @@ func decodeDateTimeOffset(scale uint8, buf []byte) time.Time { time.FixedZone("", offset*60)) } -func divFloor(x int64, y int64) int64 { - q := x / y - r := x % y - if r != 0 && ((r < 0) != (y < 0)) { - q-- - } - return q +func encodeDateTimeOffset(val time.Time, scale int) 
(buf []byte) { + timesize := calcTimeSize(scale) + buf = make([]byte, timesize+2+3) + days, seconds, ns := dateTime2(val.In(time.UTC)) + encodeTimeInt(seconds, ns, scale, buf) + buf[timesize] = byte(days) + buf[timesize+1] = byte(days >> 8) + buf[timesize+2] = byte(days >> 16) + _, offset := val.Zone() + offset /= 60 + buf[timesize+3] = byte(offset) + buf[timesize+4] = byte(offset >> 8) + return +} + +// returns days since Jan 1st 0001 in Gregorian calendar +func gregorianDays(year, yearday int) int { + year0 := year - 1 + return year0*365 + year0/4 - year0/100 + year0/400 + yearday - 1 } -func dateTime2(t time.Time) (days int32, ns int64) { - // number of days since Jan 1 1970 UTC - days64 := divFloor(t.Unix(), 24*60*60) - // number of days since Jan 1 1 UTC - days = int32(days64) + 1969*365 + 1969/4 - 1969/100 + 1969/400 - // number of seconds within day - secs := t.Unix() - days64*24*60*60 - // number of nanoseconds within day - ns = secs*1e9 + int64(t.Nanosecond()) +func dateTime2(t time.Time) (days int, seconds int, ns int) { + // days since Jan 1 1 (in same TZ as t) + days = gregorianDays(t.Year(), t.YearDay()) + seconds = t.Second() + t.Minute()*60 + t.Hour()*60*60 + ns = t.Nanosecond() + if days < 0 { + days = 0 + seconds = 0 + ns = 0 + } + max := gregorianDays(9999, 365) + if days > max { + days = max + seconds = 59 + 59*60 + 23*60*60 + ns = 999999900 + } return } -func decodeChar(col collation, buf []byte) string { - return charset2utf8(col, buf) +func decodeChar(col cp.Collation, buf []byte) string { + return cp.CharsetToUTF8(col, buf) } func decodeUcs2(buf []byte) string { @@ -789,12 +1007,129 @@ func decodeXml(ti typeInfo, buf []byte) string { return decodeUcs2(buf) } -func decodeUdt(ti typeInfo, buf []byte) int { - panic("Not implemented") +func decodeUdt(ti typeInfo, buf []byte) []byte { + return buf +} + +// makes go/sql type instance as described below +// It should return +// the value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". 
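+// An illustrative sketch of how this surfaces to callers (assumes only the
+// standard database/sql ColumnTypes API, nothing defined in this package):
+//
+//	colTypes, _ := rows.ColumnTypes()
+//	scanType := colTypes[0].ScanType() // reflect.TypeOf(int64(0)) for a bigint column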
+func makeGoLangScanType(ti typeInfo) reflect.Type { + switch ti.TypeId { + case typeInt1: + return reflect.TypeOf(int64(0)) + case typeInt2: + return reflect.TypeOf(int64(0)) + case typeInt4: + return reflect.TypeOf(int64(0)) + case typeInt8: + return reflect.TypeOf(int64(0)) + case typeFlt4: + return reflect.TypeOf(float64(0)) + case typeIntN: + switch ti.Size { + case 1: + return reflect.TypeOf(int64(0)) + case 2: + return reflect.TypeOf(int64(0)) + case 4: + return reflect.TypeOf(int64(0)) + case 8: + return reflect.TypeOf(int64(0)) + default: + panic("invalid size of INTNTYPE") + } + case typeFlt8: + return reflect.TypeOf(float64(0)) + case typeFltN: + switch ti.Size { + case 4: + return reflect.TypeOf(float64(0)) + case 8: + return reflect.TypeOf(float64(0)) + default: + panic("invalid size of FLNNTYPE") + } + case typeBigVarBin: + return reflect.TypeOf([]byte{}) + case typeVarChar: + return reflect.TypeOf("") + case typeNVarChar: + return reflect.TypeOf("") + case typeBit, typeBitN: + return reflect.TypeOf(true) + case typeDecimalN, typeNumericN: + return reflect.TypeOf([]byte{}) + case typeMoney, typeMoney4, typeMoneyN: + switch ti.Size { + case 4: + return reflect.TypeOf([]byte{}) + case 8: + return reflect.TypeOf([]byte{}) + default: + panic("invalid size of MONEYN") + } + case typeDateTim4: + return reflect.TypeOf(time.Time{}) + case typeDateTime: + return reflect.TypeOf(time.Time{}) + case typeDateTimeN: + switch ti.Size { + case 4: + return reflect.TypeOf(time.Time{}) + case 8: + return reflect.TypeOf(time.Time{}) + default: + panic("invalid size of DATETIMEN") + } + case typeDateTime2N: + return reflect.TypeOf(time.Time{}) + case typeDateN: + return reflect.TypeOf(time.Time{}) + case typeTimeN: + return reflect.TypeOf(time.Time{}) + case typeDateTimeOffsetN: + return reflect.TypeOf(time.Time{}) + case typeBigVarChar: + return reflect.TypeOf("") + case typeBigChar: + return reflect.TypeOf("") + case typeNChar: + return reflect.TypeOf("") + case typeGuid: + return reflect.TypeOf([]byte{}) + case typeXml: + return reflect.TypeOf("") + case typeText: + return reflect.TypeOf("") + case typeNText: + return reflect.TypeOf("") + case typeImage: + return reflect.TypeOf([]byte{}) + case typeBigBinary: + return reflect.TypeOf([]byte{}) + case typeVariant: + return reflect.TypeOf(nil) + default: + panic(fmt.Sprintf("not implemented makeGoLangScanType for type %d", ti.TypeId)) + } } func makeDecl(ti typeInfo) string { switch ti.TypeId { + case typeNull: + // maybe we should use something else here + // this is tested in TestNull + return "nvarchar(1)" + case typeInt1: + return "tinyint" + case typeBigBinary: + return fmt.Sprintf("binary(%d)", ti.Size) + case typeInt2: + return "smallint" + case typeInt4: + return "int" case typeInt8: return "bigint" case typeFlt4: @@ -823,25 +1158,423 @@ func makeDecl(ti typeInfo) string { default: panic("invalid size of FLNNTYPE") } + case typeDecimal, typeDecimalN: + return fmt.Sprintf("decimal(%d, %d)", ti.Prec, ti.Scale) + case typeNumeric, typeNumericN: + return fmt.Sprintf("numeric(%d, %d)", ti.Prec, ti.Scale) + case typeMoney4: + return "smallmoney" + case typeMoney: + return "money" + case typeMoneyN: + switch ti.Size { + case 4: + return "smallmoney" + case 8: + return "money" + default: + panic("invalid size of MONEYNTYPE") + } case typeBigVarBin: if ti.Size > 8000 || ti.Size == 0 { - return fmt.Sprintf("varbinary(max)") + return "varbinary(max)" } else { return fmt.Sprintf("varbinary(%d)", ti.Size) } + case typeNChar: + return 
fmt.Sprintf("nchar(%d)", ti.Size/2) + case typeBigChar, typeChar: + return fmt.Sprintf("char(%d)", ti.Size) + case typeBigVarChar, typeVarChar: + if ti.Size > 8000 || ti.Size == 0 { + return fmt.Sprintf("varchar(max)") + } else { + return fmt.Sprintf("varchar(%d)", ti.Size) + } case typeNVarChar: if ti.Size > 8000 || ti.Size == 0 { - return fmt.Sprintf("nvarchar(max)") + return "nvarchar(max)" } else { return fmt.Sprintf("nvarchar(%d)", ti.Size/2) } case typeBit, typeBitN: return "bit" - case typeDateTimeN: + case typeDateN: + return "date" + case typeDateTim4: + return "smalldatetime" + case typeDateTime: return "datetime" + case typeDateTimeN: + switch ti.Size { + case 4: + return "smalldatetime" + case 8: + return "datetime" + default: + panic("invalid size of DATETIMNTYPE") + } + case typeTimeN: + return "time" + case typeDateTime2N: + return fmt.Sprintf("datetime2(%d)", ti.Scale) case typeDateTimeOffsetN: return fmt.Sprintf("datetimeoffset(%d)", ti.Scale) + case typeText: + return "text" + case typeNText: + return "ntext" + case typeUdt: + return ti.UdtInfo.TypeName + case typeGuid: + return "uniqueidentifier" + case typeTvp: + if ti.UdtInfo.SchemaName != "" { + return fmt.Sprintf("%s.%s READONLY", ti.UdtInfo.SchemaName, ti.UdtInfo.TypeName) + } + return fmt.Sprintf("%s READONLY", ti.UdtInfo.TypeName) + default: + panic(fmt.Sprintf("not implemented makeDecl for type %#x", ti.TypeId)) + } +} + +// makes go/sql type name as described below +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". 
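+// An illustrative sketch of how this surfaces to callers (assumes only the
+// standard database/sql ColumnTypes API):
+//
+//	colTypes, _ := rows.ColumnTypes()
+//	name := colTypes[0].DatabaseTypeName() // e.g. "NVARCHAR", without length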
+func makeGoLangTypeName(ti typeInfo) string { + switch ti.TypeId { + case typeInt1: + return "TINYINT" + case typeInt2: + return "SMALLINT" + case typeInt4: + return "INT" + case typeInt8: + return "BIGINT" + case typeFlt4: + return "REAL" + case typeIntN: + switch ti.Size { + case 1: + return "TINYINT" + case 2: + return "SMALLINT" + case 4: + return "INT" + case 8: + return "BIGINT" + default: + panic("invalid size of INTNTYPE") + } + case typeFlt8: + return "FLOAT" + case typeFltN: + switch ti.Size { + case 4: + return "REAL" + case 8: + return "FLOAT" + default: + panic("invalid size of FLNNTYPE") + } + case typeBigVarBin: + return "VARBINARY" + case typeVarChar: + return "VARCHAR" + case typeNVarChar: + return "NVARCHAR" + case typeBit, typeBitN: + return "BIT" + case typeDecimalN, typeNumericN: + return "DECIMAL" + case typeMoney, typeMoney4, typeMoneyN: + switch ti.Size { + case 4: + return "SMALLMONEY" + case 8: + return "MONEY" + default: + panic("invalid size of MONEYN") + } + case typeDateTim4: + return "SMALLDATETIME" + case typeDateTime: + return "DATETIME" + case typeDateTimeN: + switch ti.Size { + case 4: + return "SMALLDATETIME" + case 8: + return "DATETIME" + default: + panic("invalid size of DATETIMEN") + } + case typeDateTime2N: + return "DATETIME2" + case typeDateN: + return "DATE" + case typeTimeN: + return "TIME" + case typeDateTimeOffsetN: + return "DATETIMEOFFSET" + case typeBigVarChar: + return "VARCHAR" + case typeBigChar: + return "CHAR" + case typeNChar: + return "NCHAR" + case typeGuid: + return "UNIQUEIDENTIFIER" + case typeXml: + return "XML" + case typeText: + return "TEXT" + case typeNText: + return "NTEXT" + case typeImage: + return "IMAGE" + case typeVariant: + return "SQL_VARIANT" + case typeBigBinary: + return "BINARY" + default: + panic(fmt.Sprintf("not implemented makeGoLangTypeName for type %d", ti.TypeId)) + } +} + +// makes go/sql type length as described below +// It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. 
+// The following are examples of returned values for various types: +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +func makeGoLangTypeLength(ti typeInfo) (int64, bool) { + switch ti.TypeId { + case typeInt1: + return 0, false + case typeInt2: + return 0, false + case typeInt4: + return 0, false + case typeInt8: + return 0, false + case typeFlt4: + return 0, false + case typeIntN: + switch ti.Size { + case 1: + return 0, false + case 2: + return 0, false + case 4: + return 0, false + case 8: + return 0, false + default: + panic("invalid size of INTNTYPE") + } + case typeFlt8: + return 0, false + case typeFltN: + switch ti.Size { + case 4: + return 0, false + case 8: + return 0, false + default: + panic("invalid size of FLNNTYPE") + } + case typeBit, typeBitN: + return 0, false + case typeDecimalN, typeNumericN: + return 0, false + case typeMoney, typeMoney4, typeMoneyN: + switch ti.Size { + case 4: + return 0, false + case 8: + return 0, false + default: + panic("invalid size of MONEYN") + } + case typeDateTim4, typeDateTime: + return 0, false + case typeDateTimeN: + switch ti.Size { + case 4: + return 0, false + case 8: + return 0, false + default: + panic("invalid size of DATETIMEN") + } + case typeDateTime2N: + return 0, false + case typeDateN: + return 0, false + case typeTimeN: + return 0, false + case typeDateTimeOffsetN: + return 0, false + case typeBigVarBin: + if ti.Size == 0xffff { + return 2147483645, true + } else { + return int64(ti.Size), true + } + case typeVarChar: + return int64(ti.Size), true + case typeBigVarChar: + if ti.Size == 0xffff { + return 2147483645, true + } else { + return int64(ti.Size), true + } + case typeBigChar: + return int64(ti.Size), true + case typeNVarChar: + if ti.Size == 0xffff { + return 2147483645 / 2, true + } else { + return int64(ti.Size) / 2, true + } + case typeNChar: + return int64(ti.Size) / 2, true + case typeGuid: + return 0, false + case typeXml: + return 1073741822, true + case typeText: + return 2147483647, true + case typeNText: + return 1073741823, true + case typeImage: + return 2147483647, true + case typeVariant: + return 0, false + case typeBigBinary: + return 0, false + default: + panic(fmt.Sprintf("not implemented makeGoLangTypeLength for type %d", ti.TypeId)) + } +} + +// makes go/sql type precision and scale as described below +// It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. 
+// The following are examples of returned values for various types: +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +func makeGoLangTypePrecisionScale(ti typeInfo) (int64, int64, bool) { + switch ti.TypeId { + case typeInt1: + return 0, 0, false + case typeInt2: + return 0, 0, false + case typeInt4: + return 0, 0, false + case typeInt8: + return 0, 0, false + case typeFlt4: + return 0, 0, false + case typeIntN: + switch ti.Size { + case 1: + return 0, 0, false + case 2: + return 0, 0, false + case 4: + return 0, 0, false + case 8: + return 0, 0, false + default: + panic("invalid size of INTNTYPE") + } + case typeFlt8: + return 0, 0, false + case typeFltN: + switch ti.Size { + case 4: + return 0, 0, false + case 8: + return 0, 0, false + default: + panic("invalid size of FLNNTYPE") + } + case typeBit, typeBitN: + return 0, 0, false + case typeDecimalN, typeNumericN: + return int64(ti.Prec), int64(ti.Scale), true + case typeMoney, typeMoney4, typeMoneyN: + switch ti.Size { + case 4: + return 0, 0, false + case 8: + return 0, 0, false + default: + panic("invalid size of MONEYN") + } + case typeDateTim4, typeDateTime: + return 0, 0, false + case typeDateTimeN: + switch ti.Size { + case 4: + return 0, 0, false + case 8: + return 0, 0, false + default: + panic("invalid size of DATETIMEN") + } + case typeDateTime2N: + return 0, 0, false + case typeDateN: + return 0, 0, false + case typeTimeN: + return 0, 0, false + case typeDateTimeOffsetN: + return 0, 0, false + case typeBigVarBin: + return 0, 0, false + case typeVarChar: + return 0, 0, false + case typeBigVarChar: + return 0, 0, false + case typeBigChar: + return 0, 0, false + case typeNVarChar: + return 0, 0, false + case typeNChar: + return 0, 0, false + case typeGuid: + return 0, 0, false + case typeXml: + return 0, 0, false + case typeText: + return 0, 0, false + case typeNText: + return 0, 0, false + case typeImage: + return 0, 0, false + case typeVariant: + return 0, 0, false + case typeBigBinary: + return 0, 0, false default: - panic(fmt.Sprintf("not implemented makeDecl for type %d", ti.TypeId)) + panic(fmt.Sprintf("not implemented makeGoLangTypePrecisionScale for type %d", ti.TypeId)) } } diff --git a/vendor/github.com/denisenkom/go-mssqldb/uniqueidentifier.go b/vendor/github.com/denisenkom/go-mssqldb/uniqueidentifier.go new file mode 100644 index 0000000..3ad6f86 --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/uniqueidentifier.go @@ -0,0 +1,80 @@ +package mssql + +import ( + "database/sql/driver" + "encoding/hex" + "errors" + "fmt" +) + +type UniqueIdentifier [16]byte + +func (u *UniqueIdentifier) Scan(v interface{}) error { + reverse := func(b []byte) { + for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { + b[i], b[j] = b[j], b[i] + } + } + + switch vt := v.(type) { + case []byte: + if len(vt) != 16 { + return errors.New("mssql: invalid UniqueIdentifier length") + } + + var raw UniqueIdentifier + + copy(raw[:], vt) + + reverse(raw[0:4]) + reverse(raw[4:6]) + reverse(raw[6:8]) + *u = raw + + return nil + case string: + if len(vt) != 36 { + return errors.New("mssql: invalid UniqueIdentifier string length") + } + + b := []byte(vt) + for i, c := range b { + switch c { + case '-': + b = append(b[:i], b[i+1:]...) 
+ } + } + + _, err := hex.Decode(u[:], []byte(b)) + return err + default: + return fmt.Errorf("mssql: cannot convert %T to UniqueIdentifier", v) + } +} + +func (u UniqueIdentifier) Value() (driver.Value, error) { + reverse := func(b []byte) { + for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { + b[i], b[j] = b[j], b[i] + } + } + + raw := make([]byte, len(u)) + copy(raw, u[:]) + + reverse(raw[0:4]) + reverse(raw[4:6]) + reverse(raw[6:8]) + + return raw, nil +} + +func (u UniqueIdentifier) String() string { + return fmt.Sprintf("%X-%X-%X-%X-%X", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) +} + +// MarshalText converts Uniqueidentifier to bytes corresponding to the stringified hexadecimal representation of the Uniqueidentifier +// e.g., "AAAAAAAA-AAAA-AAAA-AAAA-AAAAAAAAAAAA" -> [65 65 65 65 65 65 65 65 45 65 65 65 65 45 65 65 65 65 45 65 65 65 65 65 65 65 65 65 65 65 65] +func (u UniqueIdentifier) MarshalText() []byte { + return []byte(u.String()) +} diff --git a/vendor/github.com/golang-sql/civil/CONTRIBUTING.md b/vendor/github.com/golang-sql/civil/CONTRIBUTING.md new file mode 100644 index 0000000..d0635c3 --- /dev/null +++ b/vendor/github.com/golang-sql/civil/CONTRIBUTING.md @@ -0,0 +1,73 @@ +# Contributing + +1. Sign one of the contributor license agreements below. + +#### Running + +Once you've done the necessary setup, you can run the integration tests by +running: + +``` sh +$ go test -v github.com/golang-sql/civil +``` + +## Contributor License Agreements + +Before we can accept your pull requests you'll need to sign a Contributor +License Agreement (CLA): + +- **If you are an individual writing original source code** and **you own the +intellectual property**, then you'll need to sign an [individual CLA][indvcla]. +- **If you work for a company that wants to allow you to contribute your +work**, then you'll need to sign a [corporate CLA][corpcla]. + +You can sign these electronically (just scroll to the bottom). After that, +we'll be able to accept your pull requests. + +## Contributor Code of Conduct + +As contributors and maintainers of this project, +and in the interest of fostering an open and welcoming community, +we pledge to respect all people who contribute through reporting issues, +posting feature requests, updating documentation, +submitting pull requests or patches, and other activities. + +We are committed to making participation in this project +a harassment-free experience for everyone, +regardless of level of experience, gender, gender identity and expression, +sexual orientation, disability, personal appearance, +body size, race, ethnicity, age, religion, or nationality. + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery +* Personal attacks +* Trolling or insulting/derogatory comments +* Public or private harassment +* Publishing other's private information, +such as physical or electronic +addresses, without explicit permission +* Other unethical or unprofessional conduct. + +Project maintainers have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct. +By adopting this Code of Conduct, +project maintainers commit themselves to fairly and consistently +applying these principles to every aspect of managing this project. +Project maintainers who do not follow or enforce the Code of Conduct +may be permanently removed from the project team. 
+ +This code of conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. + +Instances of abusive, harassing, or otherwise unacceptable behavior +may be reported by opening an issue +or contacting one or more of the project maintainers. + +This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0, +available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/) + +[gcloudcli]: https://developers.google.com/cloud/sdk/gcloud/ +[indvcla]: https://developers.google.com/open-source/cla/individual +[corpcla]: https://developers.google.com/open-source/cla/corporate \ No newline at end of file diff --git a/vendor/github.com/golang-sql/civil/LICENSE b/vendor/github.com/golang-sql/civil/LICENSE new file mode 100644 index 0000000..7a4a3ea --- /dev/null +++ b/vendor/github.com/golang-sql/civil/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/vendor/github.com/golang-sql/civil/README.md b/vendor/github.com/golang-sql/civil/README.md new file mode 100644 index 0000000..3a7956e --- /dev/null +++ b/vendor/github.com/golang-sql/civil/README.md @@ -0,0 +1,15 @@ +# Civil Date and Time + +[![GoDoc](https://godoc.org/github.com/golang-sql/civil?status.svg)](https://godoc.org/github.com/golang-sql/civil) + +Civil provides Date, Time of Day, and DateTime data types. + +While there are many uses, using specific types when working +with databases make is conceptually eaiser to understand what value +is set in the remote system. + +## Source + +This civil package was extracted and forked from `cloud.google.com/go/civil`. +As such the license and contributing requirements remain the same as that +module. diff --git a/vendor/github.com/golang-sql/civil/civil.go b/vendor/github.com/golang-sql/civil/civil.go new file mode 100644 index 0000000..29272ef --- /dev/null +++ b/vendor/github.com/golang-sql/civil/civil.go @@ -0,0 +1,277 @@ +// Copyright 2016 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package civil implements types for civil time, a time-zone-independent +// representation of time that follows the rules of the proleptic +// Gregorian calendar with exactly 24-hour days, 60-minute hours, and 60-second +// minutes. +// +// Because they lack location information, these types do not represent unique +// moments or intervals of time. Use time.Time for that purpose. +package civil + +import ( + "fmt" + "time" +) + +// A Date represents a date (year, month, day). +// +// This type does not include location information, and therefore does not +// describe a unique 24-hour timespan. +type Date struct { + Year int // Year (e.g., 2014). + Month time.Month // Month of the year (January = 1, ...). 
+ Day int // Day of the month, starting at 1. +} + +// DateOf returns the Date in which a time occurs in that time's location. +func DateOf(t time.Time) Date { + var d Date + d.Year, d.Month, d.Day = t.Date() + return d +} + +// ParseDate parses a string in RFC3339 full-date format and returns the date value it represents. +func ParseDate(s string) (Date, error) { + t, err := time.Parse("2006-01-02", s) + if err != nil { + return Date{}, err + } + return DateOf(t), nil +} + +// String returns the date in RFC3339 full-date format. +func (d Date) String() string { + return fmt.Sprintf("%04d-%02d-%02d", d.Year, d.Month, d.Day) +} + +// IsValid reports whether the date is valid. +func (d Date) IsValid() bool { + return DateOf(d.In(time.UTC)) == d +} + +// In returns the time corresponding to time 00:00:00 of the date in the location. +// +// In is always consistent with time.Date, even when time.Date returns a time +// on a different day. For example, if loc is America/Indiana/Vincennes, then both +// time.Date(1955, time.May, 1, 0, 0, 0, 0, loc) +// and +// civil.Date{Year: 1955, Month: time.May, Day: 1}.In(loc) +// return 23:00:00 on April 30, 1955. +// +// In panics if loc is nil. +func (d Date) In(loc *time.Location) time.Time { + return time.Date(d.Year, d.Month, d.Day, 0, 0, 0, 0, loc) +} + +// AddDays returns the date that is n days in the future. +// n can also be negative to go into the past. +func (d Date) AddDays(n int) Date { + return DateOf(d.In(time.UTC).AddDate(0, 0, n)) +} + +// DaysSince returns the signed number of days between the date and s, not including the end day. +// This is the inverse operation to AddDays. +func (d Date) DaysSince(s Date) (days int) { + // We convert to Unix time so we do not have to worry about leap seconds: + // Unix time increases by exactly 86400 seconds per day. + deltaUnix := d.In(time.UTC).Unix() - s.In(time.UTC).Unix() + return int(deltaUnix / 86400) +} + +// Before reports whether d1 occurs before d2. +func (d1 Date) Before(d2 Date) bool { + if d1.Year != d2.Year { + return d1.Year < d2.Year + } + if d1.Month != d2.Month { + return d1.Month < d2.Month + } + return d1.Day < d2.Day +} + +// After reports whether d1 occurs after d2. +func (d1 Date) After(d2 Date) bool { + return d2.Before(d1) +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The output is the result of d.String(). +func (d Date) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The date is expected to be a string in a format accepted by ParseDate. +func (d *Date) UnmarshalText(data []byte) error { + var err error + *d, err = ParseDate(string(data)) + return err +} + +// A Time represents a time with nanosecond precision. +// +// This type does not include location information, and therefore does not +// describe a unique moment in time. +// +// This type exists to represent the TIME type in storage-based APIs like BigQuery. +// Most operations on Times are unlikely to be meaningful. Prefer the DateTime type. +type Time struct { + Hour int // The hour of the day in 24-hour format; range [0-23] + Minute int // The minute of the hour; range [0-59] + Second int // The second of the minute; range [0-59] + Nanosecond int // The nanosecond of the second; range [0-999999999] +} + +// TimeOf returns the Time representing the time of day in which a time occurs +// in that time's location. It ignores the date. 
+func TimeOf(t time.Time) Time { + var tm Time + tm.Hour, tm.Minute, tm.Second = t.Clock() + tm.Nanosecond = t.Nanosecond() + return tm +} + +// ParseTime parses a string and returns the time value it represents. +// ParseTime accepts an extended form of the RFC3339 partial-time format. After +// the HH:MM:SS part of the string, an optional fractional part may appear, +// consisting of a decimal point followed by one to nine decimal digits. +// (RFC3339 admits only one digit after the decimal point). +func ParseTime(s string) (Time, error) { + t, err := time.Parse("15:04:05.999999999", s) + if err != nil { + return Time{}, err + } + return TimeOf(t), nil +} + +// String returns the date in the format described in ParseTime. If Nanoseconds +// is zero, no fractional part will be generated. Otherwise, the result will +// end with a fractional part consisting of a decimal point and nine digits. +func (t Time) String() string { + s := fmt.Sprintf("%02d:%02d:%02d", t.Hour, t.Minute, t.Second) + if t.Nanosecond == 0 { + return s + } + return s + fmt.Sprintf(".%09d", t.Nanosecond) +} + +// IsValid reports whether the time is valid. +func (t Time) IsValid() bool { + // Construct a non-zero time. + tm := time.Date(2, 2, 2, t.Hour, t.Minute, t.Second, t.Nanosecond, time.UTC) + return TimeOf(tm) == t +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The output is the result of t.String(). +func (t Time) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The time is expected to be a string in a format accepted by ParseTime. +func (t *Time) UnmarshalText(data []byte) error { + var err error + *t, err = ParseTime(string(data)) + return err +} + +// A DateTime represents a date and time. +// +// This type does not include location information, and therefore does not +// describe a unique moment in time. +type DateTime struct { + Date Date + Time Time +} + +// Note: We deliberately do not embed Date into DateTime, to avoid promoting AddDays and Sub. + +// DateTimeOf returns the DateTime in which a time occurs in that time's location. +func DateTimeOf(t time.Time) DateTime { + return DateTime{ + Date: DateOf(t), + Time: TimeOf(t), + } +} + +// ParseDateTime parses a string and returns the DateTime it represents. +// ParseDateTime accepts a variant of the RFC3339 date-time format that omits +// the time offset but includes an optional fractional time, as described in +// ParseTime. Informally, the accepted format is +// YYYY-MM-DDTHH:MM:SS[.FFFFFFFFF] +// where the 'T' may be a lower-case 't'. +func ParseDateTime(s string) (DateTime, error) { + t, err := time.Parse("2006-01-02T15:04:05.999999999", s) + if err != nil { + t, err = time.Parse("2006-01-02t15:04:05.999999999", s) + if err != nil { + return DateTime{}, err + } + } + return DateTimeOf(t), nil +} + +// String returns the date in the format described in ParseDate. +func (dt DateTime) String() string { + return dt.Date.String() + "T" + dt.Time.String() +} + +// IsValid reports whether the datetime is valid. +func (dt DateTime) IsValid() bool { + return dt.Date.IsValid() && dt.Time.IsValid() +} + +// In returns the time corresponding to the DateTime in the given location. +// +// If the time is missing or ambigous at the location, In returns the same +// result as time.Date. 
For example, if loc is America/Indiana/Vincennes, then +// both +// time.Date(1955, time.May, 1, 0, 30, 0, 0, loc) +// and +// civil.DateTime{ +// civil.Date{Year: 1955, Month: time.May, Day: 1}}, +// civil.Time{Minute: 30}}.In(loc) +// return 23:30:00 on April 30, 1955. +// +// In panics if loc is nil. +func (dt DateTime) In(loc *time.Location) time.Time { + return time.Date(dt.Date.Year, dt.Date.Month, dt.Date.Day, dt.Time.Hour, dt.Time.Minute, dt.Time.Second, dt.Time.Nanosecond, loc) +} + +// Before reports whether dt1 occurs before dt2. +func (dt1 DateTime) Before(dt2 DateTime) bool { + return dt1.In(time.UTC).Before(dt2.In(time.UTC)) +} + +// After reports whether dt1 occurs after dt2. +func (dt1 DateTime) After(dt2 DateTime) bool { + return dt2.Before(dt1) +} + +// MarshalText implements the encoding.TextMarshaler interface. +// The output is the result of dt.String(). +func (dt DateTime) MarshalText() ([]byte, error) { + return []byte(dt.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The datetime is expected to be a string in a format accepted by ParseDateTime +func (dt *DateTime) UnmarshalText(data []byte) error { + var err error + *dt, err = ParseDateTime(string(data)) + return err +} diff --git a/vendor/vendor.json b/vendor/vendor.json index cbb09f1..ea55cf8 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -2,6 +2,40 @@ "comment": "", "ignore": "test", "package": [ + { + "checksumSHA1": "QDDeYsNxrtjwqUR1rcMYtSTSULA=", + "path": "github.com/denisenkom/go-mssqldb", + "revision": "cfbb681360f0a7de54ae77703318f0e60d422e00", + "revisionTime": "2019-10-01T01:33:58Z" + }, + { + "checksumSHA1": "wu8t19t2rmyrrfDfdu9v7f/+iag=", + "path": "github.com/denisenkom/go-mssqldb/internal/cp", + "revision": "cfbb681360f0a7de54ae77703318f0e60d422e00", + "revisionTime": "2019-10-01T01:33:58Z" + }, + { + "checksumSHA1": "Kvwtq+/9H6FYcVoYisDdj+XrHao=", + "path": "github.com/denisenkom/go-mssqldb/internal/decimal", + "revision": "cfbb681360f0a7de54ae77703318f0e60d422e00", + "revisionTime": "2019-10-01T01:33:58Z" + }, + { + "checksumSHA1": "zKX5r8sv52kDL+qQ/qb5OobqCSo=", + "path": "github.com/denisenkom/go-mssqldb/internal/querytext", + "revision": "cfbb681360f0a7de54ae77703318f0e60d422e00", + "revisionTime": "2019-10-01T01:33:58Z" + }, + { + "checksumSHA1": "ikmefTNVlcOywmKiyDHTXmr6LL4=", + "path": "github.com/golang-sql/civil", + "revision": "cb61b32ac6fe84d34b81730175f91965e43d0f90", + "revisionTime": "2019-07-19T16:38:53Z" + }, + { + "path": "github.com/prometheus", + "revision": "" + }, { "checksumSHA1": "kO2MAzPuortudABaRd7fY+7EaoM=", "path": "github.com/prometheus/procfs", @@ -15,6 +49,7 @@ "revisionTime": "2016-03-17T12:45:02+01:00" }, { + "checksumSHA1": "KgT+peLCcuh0/m2mpoOZXuxXmwc=", "path": "gopkg.in/yaml.v2", "revision": "7ad95dd0798a40da1ccdff6dff35fd177b5edf40", "revisionTime": "2015-06-24T11:29:02+01:00" @@ -35,7 +70,6 @@ "revisionTime": "2015-06-24T11:29:02+01:00" }, { - "checksumSHA1": "KgT+peLCcuh0/m2mpoOZXuxXmwc=", "path": "gopkg.in/yaml.v2", "revision": "7ad95dd0798a40da1ccdff6dff35fd177b5edf40", "revisionTime": "2015-06-24T11:29:02+01:00" @@ -46,7 +80,6 @@ "revisionTime": "2015-06-24T11:29:02+01:00" }, { - "checksumSHA1": "KgT+peLCcuh0/m2mpoOZXuxXmwc=", "path": "gopkg.in/yaml.v2", "revision": "7ad95dd0798a40da1ccdff6dff35fd177b5edf40", "revisionTime": "2015-06-24T11:29:02+01:00" @@ -57,6 +90,7 @@ "revisionTime": "2015-06-24T11:29:02+01:00" }, { + "checksumSHA1": "KgT+peLCcuh0/m2mpoOZXuxXmwc=", "path": "gopkg.in/yaml.v2", 
"revision": "7ad95dd0798a40da1ccdff6dff35fd177b5edf40", "revisionTime": "2015-06-24T11:29:02+01:00" @@ -76,5 +110,6 @@ "revision": "7ad95dd0798a40da1ccdff6dff35fd177b5edf40", "revisionTime": "2015-06-24T11:29:02+01:00" } - ] + ], + "rootPath": "azure_sql_exporter" } From 07b2dce7d9498ef8fbe6293651c472ef77206457 Mon Sep 17 00:00:00 2001 From: "s.hardt" Date: Wed, 30 Oct 2019 12:19:57 +0100 Subject: [PATCH 2/3] Bumped to version 0.1.1 --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a11b82c..4f12912 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -VERSION := 0.1.0 +VERSION := 0.1.1 LDFLAGS := -X main.Version=$(VERSION) GOFLAGS := -ldflags "$(LDFLAGS)" From 71a9dc66a6deba352eb7446a9fb595cf45094a42 Mon Sep 17 00:00:00 2001 From: "s.hardt" Date: Wed, 30 Oct 2019 16:11:42 +0100 Subject: [PATCH 3/3] Disable cgo so we can use it in an alpine docker image --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 4f12912..4ef79c3 100644 --- a/Makefile +++ b/Makefile @@ -26,4 +26,4 @@ format: .PHONY: docker docker: - docker run --rm -v "$$PWD":/go/src/github.com/iamseth/azure_sql_exporter -w /go/src/github.com/iamseth/azure_sql_exporter golang:1.8 bash -c make + docker run --rm -v "$$PWD":/go/src/github.com/iamseth/azure_sql_exporter -e CGO_ENABLED=0 -w /go/src/github.com/iamseth/azure_sql_exporter golang:1.8 bash -c make