|
18 | 18 | package org.apache.spark.broadcast
|
19 | 19 |
|
20 | 20 | import java.io._
|
| 21 | +import java.lang.ref.SoftReference |
21 | 22 | import java.nio.ByteBuffer
|
22 | 23 | import java.util.zip.Adler32
|
23 | 24 |
|
@@ -63,9 +64,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
|
63 | 64 | * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
|
64 | 65 | * which builds this value by reading blocks from the driver and/or other executors.
|
65 | 66 | *
|
66 |
| - * On the driver, if the value is required, it is read lazily from the block manager. |
| 67 | + * On the driver, if the value is required, it is read lazily from the block manager. We hold |
| 68 | + * a soft reference so that it can be garbage collected if required, as we can always reconstruct |
| 69 | + * it in the future. |
67 | 70 | */
|
68 |
| - @transient private lazy val _value: T = readBroadcastBlock() |
| 71 | + @transient private var _value: SoftReference[T] = _ |
69 | 72 |
|
70 | 73 | /** The compression codec to use, or None if compression is disabled */
|
71 | 74 | @transient private var compressionCodec: Option[CompressionCodec] = _
|
@@ -94,8 +97,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
|
94 | 97 | /** The checksum for all the blocks. */
|
95 | 98 | private var checksums: Array[Int] = _
|
96 | 99 |
|
97 |
| - override protected def getValue() = { |
98 |
| - _value |
| 100 | + override protected def getValue() = synchronized { |
| 101 | + val memoized: T = if (_value == null) null.asInstanceOf[T] else _value.get |
| 102 | + if (memoized != null) { |
| 103 | + memoized |
| 104 | + } else { |
| 105 | + val newlyRead = readBroadcastBlock() |
| 106 | + _value = new SoftReference[T](newlyRead) |
| 107 | + newlyRead |
| 108 | + } |
99 | 109 | }
|
100 | 110 |
|
101 | 111 | private def calcChecksum(block: ByteBuffer): Int = {
|
@@ -209,8 +219,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
|
209 | 219 | }
|
210 | 220 |
|
211 | 221 | private def readBroadcastBlock(): T = Utils.tryOrIOException {
|
212 |
| - TorrentBroadcast.synchronized { |
213 |
| - val broadcastCache = SparkEnv.get.broadcastManager.cachedValues |
| 222 | + val broadcastCache = SparkEnv.get.broadcastManager.cachedValues |
| 223 | + broadcastCache.synchronized { |
214 | 224 |
|
215 | 225 | Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
|
216 | 226 | setConf(SparkEnv.get.conf)
|
|
0 commit comments