Skip to content

Commit edabc69

Browse files
authored
performance: parallelize diffgraph application, store and load (#288)
1 parent a64ce01 commit edabc69

File tree

8 files changed

+699
-402
lines changed

8 files changed

+699
-402
lines changed

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,21 @@ Code formatting is maintained via
66
sbt scalafmt Test/scalafmt
77
```
88

9+
## Diverse notes
10+
By default, diffgraph application, deserialization from storage, and serialization to storage are all multi-threaded.
11+
12+
This can be globally disabled via `flatgraph.misc.Misc.force_singlethreaded()`, for easier debugging.
13+
14+
In order to quickly inspect the contents of flatgraph files, you can extract the manifest json with `tail`, e.g. `tail someGraph.fg | jless`:
15+
Our output writer always places the manifest at the end of the file, preceded by a bunch of newlines, such that the `tail` output will not contain binary garbage.
16+
17+
This is suitable for quick command-line debugging. However, that approach will fail if e.g. somebody has appended two flatgraph files -- deserialization will
18+
read the file from the beginning, and find the offset of the true manifest from the header, and ignore trailing garbage like an appended fake manifest.
19+
So do not rely on that approach for security checks!
20+
21+
22+
23+
924
## Core Features
1025
- [x] Access nodes and neighbors
1126
- [x] Add nodes and edges

core/src/main/scala/flatgraph/DiffGraphApplier.scala

Lines changed: 261 additions & 197 deletions
Large diffs are not rendered by default.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package flatgraph.misc
2+
3+
import java.util.concurrent
4+
5+
object Misc {
6+
7+
@volatile var _overrideExecutor: Option[concurrent.ExecutorService] => concurrent.ExecutorService = defaultExecutorProvider
8+
9+
def force_singlethreaded(): Unit = {
10+
// This one is magic -- it can get garbage collected, no manual shutdown required!
11+
// Unfortunately this behavior is apparently not documented officially :(
12+
// But that is the behavior on java8 and java23 and, presumably, everywhere we care about:
13+
// https://github.com/openjdk/jdk8u/blob/1a6e3a5ea32d5c671cb46a590046f16426089921/jdk/src/share/classes/java/util/concurrent/Executors.java#L170
14+
// https://github.com/openjdk/jdk23u/blob/9101cc14972ce6bdeb966e67bcacc8b693c37d0a/src/java.base/share/classes/java/util/concurrent/Executors.java#L192
15+
this._overrideExecutor = (something: Option[concurrent.ExecutorService]) => concurrent.Executors.newSingleThreadExecutor()
16+
}
17+
18+
def defaultExecutorProvider(requested: Option[concurrent.ExecutorService]): concurrent.ExecutorService = requested.getOrElse {
19+
java.lang.Thread.currentThread() match {
20+
case fjt: concurrent.ForkJoinWorkerThread => fjt.getPool
21+
case _ => concurrent.ForkJoinPool.commonPool()
22+
}
23+
}
24+
25+
def maybeOverrideExecutor(requested: Option[concurrent.ExecutorService]): concurrent.ExecutorService =
26+
this._overrideExecutor.apply(requested)
27+
}

core/src/main/scala/flatgraph/storage/Deserialization.scala

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package flatgraph.storage
22

3-
import com.github.luben.zstd.Zstd
4-
import flatgraph.*
3+
import flatgraph.{AccessHelpers, FreeSchema, GNode, Graph, Schema}
54
import flatgraph.Edge.Direction
5+
import flatgraph.misc.Misc
66
import flatgraph.storage.Manifest.{GraphItem, OutlineStorage}
77

88
import java.nio.channels.FileChannel
@@ -11,29 +11,42 @@ import java.nio.file.Path
1111
import java.nio.{ByteBuffer, ByteOrder}
1212
import java.util.Arrays
1313
import scala.collection.mutable
14+
import java.util.concurrent
1415

1516
object Deserialization {
1617

17-
def readGraph(storagePath: Path, schemaMaybe: Option[Schema], persistOnClose: Boolean = true): Graph = {
18+
def readGraph(
19+
storagePath: Path,
20+
schemaMaybe: Option[Schema],
21+
persistOnClose: Boolean = true,
22+
requestedExecutor: Option[concurrent.ExecutorService] = None
23+
): Graph = {
24+
val executor = Misc.maybeOverrideExecutor(requestedExecutor)
1825
val fileChannel = new java.io.RandomAccessFile(storagePath.toAbsolutePath.toFile, "r").getChannel
26+
val queue = mutable.ArrayBuffer[concurrent.Future[Any]]()
27+
val zstdCtx = new ZstdWrapper.ZstdCtx
28+
def submitJob[T](block: => T): concurrent.Future[T] = {
29+
val res = executor.submit((() => block))
30+
queue.addOne(res.asInstanceOf[concurrent.Future[Any]])
31+
res
32+
}
33+
1934
try {
2035
// fixme: Use convenience methods from schema to translate string->id. Fix after we get strict schema checking.
2136
val manifest = GraphItem.read(readManifest(fileChannel))
22-
val pool = readPool(manifest, fileChannel)
37+
val pool = submitJob { readPool(manifest, fileChannel, zstdCtx) }
2338
val schema = schemaMaybe.getOrElse(freeSchemaFromManifest(manifest))
2439
val storagePathMaybe =
2540
if (persistOnClose) Option(storagePath)
2641
else None
2742
val g = new Graph(schema, storagePathMaybe)
2843
val nodekinds = mutable.HashMap[String, Short]()
2944
for (nodeKind <- g.schema.nodeKinds) nodekinds(g.schema.getNodeLabel(nodeKind)) = nodeKind.toShort
30-
val kindRemapper = Array.fill(manifest.nodes.size)(-1.toShort)
3145
val nodeRemapper = new Array[Array[GNode]](manifest.nodes.length)
3246
for {
3347
(nodeItem, idx) <- manifest.nodes.zipWithIndex
3448
nodeKind <- nodekinds.get(nodeItem.nodeLabel)
3549
} {
36-
kindRemapper(idx) = nodeKind
3750
val nodes = new Array[GNode](nodeItem.nnodes)
3851
for (seq <- Range(0, nodes.length)) nodes(seq) = g.schema.makeNode(g, nodeKind, seq)
3952
g.nodesArray(nodeKind) = nodes
@@ -66,11 +79,17 @@ object Deserialization {
6679
val direction = Direction.fromOrdinal(edgeItem.inout)
6780
if (nodeKind.isDefined && edgeKind.isDefined) {
6881
val pos = g.schema.neighborOffsetArrayIndex(nodeKind.get, direction, edgeKind.get)
69-
g.neighbors(pos) = deltaDecode(readArray(fileChannel, edgeItem.qty, nodeRemapper, pool).asInstanceOf[Array[Int]])
70-
g.neighbors(pos + 1) = readArray(fileChannel, edgeItem.neighbors, nodeRemapper, pool)
71-
val property = readArray(fileChannel, edgeItem.property, nodeRemapper, pool)
72-
if (property != null)
73-
g.neighbors(pos + 2) = property
82+
submitJob {
83+
g.neighbors(pos) = deltaDecode(readArray(fileChannel, edgeItem.qty, nodeRemapper, pool, zstdCtx).asInstanceOf[Array[Int]])
84+
}
85+
submitJob {
86+
g.neighbors(pos + 1) = readArray(fileChannel, edgeItem.neighbors, nodeRemapper, pool, zstdCtx)
87+
}
88+
submitJob {
89+
val property = readArray(fileChannel, edgeItem.property, nodeRemapper, pool, zstdCtx)
90+
if (property != null)
91+
g.neighbors(pos + 2) = property
92+
}
7493
}
7594
}
7695

@@ -91,12 +110,18 @@ object Deserialization {
91110
val propertyKind = propertykinds.get((property.nodeLabel, property.propertyLabel))
92111
if (nodeKind.isDefined && propertyKind.isDefined) {
93112
val pos = g.schema.propertyOffsetArrayIndex(nodeKind.get, propertyKind.get)
94-
g.properties(pos) = deltaDecode(readArray(fileChannel, property.qty, nodeRemapper, pool).asInstanceOf[Array[Int]])
95-
g.properties(pos + 1) = readArray(fileChannel, property.property, nodeRemapper, pool)
113+
submitJob {
114+
g.properties(pos) = deltaDecode(readArray(fileChannel, property.qty, nodeRemapper, pool, zstdCtx).asInstanceOf[Array[Int]])
115+
}
116+
submitJob { g.properties(pos + 1) = readArray(fileChannel, property.property, nodeRemapper, pool, zstdCtx) }
96117
}
97118
}
119+
queue.foreach { _.get() }
98120
g
99-
} finally fileChannel.close()
121+
} catch {
122+
case ex: java.util.concurrent.ExecutionException =>
123+
throw ex.getCause()
124+
} finally { fileChannel.close(); zstdCtx.close(); }
100125
}
101126

102127
private def freeSchemaFromManifest(manifest: Manifest.GraphItem): FreeSchema = {
@@ -171,23 +196,17 @@ object Deserialization {
171196

172197
}
173198

174-
private def readPool(manifest: GraphItem, fileChannel: FileChannel): Array[String] = {
175-
val stringPoolLength = ZstdWrapper(
176-
Zstd
177-
.decompress(
178-
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolLength.startOffset, manifest.stringPoolLength.compressedLength),
179-
manifest.stringPoolLength.decompressedLength
180-
)
181-
.order(ByteOrder.LITTLE_ENDIAN)
182-
)
183-
val stringPoolBytes = ZstdWrapper(
184-
Zstd
185-
.decompress(
186-
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolBytes.startOffset, manifest.stringPoolBytes.compressedLength),
187-
manifest.stringPoolBytes.decompressedLength
188-
)
189-
.order(ByteOrder.LITTLE_ENDIAN)
190-
)
199+
private def readPool(manifest: GraphItem, fileChannel: FileChannel, zstdCtx: ZstdWrapper.ZstdCtx): Array[String] = {
200+
val stringPoolLength = zstdCtx
201+
.decompress(
202+
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolLength.startOffset, manifest.stringPoolLength.compressedLength),
203+
manifest.stringPoolLength.decompressedLength
204+
)
205+
val stringPoolBytes = zstdCtx
206+
.decompress(
207+
fileChannel.map(FileChannel.MapMode.READ_ONLY, manifest.stringPoolBytes.startOffset, manifest.stringPoolBytes.compressedLength),
208+
manifest.stringPoolBytes.decompressedLength
209+
)
191210
val poolBytes = new Array[Byte](manifest.stringPoolBytes.decompressedLength)
192211
stringPoolBytes.get(poolBytes)
193212
val pool = new Array[String](manifest.stringPoolLength.decompressedLength >> 2)
@@ -215,11 +234,18 @@ object Deserialization {
215234
a
216235
}
217236

218-
private def readArray(channel: FileChannel, ptr: OutlineStorage, nodes: Array[Array[GNode]], stringPool: Array[String]): Array[?] = {
237+
private def readArray(
238+
channel: FileChannel,
239+
ptr: OutlineStorage,
240+
nodes: Array[Array[GNode]],
241+
stringPoolFuture: concurrent.Future[Array[String]],
242+
zstdCtx: ZstdWrapper.ZstdCtx
243+
): Array[?] = {
219244
if (ptr == null) return null
220-
val dec = ZstdWrapper(
221-
Zstd.decompress(channel.map(FileChannel.MapMode.READ_ONLY, ptr.startOffset, ptr.compressedLength), ptr.decompressedLength)
222-
).order(ByteOrder.LITTLE_ENDIAN)
245+
if (ptr.typ == StorageType.String) stringPoolFuture.get()
246+
247+
val dec =
248+
zstdCtx.decompress(channel.map(FileChannel.MapMode.READ_ONLY, ptr.startOffset, ptr.compressedLength), ptr.decompressedLength)
223249
ptr.typ match {
224250
case StorageType.Bool =>
225251
val bytes = new Array[Byte](dec.limit())
@@ -253,9 +279,10 @@ object Deserialization {
253279
dec.asDoubleBuffer().get(res)
254280
res
255281
case StorageType.String =>
256-
val res = new Array[String](dec.limit() >> 2)
257-
val intbuf = dec.asIntBuffer()
258-
var idx = 0
282+
val stringPool = stringPoolFuture.get()
283+
val res = new Array[String](dec.limit() >> 2)
284+
val intbuf = dec.asIntBuffer()
285+
var idx = 0
259286
while (idx < res.length) {
260287
val offset = intbuf.get(idx)
261288
if (offset >= 0) res(idx) = stringPool(offset)

core/src/main/scala/flatgraph/storage/Manifest.scala

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ object Manifest {
3434
var nodes: Array[NodeItem],
3535
var edges: Array[EdgeItem],
3636
var properties: Array[PropertyItem],
37-
val stringPoolLength: OutlineStorage,
38-
val stringPoolBytes: OutlineStorage
37+
val stringPoolLength: OutlineStorage = new OutlineStorage(StorageType.Int),
38+
val stringPoolBytes: OutlineStorage = new OutlineStorage(StorageType.Byte)
3939
) {
4040
var version = 0
4141
}
@@ -96,9 +96,9 @@ object Manifest {
9696
val nodeLabel: String,
9797
val edgeLabel: String,
9898
val inout: Byte, // 0: Incoming, 1: Outgoing; see Edge.Direction enum
99-
var qty: OutlineStorage,
100-
var neighbors: OutlineStorage,
101-
var property: OutlineStorage
99+
var qty: OutlineStorage = new OutlineStorage,
100+
var neighbors: OutlineStorage = new OutlineStorage,
101+
var property: OutlineStorage = new OutlineStorage
102102
) {
103103
Edge.Direction.verifyEncodingRange(inout)
104104
}
@@ -122,11 +122,20 @@ object Manifest {
122122
}
123123
}
124124

125-
class PropertyItem(val nodeLabel: String, val propertyLabel: String, var qty: OutlineStorage, var property: OutlineStorage)
125+
class PropertyItem(
126+
val nodeLabel: String,
127+
val propertyLabel: String,
128+
var qty: OutlineStorage = new OutlineStorage,
129+
var property: OutlineStorage = new OutlineStorage
130+
)
126131

127132
object OutlineStorage {
128133
def write(item: OutlineStorage): ujson.Value = {
129134
if (item == null) return ujson.Null
135+
if (item.typ == null) {
136+
assert(item.startOffset == -1L && item.compressedLength == -1 && item.decompressedLength == -1, s"bad OutlineStorage ${item}")
137+
return ujson.Null
138+
}
130139
val res = ujson.Obj()
131140
res(Keys.Type) = item.typ
132141
res(Keys.StartOffset) = ujson.Num(item.startOffset.toDouble)
@@ -143,17 +152,25 @@ object Manifest {
143152

144153
def read(item: ujson.Value): OutlineStorage = {
145154
if (item.isNull) return null
146-
val res = new OutlineStorage(item.obj(Keys.Type).str)
155+
val res = new OutlineStorage
156+
res.typ = item.obj(Keys.Type).str
147157
res.startOffset = item.obj(Keys.StartOffset).num.toLong
148158
res.compressedLength = item.obj(Keys.CompressedLength).num.toInt
149159
res.decompressedLength = item.obj(Keys.DecompressedLength).num.toInt
150160
res
151161
}
152162
}
153163

154-
class OutlineStorage(var typ: String) {
164+
class OutlineStorage {
165+
var typ: String = null
155166
var startOffset: Long = -1L
156167
var compressedLength: Int = -1
157168
var decompressedLength: Int = -1
169+
def this(_typ: String) = {
170+
this()
171+
this.typ = _typ
172+
}
173+
174+
override def toString: String = super.toString + s"($typ, $startOffset, $compressedLength, $decompressedLength)"
158175
}
159176
}

0 commit comments

Comments
 (0)