NOTE(review): line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy; confirm the context-line whitespace against
the target files (or apply with `git apply --ignore-whitespace`). The `index`
post-image hashes predate the commonprefix -> commonpath fix below.

diff --git a/summarize/nnsum/summarization-datasets/preprocess_ami.py b/summarize/nnsum/summarization-datasets/preprocess_ami.py
index a412c9f..dbbfb03 100644
--- a/summarize/nnsum/summarization-datasets/preprocess_ami.py
+++ b/summarize/nnsum/summarization-datasets/preprocess_ami.py
@@ -42,7 +42,37 @@ def main():
         print("Extracting data... ", end="", flush=True)
         f.seek(0)
         with tarfile.open(fileobj=f, mode='r:gz') as tf:
-            tf.extractall(args.data_dir)
+            import os
+
+            def is_within_directory(directory, target):
+                """Return True if target is directory or lies inside it."""
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                # commonpath compares whole path components; commonprefix is
+                # character-wise and would accept a sibling such as "data_evil"
+                # when the directory is "data".
+                try:
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:
+                    # Raised on Windows when the paths are on different drives.
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                """Extract tar into path after checking every member name.
+
+                Raises Exception if any member's name resolves outside path.
+                """
+                # NOTE(review): only member names are checked; link targets are
+                # not validated.
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(tf, args.data_dir)
         print(" done!")
 
 if __name__ == "__main__":
diff --git a/summarize/nnsum/summarization-datasets/preprocess_pubmed.py b/summarize/nnsum/summarization-datasets/preprocess_pubmed.py
index 3372ab7..4c5c223 100644
--- a/summarize/nnsum/summarization-datasets/preprocess_pubmed.py
+++ b/summarize/nnsum/summarization-datasets/preprocess_pubmed.py
@@ -41,7 +41,37 @@ def main():
         print("Extracting data... ", end="", flush=True)
         f.seek(0)
         with tarfile.open(fileobj=f, mode='r:gz') as tf:
-            tf.extractall(args.data_dir)
+            import os
+
+            def is_within_directory(directory, target):
+                """Return True if target is directory or lies inside it."""
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                # commonpath compares whole path components; commonprefix is
+                # character-wise and would accept a sibling such as "data_evil"
+                # when the directory is "data".
+                try:
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:
+                    # Raised on Windows when the paths are on different drives.
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                """Extract tar into path after checking every member name.
+
+                Raises Exception if any member's name resolves outside path.
+                """
+                # NOTE(review): only member names are checked; link targets are
+                # not validated.
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(tf, args.data_dir)
         print(" done!")
 
 if __name__ == "__main__":