Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions core/src/main/scala/flatgraph/storage/Deserialization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ object Deserialization {
}

def readManifest(channel: FileChannel): ujson.Value = {
if (channel.size() < HeaderSize)
throw new DeserializationException(s"corrupt file, expected at least $HeaderSize bytes, but only found ${channel.size()}")
val fileSize = channel.size()
if (fileSize < HeaderSize)
throw new DeserializationException(s"corrupt file: expected at least $HeaderSize bytes, but only found ${channel.size()}")

val header = ByteBuffer.allocate(HeaderSize).order(ByteOrder.LITTLE_ENDIAN)
var readBytes = 0
Expand All @@ -185,21 +186,26 @@ object Deserialization {

val headerBytes = new Array[Byte](Keys.Header.length)
header.get(headerBytes)
if (!Arrays.equals(headerBytes, Keys.Header))
if (!Arrays.equals(headerBytes, Keys.Header)) {
throw new DeserializationException(
s"expected header '$MagicBytesString' (`${Keys.Header.mkString("")}`), but found '${headerBytes.mkString("")}'"
)
}

val manifestOffset = header.getLong()
val manifestSize = channel.size() - manifestOffset
val manifestBytes = ByteBuffer.allocate(manifestSize.toInt)
if (manifestSize > fileSize)
throw new DeserializationException(s"corrupt file: manifest size ($manifestSize) cannot be larger than the file's size ($fileSize)")
if (manifestSize > Int.MaxValue)
throw new DeserializationException(s"corrupt file: unreasonably large manifest size ($manifestSize)... aborting")

val manifestBytes = ByteBuffer.allocate(manifestSize.toInt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to do this check at all, then please also handle the case where manifestSize overflows the 32 bit integer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻 doen

readBytes = 0
while (readBytes < manifestSize) {
readBytes += channel.read(manifestBytes, readBytes + manifestOffset)
}
manifestBytes.flip()
ujson.read(manifestBytes)

}

private def readPool(manifest: GraphItem, fileChannel: FileChannel, zstdCtx: ZstdWrapper.ZstdCtx): Array[String] = {
Expand Down
48 changes: 47 additions & 1 deletion core/src/test/scala/flatgraph/SerializationTests.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package flatgraph

import flatgraph.misc.DebugDump.debugDump
import flatgraph.storage.Deserialization.DeserializationException
import flatgraph.storage.{Deserialization, Serialization}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import java.nio.file.Files
import java.nio.file.{Files, Path}
import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.ByteOrder
import scala.util.Using

class SerializationTests extends AnyWordSpec with Matchers {

Expand Down Expand Up @@ -40,4 +45,45 @@ class SerializationTests extends AnyWordSpec with Matchers {
originalDump shouldBe newDump
}

/* Show that we're no longer vulnerable to the 'denial of service attack by manipulating the manifest'
* issue filed here: https://github.com/joernio/flatgraph/security/advisories/GHSA-jqmx-3x2p-69vh
* Note that we cannot prevent all potential 'small flatgraph file leads to OOM error' attacks.
* Always treat untrusted files with precaution...
*/
"is no longer vulnerable to manifest size attack" in {
val schema = TestSchema.make(1, 0)
val graph = Graph(schema)
val diff = DiffGraphBuilder(schema).addNode(new GenericDNode(0))
DiffGraphApplier.applyDiff(graph, diff)

val storagePath = Files.createTempFile(s"flatgraph-${getClass.getSimpleName}", "fg")
Serialization.writeGraph(graph, storagePath)
patchFile(storagePath)

// when the vulnerability was reported, the following line raised a:
// `java.lang.OutOfMemoryError: Requested array size exceeds VM limit`
intercept[DeserializationException] {
Deserialization.readGraph(storagePath, Option(graph.schema))
}.getMessage should include("corrupt file: manifest size")
}

/** manipulate file as detailed in https: //github.com/joernio/flatgraph/security/advisories/GHSA-jqmx-3x2p-69vh */
private def patchFile(path: Path): Unit = {
Using.resource(new RandomAccessFile(path.toFile, "rw")) { file =>
// Seek to end and get file size
file.seek(file.length())
val fileSize = file.getFilePointer

// Calculate malicious offset
val maliciousOffset = fileSize - 2147483647L

// Seek to position 8 and write the offset as little-endian long
file.seek(8)
val buffer = ByteBuffer.allocate(8)
buffer.order(ByteOrder.LITTLE_ENDIAN)
buffer.putLong(maliciousOffset)
file.write(buffer.array())
}
}

}