diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa7974b87b..8f1107d8a796f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -35,6 +35,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; @@ -106,7 +107,7 @@ protected void handleMessage( } else if (msgObj instanceof RegisterExecutor) { final Timer.Context responseDelayContext = - metrics.registerExecutorRequestLatencyMillis.time(); + metrics.registerExecutorRequestLatencyMillis.time(); try { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); @@ -116,9 +117,49 @@ protected void handleMessage( responseDelayContext.stop(); } + } else if (msgObj instanceof RegisterExecutorForBackupsOnly) { + final Timer.Context responseDelayContext = + metrics.registerExecutorRequestLatencyMillis.time(); + try { + RegisterExecutorForBackupsOnly msg = (RegisterExecutorForBackupsOnly) msgObj; + checkAuth(client, msg.appId); + blockManager.registerExecutorForBackups(msg.appId, msg.execId, msg.shuffleManager); + callback.onSuccess(ByteBuffer.wrap(new byte[0])); + } finally { + responseDelayContext.stop(); + } + } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); } + + } + + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + BlockTransferMessage header = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader); + if (header instanceof UploadShuffleFileStream) { + UploadShuffleFileStream msg = (UploadShuffleFileStream) header; + checkAuth(client, msg.appId); + return blockManager.openShuffleFileForBackup( + msg.appId, + msg.execId, + msg.shuffleId, + msg.mapId); + } else if (header instanceof UploadShuffleIndexFileStream) { + UploadShuffleIndexFileStream msg = (UploadShuffleIndexFileStream) header; + checkAuth(client, msg.appId); + return blockManager.openShuffleIndexFileForBackup( + msg.appId, + msg.execId, + msg.shuffleId, + msg.mapId); + } else { + throw new UnsupportedOperationException("Unexpected message header: " + header); + } } public MetricSet getAllMetrics() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369d..4304a421e7167 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -18,7 +18,12 @@ package org.apache.spark.network.shuffle; import java.io.*; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.*; import 
java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; @@ -44,6 +49,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; @@ -75,6 +81,8 @@ public class ExternalShuffleBlockResolver { @VisibleForTesting final ConcurrentMap executors; + private final ConcurrentMap backupExecutors; + /** * Caches index file information so that we can avoid open/close the index files * for each block fetch. @@ -95,6 +103,7 @@ public class ExternalShuffleBlockResolver { "org.apache.spark.shuffle.sort.SortShuffleManager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); + public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) throws IOException { this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( @@ -144,6 +153,12 @@ public void registerExecutor( String execId, ExecutorShuffleInfo executorInfo) { AppExecId fullId = new AppExecId(appId, execId); + if (backupExecutors.containsKey(fullId)) { + throw new UnsupportedOperationException( + String.format( + "Executor %s cannot be registered for both primary shuffle management and backup" + + " shuffle management.", fullId)); + } logger.info("Registered executor {} with {}", fullId, executorInfo); if (!knownManagers.contains(executorInfo.shuffleManager)) { throw new UnsupportedOperationException( @@ -161,6 +176,32 @@ public void registerExecutor( executors.put(fullId, executorInfo); } + private StreamCallbackWithID getFileWriterStreamCallback( + String appId, + String execId, + int shuffleId, + int mapId, + String extension, + FileWriterStreamCallback.BackupFileType backupFileType) { + AppExecId fullId = new AppExecId(appId, execId); + ExecutorShuffleInfo executor = backupExecutors.get(fullId); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered for shuffle file backups" + + " (appId=%s, execId=%s)", appId, execId)); + } + File backedUpFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0." + extension); + FileWriterStreamCallback streamCallback = new FileWriterStreamCallback( + fullId, + shuffleId, + mapId, + backedUpFile, + backupFileType); + streamCallback.open(); + return streamCallback; + } + /** * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions * about how the hash and sort based shuffles store their data. @@ -173,6 +214,13 @@ public ManagedBuffer getBlockData( int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { + logger.info("application's shuffle data isn't in main file system, checking backups..." 
+ + "app id: {}, executor id: {}, shuffle id: {}, map id: {}, reduce id: {}", + appId, execId, shuffleId, mapId, reduceId); + executor = backupExecutors.get(new AppExecId(appId, execId)); + } + if (executor == null) { + logger.warn("Executor is not registered (appId: {}, execId: {}", appId, execId); throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index e49e27ab5aa79..06a8b51f85190 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -34,7 +34,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.RegisterExecutorForBackupsOnly; import org.apache.spark.network.util.TransportConf; /** @@ -43,7 +45,7 @@ * BlockTransferService), which has the downside of losing the shuffle data if we lose the * executors. */ -public class ExternalShuffleClient extends ShuffleClient { +public class ExternalShuffleClient extends ShuffleClient{ private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportConf conf; @@ -90,6 +92,7 @@ public void fetchBlocks( int port, String execId, String[] blockIds, + boolean isRemote, BlockFetchingListener listener, DownloadFileManager downloadFileManager) { checkInit(); @@ -145,6 +148,20 @@ public void registerWithShuffleServer( } } + public void registerWithRemoteShuffleServer( + String driverHostPort, + String host, + int port, + String execId, + String shuffleManager) throws IOException, InterruptedException{ + checkInit(); + try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { + ByteBuffer registerMessage = new RegisterExecutorForBackupsOnly( + driverHostPort, appId, execId, shuffleManager).toByteBuffer(); + client.sendRpcSync(registerMessage, registrationTimeoutMs); + } + } + @Override public void close() { checkInit(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java new file mode 100644 index 0000000000000..56ec8283ce735 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -0,0 +1,149 @@ +package org.apache.spark.network.shuffle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.network.client.StreamCallbackWithID; + +final class FileWriterStreamCallback implements StreamCallbackWithID { + + private static final Logger logger = LoggerFactory.getLogger(FileWriterStreamCallback.class); + + public enum BackupFileType { + DATA("shuffle-data"), + 
INDEX("shuffle-index"); + + private final String typeString; + + BackupFileType(String typeString) { + this.typeString = typeString; + } + + @Override + public String toString() { + return typeString; + } + } + private final ExternalShuffleBlockResolver.AppExecId fullExecId; + private final int shuffleId; + private final int mapId; + private final File file; + private final BackupFileType fileType; + private WritableByteChannel fileOutputChannel = null; + + FileWriterStreamCallback( + ExternalShuffleBlockResolver.AppExecId fullExecId, + int shuffleId, + int mapId, + File file, + BackupFileType fileType) { + this.fullExecId = fullExecId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.file = file; + this.fileType = fileType; + } + + public void open() { + logger.info( + "Opening {} for backup writing. File type: {}", file.getAbsolutePath(), fileType); + if (fileOutputChannel != null) { + throw new IllegalStateException( + String.format( + "File %s for is already open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + if (!file.exists()) { + try { + if (!file.getParentFile().isDirectory() && !file.getParentFile().mkdirs()) { + throw new IOException( + String.format( + "Failed to create shuffle file directory at" + + file.getParentFile().getAbsolutePath() + "(type: %s).", fileType)); + } + + if (!file.createNewFile()) { + throw new IOException( + String.format( + "Failed to create shuffle file (type: %s).", fileType)); + } + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to create shuffle file at %s for backup (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + try { + // TODO encryption + fileOutputChannel = Channels.newChannel(new FileOutputStream(file)); + } catch (FileNotFoundException e) { + throw new RuntimeException( + String.format( + "Failed to find file for writing at %s (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + + @Override + public String getID() { + return String.format("%s-%s-%d-%d-%s", + fullExecId.appId, + fullExecId.execId, + shuffleId, + mapId, + fileType); + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + verifyShuffleFileOpenForWriting(); + while (buf.hasRemaining()) { + fileOutputChannel.write(buf); + } + } + + @Override + public void onComplete(String streamId) throws IOException { + fileOutputChannel.close(); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + logger.warn("Failed to back up shuffle file at {} (type: %s).", + file.getAbsolutePath(), + fileType, + cause); + fileOutputChannel.close(); + // TODO delete parent dirs too + if (!file.delete()) { + logger.warn( + "Failed to delete incomplete backup shuffle file at %s (type: %s)", + file.getAbsolutePath(), + fileType); + } + } + + private void verifyShuffleFileOpenForWriting() { + if (fileOutputChannel == null) { + throw new RuntimeException( + String.format( + "Shuffle file at %s not open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + } +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 62b99c40f61f9..5f42cefab9080 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ 
-53,6 +53,7 @@ public abstract void fetchBlocks( int port, String execId, String[] blockIds, + boolean isRemote, BlockFetchingListener listener, DownloadFileManager downloadFileManager); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 60179f126bc44..8d3d86698974f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -24,7 +24,6 @@ import java.util.concurrent.TimeUnit; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +31,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.ExternalServiceHeartbeat; import org.apache.spark.network.util.TransportConf; /** @@ -117,7 +117,7 @@ private Heartbeater(TransportClient client) { @Override public void run() { // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout - client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + client.send(new ExternalServiceHeartbeat(appId).toByteBuffer()); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a68a297519b66..dfdbde0eac1c6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -23,8 +23,6 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -42,7 +40,8 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. 
*/ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_FILE_STREAM(7), UPLOAD_SHUFFLE_INDEX_STREAM(8), + REGISTER_EXECUTOR_FOR_BACKUPS(9); private final byte id; @@ -66,8 +65,11 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 2: return RegisterExecutor.decode(buf); case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); - case 5: return ShuffleServiceHeartbeat.decode(buf); + case 5: return ExternalServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); + case 7: return UploadShuffleFileStream.decode(buf); + case 8: return UploadShuffleIndexFileStream.decode(buf); + case 9: return RegisterExecutorForBackupsOnly.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index 93758bdc58fb0..8daaf6b772b64 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -89,6 +89,9 @@ public static ExecutorShuffleInfo decode(ByteBuf buf) { String[] localDirs = Encoders.StringArrays.decode(buf); int subDirsPerLocalDir = buf.readInt(); String shuffleManager = Encoders.Strings.decode(buf); - return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + return new ExecutorShuffleInfo( + localDirs, + subDirsPerLocalDir, + shuffleManager); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java new file mode 100644 index 0000000000000..d4403e8b94aa9 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExternalServiceHeartbeat.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +/** + * A heartbeat sent from the driver to some ExternalService + */ +public class ExternalServiceHeartbeat extends BlockTransferMessage { + private final String appId; + + public ExternalServiceHeartbeat(String appId) { + this.appId = appId; + } + + public String getAppId() { return appId; } + + @Override + protected Type type() { return Type.HEARTBEAT; } + + @Override + public int encodedLength() { return Encoders.Strings.encodedLength(appId); } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + } + + public static ExternalServiceHeartbeat decode(ByteBuf buf) { + return new ExternalServiceHeartbeat(Encoders.Strings.decode(buf)); + } +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java similarity index 50% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java index d5f53ccb7f741..1a3918a9b36aa 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -30,48 +30,48 @@ * A message sent from the driver to register with the MesosExternalShuffleService. 
*/ public class RegisterDriver extends BlockTransferMessage { - private final String appId; - private final long heartbeatTimeoutMs; + private final String appId; + private final long heartbeatTimeoutMs; - public RegisterDriver(String appId, long heartbeatTimeoutMs) { - this.appId = appId; - this.heartbeatTimeoutMs = heartbeatTimeoutMs; - } + public RegisterDriver(String appId, long heartbeatTimeoutMs) { + this.appId = appId; + this.heartbeatTimeoutMs = heartbeatTimeoutMs; + } - public String getAppId() { return appId; } + public String getAppId() { return appId; } - public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } + public long getHeartbeatTimeoutMs() { return heartbeatTimeoutMs; } - @Override - protected Type type() { return Type.REGISTER_DRIVER; } + @Override + protected Type type() { return Type.REGISTER_DRIVER; } - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; - } + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + Long.SIZE / Byte.SIZE; + } - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - buf.writeLong(heartbeatTimeoutMs); - } + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeLong(heartbeatTimeoutMs); + } - @Override - public int hashCode() { - return Objects.hashCode(appId, heartbeatTimeoutMs); - } + @Override + public int hashCode() { + return Objects.hashCode(appId, heartbeatTimeoutMs); + } - @Override - public boolean equals(Object o) { - if (!(o instanceof RegisterDriver)) { - return false; + @Override + public boolean equals(Object o) { + if (!(o instanceof RegisterDriver)) { + return false; + } + return Objects.equal(appId, ((RegisterDriver) o).appId); } - return Objects.equal(appId, ((RegisterDriver) o).appId); - } - public static RegisterDriver decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - long heartbeatTimeout = buf.readLong(); - return new RegisterDriver(appId, heartbeatTimeout); - } + public static RegisterDriver decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + long heartbeatTimeout = buf.readLong(); + return new RegisterDriver(appId, heartbeatTimeout); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java new file mode 100644 index 0000000000000..7e2806a7b014e --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorForBackupsOnly.java @@ -0,0 +1,80 @@ +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +public class RegisterExecutorForBackupsOnly extends BlockTransferMessage { + public final String driverHostPort; + public final String appId; + public final String execId; + public final String shuffleManager; + + public RegisterExecutorForBackupsOnly( + String driverHostPort, + String appId, + String execId, + String shuffleManager) { + this.driverHostPort = driverHostPort; + this.appId = appId; + this.execId = execId; + this.shuffleManager = shuffleManager; + } + + @Override + protected Type type() { + return Type.REGISTER_EXECUTOR_FOR_BACKUPS; + } + + @Override + public int encodedLength() { + return 
Encoders.Strings.encodedLength(driverHostPort) + + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, driverHostPort); + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, shuffleManager); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RegisterExecutorForBackupsOnly) { + RegisterExecutorForBackupsOnly o = (RegisterExecutorForBackupsOnly) other; + return Objects.equal(driverHostPort, o.driverHostPort) + && Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hashCode(driverHostPort, appId, execId, shuffleManager); + } + + @Override + public String toString() { + return Objects.toStringHelper(RegisterExecutorForBackupsOnly.class) + .add("driverHostPort", driverHostPort) + .add("appId", appId) + .add("execId", execId) + .add("shuffleManager", shuffleManager) + .toString(); + } + + public static RegisterExecutorForBackupsOnly decode(ByteBuf buf) { + String driverHostPort = Encoders.Strings.decode(buf); + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String shuffleManager = Encoders.Strings.decode(buf); + return new RegisterExecutorForBackupsOnly(driverHostPort, appId, execId, shuffleManager); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java new file mode 100644 index 0000000000000..409a00c1d89ac --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleFileStream.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +public class UploadShuffleFileStream extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + + public UploadShuffleFileStream( + String appId, + String execId, + int shuffleId, + int mapId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_FILE_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode( + appId, + execId, + shuffleId, + mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 8; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static UploadShuffleFileStream decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new UploadShuffleFileStream(appId, execId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java new file mode 100644 index 0000000000000..0bd6301517716 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexFileStream.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +public class UploadShuffleIndexFileStream extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + + public UploadShuffleIndexFileStream( + String appId, + String execId, + int shuffleId, + int mapId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_INDEX_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode( + appId, + execId, + shuffleId, + mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 8; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static UploadShuffleIndexFileStream decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new UploadShuffleIndexFileStream(appId, execId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java deleted file mode 100644 index b30bb9aed55b6..0000000000000 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol.mesos; - -import io.netty.buffer.ByteBuf; -import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -/** - * A heartbeat sent from the driver to the MesosExternalShuffleService. 
- */ -public class ShuffleServiceHeartbeat extends BlockTransferMessage { - private final String appId; - - public ShuffleServiceHeartbeat(String appId) { - this.appId = appId; - } - - public String getAppId() { return appId; } - - @Override - protected Type type() { return Type.HEARTBEAT; } - - @Override - public int encodedLength() { return Encoders.Strings.encodedLength(appId); } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - } - - public static ShuffleServiceHeartbeat decode(ByteBuf buf) { - return new ShuffleServiceHeartbeat(Encoders.Strings.decode(buf)); - } -}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 526b96b364473..3e7bc8107924e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -135,7 +135,7 @@ private FetchResult fetchBlocks( try (ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, 5000)) { client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, false, new BlockFetchingListener() { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4839d04522f10..5b6f39c37e335 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -424,6 +424,65 @@ private long[] mergeSpillsWithFileStream( return partitionLengths; } + /** + * Merges spill files using the ShufflePartitionWriter API. + */ + private long[] mergeSpillsWithPluggableWriter( + SpillInfo[] spills, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + assert (blockManager != null); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new InputStream[spills.length]; + boolean threwException = true; + try (ShufflePartitionWriter writer = writeSupport.newPartitionWriter( + sparkConf.getAppId(), shuffleId, mapId)) { + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new NioBufferedFileInputStream( + spills[i].file, + inputBufferSizeInBytes); + } + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + try { + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + partitionLengths[partition] = writer.appendPartition(partition, partitionInputStream); + } finally { + partitionInputStream.close(); + } + } + } + } + } catch (Exception e) { + try { + writer.abort(); + } catch (Exception e2) { + logger.warn("Failed to close shuffle writer upon aborting.", e2); + } + throw e; + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + } + return partitionLengths; + } + + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. 
* This is only safe when the IO compression codec and serializer support concatenation of diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4bc6541f..925947a0c6fd2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,18 +22,19 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal - import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.external.ShuffleServiceAddressProvider import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -213,9 +214,12 @@ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case object GetRemoteShuffleServiceAddresses extends MapOutputTrackerMessage +private[spark] sealed trait BackupMessage +private[spark] case class HeartbeaterMessage(appId: String) extends BackupMessage private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) - + extends BackupMessage /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) @@ -227,7 +231,11 @@ private[spark] class MapOutputTrackerMasterEndpoint( case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context)) + val message = GetMapOutputMessage(shuffleId, context) + tracker.post[GetMapOutputMessage](message) + + case GetRemoteShuffleServiceAddresses => + context.reply(tracker.getRemoteShuffleServiceAddresses) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -295,7 +303,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. 
*/ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + def getMapSizesByExecutorId( + shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] /** @@ -318,9 +327,20 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private[spark] class MapOutputTrackerMaster( conf: SparkConf, broadcastManager: BroadcastManager, - isLocal: Boolean) + isLocal: Boolean, + backupMaster: Option[MapOutputTrackerMaster], + shuffleServiceAddressProvider: ShuffleServiceAddressProvider) extends MapOutputTracker(conf) { + def this( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean, + shuffleServiceAddressProvider: ShuffleServiceAddressProvider) = this( + conf, broadcastManager, isLocal, Some(new MapOutputTrackerMaster( + conf, broadcastManager, isLocal, None, shuffleServiceAddressProvider)), + shuffleServiceAddressProvider) + // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -348,7 +368,7 @@ private[spark] class MapOutputTrackerMaster( private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) // requests for map output statuses - private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + private val mapOutputRequests = new LinkedBlockingQueue[BackupMessage] // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. @@ -371,10 +391,15 @@ private[spark] class MapOutputTrackerMaster( throw new IllegalArgumentException(msg) } - def post(message: GetMapOutputMessage): Unit = { + def post[T <: BackupMessage](message: T): Unit = { mapOutputRequests.offer(message) } + def postToBackup[T <: BackupMessage](message: T): Unit = { + require(backupMaster.isDefined, "No backup master available") + backupMaster.foreach(_.post(message)) + } + /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { @@ -382,19 +407,20 @@ private[spark] class MapOutputTrackerMaster( while (true) { try { val data = mapOutputRequests.take() - if (data == PoisonPill) { - // Put PoisonPill back so that other MessageLoops can see it. - mapOutputRequests.offer(PoisonPill) - return + data match { + case PoisonPill => + // Put PoisonPill back so that other MessageLoops can see it. 
+ mapOutputRequests.offer(PoisonPill) + return + case GetMapOutputMessage(shuffleId, context) => + val hostPort = context.senderAddress.hostPort + // TODO: Change back to debug + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } - val context = data.context - val shuffleId = data.shuffleId - val hostPort = context.senderAddress.hostPort - logDebug("Handling request to send map output locations for shuffle " + shuffleId + - " to " + hostPort) - val shuffleStatus = shuffleStatuses.get(shuffleId).head - context.reply( - shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -417,6 +443,7 @@ private[spark] class MapOutputTrackerMaster( if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } + backupMaster.foreach(_.registerShuffle(shuffleId, numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -644,13 +671,16 @@ private[spark] class MapOutputTrackerMaster( } } + def getRemoteShuffleServiceAddresses: List[(String, Int)] = + shuffleServiceAddressProvider.getShuffleServiceAddresses() + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { - case Some (shuffleStatus) => + case Some(shuffleStatus) => shuffleStatus.withMapStatuses { statuses => MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } @@ -671,7 +701,7 @@ private[spark] class MapOutputTrackerMaster( /** * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. * Note that this is not used in local-mode; instead, local-mode Executors access the - * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * MapOutputTrackerMaster directly (which is possible because the master and worker share a common * superclass). */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { @@ -683,10 +713,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private val fetching = new HashSet[Int] // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. 
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + override def getMapSizesByExecutorId( + shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) + val statuses = getStatuses(shuffleId, mapStatuses, fetching) try { MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } catch { @@ -698,22 +729,26 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses( + shuffleId: Int, + statusesToInspect: Map[Int, Array[MapStatus]], + statusesBeingFetched: mutable.HashSet[Int]) + : Array[MapStatus] = { + val statuses = statusesToInspect.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { + statusesBeingFetched.synchronized { // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { + while (statusesBeingFetched.contains(shuffleId)) { try { - fetching.wait() + statusesBeingFetched.wait() } catch { case e: InterruptedException => } @@ -721,10 +756,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Either while we waited the fetch happened successfully, or // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull + fetchedStatuses = statusesToInspect.get(shuffleId).orNull if (fetchedStatuses == null) { // We have to do the fetch, get others to wait for us. - fetching += shuffleId + statusesBeingFetched += shuffleId } } @@ -736,11 +771,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + statusesToInspect.put(shuffleId, fetchedStatuses) } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + statusesBeingFetched.synchronized { + statusesBeingFetched -= shuffleId + statusesBeingFetched.notifyAll() } } } @@ -759,7 +794,6 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } - /** Unregister shuffle data. 
*/ def unregisterShuffle(shuffleId: Int): Unit = { mapStatuses.remove(shuffleId) @@ -835,7 +869,6 @@ private[spark] object MapOutputTracker extends Logging { objIn.close() } } - bytes(0) match { case DIRECT => deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 66038eeaea54f..b3587326caae1 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,13 +19,13 @@ package org.apache.spark import java.io.File import java.net.Socket -import java.util.Locale - -import scala.collection.mutable -import scala.util.Properties +import java.util.{Locale, ServiceLoader} import com.google.common.collect.MapMaker +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Properties import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager @@ -39,6 +39,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} +import org.apache.spark.shuffle.external.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -302,7 +303,20 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) + val loader = Utils.getContextOrSparkClassLoader + val master = conf.get("spark.master") + val serviceLoaders = + ServiceLoader.load(classOf[ShuffleServiceAddressProviderFactory], loader) + .asScala.filter(_.canCreate(conf.get("spark.master"))) + if (serviceLoaders.size > 1) { + throw new SparkException( + s"Multiple external cluster managers registered for the url $master: $serviceLoaders") + } + val shuffleServiceAddressProvider = serviceLoaders.headOption + .map(_.create(conf)) + .getOrElse(DefaultShuffleServiceAddressProvider) + shuffleServiceAddressProvider.start() + new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleServiceAddressProvider) } else { new MapOutputTrackerWorker(conf) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d34601358d896..25064ea6e496e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -103,6 +103,17 @@ package object config { private[spark] val EXECUTOR_HEARTBEAT_MAX_FAILURES = ConfigBuilder("spark.executor.heartbeat.maxFailures").internal().intConf.createWithDefault(60) + private[spark] val SHUFFLE_BACKUP_HEARTBEAT_INTERVAL = + ConfigBuilder("spark.driver.externalShuffleBackup.heartbeatInterval") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("10s") + + private[spark] val SHUFFLE_REMOTE_READ_OVERRIDE = + ConfigBuilder("spark.shuffle.externalShuffleBackup.remote") + .booleanConf + .createWithDefault(false) + + private[spark] val EXECUTOR_JAVA_OPTIONS = 
ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index a58c8fa2e763f..df44cc26a8cc2 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -67,6 +67,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockIds: Array[String], + isRemote: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit @@ -92,10 +93,11 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockId: String, + isRemote: Boolean, tempFileManager: DownloadFileManager): ManagedBuffer = { // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() - fetchBlocks(host, port, execId, Array(blockId), + fetchBlocks(host, port, execId, Array(blockId), isRemote, new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { result.failure(exception) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index dc55685b1e7bd..3089b94576063 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -105,6 +105,7 @@ private[spark] class NettyBlockTransferService( port: Int, execId: String, blockIds: Array[String], + isRemote: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..6ae8554e7bad5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -243,3 +243,31 @@ private[spark] object HighlyCompressedMapStatus { hugeBlockSizes) } } + +private[spark] class RelocatedMapStatus private( + private[this] var original: MapStatus, + private[this] var newLocation: BlockManagerId) + extends MapStatus with Externalizable { + + protected def this() = this(null, null) + + override def location: BlockManagerId = newLocation + + override def getSizeForBlock(reduceId: Int): Long = original.getSizeForBlock(reduceId) + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeObject(original) + out.writeObject(newLocation) + } + + override def readExternal(in: ObjectInput): Unit = { + this.original = in.readObject().asInstanceOf[MapStatus] + this.newLocation = in.readObject().asInstanceOf[BlockManagerId] + } +} + +private[spark] object RelocatedMapStatus { + def apply(original: MapStatus, newLocation: BlockManagerId): MapStatus = { + new RelocatedMapStatus(original, newLocation) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 74b0e0b3a741a..ab2e0c7fec303 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -20,7 +20,8 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.external.ShuffleReadSupport +import org.apache.spark.storage._ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -33,6 +34,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( startPartition: Int, endPartition: Int, context: TaskContext, + appId: String, + shuffleReadSupport: ShuffleReadSupport = null, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) @@ -42,19 +45,35 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val wrappedStreams = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), - serializerManager.wrapStream, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), - SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), - SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) - + val wrappedStreams = if (shuffleReadSupport != null) { + getBlockIdsGroupedByMapIds(handle.shuffleId, startPartition, endPartition) + .flatMap { case (mapId, blockIds) => + val reader = shuffleReadSupport.newPartitionReader( + appId, handle.shuffleId, mapId) + blockIds.map { + case ShuffleBlockId(_, _, reduceId) => reader.fetchPartition(reduceId) + case ShuffleDataBlockId(_, _, reduceId) => reader.fetchPartition(reduceId) + case invalid => + throw new IllegalArgumentException(s"Invalid block id $invalid") + } + } + } else { + val mapSizesByExecId = + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) + new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapSizesByExecId, + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true), + readMetrics) + } val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream @@ -120,4 +139,25 @@ private[spark] class BlockStoreShuffleReader[K, C]( new InterruptibleIterator[Product2[K, C]](context, resultIter) } } + private def getBlockIdsGroupedByMapIds( + shuffleId: Int, startPartition: Int, endPartition: Int): Iterator[(Int, Seq[BlockId])] = { + mapOutputTracker.getMapSizesByExecutorId(shuffleId, startPartition, endPartition) + .flatMap(_._2) + 
.map(_._1) + .toStream + .filter { blockId => + blockId match { + case ShuffleBlockId(_, _, _) => true + case ShuffleDataBlockId(_, _, _) => true + case _ => false + } + } + .groupBy { + case ShuffleBlockId(_, mapId, _) => mapId + case ShuffleDataBlockId(_, mapId, _) => mapId + case blockId => + throw new IllegalArgumentException(s"Invalid block id: $blockId") + }.mapValues(_.toSeq) + .iterator + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d3f1c7ec1bbee..0a5e12c36c444 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -27,6 +27,7 @@ import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID +import org.apache.spark.shuffle.external._ import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -43,19 +44,37 @@ import org.apache.spark.util.Utils // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). private[spark] class IndexShuffleBlockResolver( conf: SparkConf, - _blockManager: BlockManager = null) + _blockManager: BlockManager = null, + _shuffleDataIO: ShuffleDataIO = null) extends ShuffleBlockResolver with Logging { + private lazy val appId = conf.getAppId private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + private var shuffleWriteSupport: ShuffleWriteSupport = _ + + private var shuffleReadSupport: ShuffleReadSupport = _ + + private var isRemote_ = false + + if (shuffleDataIO != null) { + shuffleWriteSupport = _shuffleDataIO.writeSupport() + shuffleReadSupport = _shuffleDataIO.readSupport() + isRemote_ = true + } + + def shuffleDataIO(): ShuffleDataIO = _shuffleDataIO + def isRemote(): Boolean = isRemote_ + + def getDataFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { + def getIndexFile(shuffleId: Int, mapId: Int): File = { blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } @@ -219,6 +238,15 @@ private[spark] class IndexShuffleBlockResolver( offset, nextOffset - offset) } finally { + if (isRemote()) { + val writer = shuffleWriteSupport.newPartitionWriter(appId, blockId.shuffleId, blockId.mapId) + try { + writer.appendIndexFile(blockId.reduceId, in) + } catch { + case e: Exception => + writer.abort(e) + } + } in.close() } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ExternalRemoteShuffleClient.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ExternalRemoteShuffleClient.scala new file mode 100644 index 0000000000000..c6988901a2259 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ExternalRemoteShuffleClient.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.shuffle._ + +private[spark] class ExternalRemoteShuffleClient( + externalShuffleClient: ExternalShuffleClient, + baseBlockTransferService: BlockTransferService) extends ShuffleClient { + + override def init(appId: String): Unit = { + externalShuffleClient.init(appId) + baseBlockTransferService.init(appId) + } + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + isRemote: Boolean, + listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + if (isRemote) { + externalShuffleClient.fetchBlocks( + host, + port, + execId, + blockIds, + isRemote, + listener, + downloadFileManager) + } else { + baseBlockTransferService.fetchBlocks( + host, port, execId, blockIds, isRemote, listener, downloadFileManager) + } + } + + override def close(): Unit = { + baseBlockTransferService.close() + externalShuffleClient.close() + } + + def registerWithRemoteShuffleServer( + driverHostPort: String, + host: String, + port: Int, + execId: String, + shuffleManager: String) : Unit = { + externalShuffleClient.registerWithRemoteShuffleServer( + driverHostPort, host, port, execId, shuffleManager) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala new file mode 100644 index 0000000000000..fd4b35e42ed79 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleDataIO.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.external + +private[spark] trait ShuffleDataIO { + def writeSupport(): ShuffleWriteSupport + def readSupport(): ShuffleReadSupport +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala new file mode 100644 index 0000000000000..44a1c74de5458 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionReader.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +import java.io.InputStream + +// TODO: Support batch +private[spark] trait ShufflePartitionReader { + def fetchPartition(reduceId: Long): InputStream +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala new file mode 100644 index 0000000000000..2479f98f04c54 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShufflePartitionWriter.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +import java.io.{Closeable, InputStream, OutputStream} + +private[spark] trait ShufflePartitionWriter extends Closeable { + def appendPartition(partitionId: Long, partitionOutput: OutputStream): Unit + def appendIndexFile(partitionId: Long, indexInput: InputStream): Unit + def abort(exception: Throwable): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala new file mode 100644 index 0000000000000..87772d73d7211 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleReadSupport.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +private[spark] trait ShuffleReadSupport { + def newPartitionReader(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionReader +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProvider.scala similarity index 60% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala rename to core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProvider.scala index 83daddf714489..8f5f7634f8161 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProvider.scala @@ -14,24 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.scheduler.cluster.k8s -import io.fabric8.kubernetes.api.model.Pod +package org.apache.spark.shuffle.external -sealed trait ExecutorPodState { - def pod: Pod -} - -case class PodRunning(pod: Pod) extends ExecutorPodState - -case class PodPending(pod: Pod) extends ExecutorPodState +trait ShuffleServiceAddressProvider { -sealed trait FinalPodState extends ExecutorPodState + def start(): Unit = {} -case class PodSucceeded(pod: Pod) extends FinalPodState + def getShuffleServiceAddresses(): List[(String, Int)] -case class PodFailed(pod: Pod) extends FinalPodState - -case class PodDeleted(pod: Pod) extends FinalPodState + def stop(): Unit = {} +} -case class PodUnknown(pod: Pod) extends ExecutorPodState +private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider { + override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)] +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProviderFactory.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProviderFactory.scala new file mode 100644 index 0000000000000..7379d9948cce9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +import org.apache.spark.SparkConf + +trait ShuffleServiceAddressProviderFactory { + def canCreate(masterUrl: String): Boolean + + def create(conf: SparkConf): ShuffleServiceAddressProvider +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala new file mode 100644 index 0000000000000..ffc0a75be4764 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/ShuffleWriteSupport.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.external + +private[spark] trait ShuffleWriteSupport { + def newPartitionWriter(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionWriter +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala new file mode 100644 index 0000000000000..676c2e083f396 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/external/default/DefaultShuffleDataIO.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.external.default + +import java.io.{DataInputStream, File, InputStream, OutputStream} + +import org.apache.spark.io.NioBufferedFileInputStream +import org.apache.spark.shuffle.external._ + +private[spark] object DefaultShuffleDataIO extends ShuffleDataIO { + override def writeSupport(): ShuffleWriteSupport = new ShuffleWriteSupport { + override def newPartitionWriter(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionWriter = + new ShufflePartitionWriter { + override def appendPartition(partitionId: Long, partitionOutput: OutputStream): Unit = ??? + + override def appendIndexFile(partitionId: Long, indexInput: InputStream): Unit = ??? + + override def abort(exception: Throwable): Unit = ??? + + override def close(): Unit = ??? + } + } + + override def readSupport(): ShuffleReadSupport = new ShuffleReadSupport { + override def newPartitionReader(appId: String, shuffleId: Int, mapId: Int): ShufflePartitionReader = + new ShufflePartitionReader { + override def fetchPartition(reduceId: Long): InputStream = ??? + + def getDataFile(shuffleId: Int, mapId: Int): File = ??? + + def getIndexFile(shuffleId: Int, mapId: Int): File = ??? + + def deleteDataFile(shuffleId: Int, mapId: Int): Unit = ??? + + def deleteIndexFile(shuffleId: Int, mapId: Int): Unit = ??? + } + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 0caf84c6050a8..cd203b5bf38e9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -17,11 +17,18 @@ package org.apache.spark.shuffle.sort +import java.net.URI import java.util.concurrent.ConcurrentHashMap +import scala.util.Random import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.server.NoOpRpcHandler import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.external.ShuffleDataIO +import org.apache.spark.util.ThreadUtils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -81,6 +88,26 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) + private val backupShuffleTransportConf = SparkTransportConf.fromSparkConf( + conf, "shuffle", 2) + + private lazy val remoteShuffleTransportClients = SparkEnv + .get + .blockManager + .getRemoteShuffleServiceAddresses() + .map(address => { + val addressAsUri = URI.create(s"spark://${address._1}:${address._2}") + val transportContext = new TransportContext( + backupShuffleTransportConf, + new NoOpRpcHandler(), + false) + (address, addressAsUri, transportContext.createClientFactory()) + }) + + // TODO: Fill out defaults + private val shuffleDataIO: ShuffleDataIO = null + /** * Obtains a [[ShuffleHandle]] to pass to tasks. */ @@ -116,7 +143,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + conf.getAppId, + Option(shuffleDataIO).map(_.readSupport()).orNull) } /** Get a writer for a given partition.
Called on executors by map tasks. */ @@ -174,7 +206,7 @@ private[spark] object SortShuffleManager extends Logging { * buffering map outputs in a serialized form. This is an extreme defensive programming measure, * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. * */ - val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = + val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE: Int = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 274399b9cc1f3..49972c6540d59 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort +import java.io.File + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus @@ -64,11 +66,15 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). + var tmp: File = null val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) - val tmp = Utils.tempFileWith(output) + tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + val partitionLengths = sorter.writePartitionedFile( + blockId, + tmp, + shuffleBlockResolver.shuffleDataIO) shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index edae2f95fce33..cef9046207ceb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -24,6 +24,7 @@ import java.nio.channels.Channels import java.util.Collections import java.util.concurrent.ConcurrentHashMap +import com.codahale.metrics.{MetricRegistry, MetricSet} import scala.collection.mutable import scala.collection.mutable.HashMap import scala.concurrent.{ExecutionContext, Future} @@ -32,9 +33,6 @@ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal -import com.codahale.metrics.{MetricRegistry, MetricSet} -import com.google.common.io.CountingOutputStream - import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.{config, Logging} @@ -45,12 +43,13 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.network.shuffle.protocol.{ExecutorShuffleInfo, ExternalServiceHeartbeat, RegisterDriver} import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, 
SerializerManager} import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.external.{ExternalRemoteShuffleClient, ShuffleDataIO} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ @@ -64,7 +63,7 @@ private[spark] class BlockResult( /** * Abstracts away how blocks are stored and provides different ways to read the underlying block - * data. Callers should call [[dispose()]] when they're done with the block. + * data. Callers should call [[ dispose() ]] when they're done with the block. */ private[spark] trait BlockData { @@ -128,11 +127,16 @@ private[spark] class BlockManager( shuffleManager: ShuffleManager, val blockTransferService: BlockTransferService, securityManager: SecurityManager, - numUsableCores: Int) + numUsableCores: Int, + shuffleDataIO: ShuffleDataIO = null) extends BlockDataManager with BlockEvictionHandler with Logging { private[spark] val externalShuffleServiceEnabled = conf.get(config.SHUFFLE_SERVICE_ENABLED) + + private[spark] val readFromRemote = + conf.get(config.SHUFFLE_REMOTE_READ_OVERRIDE) + private val remoteReadNioBufferConversion = conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) @@ -143,6 +147,13 @@ private[spark] class BlockManager( new DiskBlockManager(conf, deleteFilesOnStop) } + val blockMapper: BlockMapper = + if (readFromRemote) { + new RemoteBlockManager(conf, shuffleDataIO.readSupport()) + } else { + diskBlockManager + } + // Visible for testing private[storage] val blockInfoManager = new BlockInfoManager @@ -152,7 +163,7 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + private[spark] val blockStore = new BlockStore(conf, blockMapper, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. @@ -183,15 +194,11 @@ private[spark] class BlockManager( // service, or just our own Executor's BlockManager. private[spark] var shuffleServerId: BlockManagerId = _ + private var remoteShuffleServiceAddresses: List[(String, Int)] = _ + // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. - private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, - securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) - } else { - blockTransferService - } + private[spark] var shuffleClient: ShuffleClient = _ // Max number of failures before this block manager refreshes the block locations from the driver private val maxFailuresBeforeLocationRefresh = @@ -220,6 +227,7 @@ private[spark] class BlockManager( new BlockManager.RemoteBlockDownloadFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + /** * Initializes the BlockManager with the given appId. 
This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -241,7 +249,6 @@ private[spark] class BlockManager( logInfo(s"Using $priorityClass for block replication policy") ret } - val id = BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None) @@ -253,6 +260,48 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id + remoteShuffleServiceAddresses = if (blockManagerId.isDriver) { + List.empty[(String, Int)] + } else { + Random.shuffle(mapOutputTracker + .trackerEndpoint + .askSync[List[(String, Int)]](GetRemoteShuffleServiceAddresses)) + .take(3) + } + + if (remoteShuffleServiceAddresses.nonEmpty) { + require(!externalShuffleServiceEnabled, "Cannot use external shuffle service with remote" + + " shuffle services.") + } + shuffleClient = if (externalShuffleServiceEnabled) { + require(remoteShuffleServiceAddresses.isEmpty, + "Cannot use the external shuffle service while using remote shuffle services.") + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + new ExternalShuffleClient(transConf, + securityManager, + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) + } else if (remoteShuffleServiceAddresses.nonEmpty) { + logInfo("Using RemoteShuffleServices") + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + val externalShuffleClient = new ExternalShuffleClient( + transConf, + securityManager, + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) + new ExternalRemoteShuffleClient(externalShuffleClient, blockTransferService) + } else blockTransferService + shuffleClient.init(appId) + + blockReplicationPolicy = { + val priorityClass = conf.get( + "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName) + val clazz = Utils.classForName(priorityClass) + val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy] + logInfo(s"Using $priorityClass for block replication policy") + ret + } + shuffleServerId = if (externalShuffleServiceEnabled) { logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) @@ -265,6 +314,13 @@ private[spark] class BlockManager( registerWithExternalShuffleServer() } + remoteShuffleServiceAddresses.foreach(address => { + registerWithRemoteShuffleServer( + address._1, + address._2, + appId) + }) + logInfo(s"Initialized BlockManager: $blockManagerId") } @@ -306,6 +362,37 @@ private[spark] class BlockManager( } } + private def registerWithRemoteShuffleServer( + shuffleServerHost: String, + shuffleServerPort: Int, + appId: String) + : Unit = { + val MAX_ATTEMPTS = conf.get(config.SHUFFLE_REGISTRATION_MAX_ATTEMPTS) + val SLEEP_TIME_SECS = 5 + for (i <- 1 to MAX_ATTEMPTS) { + try { + // Synchronous and will throw an exception if we cannot connect. 
+ shuffleClient + .asInstanceOf[ExternalRemoteShuffleClient] + .registerWithRemoteShuffleServer( + master.driverEndpoint.address.hostPort, + shuffleServerHost, + shuffleServerPort, + shuffleServerId.executorId, + shuffleManager.getClass.getName) + return + } catch { + case e: Exception if i < MAX_ATTEMPTS => + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}" + + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) + Thread.sleep(SLEEP_TIME_SECS * 1000L) + case NonFatal(e) => + throw new SparkException("Unable to register with external shuffle server due to : " + + e.getMessage, e) + } + } + } + /** * Report all blocks to the BlockManager again. This may be necessary if we are dropped * by the BlockManager and come back or if we become capable of recovering blocks on disk after @@ -470,7 +557,7 @@ private[spark] class BlockManager( def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfoManager.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + val diskSize = if (blockStore.contains(blockId)) blockStore.getSize(blockId) else 0L BlockStatus(info.level, memSize = memSize, diskSize = diskSize) } } @@ -538,7 +625,7 @@ private[spark] class BlockManager( BlockStatus.empty case level => val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) + val onDisk = level.useDisk && blockStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false val replication = if (inMem || onDisk) level.replication else 1 val storageLevel = StorageLevel( @@ -548,7 +635,7 @@ private[spark] class BlockManager( deserialized = deserialized, replication = replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L - val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L + val diskSize = if (onDisk) blockStore.getSize(blockId) else 0L BlockStatus(storageLevel, memSize, diskSize) } } @@ -602,8 +689,8 @@ private[spark] class BlockManager( releaseLock(blockId, taskAttemptId) }) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) - } else if (level.useDisk && diskStore.contains(blockId)) { - val diskData = diskStore.getBytes(blockId) + } else if (level.useDisk && blockStore.contains(blockId)) { + val diskData = blockStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( @@ -659,12 +746,12 @@ private[spark] class BlockManager( // serializing in-memory objects, and, finally, throw an exception if the block does not exist. if (level.deserialized) { // Try to avoid expensive serialization by reading a pre-serialized copy from disk: - if (level.useDisk && diskStore.contains(blockId)) { + if (level.useDisk && blockStore.contains(blockId)) { // Note: we purposely do not try to put the block back into memory here. Since this branch // handles deserialized blocks, this block may only be cached in memory as objects, not // serialized bytes. Because the caller only requested bytes, it doesn't make sense to // cache the block's deserialized objects since that caching may not have a payoff. 
- diskStore.getBytes(blockId) + blockStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( @@ -675,8 +762,8 @@ private[spark] class BlockManager( } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) - } else if (level.useDisk && diskStore.contains(blockId)) { - val diskData = diskStore.getBytes(blockId) + } else if (level.useDisk && blockStore.contains(blockId)) { + val diskData = blockStore.getBytes(blockId) maybeCacheDiskBytesInMemory(info, blockId, level, diskData) .map(new ByteBufferBlockData(_, false)) .getOrElse(diskData) @@ -755,7 +842,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) + loc.host, loc.port, loc.executorId, blockId.toString, loc.isRemote, tempFileManager) } catch { case NonFatal(e) => runningFailureCount += 1 @@ -1020,10 +1107,10 @@ private[spark] class BlockManager( } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.putBytes(blockId, bytes) + blockStore.putBytes(blockId, bytes) } } else if (level.useDisk) { - diskStore.putBytes(blockId, bytes) + blockStore.putBytes(blockId, bytes) } val putBlockStatus = getCurrentBlockStatus(blockId, info) @@ -1169,11 +1256,11 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } - size = diskStore.getSize(blockId) + size = blockStore.getSize(blockId) } else { iteratorFromFailedMemoryStorePut = Some(iter) } @@ -1186,11 +1273,11 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) partiallySerializedValues.finishWritingToStream(out) } - size = diskStore.getSize(blockId) + size = blockStore.getSize(blockId) } else { iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator) } @@ -1198,11 +1285,11 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } - size = diskStore.getSize(blockId) + size = blockStore.getSize(blockId) } val putBlockStatus = getCurrentBlockStatus(blockId, info) @@ -1501,11 +1588,11 @@ private[spark] class BlockManager( val level = info.level // Drop to disk, if storage level requires - if (level.useDisk && !diskStore.contains(blockId)) { + if (level.useDisk && !blockStore.contains(blockId)) { logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { channel => + blockStore.put(blockId) { channel => val out = Channels.newOutputStream(channel) 
serializerManager.dataSerializeStream( blockId, @@ -1513,7 +1600,7 @@ private[spark] class BlockManager( elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => - diskStore.putBytes(blockId, bytes) + blockStore.putBytes(blockId, bytes) } blockIsUpdated = true } @@ -1585,7 +1672,7 @@ private[spark] class BlockManager( private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = { // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) + val removedFromDisk = blockStore.remove(blockId) if (!removedFromMemory && !removedFromDisk) { logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory") } @@ -1603,6 +1690,8 @@ private[spark] class BlockManager( } } + def getRemoteShuffleServiceAddresses(): List[(String, Int)] = remoteShuffleServiceAddresses + def releaseLockAndDispose( blockId: BlockId, data: BlockData, @@ -1618,7 +1707,7 @@ private[spark] class BlockManager( shuffleClient.close() } remoteBlockTempFileManager.stop() - diskBlockManager.stop() + blockMapper.stop() rpcEnv.stop(slaveEndpoint) blockInfoManager.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d4a59c33b974c..1b92b41fedda5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -39,7 +39,8 @@ class BlockManagerId private ( private var executorId_ : String, private var host_ : String, private var port_ : Int, - private var topologyInfo_ : Option[String]) + private var topologyInfo_ : Option[String], + private var isRemote_ : Boolean = false) extends Externalizable { private def this() = this(null, null, 0, None) // For deserialization only @@ -62,6 +63,8 @@ class BlockManagerId private ( def port: Int = port_ + def isRemote: Boolean = isRemote_ + def topologyInfo: Option[String] = topologyInfo_ def isDriver: Boolean = { @@ -69,10 +72,12 @@ class BlockManagerId private ( executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER } + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) + out.writeBoolean(isRemote) out.writeBoolean(topologyInfo_.isDefined) // we only write topologyInfo if we have it topologyInfo.foreach(out.writeUTF(_)) @@ -82,6 +87,7 @@ class BlockManagerId private ( executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() + isRemote_ = in.readBoolean() val isTopologyInfoAvailable = in.readBoolean() topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None } @@ -124,8 +130,10 @@ private[spark] object BlockManagerId { execId: String, host: String, port: Int, - topologyInfo: Option[String] = None): BlockManagerId = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo)) + topologyInfo: Option[String] = None, + isRemote: Boolean = false): BlockManagerId = + getCachedBlockManagerId(new BlockManagerId( + execId, host, port, topologyInfo, isRemote)) def apply(in: ObjectInput): BlockManagerId = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMapper.scala b/core/src/main/scala/org/apache/spark/storage/BlockMapper.scala new file mode 100644 index 0000000000000..094ab8487cb0b --- /dev/null +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockMapper.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +/** + * Creates and maintains the logical mapping between logical blocks and physical on-disk + * locations. One block is mapped to one file with a name given by its BlockId. + */ +private[spark] trait BlockMapper { + def containsBlock(blockId: BlockId): Boolean +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala similarity index 70% rename from core/src/main/scala/org/apache/spark/storage/DiskStore.scala rename to core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 29963a95cb074..547d0219249fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.ListBuffer import com.google.common.io.Closeables import io.netty.channel.DefaultFileRegion -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils} @@ -38,11 +38,11 @@ import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBuffer /** - * Stores BlockManager blocks on disk. + * Stores BlockManager blocks. 
*/ -private[spark] class DiskStore( +private[spark] class BlockStore( conf: SparkConf, - diskManager: DiskBlockManager, + blockMapper: BlockMapper, securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") @@ -60,35 +60,40 @@ private[spark] class DiskStore( if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } - logDebug(s"Attempting to put block $blockId") - val startTime = System.currentTimeMillis - val file = diskManager.getFile(blockId) - val out = new CountingWritableChannel(openForWrite(file)) - var threwException: Boolean = true - try { - writeFunc(out) - blockSizes.put(blockId, out.getCount) - threwException = false - } finally { - try { - out.close() - } catch { - case ioe: IOException => - if (!threwException) { - threwException = true - throw ioe + blockMapper match { + case _: RemoteBlockManager => + throw new IllegalAccessError("Remote BlockMapper does not support this writing feature") + case d: DiskBlockManager => + logDebug(s"Attempting to put block $blockId") + val startTime = System.currentTimeMillis + val file = d.getFile(blockId) + val out = new CountingWritableChannel(openForWrite(file)) + var threwException: Boolean = true + try { + writeFunc(out) + blockSizes.put(blockId, out.getCount) + threwException = false + } finally { + try { + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } + } finally { + if (threwException) { + remove(blockId) + } } - } finally { - if (threwException) { - remove(blockId) } - } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file on disk in %d ms".format( + file.getName, + Utils.bytesToString(file.length()), + finishTime - startTime)) } - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, - Utils.bytesToString(file.length()), - finishTime - startTime)) } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { @@ -98,53 +103,60 @@ private[spark] class DiskStore( } def getBytes(blockId: BlockId): BlockData = { - val file = diskManager.getFile(blockId.name) val blockSize = getSize(blockId) securityManager.getIOEncryptionKey() match { case Some(key) => // Encrypted blocks cannot be memory mapped; return a special object that does decryption // and provides InputStream / FileRegion implementations for reading the data. 
- new EncryptedBlockData(file, blockSize, conf, key) - + blockMapper match { + case d: DiskBlockManager => + new EncryptedBlockData(d.getFile(blockId), blockSize, conf, key) + case r: RemoteBlockManager => + new EncryptedBlockData(null, blockSize, conf, key, r.getInputStream(blockId)) + } case _ => - new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, file, blockSize) + blockMapper match { + case d: DiskBlockManager => + new DiskBlockData(minMemoryMapBytes, maxMemoryMapBytes, d.getFile(blockId), blockSize) + case _: RemoteBlockManager => + throw new SparkException("Cant read from non-encrypted remote block") + } } } def remove(blockId: BlockId): Boolean = { - blockSizes.remove(blockId) - val file = diskManager.getFile(blockId.name) - if (file.exists()) { - val ret = file.delete() - if (!ret) { - logWarning(s"Error deleting ${file.getPath()}") - } - ret - } else { - false + blockMapper match { + case d: DiskBlockManager => + blockSizes.remove(blockId) + d.removeBlock(blockId) + case _: RemoteBlockManager => + throw new IllegalAccessError("Remote Block Mapper does not support this writing feature") } } def contains(blockId: BlockId): Boolean = { - val file = diskManager.getFile(blockId.name) - file.exists() + blockMapper.containsBlock(blockId) } private def openForWrite(file: File): WritableByteChannel = { - val out = new FileOutputStream(file).getChannel() - try { - securityManager.getIOEncryptionKey().map { key => - CryptoStreamUtils.createWritableChannel(out, conf, key) - }.getOrElse(out) - } catch { - case e: Exception => - Closeables.close(out, true) - file.delete() - throw e + blockMapper match { + case _: DiskBlockManager => + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + case _: RemoteBlockManager => + throw new IllegalAccessError("Remote Block Mapper does not support this writing feature") } } - } private class DiskBlockData( @@ -156,10 +168,10 @@ private class DiskBlockData( override def toInputStream(): InputStream = new FileInputStream(file) /** - * Returns a Netty-friendly wrapper for the block's data. - * - * Please see `ManagedBuffer.convertToNetty()` for more details. - */ + * Returns a Netty-friendly wrapper for the block's data. + * + * Please see `ManagedBuffer.convertToNetty()` for more details. + */ override def toNetty(): AnyRef = new DefaultFileRegion(file, 0, size) override def toChunkedByteBuffer(allocator: (Int) => ByteBuffer): ChunkedByteBuffer = { @@ -181,7 +193,7 @@ private class DiskBlockData( override def toByteBuffer(): ByteBuffer = { require(blockSize < maxMemoryMapBytes, s"can't create a byte buffer of size $blockSize" + - s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.") + s" since it exceeds ${Utils.bytesToString(maxMemoryMapBytes)}.") Utils.tryWithResource(open()) { channel => if (blockSize < minMemoryMapBytes) { // For small files, directly read rather than memory map. 
@@ -206,7 +218,8 @@ private[spark] class EncryptedBlockData( file: File, blockSize: Long, conf: SparkConf, - key: Array[Byte]) extends BlockData { + key: Array[Byte], + iStream: InputStream = null) extends BlockData { override def toInputStream(): InputStream = Channels.newInputStream(open()) @@ -254,13 +267,17 @@ private[spark] class EncryptedBlockData( override def dispose(): Unit = { } private def open(): ReadableByteChannel = { - val channel = new FileInputStream(file).getChannel() - try { - CryptoStreamUtils.createReadableChannel(channel, conf, key) - } catch { - case e: Exception => - Closeables.close(channel, true) - throw e + if (iStream != null) { + Channels.newChannel(iStream) + } else { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index a69bcc9259995..8d765910023fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -32,7 +32,8 @@ import org.apache.spark.util.{ShutdownHookManager, Utils} * Block files are hashed among the directories listed in spark.local.dir (or in * SPARK_LOCAL_DIRS, if it's set). */ -private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolean) extends Logging { +private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolean) + extends Logging with BlockMapper { private[spark] val subDirsPerLocalDir = conf.getInt("spark.diskStore.subDirectories", 64) @@ -80,10 +81,23 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea def getFile(blockId: BlockId): File = getFile(blockId.name) /** Check if disk block manager has a block. */ - def containsBlock(blockId: BlockId): Boolean = { + override def containsBlock(blockId: BlockId): Boolean = { getFile(blockId.name).exists() } + def removeBlock(blockId: BlockId): Boolean = { + val file = getFile(blockId) + if (file.exists()) { + val ret = file.delete() + if (!ret) { + logWarning(s"Error deleting ${file.getPath()}") + } + ret + } else { + false + } + } + /** List all the files currently stored on disk by the disk manager. */ def getAllFiles(): Seq[File] = { // Get all the files inside the array of array of directories diff --git a/core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala new file mode 100644 index 0000000000000..7bff63248d635 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/RemoteBlockManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.InputStream + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.external.ShuffleReadSupport + + /** + * Creates and maintains the logical mapping between logical blocks and physical on-disk + * locations. One block is mapped to one file with a name given by its BlockId. + * + * Block files are hashed among the directories listed in spark.local.dir (or in + * SPARK_LOCAL_DIRS, if it's set). + */ +private[spark] class RemoteBlockManager( + conf: SparkConf, + shuffleReadSupport: ShuffleReadSupport) + extends Logging with BlockMapper { + + def getInputStream(blockId: BlockId): InputStream = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + val reader = shuffleReadSupport.newPartitionReader( + conf.getAppId, shufId, mapId) + reader.fetchPartition(reduceId) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") + } + } + + override def containsBlock(blockId: BlockId): Boolean = + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + val reader = shuffleReadSupport.newPartitionReader( + conf.getAppId, shufId, mapId) + reader.fetchPartition(reduceId).available() > 0 + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") + } + } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala new file mode 100644 index 0000000000000..001a2f59b848b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/RemoteBlockObjectWriter.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io._ +import java.nio.channels.FileChannel + +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.external.ShuffleWriteSupport +import org.apache.spark.util.Utils + +/** + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block. For efficiency, it retains the underlying file channel across + * multiple commits. This channel is kept open until close() is called. In case of faults, + * callers should instead close with revertPartialWritesAndClose() to atomically revert the + * uncommitted partial writes. 
+ * + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. + */ +private[spark] class RemoteBlockObjectWriter( + shuffleWriteSupport: ShuffleWriteSupport, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + bufferSize: Int, + syncWrites: Boolean, + // These write metrics concurrently shared with other active BlockObjectWriters who + // are themselves performing writes. All updates must be relative. + writeMetrics: ShuffleWriteMetrics, + val blockId: BlockId = null) extends Logging { + + private var byteArrayOutput: ByteArrayOutputStream = null + private var bs: OutputStream = null + private var objectOutputStream: ObjectOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var bufferedOS: BufferedOutputStream = null + private var serializationStream: SerializationStream = null + private var mcOS: ManualCloseOutputStream = null + private var initialized = false + private var streamOpen = false + private var hasBeenClosed = false + + private def initialize(): Unit = { + byteArrayOutput = new ByteArrayOutputStream() + objectOutputStream = new ObjectOutputStream(byteArrayOutput) + ts = new TimeTrackingOutputStream(writeMetrics, objectOutputStream) + // Buffer the time-tracking stream and guard it against premature close() calls. + mcOS = new BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream + } + + /** + * Guards against close calls, e.g. from a wrapping stream. + * Call manualClose to close the stream that was extended by this trait. + * Commit uses this trait to close object streams without paying the + * cost of closing and opening the underlying file. + */ + private trait ManualCloseOutputStream extends OutputStream { + abstract override def close(): Unit = { + flush() + } + + def manualClose(): Unit = { + super.close() + } + } + + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + * And we reset it after every commitAndGet called. + */ + private var numRecordsWritten = 0 + + def open(): RemoteBlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } + if (!initialized) { + initialize() + initialized = true + } + + bs = serializerManager.wrapStream(blockId, new BufferedOutputStream(ts, bufferSize)) + serializationStream = serializerInstance.serializeStream(bs) + streamOpen = true + this + } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index aecc2284a9588..d91d86a12c6f1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -254,11 +254,22 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly.
if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + address.isRemote, + blockFetchingListener, this) } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + address.isRemote, + blockFetchingListener, + null) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index b159200d79222..567fb935100c0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -22,13 +22,12 @@ import java.util.Comparator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer - import com.google.common.io.ByteStreams - import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ +import org.apache.spark.shuffle.external.ShuffleDataIO import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** @@ -682,13 +681,13 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - outputFile: File): Array[Long] = { + outputFile: File, + _shuffleDataIO: ShuffleDataIO = null): Array[Long] = { // Track location of each range in the output file val lengths = new Array[Long](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics().shuffleWriteMetrics) - if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 5232c2bd8d6f6..f061eb57bf6c2 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -57,6 +57,8 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { cur = if (it.hasNext) it.next() else null } + + def hasNext(): Boolean = cur != null def nextPartition(): Int = cur._1._1 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 629a323042ff2..462769188a13f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -189,8 +189,13 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(locations.size === storageLevel.replication, s"; got ${locations.size} replicas instead of ${storageLevel.replication}") locations.foreach { cmId => - val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString, null) + val bytes = blockTransfer.fetchBlockSync( + cmId.host, + cmId.port, + cmId.executorId, + blockId.toString, + cmId.isRemote, + null) val deserialized = serializerManager.dataDeserializeStream(blockId, new 
ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d477242..119a443a15e15 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer - import org.mockito.Matchers.any import org.mockito.Mockito._ - import org.apache.spark.LocalSparkContext._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { @@ -35,7 +34,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, new SecurityManager(sparkConf)) - new MapOutputTrackerMaster(sparkConf, broadcastManager, true) + new MapOutputTrackerMaster( + sparkConf, + broadcastManager, + true, + DefaultShuffleServiceAddressProvider) } def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, @@ -186,7 +189,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10, false)) // Default size for broadcast in this testsuite is set to -1 so should not cause broadcast // to be used. 
verify(rpcCallContext, timeout(30000)).reply(any()) @@ -265,7 +268,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20, false)) // should succeed since majority of data is broadcast and actual serialized // message size is small verify(rpcCallContext, timeout(30000)).reply(any()) @@ -313,7 +316,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByExecutorId(10, 0, 4, false).toSeq === Seq( (BlockManagerId("a", "hostA", 1000), Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 21138bd4a16ba..8c7065d5728bd 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -156,7 +156,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val promise = Promise[ManagedBuffer]() - self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), false, new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { promise.failure(exception) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5f4ffa151d19b..eeec7f684c668 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -24,16 +24,15 @@ import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import scala.util.control.NonFatal - import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -250,7 +249,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi results.clear() securityMgr = new SecurityManager(conf) broadcastManager = new BroadcastManager(true, conf, securityMgr) - mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) { + mapOutputTracker 
= new MapOutputTrackerMaster( + conf, broadcastManager, true, DefaultShuffleServiceAddressProvider) { override def sendTracker(message: Any): Unit = { // no-op, just so we can stop this to avoid leaking threads } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 2d8a83c6fabed..f0096e60841e6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -101,7 +101,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + when(mapOutputTracker.getMapSizesByExecutorId( + shuffleId, reduceId, reduceId + 1, false)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 3962bdc27d22c..6f81a3c7482a2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -23,11 +23,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps - import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging @@ -38,6 +36,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ import org.apache.spark.util.Utils @@ -53,7 +52,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected var master: BlockManagerMaster = null protected lazy val securityMgr = new SecurityManager(conf) protected lazy val bcastManager = new BroadcastManager(true, conf, securityMgr) - protected lazy val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) + protected lazy val mapOutputTracker = new MapOutputTrackerMaster( + conf, bcastManager, true, DefaultShuffleServiceAddressProvider) protected lazy val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 32d6e8b94e1a2..5cd1954531e4e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -25,14 +25,12 @@ 
import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} import scala.reflect.ClassTag - import org.apache.commons.lang3.RandomUtils import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.concurrent.Eventually._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod @@ -50,6 +48,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ @@ -72,7 +71,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) + val mapOutputTracker = new MapOutputTrackerMaster( + new SparkConf(false), bcastManager, true, DefaultShuffleServiceAddressProvider) val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test @@ -1017,7 +1017,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE case _ => fail("Updated block is neither list2 nor list4") } } - assert(store.diskStore.contains("list2"), "list2 was not in disk store") + assert(store.blockStore.contains("list2"), "list2 was not in disk store") assert(store.memoryStore.contains("list4"), "list4 was not in memory store") // No updated blocks - list5 is too big to fit in store and nothing is kicked out @@ -1035,11 +1035,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.memoryStore.contains("list5"), "list5 was in memory store") // disk store contains only list2 - assert(!store.diskStore.contains("list1"), "list1 was in disk store") - assert(store.diskStore.contains("list2"), "list2 was not in disk store") - assert(!store.diskStore.contains("list3"), "list3 was in disk store") - assert(!store.diskStore.contains("list4"), "list4 was in disk store") - assert(!store.diskStore.contains("list5"), "list5 was in disk store") + assert(!store.blockStore.contains("list1"), "list1 was in disk store") + assert(store.blockStore.contains("list2"), "list2 was not in disk store") + assert(!store.blockStore.contains("list3"), "list3 was in disk store") + assert(!store.blockStore.contains("list4"), "list4 was in disk store") + assert(!store.blockStore.contains("list5"), "list5 was in disk store") // remove block - list2 should be removed from disk val updatedBlocks6 = getUpdatedBlocks { @@ -1049,7 +1049,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(updatedBlocks6.size === 1) assert(updatedBlocks6.head._1 === TestBlockId("list2")) assert(updatedBlocks6.head._2.storageLevel == StorageLevel.NONE) - assert(!store.diskStore.contains("list2"), "list2 was in disk store") + 
assert(!store.blockStore.contains("list2"), "list2 was in disk store") } test("query block statuses") { @@ -1158,7 +1158,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("safely unroll blocks through putIterator (disk)") { store = makeBlockManager(12000) val memoryStore = store.memoryStore - val diskStore = store.diskStore + val diskStore = store.blockStore val smallList = List.fill(40)(new Array[Byte](100)) val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] @@ -1446,6 +1446,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockIds: Array[String], + isRemote: Boolean, listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) @@ -1474,13 +1475,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, + isRemote: Boolean, tempFileManager: DownloadFileManager): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId, tempFileManager) + super.fetchBlockSync(host, port, execId, blockId, isRemote, tempFileManager) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index eec961a491101..461716585f49f 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -48,13 +48,13 @@ class DiskStoreSuite extends SparkFunSuite { val blockId = BlockId("rdd_1_2") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + val diskStoreMapped = new BlockStore(conf.clone().set(confKey, "0"), diskBlockManager, securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) val mapped = diskStoreMapped.getBytes(blockId).toByteBuffer() assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + val diskStoreNotMapped = new BlockStore(conf.clone().set(confKey, "1m"), diskBlockManager, securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) val notMapped = diskStoreNotMapped.getBytes(blockId).toByteBuffer() @@ -78,7 +78,7 @@ class DiskStoreSuite extends SparkFunSuite { test("block size tracking") { val conf = new SparkConf() val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + val diskStore = new BlockStore(conf, diskBlockManager, new SecurityManager(conf)) val blockId = BlockId("rdd_1_2") diskStore.put(blockId) { chan => @@ -97,7 +97,7 @@ class DiskStoreSuite extends SparkFunSuite { val conf = new SparkConf() .set(config.MEMORY_MAP_LIMIT_FOR_TESTS.key, "10k") val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + val diskStore = new BlockStore(conf, diskBlockManager, new SecurityManager(conf)) val blockId = BlockId("rdd_1_2") diskStore.put(blockId) { chan => @@ -139,7 +139,7 @@ 
class DiskStoreSuite extends SparkFunSuite { val conf = new SparkConf() val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + val diskStore = new BlockStore(conf, diskBlockManager, securityManager) val blockId = BlockId("rdd_1_2") diskStore.put(blockId) { chan => diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b268195e09a5b..2e694d725441e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -140,7 +140,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -159,7 +159,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -297,7 +297,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] 
@@ -337,7 +337,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -415,7 +415,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -479,7 +479,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) var tempFileManager: DownloadFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] diff --git a/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.external.ShuffleServiceAddressProviderFactory b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.external.ShuffleServiceAddressProviderFactory new file mode 100644 index 0000000000000..c39f2ec633ea3 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/resources/META-INF/services/org.apache.spark.shuffle.external.ShuffleServiceAddressProviderFactory @@ -0,0 +1 @@ +org.apache.spark.shuffle.k8s.KubernetesShuffleServiceAddressProviderFactory \ No newline at end of file diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index a32bd93bb65bc..48aff1aa5bb8a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -276,6 +276,25 @@ private[spark] object Config extends Logging { .stringConf .createOptional + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_ENABLED = + ConfigBuilder("spark.kubernetes.shuffle.service.backups.enabled") + .doc("Use shuffle service to back up shuffle data in Kubernetes applications.") + .booleanConf + .createWithDefault(false) + + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE = + ConfigBuilder("spark.kubernetes.shuffle.service.backups.pods.namespace") + .doc("Namespace of the pods that are running the shuffle service instances for backing up" + + " shuffle data.") + .stringConf + .createOptional + + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_PORT = + ConfigBuilder("spark.kubernetes.shuffle.service.backups.port") + .doc("Port of the shuffle services that will back up the application's shuffle data.") + .intConf + .createWithDefault(7337) + val KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX = 
"spark.kubernetes.authenticate.submission" @@ -304,4 +323,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." + + val KUBERNETES_BACKUP_SHUFFLE_SERVICE_LABELS = + "spark.kubernetes.shuffle.service.backups.label." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 77bd66b608e7c..86212f9cce2af 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -21,11 +21,13 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.Config._ import io.fabric8.kubernetes.client.utils.HttpClientUtils import okhttp3.Dispatcher import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.util.ThreadUtils /** @@ -35,6 +37,36 @@ import org.apache.spark.util.ThreadUtils */ private[spark] object SparkKubernetesClientFactory { + def getDriverKubernetesClient(conf: SparkConf, masterURL: String): KubernetesClient = { + val wasSparkSubmittedInClusterMode = conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + + val kubernetesClient = createKubernetesClient( + apiServerUri, + Some(conf.get(KUBERNETES_NAMESPACE)), + authConfPrefix, + conf, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) + kubernetesClient + } + def createKubernetesClient( master: String, namespace: Option[String], diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala index 435a5f1461c92..afd97240255bc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging /** * An immutable view of the current executor pods that are running in the cluster. 
*/ -private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, SparkPodState]) { import ExecutorPodsSnapshot._ @@ -42,15 +42,15 @@ object ExecutorPodsSnapshot extends Logging { ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) } - def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, SparkPodState]) - private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, SparkPodState] = { executorPods.map { pod => - (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) + (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, SparkPodState.toState(pod)) }.toMap } - private def toState(pod: Pod): ExecutorPodState = { + private def toState(pod: Pod): SparkPodState = { if (isDeleted(pod)) { PodDeleted(pod) } else { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ce10f766334ff..c9bbc51a6bde7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -16,16 +16,13 @@ */ package org.apache.spark.scheduler.cluster.k8s -import java.io.File +import java.lang import java.util.concurrent.TimeUnit import com.google.common.cache.CacheBuilder -import io.fabric8.kubernetes.client.Config import org.apache.spark.SparkContext -import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.util.{SystemClock, ThreadUtils} @@ -42,32 +39,8 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) - val (authConfPrefix, - apiServerUri, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { - require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, - "If the application is deployed using spark-submit in cluster mode, the driver pod name " + - "must be provided.") - (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - KUBERNETES_MASTER_INTERNAL_URL, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - } else { - (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, - KubernetesUtils.parseMasterUrl(masterURL), - None, - None) - } - - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - apiServerUri, - Some(sc.conf.get(KUBERNETES_NAMESPACE)), - authConfPrefix, - sc.conf, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( 
+ sc.conf, masterURL) if (sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { KubernetesUtils.loadPodFromTemplate( @@ -85,7 +58,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit val snapshotsStore = new ExecutorPodsSnapshotsStoreImpl(subscribersExecutor) val removedExecutorsCache = CacheBuilder.newBuilder() .expireAfterWrite(3, TimeUnit.MINUTES) - .build[java.lang.Long, java.lang.Long]() + .build[lang.Long, lang.Long]() val executorPodsLifecycleEventHandler = new ExecutorPodsLifecycleManager( sc.conf, kubernetesClient, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala new file mode 100644 index 0000000000000..0910b787b7f4b --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.internal.Logging + +sealed trait SparkPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends SparkPodState + +case class PodPending(pod: Pod) extends SparkPodState + +sealed trait FinalPodState extends SparkPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends SparkPodState + +object SparkPodState extends Logging { + def toState(pod: Pod): SparkPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala new file mode 100644 index 0000000000000..959b77f40a14c --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle.k8s + +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.cluster.k8s._ +import org.apache.spark.shuffle.external.ShuffleServiceAddressProvider +import org.apache.spark.util.Utils + +class KubernetesShuffleServiceAddressProvider( + kubernetesClient: KubernetesClient, + pollForPodsExecutor: ScheduledExecutorService, + podLabels: Map[String, String], + namespace: String, + portNumber: Int) + extends ShuffleServiceAddressProvider with Logging { + + // General implementation remark: this bears a strong resemblance to ExecutorPodsSnapshotsStore, + // but we don't need all "in-between" lists of all executor pods, just the latest known list + // when we query in getShuffleServiceAddresses. + + private val podsUpdateLock = new ReentrantReadWriteLock() + + private val shuffleServicePods = mutable.HashMap.empty[String, Pod] + + private var shuffleServicePodsWatch: Watch = _ + private var pollForPodsTask: ScheduledFuture[_] = _ + + override def start(): Unit = { + pollForPods() + pollForPodsTask = pollForPodsExecutor.scheduleWithFixedDelay( + () => pollForPods(), 0, 10, TimeUnit.SECONDS) + shuffleServicePodsWatch = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava).watch(new PutPodsInCacheWatcher()) + } + + override def stop(): Unit = { + Utils.tryLogNonFatalError { + if (pollForPodsTask != null) { + pollForPodsTask.cancel(false) + } + } + + Utils.tryLogNonFatalError { + if (shuffleServicePodsWatch != null) { + shuffleServicePodsWatch.close() + } + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() + } + } + + override def getShuffleServiceAddresses(): List[(String, Int)] = { + val readLock = podsUpdateLock.readLock() + readLock.lock() + try { + val addresses = shuffleServicePods.values.map(pod => { + (pod.getStatus.getPodIP, portNumber) + }).toList + logInfo(s"Found backup shuffle service addresses at $addresses.") + addresses + } finally { + readLock.unlock() + } + } + + private def pollForPods(): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + val allPods = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava) + .list() + shuffleServicePods.clear() + allPods.getItems.asScala.foreach(updatePod) + } finally { + writeLock.unlock() + } + } + + private def updatePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only update pods under lock.") + val state = SparkPodState.toState(pod) + state match { + case PodPending(_) | PodFailed(_) | PodSucceeded(_) | PodDeleted(_) => + shuffleServicePods.remove(pod.getMetadata.getName) + case PodRunning(_) => + shuffleServicePods.put(pod.getMetadata.getName, pod) + case _ => + logWarning(s"Unknown state $state for pod named ${pod.getMetadata.getName}") + } + } + + private def deletePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only delete under lock.") + shuffleServicePods.remove(pod.getMetadata.getName) + } + + private class PutPodsInCacheWatcher extends Watcher[Pod] { + override def eventReceived(action: Watcher.Action, pod: Pod): Unit = { + val writeLock = 
podsUpdateLock.writeLock() + writeLock.lock() + try { + updatePod(pod) + } finally { + writeLock.unlock() + } + } + + override def onClose(e: KubernetesClientException): Unit = {} + } + + private implicit def toRunnable(func: () => Unit): Runnable = { + new Runnable { + override def run(): Unit = func() + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala new file mode 100644 index 0000000000000..5bd0fc410f4ba --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.k8s + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.shuffle.external.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider, ShuffleServiceAddressProviderFactory} +import org.apache.spark.util.ThreadUtils + +class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { + override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") + + override def create(conf: SparkConf): ShuffleServiceAddressProvider = { + if (conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_ENABLED)) { + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + conf, conf.get("spark.master")) + val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( + "poll-shuffle-service-pods", 1) + val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_BACKUP_SHUFFLE_SERVICE_LABELS) + val shuffleServicePodsNamespace = conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE) + require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the backup" + + s" shuffle service must be defined by" + + s" ${KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") + require(shuffleServiceLabels.nonEmpty, "Requires labels for the backup shuffle service pods.") + + val port: Int = conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_PORT) + new KubernetesShuffleServiceAddressProvider( + kubernetesClient, + pollForPodsExecutor, + shuffleServiceLabels.toMap, + shuffleServicePodsNamespace.get, + port) + } else DefaultShuffleServiceAddressProvider + } +} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 859aa836a3157..918606e40112e 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.shuffle.protocol.BlockTransferMessage -import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat} +import org.apache.spark.network.shuffle.protocol.{RegisterDriver, ShuffleServiceHeartbeat} import org.apache.spark.network.util.TransportConf import org.apache.spark.util.ThreadUtils diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index da71f8f9e407c..9f7336af3f8ea 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider +import org.apache.spark.shuffle.external.ShuffleServiceAddressProvider /** * Cluster Manager for creation of Mesos scheduler and backend @@ -60,5 +62,8 @@ private[spark] class MesosClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + + override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + DefaultShuffleServiceAddressProvider } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index 64cd1bd088001..f924088913fa4 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.shuffle.external.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Yarn scheduler and backend @@ -53,4 +54,7 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + + override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + DefaultShuffleServiceAddressProvider } diff --git
a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index fe65353b9d502..324f7e2e16069 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,11 +24,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps import scala.reflect.ClassTag - import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ - import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging @@ -39,6 +37,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.external.DefaultShuffleServiceAddressProvider import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ @@ -70,7 +69,8 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) val streamId = 1 val securityMgr = new SecurityManager(conf, encryptionKey) val broadcastManager = new BroadcastManager(true, conf, securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) + val mapOutputTracker = new MapOutputTrackerMaster( + conf, broadcastManager, true, DefaultShuffleServiceAddressProvider) val shuffleManager = new SortShuffleManager(conf) val serializer = new KryoSerializer(conf) var serializerManager = new SerializerManager(serializer, conf, encryptionKey)
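For context on the new plug point exercised above: the patch routes backup shuffle service discovery through ShuffleServiceAddressProvider instances, created per cluster manager via ShuffleServiceAddressProviderFactory (registered through META-INF/services) and handed to MapOutputTrackerMaster, with KubernetesShuffleServiceAddressProvider as the only non-default implementation. Below is a minimal sketch of what another implementation could look like. It is illustrative only and not part of this patch; it assumes the trait exposes exactly the start()/stop()/getShuffleServiceAddresses() members and the canCreate(masterUrl)/create(conf) factory methods shown in the diff, and the StaticShuffleServiceAddressProvider names and the spark.shuffle.service.backups.* keys it reads are hypothetical.

package org.apache.spark.shuffle.external

import org.apache.spark.SparkConf

// Hypothetical provider that serves a fixed list of backup shuffle service endpoints, e.g. for
// clusters where the services run on well-known hosts.
class StaticShuffleServiceAddressProvider(hostsAndPorts: List[(String, Int)])
  extends ShuffleServiceAddressProvider {

  override def start(): Unit = {}

  override def getShuffleServiceAddresses(): List[(String, Int)] = hostsAndPorts

  override def stop(): Unit = {}
}

// Hypothetical factory, wired up through the same META-INF/services mechanism used for
// KubernetesShuffleServiceAddressProviderFactory in this patch.
class StaticShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory {
  override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("spark://")

  override def create(conf: SparkConf): ShuffleServiceAddressProvider = {
    // Assumed configuration keys; the patch itself only defines Kubernetes-specific ones.
    val hosts = conf.get("spark.shuffle.service.backups.hosts", "")
      .split(",").map(_.trim).filter(_.nonEmpty)
    val port = conf.getInt("spark.shuffle.service.backups.port", 7337)
    if (hosts.isEmpty) {
      DefaultShuffleServiceAddressProvider
    } else {
      new StaticShuffleServiceAddressProvider(hosts.map(h => (h, port)).toList)
    }
  }
}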