Skip to content

Commit 119b33a

Browse files
committed
Add parallel processing to variant generation
Refactored CSLBuilder to use ProcessPoolExecutor for parallel processing of diff file generation and variant building, improving performance on multi-core systems. Added static methods for single diff generation and processing, and introduced a --max-workers CLI argument to control parallelism.
1 parent ccd9cfb commit 119b33a

File tree

2 files changed

+186
-103
lines changed

2 files changed

+186
-103
lines changed

src/style_variant_builder/build.py

Lines changed: 178 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import subprocess
1010
import sys
1111
import tempfile
12+
from concurrent.futures import ProcessPoolExecutor, as_completed
1213
from contextlib import suppress
1314
from dataclasses import dataclass
1415
from itertools import chain
@@ -61,6 +62,8 @@ def filter(self, record: logging.LogRecord) -> bool:
6162

6263
@dataclass(slots=True)
6364
class CSLBuilder:
65+
"""Builder for CSL style variants with parallel processing support."""
66+
6467
templates_dir: Path
6568
diffs_dir: Path
6669
output_dir: Path
@@ -69,9 +72,133 @@ class CSLBuilder:
6972
export_development: bool = False
7073
generate_diffs: bool = False
7174
group_by_family: bool = True
75+
max_workers: int | None = None
7276
successful_variants: int = 0
7377
failed_variants: int = 0
7478

79+
@staticmethod
80+
def _generate_single_diff(
81+
dev_file: Path,
82+
template_lines: list[str],
83+
template_path: Path,
84+
diffs_dir: Path,
85+
) -> tuple[str, bool, str]:
86+
"""
87+
Generate a diff file for a single development file.
88+
89+
Returns: (filename, success, message)
90+
"""
91+
try:
92+
with dev_file.open("r", encoding="utf-8") as df:
93+
dev_lines = df.readlines()
94+
95+
diff = list(
96+
difflib.unified_diff(
97+
template_lines,
98+
dev_lines,
99+
fromfile=str(template_path),
100+
tofile=str(dev_file),
101+
lineterm="\n",
102+
)
103+
)
104+
105+
if not diff:
106+
return (
107+
dev_file.name,
108+
True,
109+
f"No differences found for {dev_file.name}.",
110+
)
111+
112+
diff_path = diffs_dir / dev_file.with_suffix(".diff").name
113+
diffs_dir.mkdir(parents=True, exist_ok=True)
114+
diff_path.write_text("".join(diff), encoding="utf-8")
115+
return (dev_file.name, True, f"Generated diff file: {diff_path}")
116+
117+
except Exception as e:
118+
return (dev_file.name, False, f"Error generating diff: {e}")
119+
120+
@staticmethod
121+
def _process_single_diff(
122+
diff_path: Path,
123+
template_path: Path,
124+
target_output_dir: Path,
125+
development_dir: Path | None,
126+
export_development: bool,
127+
) -> tuple[str, bool, str]:
128+
"""
129+
Process a single diff file in a worker process.
130+
131+
Returns: (diff_name, success, message)
132+
"""
133+
patched_file = None
134+
try:
135+
# Apply the patch
136+
if not shutil.which("patch"):
137+
return (
138+
diff_path.name,
139+
False,
140+
"Required command 'patch' not found in PATH.",
141+
)
142+
143+
with tempfile.NamedTemporaryFile(
144+
delete=False, mode="w+", encoding="utf-8"
145+
) as tmp_file:
146+
shutil.copy(template_path, tmp_file.name)
147+
tmp_file_path = Path(tmp_file.name)
148+
149+
patched_file = tmp_file_path
150+
result = subprocess.run(
151+
["patch", "-s", "-N", str(patched_file), str(diff_path)],
152+
stdout=subprocess.PIPE,
153+
stderr=subprocess.PIPE,
154+
text=True,
155+
)
156+
157+
if result.returncode != 0:
158+
return (
159+
diff_path.name,
160+
False,
161+
f"Failed to apply patch: {result.stdout}",
162+
)
163+
164+
# Export or prune
165+
if export_development and development_dir is not None:
166+
dev_variant = (
167+
development_dir / diff_path.with_suffix(".csl").name
168+
)
169+
shutil.copy(patched_file, dev_variant)
170+
return (
171+
diff_path.name,
172+
True,
173+
f"Exported development variant to {dev_variant}",
174+
)
175+
else:
176+
output_variant = (
177+
target_output_dir / diff_path.with_suffix(".csl").name
178+
)
179+
# Prune the variant
180+
pruner = CSLPruner(
181+
input_path=patched_file,
182+
output_path=output_variant,
183+
)
184+
pruner.parse_xml()
185+
pruner.flatten_layout_macros()
186+
pruner.prune_macros()
187+
pruner.save()
188+
return (
189+
diff_path.name,
190+
True,
191+
f"Generated variant: {output_variant}",
192+
)
193+
194+
except Exception as e:
195+
return (diff_path.name, False, f"Error processing diff: {e}")
196+
197+
finally:
198+
if patched_file is not None:
199+
with suppress(FileNotFoundError):
200+
patched_file.unlink()
201+
75202
def _get_template_path(self) -> Path:
76203
template = self.templates_dir / f"{self.style_family}-template.csl"
77204
if not template.exists():
@@ -80,11 +207,11 @@ def _get_template_path(self) -> Path:
80207

