diff --git a/cc_net/execution.py b/cc_net/execution.py
index d7664d2..81b060e 100644
--- a/cc_net/execution.py
+++ b/cc_net/execution.py
@@ -15,6 +15,7 @@
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Sized

 import submitit
+from pyspark import SparkConf, SparkContext
 from typing_extensions import Protocol


@@ -47,17 +48,30 @@ def get_executor(
     options.update(
         {kv.split("=", 1)[0]: kv.split("=", 1)[1] for kv in execution.split(",")[1:]}
     )
-
+    print(
+        f"===get_executor name {name}, timeout_hour: {timeout_hour}, execution_mode: {execution_mode}, task_parallelism: {task_parallelism}"
+    )
     if execution_mode == "mp":
         warnings.warn("Execution mode 'mp' is deprecated, use 'local'.")
         execution_mode = "local"

+    if execution_mode == "spark":
+        conf = SparkConf().setAppName(name).set("spark.task.maxFailures", "10")
+        # conf.set("spark.eventLog.enabled", "true")  # Enable event logging
+        # conf.set("spark.eventLog.dir", log_dir)  # S
+        sc = SparkContext.getOrCreate(conf)
+        # We are on slurm
+        if task_parallelism == -1:
+            task_parallelism = 500
+        return functools.partial(map_spark_array, sc, task_parallelism)
+
     cluster = None if execution_mode == "auto" else execution_mode
     # use submitit to detect which executor is available
     ex = submitit.AutoExecutor(log_dir, cluster=cluster)

     if ex.cluster == "local":
         # LocalExecutor doesn't respect task_parallelism
+        ex.update_parameters(name=name, timeout_min=int(timeout_hour * 60))
         return functools.partial(custom_map_array, ex, task_parallelism)
     if ex.cluster == "debug":
         return debug_executor
@@ -77,6 +91,32 @@ def get_executor(
     return functools.partial(map_array_and_wait, ex)


+def get_spark_executor(
+    name: str,
+):
+    sc = SparkContext.getOrCreate()
+    return sc
+
+
+def map_spark_array(
+    sc: SparkContext,
+    task_parallelism: int,
+    function: Callable[..., str],
+    *args: Iterable,
+):
+    f_name = function.__name__
+    print(f"==calling spark for func: {f_name}, with args len {len(args)}")
+
+    pfunc = lambda *p: p
+    newargs = list(map(pfunc, *args))
+
+    assert len(args) > 0, f"No arguments passed to {f_name}"
+
+    rdd = sc.parallelize(newargs, task_parallelism)
+    rdd.foreach(lambda p: function(*p))
+    # rdd.collect()
+
+
 def map_array_and_wait(
     ex: submitit.AutoExecutor, function: Callable[..., str], *args: Iterable
 ):
@@ -85,7 +125,9 @@ def map_array_and_wait(
     assert len(args) > 0, f"No arguments passed to {f_name}"
     approx_length = _approx_length(*args)

-    print(f"Submitting {f_name} in a job array ({approx_length} jobs)")
+    print(
+        f"map_array_and_wait Submitting {f_name} in a job array ({approx_length} jobs)"
+    )
     jobs = ex.map_array(function, *args)
     if not jobs:
         return
@@ -157,7 +199,9 @@ def custom_map_array(
     if parallelism < 0:
         parallelism = os.cpu_count() or 0
     assert parallelism >= 0, f"Can't run any jobs with task_parallelism={parallelism}"
-    print(f"Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}")
+    print(
+        f"custom_map_array Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}"
+    )
     enqueued = 0
     done = 0
     running_jobs: List[submitit.Job] = []
diff --git a/cc_net/jsonql.py b/cc_net/jsonql.py
index b5ab405..5f9df9b 100644
--- a/cc_net/jsonql.py
+++ b/cc_net/jsonql.py
@@ -42,6 +42,7 @@
     Union,
 )

+import dill
 import numpy as np
 import psutil  # type: ignore
 import requests
@@ -409,6 +410,8 @@ def run_pipes(
     if expect_json and inputs is None:
         fns = (JsonReader(),) + fns
     transformers = []
+    print(f"==get fns in run_pipes {fns}, count {len(fns)}")
+
     for t in fns:
         if not isinstance(t, Transformer):
             break
@@ -435,18 +438,19 @@ def run_pipes(
         else:
             p = multiprocessing.current_process()
             log(f"Will start {processes} processes from {p.name}, Pid: {p.pid}")
-            pool = stack.enter_context(
-                multiprocessing.Pool(
-                    processes=processes,
-                    initializer=_set_global_transformer,
-                    initargs=(transform,),
-                )
+            cp = multiprocessing.Pool(
+                processes=processes,
+                initializer=_set_global_transformer,
+                initargs=(transform,),
             )
+            log("done with multiprocessing pool")
+            pool = stack.enter_context(cp)
             data = pool.imap_unordered(
                 _global_transformer, data, chunksize=chunksize
             )

         for fn in pipes:
+            print(f"==now handling: {fn}")
             if isinstance(fn, Transformer):
                 data = fn.map(data)
             else:
@@ -638,7 +642,7 @@ def _prepare(self):

     def do(self, doc: dict) -> Optional[dict]:
         assert self.clauses
-        if not doc or not all((c(doc) for c in self.clauses)):
+        if not doc or not all((dill.loads(c)(doc) for c in self.clauses)):
             return None
         self.n_selected += 1
         return doc
@@ -1019,7 +1023,7 @@ def open_write(


 def parse_size(size):
-    unit_map = {"B": 1, "K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
+    unit_map = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3}
     unit = size[-1].upper()
     assert (
         unit in unit_map
@@ -1082,7 +1086,7 @@ def close(self):
 _session = functools.lru_cache()(requests.Session)


-def request_get_content(url: str, n_retry: int = 3) -> bytes:
+def request_get_content(url: str, n_retry: int = 5) -> bytes:
     """Retrieve the binary content at url.

     Retry on connection errors.
@@ -1102,7 +1106,7 @@ def request_get_content(url: str, n_retry: int = 3) -> bytes:
             warnings.warn(
                 f"Swallowed error {e} while downloading {url} ({i} out of {n_retry})"
             )
-            time.sleep(10 * 2 ** i)
+            time.sleep(10 * 2**i)
     dl_time = time.time() - t0
     dl_speed = len(r.content) / dl_time / 1024
     logging.info(
@@ -1115,7 +1119,11 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]:
     """Download the files at the given url to memory and opens it as a file.
     Assumes that the file is small, and fetch it when this function is called.
     """
-    if cache and cache.exists():
+    valid_cache = False
+    if cache and cache.exists() and os.path.getsize(cache) > 1:
+        valid_cache = True
+
+    if cache and valid_cache:
         return open_read(cache)

     # TODO: open the remote file in streaming mode.
@@ -1128,11 +1136,11 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]:
     else:
         f = io.TextIOWrapper(content)

-    if cache and not cache.exists():
+    if cache and not valid_cache:
         # The file might have been created while downloading/writing.
         tmp_cache = _tmp(cache)
         tmp_cache.write_bytes(raw_bytes)
-        if not cache.exists():
+        if not valid_cache:
             tmp_cache.replace(cache)
         else:
             tmp_cache.unlink()
@@ -1148,7 +1156,7 @@ def sharded_file(file_pattern: Path, mode: str, max_size: str = "4G") -> MultiFile:
     assert 0 < n < 8
     assert "?" * n in name, f"The '?' need to be adjacents in {file_pattern}"
     assert "r" not in mode
-    files = (folder / name.replace("?" * n, f"%0{n}d" % i) for i in range(10 ** n))
+    files = (folder / name.replace("?" * n, f"%0{n}d" % i) for i in range(10**n))

     return MultiFile(files, mode, max_size)
diff --git a/cc_net/mine.py b/cc_net/mine.py
index 4b8df55..c667c19 100644
--- a/cc_net/mine.py
+++ b/cc_net/mine.py
@@ -10,10 +10,11 @@

 The pipeline parameters are described in the `Config` class.
""" - import hashlib import json +import random import time +import traceback import warnings from argparse import ArgumentParser from collections import defaultdict @@ -21,6 +22,7 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple +import dill import func_argparse # Local scripts @@ -32,8 +34,10 @@ # Constant FILE_DIR = Path(__file__).parent CUTOFF_CSV = FILE_DIR / "data" / "cutoff.csv" +LID_BIN = FILE_DIR / "bin" / "lid.bin" DEFAULT_PIPELINE = [ + "hash", "dedup", "lid", "keep_lang", @@ -57,12 +61,13 @@ class Config(NamedTuple): num_shards: number of shards to split the dump num_segments_per_shard: allow to download a small portion of CC (eg for tests) min_len: remove documents shorter than this (in chars) - hashes_in_mem: number of shards hashes to use for dedup + hash_in_mem: number of shards hashes to use for dedup lang_whitelist: only treat those languages lang_blacklist: ignore those languages lang_threshold: remove docs whose top language score is lower than this keep_bucket: keep only those perplexity bucket chose from (head, middle, tail, all) lm_dir: folder containing LMs + lm_id_path: path for the language identity bin lm_languages: only use LMs for the following languages cutoff: cutoff file to use for split in head/middle/tail mine_num_processes: number of processes to use for mining @@ -88,6 +93,7 @@ class Config(NamedTuple): lang_threshold: float = 0.5 keep_bucket: Sequence[str] = [] lm_dir: Path = Path("data/lm_sp") + lm_id_path: Path = LID_BIN cutoff: Path = CUTOFF_CSV lm_languages: Optional[Sequence[str]] = None mine_num_processes: int = 16 @@ -97,14 +103,18 @@ class Config(NamedTuple): pipeline: Sequence[str] = DEFAULT_PIPELINE experiments: Sequence[str] = [] cache_dir: Optional[Path] = None + # dbfs_lm_dir: str = "" + # dbfs_lm_id_path: str = "" def get_executor( self, name: str, timeout_hour: int = 1, mem_gb: int = 1, cpus: int = 1 ) -> Executor: name = "_".join((name, self.config_name, *self.experiments)) + log_dir = self.output_dir / "logs" + log_dir.mkdir(parents=True, exist_ok=True) return execution.get_executor( name, - self.output_dir / "logs", + log_dir, self.execution, timeout_hour=timeout_hour, mem_gb=mem_gb, @@ -160,6 +170,19 @@ def get_mined_dir(self, regroup: bool = False) -> Path: return self.output_dir / f"{self.mined_dir}_split" / self.dump return self.output_dir / self.mined_dir / self.dump + # def download_dbfs_file_to_local(self, source: Path, target: Path, overwrite: bool = False) -> bool: + # if (not overwrite and target.exists()): + # return True + + # if (not source.startswith("/dbfs")): + # print(f"==can not download file from dbfs to local as source path does not start from /dbfs: {str(source)}") + # return False + + # from databricks.sdk.runtime import dbutils + # copied = dbutils.fs.cp(source.replace('/dbfs','dbfs:'), "file:" + str(target.absolute())) + # print(f"==copied: {copied}, source: {str(source)}, target:{str(target)}") + # return copied + BASE_CONFIG = Config() @@ -192,25 +215,52 @@ def get_mined_dir(self, regroup: bool = False) -> Path: TEST_CONFIG = BASE_CONFIG._replace( config_name="test", - dump="2019-09", + dump="2023-06", output_dir=Path("test_data"), - execution="local", - num_shards=4, + execution="spark", + num_shards=3, + num_segments_per_shard=1, + hash_in_mem=2, + mine_num_processes=2, + # lang_whitelist=["de", "it", "fr"], + lang_whitelist=["en"], + target_size="320M", + cleanup_after_regroup=False, + cache_dir=Path("test_data2_wet_cache"), + 
task_parallelism=4, +) + +DBFS_ROOT_PATH = "/dbfs/tmp/cc_net/" + +TEST_SPARK_CONFIG = BASE_CONFIG._replace( + config_name="test_spark", + dump="2019-09", + output_dir=Path(DBFS_ROOT_PATH + "test_data"), + mined_dir=DBFS_ROOT_PATH + "test_data/mined", + execution="spark", + num_shards=3, num_segments_per_shard=1, hash_in_mem=2, mine_num_processes=2, - lang_whitelist=["de", "it", "fr"], - target_size="32M", + lang_whitelist=["en"], + target_size="320M", cleanup_after_regroup=False, cache_dir=Path("test_data/wet_cache"), + task_parallelism=4, ) PREDEF_CONFIGS = { "base": BASE_CONFIG, "by_lang": BYLANG_CONFIG, "test": TEST_CONFIG, + "test_spark": TEST_SPARK_CONFIG, "test_slurm": TEST_CONFIG._replace(execution="slurm,partition=dev"), - "debug": TEST_CONFIG._replace(config_name="debug", mine_num_processes=0), + "debug": TEST_CONFIG._replace( + config_name="debug", + mine_num_processes=0, + execution="local", + task_parallelism=-1, + ), "reproduce": REPRODUCE_CONFIG, "augment": BASE_CONFIG._replace( config_name="augment", dump="2019-13", lang_blacklist=["en"] @@ -259,7 +309,11 @@ def hashes(conf: Config) -> List[Path]: hashes_dir.mkdir(parents=True, exist_ok=True) # With FlatHashSet we need ~2Gb of RAM / shard, but we need to account for # overhead due to how the dynamic allocation works. + # print(f"==missing_outputs num {missing_outputs}, transpose out: {_transpose(missing_outputs)}") ex = conf.get_executor(f"hashes_{conf.dump}", mem_gb=4, timeout_hour=6, cpus=2) + print( + f"==calling _hashes_shard to continue with {len(missing_outputs)} missing output" + ) ex(_hashes_shard, repeat(conf), *_transpose(missing_outputs)) # Wait a bit so that files appears on the disk. @@ -270,6 +324,7 @@ def hashes(conf: Config) -> List[Path]: def _hashes_shard(conf: Config, shard: int, output: Path): tmp_output = tmp(output) + print(f"==running hash shard: {shard}, output: {output}, tmp {tmp_output}") jsonql.run_pipes( dedup.HashesCollector(field="raw_content", output=tmp_output), inputs=conf.get_cc_shard(shard), @@ -321,16 +376,8 @@ def mine(conf: Config) -> List[Path]: if not missing_outputs: return outputs - mined_dir.mkdir(parents=True, exist_ok=True) - ex = conf.get_executor( - f"mine_{conf.dump}", - mem_gb=mem_gb, - timeout_hour=timeout_hour, - cpus=conf.mine_num_processes + 1, - ) - # Compute hashes firsts. - if "dedup" in conf.pipeline: + if "hash" in conf.pipeline: hashes_groups = list(jsonql.grouper(hashes(conf), conf.hash_in_mem)) hashes_files: Iterable[List[Path]] = [ hashes_groups[shard // conf.hash_in_mem] for shard, o in missing_outputs @@ -338,6 +385,16 @@ def mine(conf: Config) -> List[Path]: else: hashes_files = repeat([]) + mined_dir.mkdir(parents=True, exist_ok=True) + ex = conf.get_executor( + f"mine_shard_{conf.dump}", + mem_gb=mem_gb, + timeout_hour=timeout_hour, + cpus=conf.mine_num_processes + 1, + ) + print( + f"==calling _mine_shard to continue with {len(missing_outputs)} missing output" + ) ex(_mine_shard, repeat(conf), hashes_files, *_transpose(missing_outputs)) assert all(o.exists() for o in outputs) @@ -350,6 +407,33 @@ def _get_segment(tmp_output: Path, doc: dict) -> str: def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str: + print( + f"==calling_mine_shard, with shard: {shard}, output: {output}, hashes: {hashes}" + ) + + # Workaround DBFS super unstable /dbfs mnt issue. 
+    # check whether this worker node can access the hashes; if it cannot, just skip
+    # the shard, so we keep mining the remaining shards instead of failing the whole
+    # cluster on a single failure.
+    if conf.execution == "spark" and conf.num_shards > 10:
+        retried = 0
+        while True:
+            try:
+                retried += 1
+                random_float = random.uniform(0.01, float(conf.num_shards) / 100.0)
+                print(
+                    f"==sleep for {random_float} seconds in _mine_shard, tried: {retried}"
+                )
+                time.sleep(random_float)
+                can_access = all(h.exists() for h in hashes)
+                break
+            except Exception as ex:
+                print(
+                    f"==Failed to access hashes! error type:{type(ex).__name__}, details: {traceback.format_exc()}"
+                )
+                if retried > 10:
+                    return "skipped as too many failures"
+
+    print(f"==continue _mine_shard, with shard: {shard}, output: {output}")
     assert conf.pipeline
     tmp_output = tmp(output)
     if "hashes" in conf.experiments:
@@ -357,10 +441,20 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str:
         hashes_in_mem = shard
         hashes = hashes[: HASHES_IN_MEM[hashes_in_mem]]
         shard = 0
+
+    # ensure all needed models are ready locally:
+    # if (conf.execution == 'spark'):
+    #     if (conf.dbfs_lm_id_path.startswith('/dbfs')):
+    #         conf.download_dbfs_file_to_local(conf.dbfs_lm_id_path, conf.lm_id_path)
+    #     if (conf.dbfs_lm_dir.startswith('/dbfs')):
+    #         for l in conf.get_lm_languages():
+    #             conf.download_dbfs_file_to_local(conf.dbfs_lm_dir / f"{l}.sp.model", conf.lm_dir / f"{l}.sp.model")
+    #             conf.download_dbfs_file_to_local(conf.dbfs_lm_dir / f"{l}.arpa.bin", conf.lm_dir / f"{l}.arpa.bin")
+
     cc_shard = conf.get_cc_shard(shard)

     steps: Dict[str, Optional[jsonql.Transformer]] = {}
-    lang_id = Path("bin") / "lid.bin"
+    lang_id = conf.lm_id_path
     steps["lid_before_dedup"] = split_by_lang.Classifier(
         model=lang_id, field="raw_content", out_field="lid_before_dedup", top=5
     )
@@ -379,11 +473,15 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str:

     if conf.lang_blacklist:
         steps["keep_lang"] = jsonql.where(
-            [lambda doc: doc.get("language") not in set(conf.lang_blacklist)]
+            [
+                dill.dumps(
+                    lambda doc: doc.get("language") not in set(conf.lang_blacklist)
+                )
+            ]
         )
     elif conf.lang_whitelist:
         steps["keep_lang"] = jsonql.where(
-            [lambda doc: doc.get("language") in set(conf.lang_whitelist)]
+            [dill.dumps(lambda doc: doc.get("language") in set(conf.lang_whitelist))]
         )
     else:
         steps["keep_lang"] = None
@@ -427,7 +525,15 @@ def _mine_shard(conf: Config, hashes: List[Path], shard: int, output: Path) -> str:
         split_fn=lambda doc: _get_segment(tmp_output, doc), mkdir=True
     )

-    pipeline = filter(None, (steps[s] for s in conf.pipeline))
+    remainsteps = []
+    pipeline = []
+
+    for s in conf.pipeline:
+        if s in steps and steps[s] is not None:
+            remainsteps.append(s)
+            pipeline.append(steps[s])
+
+    print(f"==remaining steps: {remainsteps}")

     jsonql.run_pipes(
         *pipeline,
@@ -643,6 +749,8 @@ def main(config: str = "base", **config_as_dict: Any) -> None:
     if conf.config_name == "test":
         _validate_test(conf, conf.get_mined_dir(regroup=True))

+    print("==Completed all pipelines!")
+

 if __name__ == "__main__":
     func_argparse.parse_and_call(get_main_parser())
diff --git a/cc_net/process_wet_file.py b/cc_net/process_wet_file.py
index a870f6c..844e045 100644
--- a/cc_net/process_wet_file.py
+++ b/cc_net/process_wet_file.py
@@ -20,7 +20,8 @@

 from cc_net import jsonql

-WET_URL_ROOT = "https://commoncrawl.s3.amazonaws.com"
+# WET_URL_ROOT = "https://commoncrawl.s3.amazonaws.com"
+WET_URL_ROOT = "https://data.commoncrawl.org"

 logger = logging.getLogger(__name__)
@@ -63,9 +64,11 @@ def parse_doc(headers: List[str], doc: List[str]) -> Optional[dict]:
     WARC-Record-ID:
     WARC-Refers-To:
     WARC-Block-Digest: sha1:S3DTWCONT2L6ORTGCY2KXEZ37LNBB7V2
+    WARC-Identified-Content-Language: rus  NOTE: this is newly added in late 2020
     Content-Type: text/plain
-    Content-Length: 7743
+    Content-Length: 6724
     """
+
     if not headers or not doc:
         return None

@@ -76,7 +79,7 @@ def parse_doc(headers: List[str], doc: List[str]) -> Optional[dict]:
         url = headers[2].split()[1]
         date = headers[3].split()[1]
         digest = headers[6].split()[1]
-        length = int(headers[8].split()[1])
+        length = int(headers[len(headers) - 2].split()[1])
     except Exception as e:
         logger.warning("Can't parse header:", e, headers, doc)
         return None
@@ -242,6 +245,7 @@ def segments(self) -> Sequence[str]:
             return self._segments
         segments = cc_segments(self.dump, self.cache_dir)
         n = len(segments)
+        print(f"==total segments: {n}, called from shard: {self.shard}")
         if self.num_shards < 0:
             self.num_shards = n // self.num_segments_per_shard
         i_min = (self.shard * n) // self.num_shards
diff --git a/cc_net/regroup.py b/cc_net/regroup.py
index 575baee..d7f2084 100644
--- a/cc_net/regroup.py
+++ b/cc_net/regroup.py
@@ -105,7 +105,7 @@ def fast_reshard(


 def determine_groups(
-    inputs: List[Path], target_size: int = 4 * 1024 ** 3
+    inputs: List[Path], target_size: int = 4 * 1024**3
 ) -> List[List[Path]]:
     if len(inputs) == 0:
         return []
diff --git a/cc_net/tools/expand_corpus.py b/cc_net/tools/expand_corpus.py
index 44f3dce..46d16bc 100644
--- a/cc_net/tools/expand_corpus.py
+++ b/cc_net/tools/expand_corpus.py
@@ -26,7 +26,7 @@

 KENLM = Path("./bin/lmplz")
 KENLM_BUILD = Path("./bin/build_binary")
-VOCAB_SIZE = 2 ** 16 - 10
+VOCAB_SIZE = 2**16 - 10
 PROCESSES = 16
diff --git a/config/test_spark.json b/config/test_spark.json
new file mode 100644
index 0000000..9eca2f7
--- /dev/null
+++ b/config/test_spark.json
@@ -0,0 +1,15 @@
+{
+    "config_name":"test_spark",
+    "dump":"2019-09",
+    "output_dir":"test_data",
+    "execution":"spark",
+    "num_shards":3,
+    "num_segments_per_shard":1,
+    "hash_in_mem":2,
+    "mine_num_processes":2,
+    "lang_whitelist":["en"],
+    "target_size":"320M",
+    "cleanup_after_regroup":true,
+    "cache_dir":"test_data/wet_cache",
+    "task_parallelism":3
+}
diff --git a/initial_script.sh b/initial_script.sh
new file mode 100644
index 0000000..282e3f7
--- /dev/null
+++ b/initial_script.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+set -e
+
+if [[ -d "/root/lm_sp" ]]; then
+    echo "Folder /root/lm_sp already exists."
+else
+    echo "Folder /root/lm_sp does not exist."
+    mkdir /root/lm_sp
+fi
+
+# cp -n /dbfs/data-mle/llm/cc_net/lm_sp/* /root/lm_sp/
+cp -n /dbfs/data-mle/llm/cc_net/lm_sp/en.arpa.bin /root/lm_sp/
+cp -n /dbfs/data-mle/llm/cc_net/lm_sp/en.sp.model /root/lm_sp/
+
+pip install /dbfs/data-mle/llm/cc_net/dist/cc_net-1.0.0-py3-none-any.whl
+pip install cc_net[getpy]
+
+echo "Done copying models"
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5ec5418..c7c91e7 100644
--- a/setup.py
+++ b/setup.py
@@ -41,6 +41,7 @@
         "sacremoses",
         "submitit>=1.0.0",
         "typing_extensions",
+        "dill",
     ],
     extras_require={
         "dev": ["mypy==0.790", "pytest", "black==19.3b0", "isort==5.6.4"],
@@ -51,5 +52,5 @@
         # Full version is at https://github.com/atom-moyer/getpy
         "getpy": ["getpy @ git+https://github.com/gwenzek/getpy.git@v0.9.10-subset"],
     },
-    package_data={"cc_net": ["data/*"]},
+    package_data={"cc_net": ["data/*", "bin/*"]},
 )
diff --git a/tests/test_flat_hash_set.py b/tests/test_flat_hash_set.py
index 00287df..0848fe2 100644
--- a/tests/test_flat_hash_set.py
+++ b/tests/test_flat_hash_set.py
@@ -69,7 +69,7 @@ def check_reload(h, dump, load, tmp_path):
 @pytest.mark.parametrize("hash_set_cls", [FlatHashSet, NaiveHashSet])
 def test_loading(tmp_path, hash_set_cls):
     h = hash_set_cls()
-    x = np.random.randint(0, 2 ** 32, (100,), dtype=h.dtype)
+    x = np.random.randint(0, 2**32, (100,), dtype=h.dtype)
     h.add(x)
     check_reload(h, hash_set_cls.dump, hash_set_cls.load, tmp_path)
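
Usage sketch (illustrative, not part of the patch): with these changes, the "spark" execution mode is selected purely through Config.execution, so the pipeline can be driven from the Spark driver (for example a Databricks notebook, after initial_script.sh has prepared the nodes) via the regular cc_net entry point. The snippet below is a minimal sketch assuming the wheel is installed on driver and workers, pyspark is importable, and the lid.bin / LM files referenced above are in place.

    # Minimal sketch: launch mining with the new "spark" execution mode.
    from cc_net import mine

    mine.main(config="test_spark")                  # predefined config added in cc_net/mine.py
    # mine.main(config="config/test_spark.json")    # equivalent, using the JSON config added above

The same thing can be done from a shell with the existing CLI, e.g. `python -m cc_net --config test_spark`.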