23 changes: 16 additions & 7 deletions nemo_curator/stages/deduplication/semantic/identify_duplicates.py
@@ -89,13 +89,22 @@ def process_batch(self, tasks: list[FileGroupTask]) -> list[FileGroupTask]:
 
         all_files = [file for task in tasks for file in task.data]
         # Read using filters
-        df: pd.DataFrame = pd.read_parquet(
-            all_files,
-            storage_options=self.input_storage_options,
-            **self.read_kwargs,
-            filters=[("cosine_sim_score", ">=", 1.0 - self.eps)],
-            engine="pyarrow",
-        )[["id"]]  # TODO: If we want we can add other columns
+        # We read file by file, since passing a list of files can fail when they are remote URLs
+        # See https://github.com/pandas-dev/pandas/issues/62922
+        df: pd.DataFrame = pd.concat(
+            [
+                pd.read_parquet(
+                    f,
+                    storage_options=self.input_storage_options,
+                    **self.read_kwargs,
+                    filters=[("cosine_sim_score", ">=", 1.0 - self.eps)],
+                    columns=["id"],
+                    engine="pyarrow",
+                )
+                for f in all_files
+            ],
+            ignore_index=True,
+        )
         # Write out sorted and with multiple row groups
         df.sort_values("id", inplace=True)  # noqa: PD002
 
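
Note on the hunk above: the per-file loop works around pd.read_parquet mishandling a list of remote URLs (the linked pandas issue). A minimal standalone sketch of the same pattern, with placeholder bucket paths, storage options, and threshold:

import pandas as pd

# Placeholder remote files and storage options, for illustration only.
files = ["s3://my-bucket/semdedup/part_0.parquet", "s3://my-bucket/semdedup/part_1.parquet"]
storage_options = {"anon": False}

# Passing the whole list to pd.read_parquet can fail for remote URLs (pandas#62922),
# so each file is read on its own and the frames are concatenated.
df = pd.concat(
    [
        pd.read_parquet(
            f,
            columns=["id"],
            filters=[("cosine_sim_score", ">=", 0.99)],  # i.e. 1.0 - eps for eps = 0.01
            storage_options=storage_options,
            engine="pyarrow",
        )
        for f in files
    ],
    ignore_index=True,
)
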
5 changes: 4 additions & 1 deletion nemo_curator/stages/deduplication/semantic/pairwise_io.py
@@ -21,6 +21,7 @@
 from nemo_curator.stages.base import ProcessingStage
 from nemo_curator.stages.resources import Resources
 from nemo_curator.tasks import FileGroupTask, _EmptyTask
+from nemo_curator.utils.client_utils import is_remote_url
 from nemo_curator.utils.file_utils import get_all_file_paths_under, get_fs, infer_dataset_name_from_path
 
 if TYPE_CHECKING:
@@ -52,6 +53,7 @@ def __init__(
         self._name = "pairwise_file_partitioning"
         self._resources = Resources(cpus=0.5)
         self.fs: AbstractFileSystem | None = None
+        self.path_normalizer = lambda x: x
 
     def inputs(self) -> tuple[list[str], list[str]]:
         return ["data"], []
@@ -61,6 +63,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
 
     def setup(self, _: WorkerMetadata | None = None) -> None:
         self.fs = get_fs(self.input_path, storage_options=self.storage_options)
+        self.path_normalizer = self.fs.unstrip_protocol if is_remote_url(self.input_path) else (lambda x: x)
 
     def ray_stage_spec(self) -> dict[str, Any]:
         """Ray stage specification for this stage."""
@@ -83,7 +86,7 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]:
             # Extract centroid ID from directory name (e.g., "centroid=0" -> 0)
             if "centroid=" in entry:
                 centroid_id = int(entry.split("centroid=")[-1])
-                centroid_dirs[centroid_id] = entry
+                centroid_dirs[centroid_id] = self.path_normalizer(entry)
 
         logger.debug(
             f"Found {len(centroid_dirs)} centroid directories e.g. {next(iter(centroid_dirs.values())) if centroid_dirs else None}"
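
Note on path_normalizer: fsspec filesystems typically list remote entries without their protocol prefix (an S3 listing yields "bucket/key", not "s3://bucket/key"), so downstream readers would otherwise treat the centroid directories as local paths. fs.unstrip_protocol restores the scheme, while local paths keep the identity function. A small sketch with a placeholder bucket:

from fsspec.core import url_to_fs

fs, root = url_to_fs("s3://my-bucket/pairwise")  # root == "my-bucket/pairwise"
entry = fs.sep.join([root, "centroid=0"])        # listing-style path without the scheme

print(fs.unstrip_protocol(entry))                # "s3://my-bucket/pairwise/centroid=0"
# For a local input_path, unstrip_protocol would prepend "file://",
# which local readers do not need, hence the identity fallback.
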
4 changes: 3 additions & 1 deletion nemo_curator/stages/text/deduplication/removal.py
@@ -72,12 +72,14 @@ def process(self, task: DocumentBatch) -> DocumentBatch:
         input_df_min_max_time = time.perf_counter() - input_df_t0
         # Filter the parquet files for IDs to remove within this range
         read_dupes_t0 = time.perf_counter()
+
         removal_df = pd.read_parquet(
             self.ids_to_remove_path,
             filters=[(self.duplicate_id_field, ">=", min_id), (self.duplicate_id_field, "<=", max_id)],
             columns=[self.duplicate_id_field],
-            **self.read_kwargs,
+            **self.read_kwargs,  # this might fail if a filesystem object is present in read_kwargs
         )
+
         read_dupes_time = time.perf_counter() - read_dupes_t0
 
         # Filter out documents with IDs in the removal set using pandas
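
On the inline caution about read_kwargs: pandas' read_parquet also accepts a filesystem argument, and a filesystem object arriving via read_kwargs can clash with how the string path and filters are resolved here. If that turns out to be the actual failure mode, one defensive option is to pull it out and pass it explicitly; this is only a sketch reusing the fields from the hunk above, not something the PR does:

read_kwargs = dict(self.read_kwargs)  # copy so the shared kwargs are not mutated
filesystem = read_kwargs.pop("filesystem", None)

removal_df = pd.read_parquet(
    self.ids_to_remove_path,
    filters=[(self.duplicate_id_field, ">=", min_id), (self.duplicate_id_field, "<=", max_id)],
    columns=[self.duplicate_id_field],
    filesystem=filesystem,
    **read_kwargs,
)
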
5 changes: 4 additions & 1 deletion nemo_curator/stages/text/io/reader/parquet.py
@@ -59,7 +59,10 @@ def read_data(
         if "dtype_backend" not in read_kwargs:
             update_kwargs["dtype_backend"] = "pyarrow"
         read_kwargs.update(update_kwargs)
-        return pd.read_parquet(paths, **read_kwargs)
+        # We read file by file, since passing a list of files can fail when they are remote URLs
+        # See https://github.com/pandas-dev/pandas/issues/62922
+        # TODO: We can benchmark pq.read_table, but it might have edge cases with dtype_backend and long strings
+        return pd.concat([pd.read_parquet(path, **read_kwargs) for path in paths], ignore_index=True)
 
 
 @dataclass
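
On the pq.read_table TODO: a sketch of what that route could look like, assuming pyarrow-backed dtypes are still wanted. The helper name is made up for illustration; types_mapper=pd.ArrowDtype is the usual stand-in for dtype_backend="pyarrow", and long/large strings are where the two code paths are most likely to diverge:

import pandas as pd
import pyarrow.parquet as pq


def read_parquet_via_arrow(paths: list[str], columns: list[str] | None = None) -> pd.DataFrame:
    # Read each file with pyarrow directly, then convert to pandas with Arrow-backed dtypes.
    tables = [pq.read_table(p, columns=columns) for p in paths]
    return pd.concat([t.to_pandas(types_mapper=pd.ArrowDtype) for t in tables], ignore_index=True)
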
38 changes: 17 additions & 21 deletions nemo_curator/stages/text/io/writer/base.py
@@ -17,13 +17,13 @@
 from dataclasses import dataclass, field
 from typing import Any, Literal
 
-import fsspec
-from fsspec.utils import infer_storage_options
+from fsspec.core import url_to_fs
 from loguru import logger
 
 import nemo_curator.stages.text.io.writer.utils as writer_utils
 from nemo_curator.stages.base import ProcessingStage
 from nemo_curator.tasks import DocumentBatch, FileGroupTask
+from nemo_curator.utils.client_utils import is_remote_url
 from nemo_curator.utils.file_utils import check_output_mode
 
 
@@ -41,25 +41,16 @@ class BaseWriter(ProcessingStage[DocumentBatch, FileGroupTask], ABC):
     fields: list[str] | None = None
     mode: Literal["ignore", "overwrite", "append", "error"] = "ignore"
     _name: str = "BaseWriter"
-    _fs_path: str = field(init=False, repr=False, default="")
-    _protocol: str = field(init=False, repr=False, default="file")
-    _has_explicit_protocol: bool = field(init=False, repr=False, default=False)
     append_mode_implemented: bool = False
 
     def __post_init__(self):
-        # Determine protocol and normalized filesystem path
-        path_opts = infer_storage_options(self.path)
-        protocol = path_opts.get("protocol", "file")
-        self._protocol = protocol or "file"
-        # Track if the user provided an explicit URL-style protocol in the path
-        self._has_explicit_protocol = "://" in self.path
-        # Use the filesystem-native path (no protocol) for fs operations
-        self._fs_path = path_opts.get("path", self.path)
-
-        # Only pass user-provided storage options to fsspec
+        # Use fsspec's url_to_fs to get both filesystem and normalized path
         self.storage_options = (self.write_kwargs or {}).get("storage_options", {})
-        self.fs = fsspec.filesystem(protocol, **self.storage_options)
+        self.fs, self._fs_path = url_to_fs(self.path, **self.storage_options)
         check_output_mode(self.mode, self.fs, self._fs_path, append_mode_implemented=self.append_mode_implemented)
+        logger.info(
+            f"Initialized writer for {self.path} with filesystem {self.fs} and storage_options {self.storage_options}"
+        )
 
     def inputs(self) -> tuple[list[str], list[str]]:
         return ["data"], []
@@ -95,17 +86,22 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
         file_extension = self.get_file_extension()
         file_path = self.fs.sep.join([self._fs_path, f"{filename}.{file_extension}"])
 
+        # For remote URLs, restore the protocol prefix so downstream code can infer the filesystem
+        file_path_with_protocol = self.fs.unstrip_protocol(file_path) if is_remote_url(self.path) else file_path
+
+        logger.info(f"Writing {task.num_items} records to {file_path_with_protocol} with filesystem {self.fs}")
+
         if self.fs.exists(file_path):
-            logger.debug(f"File {file_path} already exists, overwriting it")
+            logger.debug(f"File {file_path_with_protocol} already exists, overwriting it")
 
-        self.write_data(task, file_path)
-        logger.debug(f"Written {task.num_items} records to {file_path}")
+        self.write_data(task, file_path_with_protocol)
+        logger.debug(f"Written {task.num_items} records to {file_path_with_protocol}")
 
-        # Create FileGroupTask with written files
+        # Create FileGroupTask with written files using the full protocol-prefixed path
         return FileGroupTask(
             task_id=task.task_id,
             dataset_name=task.dataset_name,
-            data=[file_path],
+            data=[file_path_with_protocol],
             _metadata={
                 **task._metadata,
                 "format": self.get_file_extension(),
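
On the __post_init__ change above: url_to_fs collapses the old infer_storage_options plus fsspec.filesystem bookkeeping into one call that returns both the filesystem and the protocol-stripped path, for local and remote outputs alike. A small sketch with placeholder paths:

from fsspec.core import url_to_fs

fs_local, p_local = url_to_fs("/data/out")     # (LocalFileSystem, "/data/out")
fs_s3, p_s3 = url_to_fs("s3://my-bucket/out")  # (S3FileSystem, "my-bucket/out")

# A joined local path is already usable downstream:
fs_local.sep.join([p_local, "part_0.parquet"])     # "/data/out/part_0.parquet"

# A joined remote path needs the protocol restored before later stages can infer the filesystem:
joined = fs_s3.sep.join([p_s3, "part_0.parquet"])  # "my-bucket/out/part_0.parquet"
fs_s3.unstrip_protocol(joined)                     # "s3://my-bucket/out/part_0.parquet"
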
2 changes: 1 addition & 1 deletion nemo_curator/stages/text/io/writer/jsonl.py
@@ -44,4 +44,4 @@ def write_data(self, task: DocumentBatch, file_path: str) -> None:
 
         # Add any additional kwargs, allowing them to override defaults
         write_kwargs.update(self.write_kwargs)
-        df.to_json(file_path, **write_kwargs)
+        df.to_json(file_path, **write_kwargs)  # TODO: test this for cloud path
2 changes: 1 addition & 1 deletion nemo_curator/stages/text/io/writer/parquet.py
@@ -41,4 +41,4 @@ def write_data(self, task: DocumentBatch, file_path: str) -> None:
 
         # Add any additional kwargs, allowing them to override defaults
         write_kwargs.update(self.write_kwargs)
-        df.to_parquet(file_path, **write_kwargs)
+        df.to_parquet(file_path, **write_kwargs)  # TODO: debug this for cloud path
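
On the two cloud-path TODOs: pandas can write straight to fsspec URLs when storage_options reach the call, so the protocol-prefixed paths produced by BaseWriter should work as long as write_kwargs carries the options; presumably that is what needs testing. Bucket name and options below are placeholders:

import pandas as pd

df = pd.DataFrame({"id": [1, 2], "text": ["a", "b"]})
storage_options = {"anon": False}  # illustrative only

df.to_parquet("s3://my-bucket/out/part_0.parquet", storage_options=storage_options)
df.to_json("s3://my-bucket/out/part_0.jsonl", orient="records", lines=True, storage_options=storage_options)
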