Skip to content

Commit cce9130

Browse files
committed
Refactor XML processing
Refactored XML normalization and macro handling, moving the XML-declaration quote normalization and the em-dash escaping into dedicated helper methods.
1 parent fffd7f5 commit cce9130

File tree

2 files changed

+75
-70
lines changed

2 files changed

+75
-70
lines changed

src/style_variant_builder/build.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,17 @@
99
import subprocess
1010
import sys
1111
import tempfile
12+
from contextlib import suppress
1213
from dataclasses import dataclass
14+
from itertools import chain
1315
from pathlib import Path
1416

1517
from lxml import etree
1618

1719
from style_variant_builder.prune import CSLPruner
1820

1921
logging.basicConfig(level=logging.INFO, format="%(message)s")
22+
TEMPLATE_SUFFIX = "-template.csl"
2023

2124

2225
class ColourFormatter(logging.Formatter):
@@ -99,7 +102,7 @@ def _get_diff_files(self) -> list[Path]:
99102
f"Error reading diff file {diff_file.name}: {e}",
100103
exc_info=True,
101104
)
102-
all_diffs = sorted(list(filname_diffs) + reference_diffs)
105+
all_diffs = sorted(chain(filname_diffs, reference_diffs))
103106
if not all_diffs:
104107
raise FileNotFoundError(
105108
f"No diff files found for style family '{self.style_family}' in {self.diffs_dir}"
@@ -146,8 +149,11 @@ def _insert_notice_comment_xml(self, pruner: CSLPruner) -> None:
146149
logging.error("Cannot insert notice comment: XML root is None.")
147150
return
148151

149-
comment_text = "This file was generated by the Style Variant Builder <https://github.com/citation-style-language/style-variant-builder>. To contribute changes, modify the template and regenerate variants."
150-
comment_node = etree.Comment(f" {comment_text} ")
152+
comment_node = etree.Comment(
153+
" This file was generated by the Style Variant Builder "
154+
"<https://github.com/citation-style-language/style-variant-builder>. "
155+
"To contribute changes, modify the template and regenerate variants. "
156+
)
151157
# Insert as the first child of the root <style> element
152158
pruner.root.insert(0, comment_node)
153159
# Insert a tail newline and indentation after the comment to ensure separation and correct indentation for <info>
@@ -226,8 +232,9 @@ def build_variants(self) -> tuple[int, int]:
226232
)
227233
self.failed_variants += 1
228234
finally:
229-
if patched_file is not None and patched_file.exists():
230-
patched_file.unlink(missing_ok=True)
235+
if patched_file is not None:
236+
with suppress(FileNotFoundError):
237+
patched_file.unlink()
231238

232239
return (self.successful_variants, self.failed_variants)
233240

@@ -262,7 +269,7 @@ def generate_diff_files(self) -> None:
262269
f"Error reading development file {dev_file.name}: {e}",
263270
exc_info=True,
264271
)
265-
dev_files = sorted(list(expected_dev_files) + additional_dev_files)
272+
dev_files = sorted(chain(expected_dev_files, additional_dev_files))
266273

