Skip to content
This repository was archived by the owner on Dec 20, 2022. It is now read-only.

Commit 616834d

Browse files
committed
Use single QP between executors and driver.
Fixes the issue of instantiating the wrong QP type (RPC requestor/responder) between executors that are running on the same host as the driver. An executor instantiates its driver QP with type RPC, while the driver reuses the passive channel to send announce messages from other executors. Communication between executors still uses RDMA_READ_REQUESTOR/RDMA_READ_RESPONDER QPs. Change-Id: I8acc0ac796ab1a2f16bc1e3c987fadc5bee7a110
1 parent 47c0856 commit 616834d

File tree

5 files changed

+78
-78
lines changed

5 files changed

+78
-78
lines changed

src/main/java/org/apache/spark/shuffle/rdma/RdmaChannel.java

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class RdmaChannel {
4343
private final ConcurrentHashMap<Integer, ConcurrentLinkedDeque<SVCPostSend>> svcPostSendCache =
4444
new ConcurrentHashMap();
4545

46-
enum RdmaChannelType { RPC_REQUESTOR, RPC_RESPONDER, RDMA_READ_REQUESTOR, RDMA_READ_RESPONDER }
46+
enum RdmaChannelType { RPC, RDMA_READ_REQUESTOR, RDMA_READ_RESPONDER }
4747
private final RdmaChannelType rdmaChannelType;
4848

4949
private final RdmaCompletionListener receiveListener;
@@ -130,6 +130,7 @@ private class CompletionInfo {
130130
// NOOP_RESERVED_INDEX is used for send operations that do not require a callback
131131
private static final int NOOP_RESERVED_INDEX = 0;
132132
private final AtomicInteger completionInfoIndex = new AtomicInteger(NOOP_RESERVED_INDEX);
133+
private final RdmaShuffleConf conf;
133134

134135
RdmaChannel(
135136
RdmaChannelType rdmaChannelType,
@@ -152,32 +153,20 @@ private class CompletionInfo {
152153
this.receiveListener = receiveListener;
153154
this.rdmaBufferManager = rdmaBufferManager;
154155
this.cpuVector = cpuVector;
156+
this.conf = conf;
155157

156158
switch (rdmaChannelType) {
157-
case RPC_REQUESTOR:
158-
// Requires full-size sends, and receives for credit reports only
159+
case RPC:
160+
// Single bidirectional QP between executors and driver.
159161
if (conf.swFlowControl()) {
160-
this.recvDepth = RECV_CREDIT_REPORT_RATIO;
161-
this.remoteRecvCredits = new Semaphore(conf.recvQueueDepth(), false);
162-
} else {
163-
this.recvDepth = 0;
162+
this.remoteRecvCredits = new Semaphore(
163+
conf.recvQueueDepth() - RECV_CREDIT_REPORT_RATIO, false);
164164
}
165-
this.recvWrSize = 0;
166-
this.sendDepth = conf.sendQueueDepth();
167-
this.sendBudgetSemaphore = new Semaphore(sendDepth, false);
168-
break;
169-
170-
case RPC_RESPONDER:
171-
// Requires full-size receives and sends for credit reports only
172165
this.recvDepth = conf.recvQueueDepth();
173166
this.recvWrSize = conf.recvWrSize();
174-
if (conf.swFlowControl()) {
175-
this.sendDepth = RECV_CREDIT_REPORT_RATIO;
176-
} else {
177-
this.sendDepth = 0;
178-
}
167+
this.sendDepth = conf.sendQueueDepth();
168+
this.sendBudgetSemaphore = new Semaphore(sendDepth - RECV_CREDIT_REPORT_RATIO, false);
179169
break;
180-
181170
case RDMA_READ_REQUESTOR:
182171
// Requires sends only, no need for any receives
183172
this.recvDepth = 0;
@@ -322,6 +311,10 @@ void connect(InetSocketAddress socketAddress) throws IOException {
322311
setRdmaChannelState(RdmaChannelState.CONNECTED);
323312
}
324313

314+
InetSocketAddress getSourceSocketAddress() throws IOException {
315+
return (InetSocketAddress)cmId.getSource();
316+
}
317+
325318
void accept() throws IOException {
326319
RdmaConnParam connParams = new RdmaConnParam();
327320

@@ -778,7 +771,7 @@ private void exhaustCq() throws IOException {
778771
}
779772
}
780773

781-
if (sendDepth == RECV_CREDIT_REPORT_RATIO) {
774+
if (conf.swFlowControl() && rdmaChannelType == RdmaChannelType.RPC) {
782775
// Software-level flow control is enabled
783776
localRecvCreditsPendingReport += reclaimedRecvWrs;
784777
if (localRecvCreditsPendingReport > (recvDepth / RECV_CREDIT_REPORT_RATIO)) {
@@ -895,7 +888,7 @@ void stop() throws InterruptedException, IOException {
895888
int ret = cmId.disconnect();
896889
if (ret != 0) {
897890
logger.error("disconnect failed with errno: " + ret);
898-
} else if (rdmaChannelType.equals(RdmaChannelType.RPC_REQUESTOR) ||
891+
} else if (rdmaChannelType.equals(RdmaChannelType.RPC) ||
899892
rdmaChannelType.equals(RdmaChannelType.RDMA_READ_REQUESTOR)) {
900893
try {
901894
processRdmaCmEvent(RdmaCmEvent.EventType.RDMA_CM_EVENT_DISCONNECTED.ordinal(),

src/main/java/org/apache/spark/shuffle/rdma/RdmaNode.java

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@ class RdmaNode {
5252
private InetAddress driverInetAddress;
5353
private final ArrayList<Integer> cpuArrayList = new ArrayList<>();
5454
private int cpuIndex = 0;
55+
private final RdmaCompletionListener receiveListener;
5556

5657
RdmaNode(String hostName, boolean isExecutor, final RdmaShuffleConf conf,
5758
final RdmaCompletionListener receiveListener) throws Exception {
5859
this.conf = conf;
59-
60+
this.receiveListener = receiveListener;
6061
try {
6162
driverInetAddress = InetAddress.getByName(conf.driverHost());
6263

@@ -147,10 +148,8 @@ class RdmaNode {
147148
}
148149

149150
RdmaChannel.RdmaChannelType rdmaChannelType;
150-
if (driverInetAddress.equals(inetSocketAddress.getAddress()) ||
151-
driverInetAddress.equals(localInetSocketAddress.getAddress())) {
152-
// RPC communication is limited to driver<->executor only
153-
rdmaChannelType = RdmaChannel.RdmaChannelType.RPC_RESPONDER;
151+
if (!isExecutor) {
152+
rdmaChannelType = RdmaChannel.RdmaChannelType.RPC;
154153
} else {
155154
rdmaChannelType = RdmaChannel.RdmaChannelType.RDMA_READ_RESPONDER;
156155
}
@@ -162,6 +161,12 @@ class RdmaNode {
162161
rdmaChannel.stop();
163162
continue;
164163
}
164+
if (!isExecutor) {
165+
RdmaChannel previous = activeRdmaChannelMap.put(inetSocketAddress, rdmaChannel);
166+
if (previous != null) {
167+
previous.stop();
168+
}
169+
}
165170

166171
try {
167172
rdmaChannel.accept();
@@ -275,8 +280,8 @@ private int getNextCpuVector() {
275280

276281
public RdmaBufferManager getRdmaBufferManager() { return rdmaBufferManager; }
277282

278-
public RdmaChannel getRdmaChannel(InetSocketAddress remoteAddr, boolean mustRetry)
279-
throws IOException, InterruptedException {
283+
public RdmaChannel getRdmaChannel(InetSocketAddress remoteAddr, boolean mustRetry,
284+
RdmaChannel.RdmaChannelType rdmaChannelType) throws IOException, InterruptedException {
280285
final long startTime = System.nanoTime();
281286
final int maxConnectionAttempts = conf.maxConnectionAttempts();
282287
final long connectionTimeout = maxConnectionAttempts * conf.rdmaCmEventTimeout();
@@ -287,16 +292,12 @@ public RdmaChannel getRdmaChannel(InetSocketAddress remoteAddr, boolean mustRetr
287292
do {
288293
rdmaChannel = activeRdmaChannelMap.get(remoteAddr);
289294
if (rdmaChannel == null) {
290-
RdmaChannel.RdmaChannelType rdmaChannelType;
291-
if (driverInetAddress.equals(remoteAddr.getAddress()) ||
292-
driverInetAddress.equals(localInetSocketAddress.getAddress())) {
293-
// RPC communication is limited to driver<->executor only
294-
rdmaChannelType = RdmaChannel.RdmaChannelType.RPC_REQUESTOR;
295-
} else {
296-
rdmaChannelType = RdmaChannel.RdmaChannelType.RDMA_READ_REQUESTOR;
295+
RdmaCompletionListener listener = null;
296+
if (rdmaChannelType == RdmaChannel.RdmaChannelType.RPC) {
297+
// Executor <-> Driver rdma channels need listener on both sides.
298+
listener = receiveListener;
297299
}
298-
299-
rdmaChannel = new RdmaChannel(rdmaChannelType, conf, rdmaBufferManager, null,
300+
rdmaChannel = new RdmaChannel(rdmaChannelType, conf, rdmaBufferManager, listener,
300301
getNextCpuVector());
301302

302303
RdmaChannel actualRdmaChannel = activeRdmaChannelMap.putIfAbsent(remoteAddr, rdmaChannel);

src/main/scala/org/apache/spark/shuffle/rdma/RdmaRpcMsg.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ import scala.collection.mutable.ArrayBuffer
2525

2626
import org.apache.spark.internal.Logging
2727
import org.apache.spark.shuffle.rdma.RdmaRpcMsgType.RdmaRpcMsgType
28-
import org.apache.spark.storage.BlockManagerId
29-
3028

3129
object RdmaRpcMsgType extends Enumeration {
3230
type RdmaRpcMsgType = Value
@@ -80,14 +78,14 @@ object RdmaRpcMsg extends Logging {
8078
}
8179
}
8280

83-
class RdmaShuffleManagerHelloRpcMsg(var rdmaShuffleManagerId: RdmaShuffleManagerId)
84-
extends RdmaRpcMsg {
85-
private def this() = this(null) // For deserialization only
81+
class RdmaShuffleManagerHelloRpcMsg(var rdmaShuffleManagerId: RdmaShuffleManagerId,
82+
var channelPort: Int) extends RdmaRpcMsg {
83+
private def this() = this(null, 0) // For deserialization only
8684

8785
override protected def msgType: RdmaRpcMsgType = RdmaRpcMsgType.RdmaShuffleManagerHello
8886

8987
override protected def getLengthInSegments(segmentSize: Int): Array[Int] = {
90-
val serializedLength = rdmaShuffleManagerId.serializedLength
88+
val serializedLength = rdmaShuffleManagerId.serializedLength + 4
9189
require(serializedLength <= segmentSize, "RdmaBuffer RPC segment size is too small")
9290

9391
Array.fill(1) { serializedLength }
@@ -96,10 +94,12 @@ class RdmaShuffleManagerHelloRpcMsg(var rdmaShuffleManagerId: RdmaShuffleManager
9694
override protected def writeSegments(outs: Iterator[(DataOutputStream, Int)]): Unit = {
9795
val out = outs.next()._1
9896
rdmaShuffleManagerId.write(out)
97+
out.writeInt(channelPort)
9998
}
10099

101100
override protected def read(in: DataInputStream): Unit = {
102101
rdmaShuffleManagerId = RdmaShuffleManagerId(in)
102+
channelPort = in.readInt()
103103
}
104104
}
105105

src/main/scala/org/apache/spark/shuffle/rdma/RdmaShuffleFetcherIterator.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ private[spark] final class RdmaShuffleFetcherIterator(
169169
}
170170

171171
try {
172-
val rdmaChannel = rdmaShuffleManager.getRdmaChannel(pendingFetch.rdmaShuffleManagerId,
173-
mustRetry = true)
172+
val rdmaChannel = rdmaShuffleManager.getRdmaChannel(
173+
pendingFetch.rdmaShuffleManagerId, mustRetry = true)
174174
rdmaChannel.rdmaReadInQueue(listener, rdmaRegisteredBuffer.getRegisteredAddress,
175175
rdmaRegisteredBuffer.getLkey, pendingFetch.rdmaBlockLocations.map(_.length).toArray,
176176
pendingFetch.rdmaBlockLocations.map(_.address).toArray,

src/main/scala/org/apache/spark/shuffle/rdma/RdmaShuffleManager.scala

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -82,36 +82,36 @@ private[spark] class RdmaShuffleManager(val conf: SparkConf, isDriver: Boolean)
8282
// Book keep mapping from BlockManagerId to RdmaShuffleManagerId
8383
blockManagerIdToRdmaShuffleManagerId.put(helloMsg.rdmaShuffleManagerId.blockManagerId,
8484
helloMsg.rdmaShuffleManagerId)
85-
Future {
86-
getRdmaChannel(helloMsg.rdmaShuffleManagerId, mustRetry = false)
87-
}.onSuccess { case rdmaChannel =>
88-
rdmaShuffleManagersMap.put(helloMsg.rdmaShuffleManagerId, rdmaChannel)
89-
val buffers = new RdmaAnnounceRdmaShuffleManagersRpcMsg(
90-
rdmaShuffleManagersMap.keys.toSeq).toRdmaByteBufferManagedBuffers(
91-
getRdmaByteBufferManagedBuffer, rdmaShuffleConf.recvWrSize)
92-
93-
for ((dstRdmaShuffleManagerId, dstRdmaChannel) <- rdmaShuffleManagersMap) {
94-
buffers.foreach(_.retain())
95-
96-
val listener = new RdmaCompletionListener {
97-
override def onSuccess(buf: ByteBuffer): Unit = buffers.foreach(_.release())
98-
override def onFailure(e: Throwable): Unit = {
99-
buffers.foreach(_.release())
100-
logError("Failed to send RdmaAnnounceExecutorsRpcMsg to executor: " +
101-
dstRdmaShuffleManagerId + ", Exception: " + e)
102-
}
85+
// Since we're reusing executor <-> driver QP - whis will be taken from cache.
86+
val rdmaChannel = getRdmaChannel(helloMsg.rdmaShuffleManagerId.host,
87+
helloMsg.channelPort, false, RdmaChannel.RdmaChannelType.RPC)
88+
rdmaShuffleManagersMap.put(helloMsg.rdmaShuffleManagerId, rdmaChannel)
89+
val buffers = new RdmaAnnounceRdmaShuffleManagersRpcMsg(
90+
rdmaShuffleManagersMap.keys.toSeq).toRdmaByteBufferManagedBuffers(
91+
getRdmaByteBufferManagedBuffer, rdmaShuffleConf.recvWrSize)
92+
93+
for ((dstRdmaShuffleManagerId, dstRdmaChannel) <- rdmaShuffleManagersMap) {
94+
buffers.foreach(_.retain())
95+
96+
val listener = new RdmaCompletionListener {
97+
override def onSuccess(buf: ByteBuffer): Unit = buffers.foreach(_.release())
98+
99+
override def onFailure(e: Throwable): Unit = {
100+
buffers.foreach(_.release())
101+
logError("Failed to send RdmaAnnounceExecutorsRpcMsg to executor: " +
102+
dstRdmaShuffleManagerId + ", Exception: " + e)
103103
}
104+
}
104105

105-
try {
106-
dstRdmaChannel.rdmaSendInQueue(listener, buffers.map(_.getAddress),
107-
buffers.map(_.getLkey), buffers.map(_.getLength.toInt))
108-
} catch {
109-
case e: Exception => listener.onFailure(e)
110-
}
106+
try {
107+
dstRdmaChannel.rdmaSendInQueue(listener, buffers.map(_.getAddress),
108+
buffers.map(_.getLkey), buffers.map(_.getLength.toInt))
109+
} catch {
110+
case e: Exception => listener.onFailure(e)
111111
}
112-
// Release the reference taken by the allocation
113-
buffers.foreach(_.release())
114112
}
113+
// Release the reference taken by the allocation
114+
buffers.foreach(_.release())
115115
}
116116

117117
case announceMsg: RdmaAnnounceRdmaShuffleManagersRpcMsg =>
@@ -205,7 +205,8 @@ private[spark] class RdmaShuffleManager(val conf: SparkConf, isDriver: Boolean)
205205
Future {
206206
getRdmaChannelToDriver(mustRetry = true)
207207
}.onSuccess { case rdmaChannel =>
208-
val buffers = new RdmaShuffleManagerHelloRpcMsg(localRdmaShuffleManagerId.get).
208+
val port = rdmaChannel.getSourceSocketAddress.getPort
209+
val buffers = new RdmaShuffleManagerHelloRpcMsg(localRdmaShuffleManagerId.get, port).
209210
toRdmaByteBufferManagedBuffers(getRdmaByteBufferManagedBuffer, rdmaShuffleConf.recvWrSize)
210211

211212
val listener = new RdmaCompletionListener {
@@ -308,14 +309,19 @@ private[spark] class RdmaShuffleManager(val conf: SparkConf, isDriver: Boolean)
308309
}
309310
}
310311

311-
private def getRdmaChannel(host: String, port: Int, mustRetry: Boolean): RdmaChannel =
312-
rdmaNode.get.getRdmaChannel(new InetSocketAddress(host, port), mustRetry)
312+
private def getRdmaChannel(host: String, port: Int, mustRetry: Boolean,
313+
rdmaChannelType: RdmaChannel.RdmaChannelType): RdmaChannel =
314+
rdmaNode.get.getRdmaChannel(new InetSocketAddress(host, port), mustRetry, rdmaChannelType)
313315

314-
def getRdmaChannel(rdmaShuffleManagerId: RdmaShuffleManagerId, mustRetry: Boolean): RdmaChannel =
315-
getRdmaChannel(rdmaShuffleManagerId.host, rdmaShuffleManagerId.port, mustRetry)
316+
def getRdmaChannel(rdmaShuffleManagerId: RdmaShuffleManagerId,
317+
mustRetry: Boolean): RdmaChannel = {
318+
getRdmaChannel(rdmaShuffleManagerId.host, rdmaShuffleManagerId.port, mustRetry,
319+
RdmaChannel.RdmaChannelType.RDMA_READ_REQUESTOR)
320+
}
316321

317322
def getRdmaChannelToDriver(mustRetry: Boolean): RdmaChannel = getRdmaChannel(
318-
rdmaShuffleConf.driverHost, rdmaShuffleConf.driverPort, mustRetry)
323+
rdmaShuffleConf.driverHost, rdmaShuffleConf.driverPort, mustRetry,
324+
RdmaChannel.RdmaChannelType.RPC)
319325

320326
def getRdmaBufferManager: RdmaBufferManager = rdmaNode.get.getRdmaBufferManager
321327

0 commit comments

Comments
 (0)