This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Update CC_net code so it can be run on a Spark cluster #43

Open · wants to merge 17 commits into base: main
50 changes: 47 additions & 3 deletions cc_net/execution.py
@@ -15,6 +15,7 @@
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Sized

import submitit
from pyspark import SparkConf, SparkContext
from typing_extensions import Protocol


@@ -47,17 +48,30 @@ def get_executor(
options.update(
{kv.split("=", 1)[0]: kv.split("=", 1)[1] for kv in execution.split(",")[1:]}
)

print(
f"===get_executor name {name}, timeout_hour: {timeout_hour}, execution_mode: {execution_mode}, task_parallelism: {task_parallelism}"
)
if execution_mode == "mp":
warnings.warn("Execution mode 'mp' is deprecated, use 'local'.")
execution_mode = "local"

if execution_mode == "spark":
conf = SparkConf().setAppName(name).set("spark.task.maxFailures", "10")
# conf.set("spark.eventLog.enabled", "true") # Enable event logging
# conf.set("spark.eventLog.dir", log_dir) # S
sc = SparkContext.getOrCreate(conf)
# Default to 500 parallel tasks when no task_parallelism is requested.
if task_parallelism == -1:
task_parallelism = 500
return functools.partial(map_spark_array, sc, task_parallelism)

cluster = None if execution_mode == "auto" else execution_mode
# use submitit to detect which executor is available
ex = submitit.AutoExecutor(log_dir, cluster=cluster)

if ex.cluster == "local":
# LocalExecutor doesn't respect task_parallelism
ex.update_parameters(name=name, timeout_min=int(timeout_hour * 60))
return functools.partial(custom_map_array, ex, task_parallelism)
if ex.cluster == "debug":
return debug_executor
@@ -77,6 +91,32 @@ def get_executor(
return functools.partial(map_array_and_wait, ex)


def get_spark_executor(
name: str,
):
sc = SparkContext.getOrCreate()
return sc


def map_spark_array(
sc: SparkContext,
task_parallelism: int,
function: Callable[..., str],
*args: Iterable,
):
f_name = function.__name__
assert len(args) > 0, f"No arguments passed to {f_name}"
print(f"==calling Spark for func: {f_name}, with {len(args)} argument lists")

# Pair the per-shard argument lists into one tuple of arguments per task.
newargs = list(zip(*args))

rdd = sc.parallelize(newargs, task_parallelism)
# foreach runs the function on the executors for its side effects and returns
# None, so there is nothing to assign or collect.
rdd.foreach(lambda p: function(*p))


def map_array_and_wait(
ex: submitit.AutoExecutor, function: Callable[..., str], *args: Iterable
):
@@ -85,7 +125,9 @@ def map_array_and_wait(
assert len(args) > 0, f"No arguments passed to {f_name}"
approx_length = _approx_length(*args)

print(f"Submitting {f_name} in a job array ({approx_length} jobs)")
print(
f"map_array_and_wait Submitting {f_name} in a job array ({approx_length} jobs)"
)
jobs = ex.map_array(function, *args)
if not jobs:
return
@@ -157,7 +199,9 @@ def custom_map_array(
if parallelism < 0:
parallelism = os.cpu_count() or 0
assert parallelism >= 0, f"Can't run any jobs with task_parallelism={parallelism}"
print(f"Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}")
print(
f"custom_map_array Submitting {total} jobs for {f_name}, with task_parallelism={parallelism}"
)
enqueued = 0
done = 0
running_jobs: List[submitit.Job] = []
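For context, a minimal sketch of how the new map_spark_array helper would be driven: each positional iterable supplies one argument per task, the resulting pairs are parallelized into an RDD, and the function runs on the executors via foreach. The process_shard function and the shard names below are hypothetical placeholders, not part of this PR.

from pyspark import SparkContext

from cc_net.execution import map_spark_array


def process_shard(shard: str, out: str) -> str:
    # Hypothetical per-shard work; cc_net would pass one of its pipeline steps here.
    print(f"processing {shard} -> {out}")
    return out


sc = SparkContext.getOrCreate()
# One Spark task per (shard, output) pair, spread over 2 partitions.
map_spark_array(sc, 2, process_shard, ["shard_0", "shard_1"], ["out_0", "out_1"])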
36 changes: 22 additions & 14 deletions cc_net/jsonql.py
@@ -42,6 +42,7 @@
Union,
)

import dill
import numpy as np
import psutil # type: ignore
import requests
@@ -409,6 +410,8 @@ def run_pipes(
if expect_json and inputs is None:
fns = (JsonReader(),) + fns
transformers = []
print(f"==get fns in run_pipes {fns}, count {len(fns)}")

for t in fns:
if not isinstance(t, Transformer):
break
@@ -435,18 +438,19 @@
else:
p = multiprocessing.current_process()
log(f"Will start {processes} processes from {p.name}, Pid: {p.pid}")
pool = stack.enter_context(
multiprocessing.Pool(
processes=processes,
initializer=_set_global_transformer,
initargs=(transform,),
)
cp = multiprocessing.Pool(
processes=processes,
initializer=_set_global_transformer,
initargs=(transform,),
)
log("done with muti pool")
pool = stack.enter_context(cp)
data = pool.imap_unordered(
_global_transformer, data, chunksize=chunksize
)

for fn in pipes:
print(f"==now handling: {fn}")
if isinstance(fn, Transformer):
data = fn.map(data)
else:
@@ -638,7 +642,7 @@ def _prepare(self):

def do(self, doc: dict) -> Optional[dict]:
assert self.clauses
if not doc or not all((c(doc) for c in self.clauses)):
if not doc or not all((dill.loads(c)(doc) for c in self.clauses)):
return None
self.n_selected += 1
return doc
@@ -1019,7 +1023,7 @@ def open_write(


def parse_size(size):
unit_map = {"B": 1, "K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
unit_map = {"B": 1, "K": 1024, "M": 1024**2, "G": 1024**3}
unit = size[-1].upper()
assert (
unit in unit_map
@@ -1082,7 +1086,7 @@ def close(self):
_session = functools.lru_cache()(requests.Session)


def request_get_content(url: str, n_retry: int = 3) -> bytes:
def request_get_content(url: str, n_retry: int = 5) -> bytes:
"""Retrieve the binary content at url.

Retry on connection errors.
@@ -1102,7 +1106,7 @@ def request_get_content(url: str, n_retry: int = 3) -> bytes:
warnings.warn(
f"Swallowed error {e} while downloading {url} ({i} out of {n_retry})"
)
time.sleep(10 * 2 ** i)
time.sleep(10 * 2**i)
dl_time = time.time() - t0
dl_speed = len(r.content) / dl_time / 1024
logging.info(
@@ -1115,7 +1119,11 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]:
"""Download the files at the given url to memory and opens it as a file.
Assumes that the file is small, and fetch it when this function is called.
"""
if cache and cache.exists():
valid_cache = False
if cache and cache.exists() and os.path.getsize(cache) > 1:
valid_cache = True

if cache and valid_cache:
return open_read(cache)

# TODO: open the remote file in streaming mode.
@@ -1128,11 +1136,11 @@ def open_remote_file(url: str, cache: Path = None) -> Iterable[str]:
else:
f = io.TextIOWrapper(content)

if cache and not cache.exists():
if cache and not valid_cache:
# The file might have been created while downloading/writing.
tmp_cache = _tmp(cache)
tmp_cache.write_bytes(raw_bytes)
if not cache.exists():
if not valid_cache:
tmp_cache.replace(cache)
else:
tmp_cache.unlink()
@@ -1148,7 +1156,7 @@ def sharded_file(file_pattern: Path, mode: str, max_size: str = "4G") -> MultiFile:
assert 0 < n < 8
assert "?" * n in name, f"The '?' need to be adjacents in {file_pattern}"
assert "r" not in mode
files = (folder / name.replace("?" * n, f"%0{n}d" % i) for i in range(10 ** n))
files = (folder / name.replace("?" * n, f"%0{n}d" % i) for i in range(10**n))

return MultiFile(files, mode, max_size)

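Worth noting: the change to do() above implies that the filter clauses are now stored as dill-serialized bytes rather than live callables, presumably so that lambda predicates can be shipped to remote workers (the standard pickle module cannot serialize lambdas). A minimal sketch of that round trip, using a hypothetical language filter:

import dill

# A filter clause written as a lambda, serialized the way the updated do() expects.
clause = dill.dumps(lambda doc: doc.get("language") == "en")

doc = {"language": "en", "raw_content": "hello"}
keep = dill.loads(clause)(doc)  # mirrors the check in do()
print(keep)  # True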