267274
if not dev_files:
268275
logging.warning(
@@ -291,8 +298,7 @@ def generate_diff_files(self) -> None:
291298
continue
292299
diff_path = self.diffs_dir / dev_file.with_suffix(".diff").name
293300
self.diffs_dir.mkdir(parents=True, exist_ok=True)
294-
with diff_path.open("w", encoding="utf-8") as dfile:
295-
dfile.write("".join(diff))
301+
diff_path.write_text("".join(diff), encoding="utf-8")
296302
logging.info(f"Generated diff file: {diff_path}")
297303
except Exception as e:
298304
logging.error(
@@ -356,12 +362,13 @@ def main() -> int:
356362
args = parser.parse_args()
357363

358364
# Automatically determine style families by scanning template files.
359-
template_files = list(args.templates_path.glob("*-template.csl"))
365+
template_files = list(args.templates_path.glob(f"*{TEMPLATE_SUFFIX}"))
360366
if not template_files:
361367
logging.error(f"No template files found in {args.templates_path}.")
362368
return 1
363369
style_families = [
364-
template.stem.replace("-template", "") for template in template_files
370+
template.stem.removesuffix(TEMPLATE_SUFFIX.removesuffix(".csl"))
371+
for template in template_files
365372
]
366373

367374
overall_success = True

src/style_variant_builder/prune.py

Lines changed: 58 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
from lxml import etree
1414

1515
logging.basicConfig(level=logging.INFO, format="%(message)s")
16-
NS = "{http://purl.org/net/xbiblio/csl}"
16+
17+
NSMAP = {"csl": "http://purl.org/net/xbiblio/csl"}
18+
19+
20+
def _tag(local_name: str) -> str:
21+
"""Helper to construct fully-qualified tag names."""
22+
return f"{{{NSMAP['csl']}}}{local_name}"
1723

1824

1925
@dataclass(slots=True)
@@ -69,9 +75,8 @@ def collect_macro_definitions(self) -> None:
6975
logging.error(msg)
7076
raise ValueError(msg)
7177
self.macro_defs = {}
72-
for elem in self.root.iter(f"{NS}macro"):
73-
name = elem.attrib.get("name")
74-
if name:
78+
for elem in self.root.iter(_tag("macro")):
79+
if name := elem.attrib.get("name"):
7580
self.macro_defs[name] = elem
7681

7782
def flatten_layout_macros(self) -> int:
@@ -90,30 +95,27 @@ def flatten_layout_macros(self) -> int:
9095
raise ValueError(msg)
9196

9297
updated = 0
93-
for layout in self.root.iter(f"{NS}layout"):
98+
for layout in self.root.iter(_tag("layout")):
9499
# Consider only real element children, ignore comments/whitespace
95100
children = [ch for ch in list(layout) if isinstance(ch.tag, str)]
96101
if len(children) != 1:
97102
continue
98103
only_child = children[0]
99-
if only_child.tag != f"{NS}text":
104+
if only_child.tag != _tag("text"):
100105
continue
101106

102107
# Require a pure macro call to avoid changing semantics if other attributes exist
103108
# Build a concrete dict[str, str] of attributes for type safety
104109
attrs: dict[str, str] = {
105110
str(k): str(v) for k, v in only_child.attrib.items()
106111
}
107-
macro_attr = attrs.get("macro")
108112
# Ensure the attribute is a string (lxml can expose AnyStr)
109-
if not isinstance(macro_attr, str):
113+
if not isinstance(macro_attr := attrs.get("macro"), str):
110114
continue
111-
macro_name: str = macro_attr
112-
if not macro_name or len(attrs) != 1:
115+
if not macro_attr or len(attrs) != 1:
113116
continue
114117

115-
macro_def = self.macro_defs.get(macro_name)
116-
if macro_def is None:
118+
if (macro_def := self.macro_defs.get(macro_attr)) is None:
117119
# Unknown macro; skip
118120
continue
119121

@@ -131,46 +133,46 @@ def flatten_layout_macros(self) -> int:
131133
self.build_parent_map()
132134
return updated
133135

134-
def gather_macro_refs(self, element: etree._Element) -> set:
136+
def gather_macro_refs(self, element: etree._Element) -> set[str]:
135137
"""Recursively collect macro names to which the element refers."""
136-
refs = set()
137-
macro_attr = element.attrib.get("macro")
138-
if macro_attr:
138+
refs: set[str] = set()
139+
if macro_attr := element.attrib.get("macro"):
139140
refs.add(macro_attr)
140141
for child in element:
141142
refs.update(self.gather_macro_refs(child))
142143
return refs
143144

144-
def build_used_macros(self) -> set:
145+
def build_used_macros(self) -> set[str]:
145146
if self.root is None:
146147
msg = "Root is None. Ensure parse_xml() is called successfully."
147148
logging.error(msg)
148149
raise ValueError(msg)
149-
used_macros = set()
150+
used_macros: set[str] = set()
150151

151152
def add_macro_and_deps(macro_name: str) -> None:
152153
if macro_name in used_macros:
153154
return
154155
used_macros.add(macro_name)
155-
macro = self.macro_defs.get(macro_name)
156-
if macro is not None:
156+
if (macro := self.macro_defs.get(macro_name)) is not None:
157157
refs = self.gather_macro_refs(macro)
158158
for ref in refs:
159159
add_macro_and_deps(ref)
160160

161161
# Entry points: <citation> and <bibliography>
162-
entry_tags = [f"{NS}citation", f"{NS}bibliography"]
163-
entry_macro_refs = set()
164-
for tag in entry_tags:
165-
for entry in self.root.iter(tag):
166-
entry_macro_refs.update(self.gather_macro_refs(entry))
162+
entry_tags = [_tag("citation"), _tag("bibliography")]
163+
entry_macro_refs = {
164+
ref
165+
for tag in entry_tags
166+
for entry in self.root.iter(tag)
167+
for ref in self.gather_macro_refs(entry)
168+
}
167169
entry_macro_refs.update(self.gather_macro_refs(self.root))
168170
for ref in entry_macro_refs:
169171
add_macro_and_deps(ref)
170172
return used_macros
171173

172174
def prune_macros(self) -> None:
173-
total_removed = []
175+
total_removed_count = 0
174176
while True:
175177
used_macros = self.build_used_macros()
176178
removed = []
@@ -183,15 +185,15 @@ def prune_macros(self) -> None:
183185
logging.debug(f"Removed macro: {name}")
184186
if removed:
185187
self.modified = True
186-
total_removed.extend(removed)
188+
total_removed_count += len(removed)
187189
self.collect_macro_definitions()
188190
self.build_parent_map()
189191
else:
190192
logging.debug("No unused macros found.")
191193
break
192-
if total_removed:
194+
if total_removed_count > 0:
193195
logging.info(
194-
f"Removed a total of {len(total_removed)} unused macros."
196+
f"Removed a total of {total_removed_count} unused macros."
195197
)
196198
else:
197199
logging.info("No macros pruned.")
@@ -200,16 +202,24 @@ def remove_xml_model_declarations(self, xml_data: str) -> str:
200202
"""Remove any xml-model declarations from the XML content."""
201203
return re.sub(r"<\?xml-model [^>]+>\n?", "", xml_data)
202204

203-
def normalize_xml_content(self, xml_data: bytes) -> bytes:
204-
"""Revert Python changes to XML content and reorder the default-locale attribute in <style> tags."""
205-
text = xml_data.decode("utf-8")
206-
text = self.remove_xml_model_declarations(text)
207-
text = re.sub(
205+
def _normalize_xml_declaration(self, text: str) -> str:
206+
"""Ensure XML declaration uses double quotes."""
207+
return re.sub(
208208
r"<\?xml version='1\.0' encoding='utf-8'\?>",
209209
'<?xml version="1.0" encoding="utf-8"?>',
210210
text,
211211
)
212-
text = re.sub(r"—+", lambda m: "&#8212;" * len(m.group(0)), text)
212+
213+
def _escape_em_dashes(self, text: str) -> str:
214+
"""Replace em-dashes with numeric entities."""
215+
return re.sub(r"—+", lambda m: "&#8212;" * len(m.group(0)), text)
216+
217+
def normalize_xml_content(self, xml_data: bytes) -> bytes:
218+
"""Revert Python changes to XML content and reorder the default-locale attribute in <style> tags."""
219+
text = xml_data.decode("utf-8")
220+
text = self.remove_xml_model_declarations(text)
221+
text = self._normalize_xml_declaration(text)
222+
text = self._escape_em_dashes(text)
213223

214224
# Collapse XML fragments inside multi-line comments to match CSL repository indentation
215225
# Pattern captures the entire comment body (non-greedy) so we can post-process line breaks
@@ -232,12 +242,8 @@ def collapse_comment(match: re.Match) -> str:
232242
seq.append(lines[i])
233243
i += 1
234244
# Collapse sequence
235-
first_indent_match = re.match(r"^([ \t]*)", seq[0])
236-
indent = (
237-
first_indent_match.group(1)
238-
if first_indent_match
239-
else ""
240-
)
245+
indent_match = re.match(r"^([ \t]*)", seq[0])
246+
indent = indent_match.group(1) if indent_match else ""
241247
collapsed = indent + "".join(s.strip() for s in seq)
242248
new_lines.append(collapsed)
243249
else:
@@ -251,18 +257,16 @@ def collapse_comment(match: re.Match) -> str:
251257
# Reorder the default-locale attribute to the end in <style ...> tags.
252258
def reorder_default_locale(match: re.Match) -> str:
253259
attribs = match.group(1)
254-
# Find all attributes in the tag.
255260
attrs = re.findall(r'(\S+="[^"]*")', attribs)
256-
new_attrs: list[str] = []
257-
default_locale_attr: str | None = None
258-
for attr in attrs:
259-
if attr.startswith("default-locale="):
260-
default_locale_attr = attr
261-
else:
262-
new_attrs.append(attr)
263-
if default_locale_attr is not None:
264-
new_attrs.append(default_locale_attr)
265-
new_attribs = " ".join(new_attrs)
261+
262+
# Separate default-locale from other attributes
263+
other_attrs = [
264+
a for a in attrs if not a.startswith("default-locale=")
265+
]
266+
locale_attrs = [a for a in attrs if a.startswith("default-locale=")]
267+
268+
# Reorder: other attributes first, then locale
269+
new_attribs = " ".join(other_attrs + locale_attrs)
266270
return f"<style {new_attribs}>"
267271

268272
text = re.sub(r"<style\s+([^>]+)>", reorder_default_locale, text)
@@ -306,15 +310,9 @@ def save(self) -> None:
306310
xml_data = self.reindent_xml_bytes(xml_data)
307311
# Ensure XML declaration uses double quotes
308312
xml_text = xml_data.decode("utf-8")
309-
xml_text = re.sub(
310-
r"<\?xml version='1\.0' encoding='utf-8'\?>",
311-
'<?xml version="1.0" encoding="utf-8"?>',
312-
xml_text,
313-
)
313+
xml_text = self._normalize_xml_declaration(xml_text)
314314
# lxml will decode character entities during reparse; restore em-dashes as numeric entities
315-
xml_text = re.sub(
316-
r"—+", lambda m: "&#8212;" * len(m.group(0)), xml_text
317-
)
315+
xml_text = self._escape_em_dashes(xml_text)
318316
self.output_path.write_text(xml_text, encoding="utf-8")
319317
except Exception as e:
320318
logging.error(

0 commit comments

Comments
 (0)