From 9be06d809443f11b92ef1334cbfb3eccc5d480ff Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Thu, 23 Mar 2023 11:57:58 -0700 Subject: [PATCH 01/17] fix issues in cc_net to make it workable in spark --- cc_net/execution.py | 63 +++++++++++++++++++++++++++++++++++--- cc_net/jsonql.py | 23 +++++++++++--- cc_net/mine.py | 56 ++++++++++++++++++++++----------- cc_net/process_wet_file.py | 4 ++- config/test_spark.json | 15 +++++++++ 5 files changed, 133 insertions(+), 28 deletions(-) create mode 100644 config/test_spark.json diff --git a/cc_net/execution.py b/cc_net/execution.py index d7664d2..e0c54d2 100644 --- a/cc_net/execution.py +++ b/cc_net/execution.py @@ -16,6 +16,7 @@ import submitit from typing_extensions import Protocol +from pyspark import SparkContext, SparkConf class Executor(Protocol): @@ -42,22 +43,36 @@ def get_executor( task_parallelism: int = -1, options: dict = {}, ) -> Executor: - + execution_mode = execution.split(",")[0] options.update( {kv.split("=", 1)[0]: kv.split("=", 1)[1] for kv in execution.split(",")[1:]} ) - + print(f"===get_executor name {name}, timeout_hour: {timeout_hour}, execution_mode: {execution_mode}, task_parallelism: {task_parallelism}") if execution_mode == "mp": warnings.warn("Execution mode 'mp' is deprecated, use 'local'.") execution_mode = "local" + if execution_mode == "spark": + conf = SparkConf().setAppName(name) + #conf.set("spark.eventLog.enabled", "true") # Enable event logging + #conf.set("spark.eventLog.dir", log_dir) # S + sc = SparkContext.getOrCreate(conf) + # We are on slurm + if task_parallelism == -1: + task_parallelism = 500 + return functools.partial(map_spark_array, sc, task_parallelism) + cluster = None if execution_mode == "auto" else execution_mode # use submitit to detect which executor is available ex = submitit.AutoExecutor(log_dir, cluster=cluster) if ex.cluster == "local": # LocalExecutor doesn't respect task_parallelism + ex.update_parameters( + name=name, + timeout_min=int(timeout_hour * 60) + ) return functools.partial(custom_map_array, ex, task_parallelism) if ex.cluster == "debug": return debug_executor @@ -76,6 +91,44 @@ def get_executor( ) return functools.partial(map_array_and_wait, ex) +def get_spark_executor( + name: str, +): + #conf = SparkConf().setAppName(name).setMaster("local[5]") + #sc = SparkContext(conf=conf) + sc = SparkContext.getOrCreate() + print("===done with spark context") + return sc + + +def map_spark_array( + sc: SparkContext, + task_parallelism: int, + function: Callable[..., str], + *args: Iterable, +): + f_name = function.__name__ + print(f"===calling spark for func: {f_name}, with arg's len {len(args)}, args {args}") + + #newargs = [] + + #for x in range(len(args)): + # newargs[x].append(args[x]) + pfunc = lambda *p: p + + + newargs = list(map(pfunc, *args)) + + print(f"=====new args {newargs}") + + assert len(args) > 0, f"No arguments passed to {f_name}" + + rdd = sc.parallelize(newargs, task_parallelism) + rdd = rdd.map(lambda p: function(*p)) + #rdd = rdd.map(function) + rdd.collect() + #print(f"===result in map_spark: {tmp} ") + def map_array_and_wait( ex: submitit.AutoExecutor, function: Callable[..., str], *args: Iterable @@ -85,7 +138,7 @@ def map_array_and_wait( assert len(args) > 0, f"No arguments passed to {f_name}" approx_length = _approx_length(*args) - print(f"Submitting {f_name} in a job array ({approx_length} jobs)") + print(f"map_array_and_wait Submitting {f_name} in a job array ({approx_length} jobs)") jobs = ex.map_array(function, *args) if not jobs: return @@ -151,13 +204,13 @@ 
def custom_map_array( ) -> None: f_name = function.__name__ assert len(args) > 0, f"No arguments passed to {f_name}" - + jobs_args = list(zip(*args)) total = len(jobs_args) if parallelism < 0: parallelism = os.cpu_count() or 0 assert parallelism >= 0, f"Can't run any jobs with task_parallelism={parallelism}" - print(f"Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}") + print(f"custom_map_array Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}") enqueued = 0 done = 0 running_jobs: List[submitit.Job] = [] diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py index b5ab405..1d261cc 100644 --- a/cc_net/jsonql.py +++ b/cc_net/jsonql.py @@ -7,6 +7,7 @@ """ Manipulate files containing one json per line. """ +import dill import argparse import collections import contextlib @@ -409,14 +410,17 @@ def run_pipes( if expect_json and inputs is None: fns = (JsonReader(),) + fns transformers = [] + print(f"============get fns in run_pipes {fns}, count {len(fns)}") + for t in fns: + print(f"============checking {t}, is Transformer? {isinstance(t, Transformer)}, is Parall? {t.parallelisable}") if not isinstance(t, Transformer): break if not t.parallelisable: break transformers.append(t) pipes = fns[len(transformers) :] - + log = logging.getLogger(__name__).info if inputs is None: data: Iterable = open_read(file) @@ -426,31 +430,39 @@ def run_pipes( if processes == -1: processes = os.cpu_count() or 0 + print(f"============run transformers {transformers}, process count: {processes}, total cpu count: {os.cpu_count()}") with contextlib.suppress(BrokenPipeError), contextlib.ExitStack() as stack: if transformers: log(f"preparing {transformers}") + print(f"================ prepare {transformers}, process count: {processes}") transform = stack.enter_context(compose(transformers)) if processes <= 1: data = transform.map(data) else: p = multiprocessing.current_process() log(f"Will start {processes} processes from {p.name}, Pid: {p.pid}") - pool = stack.enter_context( - multiprocessing.Pool( + cp = multiprocessing.Pool( processes=processes, initializer=_set_global_transformer, initargs=(transform,), ) + log("done with muti pool") + pool = stack.enter_context( + cp ) data = pool.imap_unordered( _global_transformer, data, chunksize=chunksize ) for fn in pipes: + print(f"======================= now handling: {fn}") if isinstance(fn, Transformer): + print(f"======================= {fn} is Transformer") data = fn.map(data) else: + print(f"======================= {fn} is not Transformer") data = fn(data) + write_jsons(data, output) @@ -505,6 +517,7 @@ def write_jsons(source: Iterable[dict], file: WritableFileLike) -> None: print(res, file=o) + class JsonReader(Transformer): def __init__(self, strict: bool = False): super().__init__() @@ -638,8 +651,10 @@ def _prepare(self): def do(self, doc: dict) -> Optional[dict]: assert self.clauses - if not doc or not all((c(doc) for c in self.clauses)): + if not doc or not all((dill.loads(c)(doc) for c in self.clauses)): + # print(f"=====not load for lan {doc.get('language')}") return None + # print(f"===== load for lan {doc.get('language')}") self.n_selected += 1 return doc diff --git a/cc_net/mine.py b/cc_net/mine.py index 4b8df55..35e8cf0 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -10,7 +10,7 @@ The pipeline parameters are described in the `Config` class. 
""" - +import dill import hashlib import json import time @@ -57,7 +57,7 @@ class Config(NamedTuple): num_shards: number of shards to split the dump num_segments_per_shard: allow to download a small portion of CC (eg for tests) min_len: remove documents shorter than this (in chars) - hashes_in_mem: number of shards hashes to use for dedup + hash_in_mem: number of shards hashes to use for dedup lang_whitelist: only treat those languages lang_blacklist: ignore those languages lang_threshold: remove docs whose top language score is lower than this @@ -102,9 +102,11 @@ def get_executor( self, name: str, timeout_hour: int = 1, mem_gb: int = 1, cpus: int = 1 ) -> Executor: name = "_".join((name, self.config_name, *self.experiments)) + log_dir = self.output_dir / "logs" + log_dir.mkdir(parents=True, exist_ok=True) return execution.get_executor( name, - self.output_dir / "logs", + log_dir, self.execution, timeout_hour=timeout_hour, mem_gb=mem_gb, @@ -194,15 +196,17 @@ def get_mined_dir(self, regroup: bool = False) -> Path: config_name="test", dump="2019-09", output_dir=Path("test_data"), - execution="local", - num_shards=4, + execution="spark", + num_shards=3, num_segments_per_shard=1, hash_in_mem=2, mine_num_processes=2, - lang_whitelist=["de", "it", "fr"], - target_size="32M", + #lang_whitelist=["de", "it", "fr"], + lang_whitelist=["en"], + target_size="320M", cleanup_after_regroup=False, cache_dir=Path("test_data/wet_cache"), + task_parallelism=3, ) PREDEF_CONFIGS = { @@ -259,6 +263,7 @@ def hashes(conf: Config) -> List[Path]: hashes_dir.mkdir(parents=True, exist_ok=True) # With FlatHashSet we need ~2Gb of RAM / shard, but we need to account for # overhead due to how the dynamic allocation works. + print(f"=============missing_outputs num {missing_outputs}, transpose out: {_transpose(missing_outputs)}") ex = conf.get_executor(f"hashes_{conf.dump}", mem_gb=4, timeout_hour=6, cpus=2) ex(_hashes_shard, repeat(conf), *_transpose(missing_outputs)) @@ -270,6 +275,7 @@ def hashes(conf: Config) -> List[Path]: def _hashes_shard(conf: Config, shard: int, output: Path): tmp_output = tmp(output) + print(f"===========running hash shard: {shard}, output: {output}, tmp {tmp_output}") jsonql.run_pipes( dedup.HashesCollector(field="raw_content", output=tmp_output), inputs=conf.get_cc_shard(shard), @@ -321,13 +327,6 @@ def mine(conf: Config) -> List[Path]: if not missing_outputs: return outputs - mined_dir.mkdir(parents=True, exist_ok=True) - ex = conf.get_executor( - f"mine_{conf.dump}", - mem_gb=mem_gb, - timeout_hour=timeout_hour, - cpus=conf.mine_num_processes + 1, - ) # Compute hashes firsts. 
if "dedup" in conf.pipeline: @@ -338,6 +337,17 @@ def mine(conf: Config) -> List[Path]: else: hashes_files = repeat([]) + #JUNJUN + #return outputs + + mined_dir.mkdir(parents=True, exist_ok=True) + ex = conf.get_executor( + f"mine_shard_{conf.dump}", + mem_gb=mem_gb, + timeout_hour=timeout_hour, + cpus=conf.mine_num_processes + 1, + ) + ex(_mine_shard, repeat(conf), hashes_files, *_transpose(missing_outputs)) assert all(o.exists() for o in outputs) @@ -350,8 +360,10 @@ def _get_segment(tmp_output: Path, doc: dict) -> str: def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str: + print(f"==_mine_shard called: conf: {conf}, hashes: {hashes}, shard: {shard}, output: {output} ") assert conf.pipeline tmp_output = tmp(output) + print(f"===output for shard: {shard}, output: {output}, tmp: {tmp_output} ") if "hashes" in conf.experiments: # HACK: used for generating paper figures hashes_in_mem = shard @@ -377,16 +389,18 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s model=lang_id, field="raw_content", out_field="lid_after_dedup", top=5 ) + #JUNJUN + # steps["keep_lang"] = None if conf.lang_blacklist: steps["keep_lang"] = jsonql.where( - [lambda doc: doc.get("language") not in set(conf.lang_blacklist)] + [dill.dumps(lambda doc: doc.get("language") not in set(conf.lang_blacklist))] ) elif conf.lang_whitelist: steps["keep_lang"] = jsonql.where( - [lambda doc: doc.get("language") in set(conf.lang_whitelist)] + [dill.dumps(lambda doc: doc.get("language") in set(conf.lang_whitelist))] ) else: - steps["keep_lang"] = None + steps["keep_lang"] = None tok_field = "tokenized" steps["sp"] = perplexity.MultiSentencePiece( @@ -427,8 +441,14 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s split_fn=lambda doc: _get_segment(tmp_output, doc), mkdir=True ) - pipeline = filter(None, (steps[s] for s in conf.pipeline)) + remainsteps = {s: steps[s] for s in conf.pipeline} + print(f"==steps: {remainsteps}") + + pipeline = filter(None, (steps[s] for s in conf.pipeline)) + #JUNJUN + + #pipeline = list(pipeline)[0:3] jsonql.run_pipes( *pipeline, inputs=cc_shard, diff --git a/cc_net/process_wet_file.py b/cc_net/process_wet_file.py index a870f6c..955cbfb 100644 --- a/cc_net/process_wet_file.py +++ b/cc_net/process_wet_file.py @@ -20,7 +20,8 @@ from cc_net import jsonql -WET_URL_ROOT = "https://commoncrawl.s3.amazonaws.com" +#WET_URL_ROOT = "https://commoncrawl.s3.amazonaws.com" +WET_URL_ROOT = "https://data.commoncrawl.org" logger = logging.getLogger(__name__) @@ -242,6 +243,7 @@ def segments(self) -> Sequence[str]: return self._segments segments = cc_segments(self.dump, self.cache_dir) n = len(segments) + print(f">>>>>>>>total segments: {n}, called from shard: {self.shard}") if self.num_shards < 0: self.num_shards = n // self.num_segments_per_shard i_min = (self.shard * n) // self.num_shards diff --git a/config/test_spark.json b/config/test_spark.json new file mode 100644 index 0000000..9eca2f7 --- /dev/null +++ b/config/test_spark.json @@ -0,0 +1,15 @@ +{ + "config_name":"test_spark", + "dump":"2019-09", + "output_dir":"test_data", + "execution":"spark", + "num_shards":3, + "num_segments_per_shard":1, + "hash_in_mem":2, + "mine_num_processes":2, + "lang_whitelist":["en"], + "target_size":"320M", + "cleanup_after_regroup":true, + "cache_dir":"test_data/wet_cache", + "task_parallelism":3 +} From 0b73b9b907fcd263cda2f7e2566282677a033853 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Fri, 24 Mar 2023 16:55:25 -0700 Subject: [PATCH 
02/17] fix the config and path issue in Spark executors --- cc_net/mine.py | 46 +++++++++++++++++++++++++++++++++++++++++++--- setup.py | 1 + 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index 35e8cf0..99e708f 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -63,6 +63,7 @@ class Config(NamedTuple): lang_threshold: remove docs whose top language score is lower than this keep_bucket: keep only those perplexity bucket chose from (head, middle, tail, all) lm_dir: folder containing LMs + lm_id_path: path for the language identity bin lm_languages: only use LMs for the following languages cutoff: cutoff file to use for split in head/middle/tail mine_num_processes: number of processes to use for mining @@ -88,6 +89,7 @@ class Config(NamedTuple): lang_threshold: float = 0.5 keep_bucket: Sequence[str] = [] lm_dir: Path = Path("data/lm_sp") + lm_id_path: Path = Path("bin/lid.bin") cutoff: Path = CUTOFF_CSV lm_languages: Optional[Sequence[str]] = None mine_num_processes: int = 16 @@ -206,15 +208,53 @@ def get_mined_dir(self, regroup: bool = False) -> Path: target_size="320M", cleanup_after_regroup=False, cache_dir=Path("test_data/wet_cache"), - task_parallelism=3, + task_parallelism=4, +) + +DBFS_ROOT_PATH='/dbfs/tmp/users/jun.wan/cc_net/' + +TEST_SPARK_CONFIG = BASE_CONFIG._replace( + #output_dir: Path = Path("data") + #mined_dir: str = "mined" + #lm_dir: Path = Path("data/lm_sp") + #cutoff = Path(CUTOFF_CSV) + #cache_dir: Optional[Path] = None + + #metadata: Optional[str] = None + #min_len: int = 300 + #lang_blacklist: Sequence[str] = [] + #lang_threshold: float = 0.5 + #keep_bucket: Sequence[str] = [] + #lm_languages: Optional[Sequence[str]] = None + #pipeline: Sequence[str] = DEFAULT_PIPELINE + #experiments: Sequence[str] = [] + config_name="test_spark", + dump="2019-09", + output_dir=Path(DBFS_ROOT_PATH + "test_data"), + mined_dir= DBFS_ROOT_PATH + "test_data/mined", + execution="spark", + num_shards=3, + num_segments_per_shard=1, + hash_in_mem=2, + mine_num_processes=2, + #lang_whitelist=["de", "it", "fr"], + lang_whitelist=["en"], + lm_dir=Path(DBFS_ROOT_PATH + "data/lm_sp"), + lm_id_path = Path(DBFS_ROOT_PATH + "bin/lid.bin"), + cutoff = Path(CUTOFF_CSV), + target_size="320M", + cleanup_after_regroup=False, + cache_dir=Path(DBFS_ROOT_PATH + "test_data/wet_cache"), + task_parallelism=4, ) PREDEF_CONFIGS = { "base": BASE_CONFIG, "by_lang": BYLANG_CONFIG, "test": TEST_CONFIG, + "test_spark": TEST_SPARK_CONFIG, "test_slurm": TEST_CONFIG._replace(execution="slurm,partition=dev"), - "debug": TEST_CONFIG._replace(config_name="debug", mine_num_processes=0), + "debug": TEST_CONFIG._replace(config_name="debug", mine_num_processes=0, execution="local", task_parallelism=-1), "reproduce": REPRODUCE_CONFIG, "augment": BASE_CONFIG._replace( config_name="augment", dump="2019-13", lang_blacklist=["en"] @@ -372,7 +412,7 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s cc_shard = conf.get_cc_shard(shard) steps: Dict[str, Optional[jsonql.Transformer]] = {} - lang_id = Path("bin") / "lid.bin" + lang_id = conf.lm_id_path steps["lid_before_dedup"] = split_by_lang.Classifier( model=lang_id, field="raw_content", out_field="lid_before_dedup", top=5 ) diff --git a/setup.py b/setup.py index 5ec5418..114934d 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ "sacremoses", "submitit>=1.0.0", "typing_extensions", + "dill", ], extras_require={ "dev": ["mypy==0.790", "pytest", "black==19.3b0", "isort==5.6.4"], From 
0c9d57b9797931f5d29cef3b6348c0823374c8f2 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Mon, 27 Mar 2023 14:35:17 -0700 Subject: [PATCH 03/17] remove unnecessary logging --- cc_net/execution.py | 15 +-------------- cc_net/jsonql.py | 12 ++---------- cc_net/mine.py | 21 ++++----------------- cc_net/process_wet_file.py | 2 +- 4 files changed, 8 insertions(+), 42 deletions(-) diff --git a/cc_net/execution.py b/cc_net/execution.py index e0c54d2..8d1e281 100644 --- a/cc_net/execution.py +++ b/cc_net/execution.py @@ -94,10 +94,7 @@ def get_executor( def get_spark_executor( name: str, ): - #conf = SparkConf().setAppName(name).setMaster("local[5]") - #sc = SparkContext(conf=conf) sc = SparkContext.getOrCreate() - print("===done with spark context") return sc @@ -108,26 +105,16 @@ def map_spark_array( *args: Iterable, ): f_name = function.__name__ - print(f"===calling spark for func: {f_name}, with arg's len {len(args)}, args {args}") + print(f"==calling spark for func: {f_name}, with arg's len {len(args)}, args {args}") - #newargs = [] - - #for x in range(len(args)): - # newargs[x].append(args[x]) pfunc = lambda *p: p - - newargs = list(map(pfunc, *args)) - print(f"=====new args {newargs}") - assert len(args) > 0, f"No arguments passed to {f_name}" rdd = sc.parallelize(newargs, task_parallelism) rdd = rdd.map(lambda p: function(*p)) - #rdd = rdd.map(function) rdd.collect() - #print(f"===result in map_spark: {tmp} ") def map_array_and_wait( diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py index 1d261cc..4fecf91 100644 --- a/cc_net/jsonql.py +++ b/cc_net/jsonql.py @@ -410,10 +410,9 @@ def run_pipes( if expect_json and inputs is None: fns = (JsonReader(),) + fns transformers = [] - print(f"============get fns in run_pipes {fns}, count {len(fns)}") + print(f"==get fns in run_pipes {fns}, count {len(fns)}") for t in fns: - print(f"============checking {t}, is Transformer? {isinstance(t, Transformer)}, is Parall? 
{t.parallelisable}") if not isinstance(t, Transformer): break if not t.parallelisable: @@ -430,11 +429,9 @@ def run_pipes( if processes == -1: processes = os.cpu_count() or 0 - print(f"============run transformers {transformers}, process count: {processes}, total cpu count: {os.cpu_count()}") with contextlib.suppress(BrokenPipeError), contextlib.ExitStack() as stack: if transformers: log(f"preparing {transformers}") - print(f"================ prepare {transformers}, process count: {processes}") transform = stack.enter_context(compose(transformers)) if processes <= 1: data = transform.map(data) @@ -455,15 +452,12 @@ def run_pipes( ) for fn in pipes: - print(f"======================= now handling: {fn}") + print(f"==now handling: {fn}") if isinstance(fn, Transformer): - print(f"======================= {fn} is Transformer") data = fn.map(data) else: - print(f"======================= {fn} is not Transformer") data = fn(data) - write_jsons(data, output) @@ -652,9 +646,7 @@ def _prepare(self): def do(self, doc: dict) -> Optional[dict]: assert self.clauses if not doc or not all((dill.loads(c)(doc) for c in self.clauses)): - # print(f"=====not load for lan {doc.get('language')}") return None - # print(f"===== load for lan {doc.get('language')}") self.n_selected += 1 return doc diff --git a/cc_net/mine.py b/cc_net/mine.py index 99e708f..07cc606 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -211,15 +211,9 @@ def get_mined_dir(self, regroup: bool = False) -> Path: task_parallelism=4, ) -DBFS_ROOT_PATH='/dbfs/tmp/users/jun.wan/cc_net/' +DBFS_ROOT_PATH='/dbfs/tmp/cc_net/' TEST_SPARK_CONFIG = BASE_CONFIG._replace( - #output_dir: Path = Path("data") - #mined_dir: str = "mined" - #lm_dir: Path = Path("data/lm_sp") - #cutoff = Path(CUTOFF_CSV) - #cache_dir: Optional[Path] = None - #metadata: Optional[str] = None #min_len: int = 300 #lang_blacklist: Sequence[str] = [] @@ -303,7 +297,7 @@ def hashes(conf: Config) -> List[Path]: hashes_dir.mkdir(parents=True, exist_ok=True) # With FlatHashSet we need ~2Gb of RAM / shard, but we need to account for # overhead due to how the dynamic allocation works. 
- print(f"=============missing_outputs num {missing_outputs}, transpose out: {_transpose(missing_outputs)}") + # print(f"==missing_outputs num {missing_outputs}, transpose out: {_transpose(missing_outputs)}") ex = conf.get_executor(f"hashes_{conf.dump}", mem_gb=4, timeout_hour=6, cpus=2) ex(_hashes_shard, repeat(conf), *_transpose(missing_outputs)) @@ -315,7 +309,7 @@ def hashes(conf: Config) -> List[Path]: def _hashes_shard(conf: Config, shard: int, output: Path): tmp_output = tmp(output) - print(f"===========running hash shard: {shard}, output: {output}, tmp {tmp_output}") + print(f"==running hash shard: {shard}, output: {output}, tmp {tmp_output}") jsonql.run_pipes( dedup.HashesCollector(field="raw_content", output=tmp_output), inputs=conf.get_cc_shard(shard), @@ -377,8 +371,6 @@ def mine(conf: Config) -> List[Path]: else: hashes_files = repeat([]) - #JUNJUN - #return outputs mined_dir.mkdir(parents=True, exist_ok=True) ex = conf.get_executor( @@ -400,10 +392,8 @@ def _get_segment(tmp_output: Path, doc: dict) -> str: def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str: - print(f"==_mine_shard called: conf: {conf}, hashes: {hashes}, shard: {shard}, output: {output} ") assert conf.pipeline tmp_output = tmp(output) - print(f"===output for shard: {shard}, output: {output}, tmp: {tmp_output} ") if "hashes" in conf.experiments: # HACK: used for generating paper figures hashes_in_mem = shard @@ -429,8 +419,7 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s model=lang_id, field="raw_content", out_field="lid_after_dedup", top=5 ) - #JUNJUN - # steps["keep_lang"] = None + if conf.lang_blacklist: steps["keep_lang"] = jsonql.where( [dill.dumps(lambda doc: doc.get("language") not in set(conf.lang_blacklist))] @@ -486,9 +475,7 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s print(f"==steps: {remainsteps}") pipeline = filter(None, (steps[s] for s in conf.pipeline)) - #JUNJUN - #pipeline = list(pipeline)[0:3] jsonql.run_pipes( *pipeline, inputs=cc_shard, diff --git a/cc_net/process_wet_file.py b/cc_net/process_wet_file.py index 955cbfb..a161142 100644 --- a/cc_net/process_wet_file.py +++ b/cc_net/process_wet_file.py @@ -243,7 +243,7 @@ def segments(self) -> Sequence[str]: return self._segments segments = cc_segments(self.dump, self.cache_dir) n = len(segments) - print(f">>>>>>>>total segments: {n}, called from shard: {self.shard}") + print(f"==total segments: {n}, called from shard: {self.shard}") if self.num_shards < 0: self.num_shards = n // self.num_segments_per_shard i_min = (self.shard * n) // self.num_shards From a476795919a30ff1fe3b5f6743d4e868cfb893b5 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Mon, 27 Mar 2023 14:45:21 -0700 Subject: [PATCH 04/17] format the code --- cc_net/execution.py | 42 +++++++++++++++----------- cc_net/jsonql.py | 25 +++++++--------- cc_net/mine.py | 56 ++++++++++++++++++++--------------- cc_net/process_wet_file.py | 2 +- cc_net/regroup.py | 2 +- cc_net/tools/expand_corpus.py | 2 +- tests/test_flat_hash_set.py | 2 +- 7 files changed, 71 insertions(+), 60 deletions(-) diff --git a/cc_net/execution.py b/cc_net/execution.py index 8d1e281..5228388 100644 --- a/cc_net/execution.py +++ b/cc_net/execution.py @@ -15,8 +15,8 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Sized import submitit +from pyspark import SparkConf, SparkContext from typing_extensions import Protocol -from pyspark import SparkContext, SparkConf class Executor(Protocol): @@ 
-43,22 +43,24 @@ def get_executor( task_parallelism: int = -1, options: dict = {}, ) -> Executor: - + execution_mode = execution.split(",")[0] options.update( {kv.split("=", 1)[0]: kv.split("=", 1)[1] for kv in execution.split(",")[1:]} ) - print(f"===get_executor name {name}, timeout_hour: {timeout_hour}, execution_mode: {execution_mode}, task_parallelism: {task_parallelism}") + print( + f"===get_executor name {name}, timeout_hour: {timeout_hour}, execution_mode: {execution_mode}, task_parallelism: {task_parallelism}" + ) if execution_mode == "mp": warnings.warn("Execution mode 'mp' is deprecated, use 'local'.") execution_mode = "local" if execution_mode == "spark": conf = SparkConf().setAppName(name) - #conf.set("spark.eventLog.enabled", "true") # Enable event logging - #conf.set("spark.eventLog.dir", log_dir) # S + # conf.set("spark.eventLog.enabled", "true") # Enable event logging + # conf.set("spark.eventLog.dir", log_dir) # S sc = SparkContext.getOrCreate(conf) - # We are on slurm + # We are on slurm if task_parallelism == -1: task_parallelism = 500 return functools.partial(map_spark_array, sc, task_parallelism) @@ -69,10 +71,7 @@ def get_executor( if ex.cluster == "local": # LocalExecutor doesn't respect task_parallelism - ex.update_parameters( - name=name, - timeout_min=int(timeout_hour * 60) - ) + ex.update_parameters(name=name, timeout_min=int(timeout_hour * 60)) return functools.partial(custom_map_array, ex, task_parallelism) if ex.cluster == "debug": return debug_executor @@ -91,6 +90,7 @@ def get_executor( ) return functools.partial(map_array_and_wait, ex) + def get_spark_executor( name: str, ): @@ -99,19 +99,21 @@ def get_spark_executor( def map_spark_array( - sc: SparkContext, + sc: SparkContext, task_parallelism: int, - function: Callable[..., str], + function: Callable[..., str], *args: Iterable, ): f_name = function.__name__ - print(f"==calling spark for func: {f_name}, with arg's len {len(args)}, args {args}") + print( + f"==calling spark for func: {f_name}, with arg's len {len(args)}, args {args}" + ) - pfunc = lambda *p: p + pfunc = lambda *p: p newargs = list(map(pfunc, *args)) assert len(args) > 0, f"No arguments passed to {f_name}" - + rdd = sc.parallelize(newargs, task_parallelism) rdd = rdd.map(lambda p: function(*p)) rdd.collect() @@ -125,7 +127,9 @@ def map_array_and_wait( assert len(args) > 0, f"No arguments passed to {f_name}" approx_length = _approx_length(*args) - print(f"map_array_and_wait Submitting {f_name} in a job array ({approx_length} jobs)") + print( + f"map_array_and_wait Submitting {f_name} in a job array ({approx_length} jobs)" + ) jobs = ex.map_array(function, *args) if not jobs: return @@ -191,13 +195,15 @@ def custom_map_array( ) -> None: f_name = function.__name__ assert len(args) > 0, f"No arguments passed to {f_name}" - + jobs_args = list(zip(*args)) total = len(jobs_args) if parallelism < 0: parallelism = os.cpu_count() or 0 assert parallelism >= 0, f"Can't run any jobs with task_parallelism={parallelism}" - print(f"custom_map_array Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}") + print( + f"custom_map_array Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}" + ) enqueued = 0 done = 0 running_jobs: List[submitit.Job] = [] diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py index 4fecf91..ba9b138 100644 --- a/cc_net/jsonql.py +++ b/cc_net/jsonql.py @@ -7,7 +7,6 @@ """ Manipulate files containing one json per line. 
""" -import dill import argparse import collections import contextlib @@ -43,6 +42,7 @@ Union, ) +import dill import numpy as np import psutil # type: ignore import requests @@ -419,7 +419,7 @@ def run_pipes( break transformers.append(t) pipes = fns[len(transformers) :] - + log = logging.getLogger(__name__).info if inputs is None: data: Iterable = open_read(file) @@ -439,14 +439,12 @@ def run_pipes( p = multiprocessing.current_process() log(f"Will start {processes} processes from {p.name}, Pid: {p.pid}") cp = multiprocessing.Pool( - processes=processes, - initializer=_set_global_transformer, - initargs=(transform,), - ) - log("done with muti pool") - pool = stack.enter_context( - cp + processes=processes, + initializer=_set_global_transformer, + initargs=(transform,), ) + log("done with muti pool") + pool = stack.enter_context(cp) data = pool.imap_unordered( _global_transformer, data, chunksize=chunksize ) @@ -457,7 +455,7 @@ def run_pipes( data = fn.map(data) else: data = fn(data) - + write_jsons(data, output) @@ -511,7 +509,6 @@ def write_jsons(source: Iterable[dict], file: WritableFileLike) -> None: print(res, file=o) - class JsonReader(Transformer): def __init__(self, strict: bool = False): super().__init__() @@ -1026,7 +1023,7 @@ def open_write( def parse_size(size): - unit_map = {"B": 1, "K": 1024, "M": 1024 ** 2, "G": 1024 ** 3} + unit_map = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3} unit = size[-1].upper() assert ( unit in unit_map @@ -1109,7 +1106,7 @@ def request_get_content(url: str, n_retry: int = 3) -> bytes: warnings.warn( f"Swallowed error {e} while downloading {url} ({i} out of {n_retry})" ) - time.sleep(10 * 2 ** i) + time.sleep(10 * 2**i) dl_time = time.time() - t0 dl_speed = len(r.content) / dl_time / 1024 logging.info( @@ -1155,7 +1152,7 @@ def sharded_file(file_pattern: Path, mode: str, max_size: str = "4G") -> MultiFi assert 0 < n < 8 assert "?" * n in name, f"The '?' need to be adjacents in {file_pattern}" assert "r" not in mode - files = (folder / name.replace("?" * n, f"%0{n}d" % i) for i in range(10 ** n)) + files = (folder / name.replace("?" * n, f"%0{n}d" % i) for i in range(10**n)) return MultiFile(files, mode, max_size) diff --git a/cc_net/mine.py b/cc_net/mine.py index 07cc606..3a66cf4 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -10,7 +10,6 @@ The pipeline parameters are described in the `Config` class. 
""" -import dill import hashlib import json import time @@ -21,6 +20,7 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple +import dill import func_argparse # Local scripts @@ -63,7 +63,7 @@ class Config(NamedTuple): lang_threshold: remove docs whose top language score is lower than this keep_bucket: keep only those perplexity bucket chose from (head, middle, tail, all) lm_dir: folder containing LMs - lm_id_path: path for the language identity bin + lm_id_path: path for the language identity bin lm_languages: only use LMs for the following languages cutoff: cutoff file to use for split in head/middle/tail mine_num_processes: number of processes to use for mining @@ -203,7 +203,7 @@ def get_mined_dir(self, regroup: bool = False) -> Path: num_segments_per_shard=1, hash_in_mem=2, mine_num_processes=2, - #lang_whitelist=["de", "it", "fr"], + # lang_whitelist=["de", "it", "fr"], lang_whitelist=["en"], target_size="320M", cleanup_after_regroup=False, @@ -211,31 +211,31 @@ def get_mined_dir(self, regroup: bool = False) -> Path: task_parallelism=4, ) -DBFS_ROOT_PATH='/dbfs/tmp/cc_net/' +DBFS_ROOT_PATH = "/dbfs/tmp/cc_net/" TEST_SPARK_CONFIG = BASE_CONFIG._replace( - #metadata: Optional[str] = None - #min_len: int = 300 - #lang_blacklist: Sequence[str] = [] - #lang_threshold: float = 0.5 - #keep_bucket: Sequence[str] = [] - #lm_languages: Optional[Sequence[str]] = None - #pipeline: Sequence[str] = DEFAULT_PIPELINE - #experiments: Sequence[str] = [] + # metadata: Optional[str] = None + # min_len: int = 300 + # lang_blacklist: Sequence[str] = [] + # lang_threshold: float = 0.5 + # keep_bucket: Sequence[str] = [] + # lm_languages: Optional[Sequence[str]] = None + # pipeline: Sequence[str] = DEFAULT_PIPELINE + # experiments: Sequence[str] = [] config_name="test_spark", dump="2019-09", output_dir=Path(DBFS_ROOT_PATH + "test_data"), - mined_dir= DBFS_ROOT_PATH + "test_data/mined", + mined_dir=DBFS_ROOT_PATH + "test_data/mined", execution="spark", num_shards=3, num_segments_per_shard=1, hash_in_mem=2, mine_num_processes=2, - #lang_whitelist=["de", "it", "fr"], + # lang_whitelist=["de", "it", "fr"], lang_whitelist=["en"], lm_dir=Path(DBFS_ROOT_PATH + "data/lm_sp"), - lm_id_path = Path(DBFS_ROOT_PATH + "bin/lid.bin"), - cutoff = Path(CUTOFF_CSV), + lm_id_path=Path(DBFS_ROOT_PATH + "bin/lid.bin"), + cutoff=Path(CUTOFF_CSV), target_size="320M", cleanup_after_regroup=False, cache_dir=Path(DBFS_ROOT_PATH + "test_data/wet_cache"), @@ -248,7 +248,12 @@ def get_mined_dir(self, regroup: bool = False) -> Path: "test": TEST_CONFIG, "test_spark": TEST_SPARK_CONFIG, "test_slurm": TEST_CONFIG._replace(execution="slurm,partition=dev"), - "debug": TEST_CONFIG._replace(config_name="debug", mine_num_processes=0, execution="local", task_parallelism=-1), + "debug": TEST_CONFIG._replace( + config_name="debug", + mine_num_processes=0, + execution="local", + task_parallelism=-1, + ), "reproduce": REPRODUCE_CONFIG, "augment": BASE_CONFIG._replace( config_name="augment", dump="2019-13", lang_blacklist=["en"] @@ -361,7 +366,6 @@ def mine(conf: Config) -> List[Path]: if not missing_outputs: return outputs - # Compute hashes firsts. 
if "dedup" in conf.pipeline: hashes_groups = list(jsonql.grouper(hashes(conf), conf.hash_in_mem)) @@ -371,7 +375,6 @@ def mine(conf: Config) -> List[Path]: else: hashes_files = repeat([]) - mined_dir.mkdir(parents=True, exist_ok=True) ex = conf.get_executor( f"mine_shard_{conf.dump}", @@ -419,17 +422,20 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s model=lang_id, field="raw_content", out_field="lid_after_dedup", top=5 ) - if conf.lang_blacklist: steps["keep_lang"] = jsonql.where( - [dill.dumps(lambda doc: doc.get("language") not in set(conf.lang_blacklist))] + [ + dill.dumps( + lambda doc: doc.get("language") not in set(conf.lang_blacklist) + ) + ] ) elif conf.lang_whitelist: steps["keep_lang"] = jsonql.where( - [dill.dumps(lambda doc: doc.get("language") in set(conf.lang_whitelist))] + [dill.dumps(lambda doc: doc.get("language") in set(conf.lang_whitelist))] ) else: - steps["keep_lang"] = None + steps["keep_lang"] = None tok_field = "tokenized" steps["sp"] = perplexity.MultiSentencePiece( @@ -475,7 +481,7 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s print(f"==steps: {remainsteps}") pipeline = filter(None, (steps[s] for s in conf.pipeline)) - + jsonql.run_pipes( *pipeline, inputs=cc_shard, @@ -690,6 +696,8 @@ def main(config: str = "base", **config_as_dict: Any) -> None: if conf.config_name == "test": _validate_test(conf, conf.get_mined_dir(regroup=True)) + print("==Completed all pipelines!") + if __name__ == "__main__": func_argparse.parse_and_call(get_main_parser()) diff --git a/cc_net/process_wet_file.py b/cc_net/process_wet_file.py index a161142..9768a65 100644 --- a/cc_net/process_wet_file.py +++ b/cc_net/process_wet_file.py @@ -20,7 +20,7 @@ from cc_net import jsonql -#WET_URL_ROOT = "https://commoncrawl.s3.amazonaws.com" +# WET_URL_ROOT = "https://commoncrawl.s3.amazonaws.com" WET_URL_ROOT = "https://data.commoncrawl.org" diff --git a/cc_net/regroup.py b/cc_net/regroup.py index 575baee..d7f2084 100644 --- a/cc_net/regroup.py +++ b/cc_net/regroup.py @@ -105,7 +105,7 @@ def fast_reshard( def determine_groups( - inputs: List[Path], target_size: int = 4 * 1024 ** 3 + inputs: List[Path], target_size: int = 4 * 1024**3 ) -> List[List[Path]]: if len(inputs) == 0: return [] diff --git a/cc_net/tools/expand_corpus.py b/cc_net/tools/expand_corpus.py index 44f3dce..46d16bc 100644 --- a/cc_net/tools/expand_corpus.py +++ b/cc_net/tools/expand_corpus.py @@ -26,7 +26,7 @@ KENLM = Path("./bin/lmplz") KENLM_BUILD = Path("./bin/build_binary") -VOCAB_SIZE = 2 ** 16 - 10 +VOCAB_SIZE = 2**16 - 10 PROCESSES = 16 diff --git a/tests/test_flat_hash_set.py b/tests/test_flat_hash_set.py index 00287df..0848fe2 100644 --- a/tests/test_flat_hash_set.py +++ b/tests/test_flat_hash_set.py @@ -69,7 +69,7 @@ def check_reload(h, dump, load, tmp_path): @pytest.mark.parametrize("hash_set_cls", [FlatHashSet, NaiveHashSet]) def test_loading(tmp_path, hash_set_cls): h = hash_set_cls() - x = np.random.randint(0, 2 ** 32, (100,), dtype=h.dtype) + x = np.random.randint(0, 2**32, (100,), dtype=h.dtype) h.add(x) check_reload(h, hash_set_cls.dump, hash_set_cls.load, tmp_path) From d33231380ada9294efc18f09b573b74e940fc895 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Mon, 27 Mar 2023 17:22:10 -0700 Subject: [PATCH 05/17] update log to avoid long strings --- cc_net/execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cc_net/execution.py b/cc_net/execution.py index 5228388..1e33c16 100644 --- a/cc_net/execution.py +++ 
b/cc_net/execution.py @@ -106,7 +106,7 @@ def map_spark_array( ): f_name = function.__name__ print( - f"==calling spark for func: {f_name}, with arg's len {len(args)}, args {args}" + f"==calling spark for func: {f_name}, with arg's len {len(args)}" ) pfunc = lambda *p: p From d46059c6bee8d18e2fb6e8980a7f401c6119d8c3 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Mon, 27 Mar 2023 22:32:36 -0700 Subject: [PATCH 06/17] seperate hash stage out --- cc_net/jsonql.py | 10 +++++++--- cc_net/mine.py | 15 ++++++++++----- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py index ba9b138..c67af1f 100644 --- a/cc_net/jsonql.py +++ b/cc_net/jsonql.py @@ -1119,7 +1119,11 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]: """Download the files at the given url to memory and opens it as a file. Assumes that the file is small, and fetch it when this function is called. """ - if cache and cache.exists(): + valid_cache = False + if cache and cache.exists() and os.path.getsize(cache) > 1 : + valid_cache = True + + if cache and valid_cache: return open_read(cache) # TODO: open the remote file in streaming mode. @@ -1132,11 +1136,11 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]: else: f = io.TextIOWrapper(content) - if cache and not cache.exists(): + if cache and not valid_cache: # The file might have been created while downloading/writing. tmp_cache = _tmp(cache) tmp_cache.write_bytes(raw_bytes) - if not cache.exists(): + if not valid_cache: tmp_cache.replace(cache) else: tmp_cache.unlink() diff --git a/cc_net/mine.py b/cc_net/mine.py index 3a66cf4..5417a96 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -34,6 +34,7 @@ CUTOFF_CSV = FILE_DIR / "data" / "cutoff.csv" DEFAULT_PIPELINE = [ + "hash", "dedup", "lid", "keep_lang", @@ -206,7 +207,7 @@ def get_mined_dir(self, regroup: bool = False) -> Path: # lang_whitelist=["de", "it", "fr"], lang_whitelist=["en"], target_size="320M", - cleanup_after_regroup=False, + cleanup_after_regroup=True, cache_dir=Path("test_data/wet_cache"), task_parallelism=4, ) @@ -367,7 +368,7 @@ def mine(conf: Config) -> List[Path]: return outputs # Compute hashes firsts. 
- if "dedup" in conf.pipeline: + if "hash" in conf.pipeline: hashes_groups = list(jsonql.grouper(hashes(conf), conf.hash_in_mem)) hashes_files: Iterable[List[Path]] = [ hashes_groups[shard // conf.hash_in_mem] for shard, o in missing_outputs @@ -476,11 +477,15 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s split_fn=lambda doc: _get_segment(tmp_output, doc), mkdir=True ) - remainsteps = {s: steps[s] for s in conf.pipeline} + remainsteps = [] + pipeline = [] - print(f"==steps: {remainsteps}") + for s in conf.pipeline: + if s in steps.keys(): + remainsteps.append(s) + pipeline.append(steps[s]) - pipeline = filter(None, (steps[s] for s in conf.pipeline)) + print(f"==steps: {remainsteps}") jsonql.run_pipes( *pipeline, From 9e170827d199a85e9efd63c26b4afb1d5bbe4086 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Mon, 27 Mar 2023 23:14:27 -0700 Subject: [PATCH 07/17] fix format --- cc_net/execution.py | 4 +--- cc_net/jsonql.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/cc_net/execution.py b/cc_net/execution.py index 1e33c16..7c5327a 100644 --- a/cc_net/execution.py +++ b/cc_net/execution.py @@ -105,9 +105,7 @@ def map_spark_array( *args: Iterable, ): f_name = function.__name__ - print( - f"==calling spark for func: {f_name}, with arg's len {len(args)}" - ) + print(f"==calling spark for func: {f_name}, with arg's len {len(args)}") pfunc = lambda *p: p newargs = list(map(pfunc, *args)) diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py index c67af1f..62a7b17 100644 --- a/cc_net/jsonql.py +++ b/cc_net/jsonql.py @@ -1120,7 +1120,7 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]: Assumes that the file is small, and fetch it when this function is called. """ valid_cache = False - if cache and cache.exists() and os.path.getsize(cache) > 1 : + if cache and cache.exists() and os.path.getsize(cache) > 1: valid_cache = True if cache and valid_cache: From 6204dce98b1a4c3a275503652d8f65d84bf07b27 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Tue, 28 Mar 2023 21:52:47 -0700 Subject: [PATCH 08/17] add dbfs copy to local --- cc_net/mine.py | 44 +++++++++++++++++++++++++++++--------------- initial_script.sh | 14 ++++++++++++++ setup.py | 2 +- 3 files changed, 44 insertions(+), 16 deletions(-) create mode 100644 initial_script.sh diff --git a/cc_net/mine.py b/cc_net/mine.py index 5417a96..7a6a361 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -32,6 +32,7 @@ # Constant FILE_DIR = Path(__file__).parent CUTOFF_CSV = FILE_DIR / "data" / "cutoff.csv" +LID_BIN = FILE_DIR / "bin" / "lid.bin" DEFAULT_PIPELINE = [ "hash", @@ -90,7 +91,7 @@ class Config(NamedTuple): lang_threshold: float = 0.5 keep_bucket: Sequence[str] = [] lm_dir: Path = Path("data/lm_sp") - lm_id_path: Path = Path("bin/lid.bin") + lm_id_path: Path = LID_BIN cutoff: Path = CUTOFF_CSV lm_languages: Optional[Sequence[str]] = None mine_num_processes: int = 16 @@ -100,6 +101,8 @@ class Config(NamedTuple): pipeline: Sequence[str] = DEFAULT_PIPELINE experiments: Sequence[str] = [] cache_dir: Optional[Path] = None + dbfs_lm_dir: str = "" + dbfs_lm_id_path: str = "" def get_executor( self, name: str, timeout_hour: int = 1, mem_gb: int = 1, cpus: int = 1 @@ -165,6 +168,19 @@ def get_mined_dir(self, regroup: bool = False) -> Path: return self.output_dir / f"{self.mined_dir}_split" / self.dump return self.output_dir / self.mined_dir / self.dump + def download_dbfs_file_to_local(self, source: Path, target: Path, overwrite: bool = False) -> bool: + if (not overwrite and 
target.exists()): + return True + + if (not source.startswith("/dbfs")): + print(f"==can not download file from dbfs to local as source path does not start from /dbfs: {str(source)}") + return False + + from databricks.sdk.runtime import dbutils + copied = dbutils.fs.cp(source.replace('/dbfs','dbfs:'), "file:" + str(target.absolute())) + print(f"==copied: {copied}, source: {str(source)}, target:{str(target)}") + return copied + BASE_CONFIG = Config() @@ -208,21 +224,13 @@ def get_mined_dir(self, regroup: bool = False) -> Path: lang_whitelist=["en"], target_size="320M", cleanup_after_regroup=True, - cache_dir=Path("test_data/wet_cache"), + cache_dir=Path("test_data2_wet_cache"), task_parallelism=4, ) DBFS_ROOT_PATH = "/dbfs/tmp/cc_net/" TEST_SPARK_CONFIG = BASE_CONFIG._replace( - # metadata: Optional[str] = None - # min_len: int = 300 - # lang_blacklist: Sequence[str] = [] - # lang_threshold: float = 0.5 - # keep_bucket: Sequence[str] = [] - # lm_languages: Optional[Sequence[str]] = None - # pipeline: Sequence[str] = DEFAULT_PIPELINE - # experiments: Sequence[str] = [] config_name="test_spark", dump="2019-09", output_dir=Path(DBFS_ROOT_PATH + "test_data"), @@ -232,14 +240,10 @@ def get_mined_dir(self, regroup: bool = False) -> Path: num_segments_per_shard=1, hash_in_mem=2, mine_num_processes=2, - # lang_whitelist=["de", "it", "fr"], lang_whitelist=["en"], - lm_dir=Path(DBFS_ROOT_PATH + "data/lm_sp"), - lm_id_path=Path(DBFS_ROOT_PATH + "bin/lid.bin"), - cutoff=Path(CUTOFF_CSV), target_size="320M", cleanup_after_regroup=False, - cache_dir=Path(DBFS_ROOT_PATH + "test_data/wet_cache"), + cache_dir=Path("test_data/wet_cache"), task_parallelism=4, ) @@ -403,6 +407,16 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s hashes_in_mem = shard hashes = hashes[: HASHES_IN_MEM[hashes_in_mem]] shard = 0 + + # ensure all needed model is ready locally: + # if (conf.execution == 'spark'): + # if (conf.dbfs_lm_id_path.startswith('/dbfs')): + # conf.download_dbfs_file_to_local(conf.dbfs_lm_id_path, conf.lm_id_path) + # if (conf.dbfs_lm_dir.startswith('/dbfs')): + # for l in conf.get_lm_languages(): + # conf.download_dbfs_file_to_local(conf.dbfs_lm_dir / f"{l}.sp.model", conf.lm_dir / f"{l}.sp.model") + # conf.download_dbfs_file_to_local(conf.dbfs_lm_dir / f"{l}.arpa.bin", conf.lm_dir / f"{l}.arpa.bin") + cc_shard = conf.get_cc_shard(shard) steps: Dict[str, Optional[jsonql.Transformer]] = {} diff --git a/initial_script.sh b/initial_script.sh new file mode 100644 index 0000000..a6a1efc --- /dev/null +++ b/initial_script.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -e + +if [[ -d "/root/lm_sp" ]]; then + echo "Folder /root/lm_sp already exists." +else + echo "Folder /root/lm_sp does not exist." 
+ mkdir /root/lm_sp +fi + +cp -n /dbfs/data-mle/llm/cc_net/lm_sp/* /root/lm_sp/ + +echo "Done copy models" \ No newline at end of file diff --git a/setup.py b/setup.py index 114934d..c7c91e7 100644 --- a/setup.py +++ b/setup.py @@ -52,5 +52,5 @@ # Full version is at https://github.com/atom-moyer/getpy "getpy": ["getpy @ git+https://github.com/gwenzek/getpy.git@v0.9.10-subset"], }, - package_data={"cc_net": ["data/*"]}, + package_data={"cc_net": ["data/*", "bin/*"]}, ) From e2aaaa81aaaac4fab876d8ac9a36bae730fe2cdb Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 00:20:19 -0700 Subject: [PATCH 09/17] workaround the dbfs mount failure issue --- cc_net/mine.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index 7a6a361..d63cfa2 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -14,6 +14,7 @@ import json import time import warnings +import traceback from argparse import ArgumentParser from collections import defaultdict from itertools import repeat @@ -101,8 +102,8 @@ class Config(NamedTuple): pipeline: Sequence[str] = DEFAULT_PIPELINE experiments: Sequence[str] = [] cache_dir: Optional[Path] = None - dbfs_lm_dir: str = "" - dbfs_lm_id_path: str = "" + # dbfs_lm_dir: str = "" + # dbfs_lm_id_path: str = "" def get_executor( self, name: str, timeout_hour: int = 1, mem_gb: int = 1, cpus: int = 1 @@ -168,18 +169,18 @@ def get_mined_dir(self, regroup: bool = False) -> Path: return self.output_dir / f"{self.mined_dir}_split" / self.dump return self.output_dir / self.mined_dir / self.dump - def download_dbfs_file_to_local(self, source: Path, target: Path, overwrite: bool = False) -> bool: - if (not overwrite and target.exists()): - return True + # def download_dbfs_file_to_local(self, source: Path, target: Path, overwrite: bool = False) -> bool: + # if (not overwrite and target.exists()): + # return True - if (not source.startswith("/dbfs")): - print(f"==can not download file from dbfs to local as source path does not start from /dbfs: {str(source)}") - return False + # if (not source.startswith("/dbfs")): + # print(f"==can not download file from dbfs to local as source path does not start from /dbfs: {str(source)}") + # return False - from databricks.sdk.runtime import dbutils - copied = dbutils.fs.cp(source.replace('/dbfs','dbfs:'), "file:" + str(target.absolute())) - print(f"==copied: {copied}, source: {str(source)}, target:{str(target)}") - return copied + # from databricks.sdk.runtime import dbutils + # copied = dbutils.fs.cp(source.replace('/dbfs','dbfs:'), "file:" + str(target.absolute())) + # print(f"==copied: {copied}, source: {str(source)}, target:{str(target)}") + # return copied BASE_CONFIG = Config() @@ -221,9 +222,9 @@ def download_dbfs_file_to_local(self, source: Path, target: Path, overwrite: boo hash_in_mem=2, mine_num_processes=2, # lang_whitelist=["de", "it", "fr"], - lang_whitelist=["en"], + lang_whitelist=["de"], target_size="320M", - cleanup_after_regroup=True, + cleanup_after_regroup=False, cache_dir=Path("test_data2_wet_cache"), task_parallelism=4, ) @@ -400,6 +401,18 @@ def _get_segment(tmp_output: Path, doc: dict) -> str: def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str: + print(f"==calling_mine_shard, with shard: {shard}, output: {output}, hashes: {hashes}") + + # Workaround DBFS super unstable /dbfs mnt issue. 
+ # check if the worknode can access the hashes, if not just skip + # then we can keep minding shards instead of error out the whole cluster for single failure + if conf.execution == 'spark': + try: + can_access = all(h.exists() for h in hashes) + except Exception as ex: + print(f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}") + return "skipped" + assert conf.pipeline tmp_output = tmp(output) if "hashes" in conf.experiments: From 41e7bde974c886a145769817290243efd4cfd6b3 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 00:21:56 -0700 Subject: [PATCH 10/17] fix format --- cc_net/mine.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index d63cfa2..87534f8 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -13,8 +13,8 @@ import hashlib import json import time -import warnings import traceback +import warnings from argparse import ArgumentParser from collections import defaultdict from itertools import repeat @@ -172,11 +172,11 @@ def get_mined_dir(self, regroup: bool = False) -> Path: # def download_dbfs_file_to_local(self, source: Path, target: Path, overwrite: bool = False) -> bool: # if (not overwrite and target.exists()): # return True - + # if (not source.startswith("/dbfs")): # print(f"==can not download file from dbfs to local as source path does not start from /dbfs: {str(source)}") # return False - + # from databricks.sdk.runtime import dbutils # copied = dbutils.fs.cp(source.replace('/dbfs','dbfs:'), "file:" + str(target.absolute())) # print(f"==copied: {copied}, source: {str(source)}, target:{str(target)}") @@ -401,16 +401,20 @@ def _get_segment(tmp_output: Path, doc: dict) -> str: def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str: - print(f"==calling_mine_shard, with shard: {shard}, output: {output}, hashes: {hashes}") + print( + f"==calling_mine_shard, with shard: {shard}, output: {output}, hashes: {hashes}" + ) # Workaround DBFS super unstable /dbfs mnt issue. - # check if the worknode can access the hashes, if not just skip + # check if the worknode can access the hashes, if not just skip # then we can keep minding shards instead of error out the whole cluster for single failure - if conf.execution == 'spark': + if conf.execution == "spark": try: can_access = all(h.exists() for h in hashes) except Exception as ex: - print(f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}") + print( + f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}" + ) return "skipped" assert conf.pipeline @@ -420,7 +424,7 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s hashes_in_mem = shard hashes = hashes[: HASHES_IN_MEM[hashes_in_mem]] shard = 0 - + # ensure all needed model is ready locally: # if (conf.execution == 'spark'): # if (conf.dbfs_lm_id_path.startswith('/dbfs')): From 708795eb6539f4f3d4431fb57d4e59d81a96734f Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 10:41:53 -0700 Subject: [PATCH 11/17] adding more trace inside worker node --- cc_net/mine.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index 87534f8..ecb149d 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -310,6 +310,7 @@ def hashes(conf: Config) -> List[Path]: # overhead due to how the dynamic allocation works. 
# print(f"==missing_outputs num {missing_outputs}, transpose out: {_transpose(missing_outputs)}") ex = conf.get_executor(f"hashes_{conf.dump}", mem_gb=4, timeout_hour=6, cpus=2) + print(f"==calling _hashes_shard to continue with {len(missing_outputs)} missing output") ex(_hashes_shard, repeat(conf), *_transpose(missing_outputs)) # Wait a bit so that files appears on the disk. @@ -388,7 +389,7 @@ def mine(conf: Config) -> List[Path]: timeout_hour=timeout_hour, cpus=conf.mine_num_processes + 1, ) - + print(f"==calling _mine_shard to continue with {len(missing_outputs)} missing output") ex(_mine_shard, repeat(conf), hashes_files, *_transpose(missing_outputs)) assert all(o.exists() for o in outputs) @@ -416,7 +417,10 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}" ) return "skipped" - + + print( + f"==continue _mine_shard, with shard: {shard}, output: {output}" + ) assert conf.pipeline tmp_output = tmp(output) if "hashes" in conf.experiments: @@ -516,7 +520,7 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s remainsteps.append(s) pipeline.append(steps[s]) - print(f"==steps: {remainsteps}") + print(f"==remaining steps: {remainsteps}") jsonql.run_pipes( *pipeline, From 56f477845941d793c0c14a4642ab2a61e338e7d0 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 10:43:08 -0700 Subject: [PATCH 12/17] format code --- cc_net/mine.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index ecb149d..b811f83 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -310,7 +310,9 @@ def hashes(conf: Config) -> List[Path]: # overhead due to how the dynamic allocation works. # print(f"==missing_outputs num {missing_outputs}, transpose out: {_transpose(missing_outputs)}") ex = conf.get_executor(f"hashes_{conf.dump}", mem_gb=4, timeout_hour=6, cpus=2) - print(f"==calling _hashes_shard to continue with {len(missing_outputs)} missing output") + print( + f"==calling _hashes_shard to continue with {len(missing_outputs)} missing output" + ) ex(_hashes_shard, repeat(conf), *_transpose(missing_outputs)) # Wait a bit so that files appears on the disk. @@ -389,7 +391,9 @@ def mine(conf: Config) -> List[Path]: timeout_hour=timeout_hour, cpus=conf.mine_num_processes + 1, ) - print(f"==calling _mine_shard to continue with {len(missing_outputs)} missing output") + print( + f"==calling _mine_shard to continue with {len(missing_outputs)} missing output" + ) ex(_mine_shard, repeat(conf), hashes_files, *_transpose(missing_outputs)) assert all(o.exists() for o in outputs) @@ -417,10 +421,8 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s f"==Failed to access hashes! 
error type:{type(ex).__name__}, details: {traceback.format_exc()}" ) return "skipped" - - print( - f"==continue _mine_shard, with shard: {shard}, output: {output}" - ) + + print(f"==continue _mine_shard, with shard: {shard}, output: {output}") assert conf.pipeline tmp_output = tmp(output) if "hashes" in conf.experiments: From 1b64dcaa7694476d462838909682f55215aff6ad Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 12:04:41 -0700 Subject: [PATCH 13/17] add random sleep to avoid race condition in /dbfs --- cc_net/mine.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index b811f83..14909ac 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -12,6 +12,7 @@ """ import hashlib import json +import random import time import traceback import warnings @@ -413,14 +414,24 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> s # Workaround DBFS super unstable /dbfs mnt issue. # check if the worknode can access the hashes, if not just skip # then we can keep minding shards instead of error out the whole cluster for single failure - if conf.execution == "spark": - try: - can_access = all(h.exists() for h in hashes) - except Exception as ex: - print( - f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}" - ) - return "skipped" + if conf.execution == "spark" and conf.num_shards > 10: + retried = 0 + while True: + try: + retried += 1 + random_float = random.uniform(0.01, float(conf.num_shards) / 100.0) + print( + f"==sleep for {random_float} seconds in _mine_shard, tried: {retried}" + ) + time.sleep(random_float) + can_access = all(h.exists() for h in hashes) + break + except Exception as ex: + print( + f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}" + ) + if retried > 10: + return "skipped as too many failures" print(f"==continue _mine_shard, with shard: {shard}, output: {output}") assert conf.pipeline From 16c2f9e7d8f18037eccc66a3886501c7c87067a4 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 13:00:47 -0700 Subject: [PATCH 14/17] ignore spark worker node return values --- cc_net/execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cc_net/execution.py b/cc_net/execution.py index 7c5327a..1b3b19b 100644 --- a/cc_net/execution.py +++ b/cc_net/execution.py @@ -113,8 +113,8 @@ def map_spark_array( assert len(args) > 0, f"No arguments passed to {f_name}" rdd = sc.parallelize(newargs, task_parallelism) - rdd = rdd.map(lambda p: function(*p)) - rdd.collect() + rdd = rdd.foreach(lambda p: function(*p)) + # rdd.collect() def map_array_and_wait( From 38cd0416851572bfb842315ba0bad229eddc206e Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 16:36:24 -0700 Subject: [PATCH 15/17] add getpy package to improve memory usage in hahsing and deduplication --- cc_net/jsonql.py | 2 +- cc_net/mine.py | 4 ++-- initial_script.sh | 7 ++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py index 62a7b17..5f9df9b 100644 --- a/cc_net/jsonql.py +++ b/cc_net/jsonql.py @@ -1086,7 +1086,7 @@ def close(self): _session = functools.lru_cache()(requests.Session) -def request_get_content(url: str, n_retry: int = 3) -> bytes: +def request_get_content(url: str, n_retry: int = 5) -> bytes: """Retrieve the binary content at url. Retry on connection errors. 
diff --git a/cc_net/mine.py b/cc_net/mine.py index 14909ac..9e582e6 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -215,7 +215,7 @@ def get_mined_dir(self, regroup: bool = False) -> Path: TEST_CONFIG = BASE_CONFIG._replace( config_name="test", - dump="2019-09", + dump="2020-10", output_dir=Path("test_data"), execution="spark", num_shards=3, @@ -223,7 +223,7 @@ def get_mined_dir(self, regroup: bool = False) -> Path: hash_in_mem=2, mine_num_processes=2, # lang_whitelist=["de", "it", "fr"], - lang_whitelist=["de"], + lang_whitelist=["en"], target_size="320M", cleanup_after_regroup=False, cache_dir=Path("test_data2_wet_cache"), diff --git a/initial_script.sh b/initial_script.sh index a6a1efc..282e3f7 100644 --- a/initial_script.sh +++ b/initial_script.sh @@ -9,6 +9,11 @@ else mkdir /root/lm_sp fi -cp -n /dbfs/data-mle/llm/cc_net/lm_sp/* /root/lm_sp/ +# cp -n /dbfs/data-mle/llm/cc_net/lm_sp/* /root/lm_sp/ +cp -n /dbfs/data-mle/llm/cc_net/lm_sp/en.arpa.bin /root/lm_sp/ +cp -n /dbfs/data-mle/llm/cc_net/lm_sp/en.sp.model /root/lm_sp/ + +pip install /dbfs/data-mle/llm/cc_net/dist/cc_net-1.0.0-py3-none-any.whl +pip install cc_net[getpy] echo "Done copy models" \ No newline at end of file From 7888ca33b7207a4246297db33db000840f5d26d6 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 18:23:54 -0700 Subject: [PATCH 16/17] increase spark retry for 10 --- cc_net/execution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cc_net/execution.py b/cc_net/execution.py index 1b3b19b..81b060e 100644 --- a/cc_net/execution.py +++ b/cc_net/execution.py @@ -56,7 +56,7 @@ def get_executor( execution_mode = "local" if execution_mode == "spark": - conf = SparkConf().setAppName(name) + conf = SparkConf().setAppName(name).set("spark.task.maxFailures", "10") # conf.set("spark.eventLog.enabled", "true") # Enable event logging # conf.set("spark.eventLog.dir", log_dir) # S sc = SparkContext.getOrCreate(conf) From f053e3e045be216460bf7fb2039cc62bbda86535 Mon Sep 17 00:00:00 2001 From: Jun Wan Date: Wed, 29 Mar 2023 21:35:27 -0700 Subject: [PATCH 17/17] fix the header parser bug for the new data format after 2020 --- cc_net/mine.py | 2 +- cc_net/process_wet_file.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cc_net/mine.py b/cc_net/mine.py index 9e582e6..c667c19 100644 --- a/cc_net/mine.py +++ b/cc_net/mine.py @@ -215,7 +215,7 @@ def get_mined_dir(self, regroup: bool = False) -> Path: TEST_CONFIG = BASE_CONFIG._replace( config_name="test", - dump="2020-10", + dump="2023-06", output_dir=Path("test_data"), execution="spark", num_shards=3, diff --git a/cc_net/process_wet_file.py b/cc_net/process_wet_file.py index 9768a65..844e045 100644 --- a/cc_net/process_wet_file.py +++ b/cc_net/process_wet_file.py @@ -64,9 +64,11 @@ def parse_doc(headers: List[str], doc: List[str]) -> Optional[dict]: WARC-Record-ID: WARC-Refers-To: WARC-Block-Digest: sha1:S3DTWCONT2L6ORTGCY2KXEZ37LNBB7V2 + WARC-Identified-Content-Language: rus NOTE: this is newly added in late 2020 Content-Type: text/plain - Content-Length: 7743 + Content-Length: 6724 """ + if not headers or not doc: return None @@ -77,7 +79,7 @@ def parse_doc(headers: List[str], doc: List[str]) -> Optional[dict]: url = headers[2].split()[1] date = headers[3].split()[1] digest = headers[6].split()[1] - length = int(headers[8].split()[1]) + length = int(headers[len(headers) - 2].split()[1]) except Exception as e: logger.warning("Can't parse header:", e, headers, doc) return None
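
A note on the header-parsing fix in patch 17: it replaces the hard-coded index with headers[len(headers) - 2], which assumes Content-Length is always the second-to-last header line of a WET record. A position-independent alternative, given only as an illustrative sketch (the helper name parse_warc_headers is made up here, and the field names follow the standard WET record layout quoted in the docstring above), is to parse the "Key: value" header lines into a dict first:

# Illustrative sketch, not part of the patches above: look up WET record
# headers by field name instead of by line position, so optional fields such
# as WARC-Identified-Content-Language cannot shift the offsets.
from typing import Dict, List


def parse_warc_headers(headers: List[str]) -> Dict[str, str]:
    """Turn 'Key: value' header lines into a dict, skipping lines without a colon."""
    fields: Dict[str, str] = {}
    for line in headers:
        key, sep, value = line.partition(":")
        if sep:
            fields[key.strip()] = value.strip()
    return fields


# Hypothetical usage with the fields parse_doc reads:
fields = parse_warc_headers(
    [
        "WARC-Type: conversion",
        "WARC-Target-URI: http://example.com/page",
        "WARC-Date: 2023-02-01T00:00:00Z",
        "WARC-Block-Digest: sha1:S3DTWCONT2L6ORTGCY2KXEZ37LNBB7V2",
        "WARC-Identified-Content-Language: rus",
        "Content-Type: text/plain",
        "Content-Length: 6724",
    ]
)
url = fields["WARC-Target-URI"]
date = fields["WARC-Date"]
digest = fields["WARC-Block-Digest"]
length = int(fields["Content-Length"])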