diff --git a/.travis.yml b/.travis.yml index cd44e11..1d8ae75 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,9 +9,9 @@ matrix: env: SCHEME="PostgreSQL-Package" before_install: + - brew update - gem install xcpretty - brew tap vapor/tap - - brew update - brew install vapor install: diff --git a/Sources/PostgreSQL/Connection.swift b/Sources/PostgreSQL/Connection.swift index a95dda8..9a18cb0 100644 --- a/Sources/PostgreSQL/Connection.swift +++ b/Sources/PostgreSQL/Connection.swift @@ -10,8 +10,11 @@ public final class Connection: ConnInfoInitializable { // MARK: - CConnection public typealias CConnection = OpaquePointer - - public let cConnection: CConnection + + @available(*, deprecated: 2.2, message: "needs to be optional or could cause runtime crash passing invalid reference to C") + public var cConnection: CConnection { return pgConnection! } + + public private(set) var pgConnection: CConnection? // MARK: - Init @@ -27,14 +30,14 @@ public final class Connection: ConnInfoInitializable { string = "host='\(hostname)' port='\(port)' dbname='\(database)' user='\(user)' password='\(password)' client_encoding='UTF8'" } - cConnection = PQconnectdb(string) + pgConnection = PQconnectdb(string) try validateConnection() } // MARK: - Deinit deinit { - try? close() + close() } // MARK: - Execute @@ -68,7 +71,7 @@ public final class Connection: ConnInfoInitializable { } let resultPointer: Result.Pointer? = PQexecParams( - cConnection, + pgConnection, query, Int32(binds.count), types, @@ -85,27 +88,35 @@ public final class Connection: ConnInfoInitializable { // MARK: - Connection Status public var isConnected: Bool { - return PQstatus(cConnection) == CONNECTION_OK + return pgConnection != nil && PQstatus(pgConnection) == CONNECTION_OK } public var status: ConnStatusType { - return PQstatus(cConnection) + guard pgConnection != nil else { return CONNECTION_BAD } + return PQstatus(pgConnection) } - private func validateConnection() throws { + func validateConnection() throws { + guard pgConnection != nil else { + throw PostgreSQLError(code: .connectionDoesNotExist, connection: self) + } guard isConnected else { throw PostgreSQLError(code: .connectionFailure, connection: self) } } public func reset() throws { - try validateConnection() - PQreset(cConnection) + guard let connection = pgConnection else { return } + PQreset(connection) + guard status == CONNECTION_OK else { + throw PostgreSQLError(code: .connectionFailure, connection: self) + } } - public func close() throws { - try validateConnection() - PQfinish(cConnection) + public func close() { + guard pgConnection != nil else { return } + PQfinish(pgConnection) + pgConnection = nil } // MARK: - Transaction @@ -152,6 +163,7 @@ public final class Connection: ConnInfoInitializable { public let channel: String public let payload: String? + /// internal initializer init(pgNotify: PGnotify) { channel = String(cString: pgNotify.relname) pid = Int(pgNotify.be_pid) @@ -160,8 +172,7 @@ public final class Connection: ConnInfoInitializable { let string = String(cString: pgNotify.extra) if !string.isEmpty { payload = string - } - else { + } else { payload = nil } } @@ -171,12 +182,45 @@ public final class Connection: ConnInfoInitializable { } } + /// Creates a dispatch read source for this connection that will call `callback` on `queue` when a notification is received + /// + /// - Parameter channel: the channel to register for + /// - Parameter queue: the queue to create the DispatchSource on + /// - Parameter callback: the callback + /// - Parameter notification: The notification received from the database + /// - Parameter error: Any error while reading the notification. If not nil, the source will have been canceled + /// - Returns: the dispatch socket to activate + /// - Throws: if fails to get the socket for the connection + public func listen(toChannel channel: String, queue: DispatchQueue, callback: @escaping (_ notification: Notification?, _ error: Error?) -> Void) throws -> DispatchSourceRead { + let sock = PQsocket(self.pgConnection) + guard sock >= 0 else { + throw PostgreSQLError(code: .ioError, reason: "failed to get socket for connection") + } + let src = DispatchSource.makeReadSource(fileDescriptor: sock, queue: queue) + src.setEventHandler { [weak self] in + guard let strongSelf = self else { return } + guard strongSelf.pgConnection != nil else { + callback(nil, PostgreSQLError(code: .connectionDoesNotExist, reason: "connection does not exist")) + return + } + PQconsumeInput(strongSelf.pgConnection) + while let pgNotify = PQnotifies(strongSelf.pgConnection) { + let notification = Notification(pgNotify: pgNotify.pointee) + callback(notification, nil) + PQfreemem(pgNotify) + } + } + try self.execute("LISTEN \(channel)") + return src + } + /// Registers as a listener on a specific notification channel. /// /// - Parameters: /// - channel: The channel to register for. /// - queue: The queue to perform the listening on. /// - callback: Callback containing any received notification or error and a boolean which can be set to true to stop listening. + @available(*, deprecated: 2.2, message: "replaced with version using DispatchSource") public func listen(toChannel channel: String, on queue: DispatchQueue = DispatchQueue.global(), callback: @escaping (Notification?, Error?, inout Bool) -> Void) { queue.async { var stop: Bool = false @@ -190,9 +234,9 @@ public final class Connection: ConnInfoInitializable { // Sleep to avoid looping continuously on cpu sleep(1) - PQconsumeInput(self.cConnection) + PQconsumeInput(self.pgConnection) - while !stop, let pgNotify = PQnotifies(self.cConnection) { + while !stop, let pgNotify = PQnotifies(self.pgConnection) { let notification = Notification(pgNotify: pgNotify.pointee) callback(notification, nil, &stop) @@ -234,7 +278,7 @@ public final class Connection: ConnInfoInitializable { } private func getBooleanParameterStatus(key: String, `default` defaultValue: Bool = false) -> Bool { - guard let value = PQparameterStatus(cConnection, "integer_datetimes") else { + guard let value = PQparameterStatus(pgConnection, "integer_datetimes") else { return defaultValue } return String(cString: value) == "on" diff --git a/Sources/PostgreSQL/Error.swift b/Sources/PostgreSQL/Error.swift index 82e29d8..1fbbbb3 100644 --- a/Sources/PostgreSQL/Error.swift +++ b/Sources/PostgreSQL/Error.swift @@ -305,7 +305,7 @@ extension PostgreSQLError { extension PostgreSQLError { public init(code: Code, connection: Connection) { let reason: String - if let error = PQerrorMessage(connection.cConnection) { + if let error = PQerrorMessage(connection.pgConnection) { reason = String(cString: error) } else { diff --git a/Tests/PostgreSQLTests/ConnectionTests.swift b/Tests/PostgreSQLTests/ConnectionTests.swift index cf2e3f9..5791469 100644 --- a/Tests/PostgreSQLTests/ConnectionTests.swift +++ b/Tests/PostgreSQLTests/ConnectionTests.swift @@ -9,6 +9,7 @@ class ConnectionTests: XCTestCase { ("testConnInfoRaw", testConnInfoRaw), ("testConnectionFailure", testConnectionFailure), ("testConnectionSuccess", testConnectionSuccess), + ("testInvalidConnection", testInvalidConnection), ] var postgreSQL: PostgreSQL.Database! @@ -18,11 +19,12 @@ class ConnectionTests: XCTestCase { let conn = try postgreSQL.makeConnection() let connection = try postgreSQL.makeConnection() - XCTAssert(conn.status == CONNECTION_OK) + let status = conn.status + XCTAssert(status == CONNECTION_OK) XCTAssertTrue(connection.isConnected) try connection.reset() - try connection.close() + connection.close() XCTAssertFalse(connection.isConnected) } @@ -86,4 +88,17 @@ class ConnectionTests: XCTestCase { XCTFail("Could not connect to database") } } + + func testInvalidConnection() throws { + postgreSQL = PostgreSQL.Database.makeTest() + let connection = try postgreSQL.makeConnection() + try connection.validateConnection() + connection.close() + do { + try connection.validateConnection() + XCTFail("connection was valid after close") + } catch { + // connection was invalid + } + } } diff --git a/Tests/PostgreSQLTests/PostgreSQLTests.swift b/Tests/PostgreSQLTests/PostgreSQLTests.swift index 55d35a2..9efa2c7 100644 --- a/Tests/PostgreSQLTests/PostgreSQLTests.swift +++ b/Tests/PostgreSQLTests/PostgreSQLTests.swift @@ -1,6 +1,7 @@ import XCTest @testable import PostgreSQL import Foundation +import Dispatch class PostgreSQLTests: XCTestCase { static let allTests = [ @@ -30,6 +31,9 @@ class PostgreSQLTests: XCTestCase { ("testUnsupportedObject", testUnsupportedObject), ("testNotification", testNotification), ("testNotificationWithPayload", testNotificationWithPayload), + ("testDispatchNotification", testDispatchNotification), + ("testDispatchNotificationInvalidConnection", testDispatchNotificationInvalidConnection), + ("testDispatchNotificationWithPayload", testDispatchNotificationWithPayload), ("testQueryToNode", testQueryToNode) ] @@ -779,6 +783,45 @@ class PostgreSQLTests: XCTestCase { waitForExpectations(timeout: 5) } + func testDispatchNotification() throws { + let conn1 = try postgreSQL.makeConnection() + let conn2 = try postgreSQL.makeConnection() + + let testExpectation = expectation(description: "Receive notification") + + let queue = DispatchQueue.global() + var source: DispatchSourceRead? + source = try! conn1.listen(toChannel: "test_channel1", queue: queue) { (notification, error) in + XCTAssertEqual(notification?.channel, "test_channel1") + XCTAssertNil(notification?.payload) + XCTAssertNil(error) + + testExpectation.fulfill() + source?.cancel() + } + source?.resume() + + sleep(1) + + try conn2.notify(channel: "test_channel1", payload: nil) + + waitForExpectations(timeout: 5) + } + + func testDispatchNotificationInvalidConnection() throws { + let conn1 = try postgreSQL.makeConnection() + conn1.close() + do { + _ = try conn1.listen(toChannel: "test_channel1", queue: .global()) { (notification, error) in + XCTFail("callback should never be called") + } + XCTFail("exception should have been thrown because connection was not open") + } catch { + guard let pgerror = error as? PostgreSQLError else { XCTFail("incorrect error type"); return } + XCTAssertEqual(pgerror.code, .ioError) + } + } + func testNotificationWithPayload() throws { let conn1 = try postgreSQL.makeConnection() let conn2 = try postgreSQL.makeConnection() @@ -801,6 +844,32 @@ class PostgreSQLTests: XCTestCase { waitForExpectations(timeout: 5) } + func testDispatchNotificationWithPayload() throws { + let conn1 = try postgreSQL.makeConnection() + let conn2 = try postgreSQL.makeConnection() + + let testExpectation = expectation(description: "Receive notification with payload") + + let queue = DispatchQueue.global() + var source: DispatchSourceRead? + source = try! conn1.listen(toChannel: "test_channel2", queue: queue) { (notification, error) in + XCTAssertEqual(notification?.channel, "test_channel2") + XCTAssertEqual(notification?.payload, "test_payload") + XCTAssertNotNil(notification?.payload) + XCTAssertNil(error) + + testExpectation.fulfill() + source?.cancel() + } + source?.resume() + + sleep(1) + + try conn2.notify(channel: "test_channel2", payload: "test_payload") + + waitForExpectations(timeout: 5) + } + func testQueryToNode() throws { let conn = try postgreSQL.makeConnection() diff --git a/Tests/PostgreSQLTests/Utilities.swift b/Tests/PostgreSQLTests/Utilities.swift index 1ec5374..0624520 100644 --- a/Tests/PostgreSQLTests/Utilities.swift +++ b/Tests/PostgreSQLTests/Utilities.swift @@ -8,7 +8,7 @@ extension PostgreSQL.Database { let postgreSQL = try PostgreSQL.Database( hostname: "127.0.0.1", port: 5432, - database: "postgres", + database: "test", user: "postgres", password: "" )