81208
def _get_diff_files(self) -> list[Path]:
82209
# Collect diff files that match the expected naming convention.
83-
filname_diffs = set(self.diffs_dir.glob(f"{self.style_family}*.diff"))
210+
filename_diffs = set(self.diffs_dir.glob(f"{self.style_family}*.diff"))
84211
reference_diffs = []
85212
# Also examine all diff files for an internal reference to the template.
86213
for diff_file in self.diffs_dir.glob("*.diff"):
87-
if diff_file in filname_diffs:
214+
if diff_file in filename_diffs:
88215
continue
89216
try:
90217
with diff_file.open("r", encoding="utf-8") as f:
@@ -100,52 +227,13 @@ def _get_diff_files(self) -> list[Path]:
100227
f"Error reading diff file {diff_file.name}: {e}",
101228
exc_info=True,
102229
)
103-
all_diffs = sorted(chain(filname_diffs, reference_diffs))
230+
all_diffs = sorted(chain(filename_diffs, reference_diffs))
104231
if not all_diffs:
105232
raise FileNotFoundError(
106233
f"No diff files found for style family '{self.style_family}' in {self.diffs_dir}"
107234
)
108235
return all_diffs
109236

110-
def _apply_patch(self, template_path: Path, diff_path: Path) -> Path | None:
111-
# Ensure the 'patch' command is available
112-
if not shutil.which("patch"):
113-
raise EnvironmentError(
114-
"Required command 'patch' not found in PATH."
115-
)
116-
with tempfile.NamedTemporaryFile(
117-
delete=False, mode="w+", encoding="utf-8"
118-
) as tmp_file:
119-
shutil.copy(template_path, tmp_file.name)
120-
tmp_file_path = Path(tmp_file.name)
121-
result = subprocess.run(
122-
["patch", "-s", "-N", str(tmp_file_path), str(diff_path)],
123-
stdout=subprocess.PIPE,
124-
stderr=subprocess.PIPE,
125-
text=True,
126-
)
127-
if result.returncode != 0:
128-
logging.error(
129-
f"Failed to apply patch {diff_path}:\n {result.stdout}"
130-
)
131-
tmp_file_path.unlink(missing_ok=True)
132-
return None # Fail gracefully by returning None
133-
return tmp_file_path
134-
135-
def _prune_variant(self, patched_file: Path, output_variant: Path) -> None:
136-
pruner = CSLPruner(patched_file, output_variant)
137-
pruner.parse_xml()
138-
# Flatten single-macro layouts before pruning so wrapper macros can be removed
139-
pruner.flatten_layout_macros()
140-
pruner.prune_macros()
141-
# Set notice comment to be inserted during save
142-
pruner.notice_comment = (
143-
"This file was generated by the Style Variant Builder "
144-
"<https://github.com/citation-style-language/style-variant-builder>. "
145-
"To contribute changes, modify the template and regenerate variants."
146-
)
147-
pruner.save()
148-
149237
def build_variants(self) -> tuple[int, int]:
150238
try:
151239
template_path = self._get_template_path()
@@ -157,6 +245,7 @@ def build_variants(self) -> tuple[int, int]:
157245
except FileNotFoundError as e:
158246
logging.warning(f"Skipping style family '{self.style_family}': {e}")
159247
return (0, 0)
248+
160249
# Prepare output directory (optionally group by family)
161250
target_output_dir = (
162251
self.output_dir / self.style_family
@@ -166,44 +255,29 @@ def build_variants(self) -> tuple[int, int]:
166255
target_output_dir.mkdir(parents=True, exist_ok=True)
167256
if self.export_development:
168257
self.development_dir.mkdir(parents=True, exist_ok=True)
169-
for diff_path in diff_files:
170-
patched_file = None
171-
try:
172-
logging.debug(f"Processing diff: {diff_path.name}")
173-
patched_file = self._apply_patch(template_path, diff_path)
174-
if patched_file is None:
175-
logging.error(
176-
f"Skipping diff {diff_path.name} due to patch failure."
177-
)
178-
self.failed_variants += 1
179-
continue
180-
if self.export_development:
181-
dev_variant = (
182-
self.development_dir
183-
/ diff_path.with_suffix(".csl").name
184-
)
185-
shutil.copy(patched_file, dev_variant)
186-
logging.info(
187-
f"Exported development variant to {dev_variant}"
188-
)
258+
259+
# Process diff files in parallel
260+
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
261+
futures = [
262+
executor.submit(
263+
CSLBuilder._process_single_diff,
264+
diff_path,
265+
template_path,
266+
target_output_dir,
267+
self.development_dir if self.export_development else None,
268+
self.export_development,
269+
)
270+
for diff_path in diff_files
271+
]
272+
273+
for future in as_completed(futures):
274+
diff_name, success, message = future.result()
275+
if success:
276+
logging.info(message)
189277
self.successful_variants += 1
190278
else:
191-
output_variant = (
192-
target_output_dir / diff_path.with_suffix(".csl").name
193-
)
194-
self._prune_variant(patched_file, output_variant)
195-
logging.info(f"Generated variant: {output_variant}")
196-
self.successful_variants += 1
197-
except Exception as e:
198-
logging.error(
199-
f"Error processing diff {diff_path.name}: {e}",
200-
exc_info=True,
201-
)
202-
self.failed_variants += 1
203-
finally:
204-
if patched_file is not None:
205-
with suppress(FileNotFoundError):
206-
patched_file.unlink()
279+
logging.error(f"Failed {diff_name}: {message}")
280+
self.failed_variants += 1
207281

208282
return (self.successful_variants, self.failed_variants)
209283

@@ -249,31 +323,25 @@ def generate_diff_files(self) -> None:
249323
with template_path.open("r", encoding="utf-8") as tf:
250324
template_lines = tf.readlines()
251325

252-
for dev_file in dev_files:
253-
try:
254-
with dev_file.open("r", encoding="utf-8") as df:
255-
dev_lines = df.readlines()
256-
diff = list(
257-
difflib.unified_diff(
258-
template_lines,
259-
dev_lines,
260-
fromfile=str(template_path),
261-
tofile=str(dev_file),
262-
lineterm="\n",
263-
)
264-
)
265-
if not diff:
266-
logging.info(f"No differences found for {dev_file.name}.")
267-
continue
268-
diff_path = self.diffs_dir / dev_file.with_suffix(".diff").name
269-
self.diffs_dir.mkdir(parents=True, exist_ok=True)
270-
diff_path.write_text("".join(diff), encoding="utf-8")
271-
logging.info(f"Generated diff file: {diff_path}")
272-
except Exception as e:
273-
logging.error(
274-
f"Error generating diff for {dev_file.name}: {e}",
275-
exc_info=True,
326+
# Process diff generation in parallel
327+
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
328+
futures = [
329+
executor.submit(
330+
CSLBuilder._generate_single_diff,
331+
dev_file,
332+
template_lines,
333+
template_path,
334+
self.diffs_dir,
276335
)
336+
for dev_file in dev_files
337+
]
338+
339+
for future in as_completed(futures):
340+
filename, success, message = future.result()
341+
if success:
342+
logging.info(message)
343+
else:
344+
logging.error(f"Failed {filename}: {message}")
277345

278346

279347
def main() -> int:
@@ -327,6 +395,13 @@ def main() -> int:
327395
action="store_true",
328396
help="Write pruned output styles into a flat output directory (no per-family subfolders).",
329397
)
398+
parser.add_argument(
399+
"--max-workers",
400+
"-w",
401+
type=int,
402+
default=None,
403+
help="Maximum number of parallel workers. Default is the number of CPU cores.",
404+
)
330405

331406
args = parser.parse_args()
332407

@@ -356,6 +431,7 @@ def main() -> int:
356431
export_development=args.development,
357432
generate_diffs=args.diffs,
358433
group_by_family=(not args.flat_output),
434+
max_workers=args.max_workers,
359435
)
360436
try:
361437
if args.diffs:

src/style_variant_builder/prune.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,14 @@ class CSLPruner:
3434
modified: bool = field(
3535
default=False, init=False
3636
) # Track whether changes have been made
37-
notice_comment: str | None = field(default=None, init=False)
37+
notice_comment: str | None = field(
38+
default=(
39+
"This file was generated by the Style Variant Builder "
40+
"<https://github.com/citation-style-language/style-variant-builder>. "
41+
"To contribute changes, modify the template and regenerate variants."
42+
),
43+
init=False,
44+
)
3845

3946
def parse_xml(self) -> None:
4047
try:

0 commit comments

Comments
 (0)