Skip to content

Commit 8582910

Browse files
authored
Merge pull request #419 from Titas-Ghosh/fix-zmq-specialization-shared-sources
Fix ZMQ specialization collision for shared node source scripts in mkconcore
2 parents 543d169 + 176cc94 commit 8582910

3 files changed

Lines changed: 143 additions & 52 deletions

File tree

copy_with_port_portname.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,40 @@
55
import logging
66
import json
77

8-
def run_specialization_script(template_script_path, output_dir, edge_params_list, python_exe, copy_script_path):
8+
def _normalize_output_relpath(template_script_path, output_relpath=None):
    """
    Return the forward-slash relative path for the specialized script.

    Args:
        template_script_path: Path of the template script. Its basename is
            used when ``output_relpath`` is not given (None or empty).
        output_relpath: Optional path relative to the output directory.
            Backslashes are converted to forward slashes and leading slashes
            are stripped so the result is always relative.

    Returns:
        The normalized relative path, using "/" as the separator.

    Raises:
        ValueError: If the resulting path is empty, or if it contains a ".."
            segment (which would let the joined output path escape the
            output directory).
    """
    if output_relpath:
        relpath = output_relpath.replace("\\", "/").lstrip("/")
    else:
        relpath = os.path.basename(template_script_path)
    if not relpath:
        raise ValueError("Output relative path cannot be empty.")
    # Compare whole segments so names that merely contain dots ("a..b.py")
    # remain valid; only a real parent-directory hop is rejected.
    if ".." in relpath.split("/"):
        raise ValueError(
            f"Output relative path '{relpath}' must not contain '..' segments."
        )
    return relpath
16+
17+
18+
def _join_output_path(output_dir, output_relpath):
    """Join a "/"-separated relative path onto ``output_dir`` with ``os.path.join``."""
    segments = output_relpath.split("/")
    return os.path.join(output_dir, *segments)
20+
21+
22+
def run_specialization_script(
23+
template_script_path,
24+
output_dir,
25+
edge_params_list,
26+
python_exe,
27+
copy_script_path,
28+
output_relpath=None
29+
):
930
"""
1031
Calls the copy script to generate a specialized version of a node's script.
1132
Returns the basename of the generated script on success, None on failure.
1233
"""
13-
# The new copy script generates a standardized filename, e.g., "original.py"
1434
base_template_name = os.path.basename(template_script_path)
15-
template_root, template_ext = os.path.splitext(base_template_name)
16-
output_filename = f"{template_root}{template_ext}"
17-
expected_output_path = os.path.join(output_dir, output_filename)
35+
output_relpath = _normalize_output_relpath(template_script_path, output_relpath)
36+
expected_output_path = _join_output_path(output_dir, output_relpath)
1837

1938
# If the specialized file already exists, we don't need to regenerate it.
2039
if os.path.exists(expected_output_path):
2140
logging.info(f"Specialized script '{expected_output_path}' already exists. Using existing.")
22-
return output_filename
41+
return output_relpath
2342

2443
# Convert the list of parameters to a JSON string for command line argument
2544
edge_params_json_str = json.dumps(edge_params_list)
@@ -31,13 +50,15 @@ def run_specialization_script(template_script_path, output_dir, edge_params_list
3150
output_dir,
3251
edge_params_json_str # Pass the JSON string as the last argument
3352
]
53+
if output_relpath:
54+
cmd.append(output_relpath)
3455
logging.info(f"Running specialization for '{base_template_name}': {' '.join(cmd)}")
3556
try:
3657
result = subprocess.run(cmd, capture_output=True, text=True, check=True, encoding='utf-8')
37-
logging.info(f"Successfully generated specialized script '{output_filename}'.")
58+
logging.info(f"Successfully generated specialized script '{output_relpath}'.")
3859
if result.stdout: logging.debug(f"copy_with_port_portname.py stdout:\n{result.stdout.strip()}")
3960
if result.stderr: logging.warning(f"copy_with_port_portname.py stderr:\n{result.stderr.strip()}")
40-
return output_filename
61+
return output_relpath
4162
except subprocess.CalledProcessError as e:
4263
logging.error(f"Error calling specialization script for '{template_script_path}':")
4364
logging.error(f"Command: {' '.join(e.cmd)}")
@@ -50,7 +71,7 @@ def run_specialization_script(template_script_path, output_dir, edge_params_list
5071
return None
5172

5273

53-
def create_modified_script(template_script_path, output_dir, edge_params_json_str):
74+
def create_modified_script(template_script_path, output_dir, edge_params_json_str, output_relpath=None):
5475
"""
5576
Creates a modified Python script by injecting ZMQ port and port name
5677
definitions from a JSON object.
@@ -121,17 +142,16 @@ def create_modified_script(template_script_path, output_dir, edge_params_json_st
121142
modified_lines = lines[:insert_index] + definitions + lines[insert_index:]
122143

123144
# --- Determine and create output file ---
124-
base_template_name = os.path.basename(template_script_path)
125-
template_root, template_ext = os.path.splitext(base_template_name)
126-
127-
# Standardized output filename for a node with one or more specializations
128-
output_filename = f"{template_root}{template_ext}"
129-
output_script_path = os.path.join(output_dir, output_filename)
145+
output_relpath = _normalize_output_relpath(template_script_path, output_relpath)
146+
output_script_path = _join_output_path(output_dir, output_relpath)
130147

131148
try:
132149
if not os.path.exists(output_dir):
133150
os.makedirs(output_dir)
134151
print(f"Created output directory: {output_dir}")
152+
output_parent = os.path.dirname(output_script_path)
153+
if output_parent and not os.path.exists(output_parent):
154+
os.makedirs(output_parent, exist_ok=True)
135155

136156
with open(output_script_path, 'w') as f:
137157
f.writelines(modified_lines)
@@ -149,14 +169,15 @@ def create_modified_script(template_script_path, output_dir, edge_params_json_st
149169
datefmt='%Y-%m-%d %H:%M:%S'
150170
)
151171

152-
if len(sys.argv) != 4:
153-
print("\nUsage: python3 copy_with_port_portname.py <TEMPLATE_SCRIPT_PATH> <OUTPUT_DIRECTORY> '<JSON_PARAMETERS>'\n")
172+
if len(sys.argv) not in [4, 5]:
173+
print("\nUsage: python3 copy_with_port_portname.py <TEMPLATE_SCRIPT_PATH> <OUTPUT_DIRECTORY> '<JSON_PARAMETERS>' [OUTPUT_RELATIVE_PATH]\n")
154174
print("Example JSON: '[{\"port\": \"2355\", \"port_name\": \"FUNBODY_REP_1\", \"source_node_label\": \"nodeA\", \"target_node_label\": \"nodeB\"}]'")
155175
print("Note: The JSON string must be enclosed in single quotes in shell.\n")
156176
sys.exit(1)
157177

158178
template_script_path_arg = sys.argv[1]
159179
output_directory_arg = sys.argv[2]
160180
json_params_arg = sys.argv[3]
181+
output_relpath_arg = sys.argv[4] if len(sys.argv) == 5 else None
161182

162-
create_modified_script(template_script_path_arg, output_directory_arg, json_params_arg)
183+
create_modified_script(template_script_path_arg, output_directory_arg, json_params_arg, output_relpath_arg)

mkconcore.py

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -420,40 +420,69 @@ def cleanup_script_files():
420420
logging.warning(f"Error processing edge for parameter aggregation: {e}")
421421

422422
# --- Now, run the specialization for each node that has aggregated parameters ---
423-
if node_edge_params:
424-
logging.info("Running script specialization process...")
425-
specialized_scripts_output_dir = os.path.abspath(os.path.join(outdir, "src"))
426-
os.makedirs(specialized_scripts_output_dir, exist_ok=True)
427-
428-
for node_id, params_list in node_edge_params.items():
429-
current_node_full_label = nodes_dict[node_id]
430-
try:
431-
container_name, original_script = current_node_full_label.split(':', 1)
432-
except ValueError:
433-
continue # Skip if label format is wrong
434-
435-
if not original_script or "." not in original_script:
436-
continue # Skip if not a script file
437-
438-
template_script_full_path = os.path.join(sourcedir, original_script)
439-
if not os.path.exists(template_script_full_path):
440-
logging.error(f"Cannot specialize: Original script '{template_script_full_path}' not found in '{sourcedir}'.")
441-
continue
442-
443-
new_script_basename = copy_with_port_portname.run_specialization_script(
444-
template_script_full_path,
445-
specialized_scripts_output_dir,
446-
params_list,
447-
python_executable,
448-
copy_script_py_path
449-
)
450-
451-
if new_script_basename:
452-
# Update nodes_dict to point to the new comprehensive specialized script
453-
nodes_dict[node_id] = f"{container_name}:{new_script_basename}"
454-
logging.info(f"Node ID '{node_id}' ('{container_name}') updated to use specialized script '{new_script_basename}'.")
455-
else:
456-
logging.error(f"Failed to generate specialized script for node ID '{node_id}'. It will retain its original script.")
423+
if node_edge_params:
424+
logging.info("Running script specialization process...")
425+
specialized_scripts_output_dir = os.path.abspath(os.path.join(outdir, "src"))
426+
os.makedirs(specialized_scripts_output_dir, exist_ok=True)
427+
428+
# Build one specialization plan per source script. This avoids collisions
429+
# when multiple nodes reference the same script and need different ZMQ params.
430+
script_edge_params = {}
431+
script_nodes = {}
432+
for node_id, params_list in node_edge_params.items():
433+
current_node_full_label = nodes_dict.get(node_id, "")
434+
try:
435+
container_name, original_script = current_node_full_label.split(':', 1)
436+
except ValueError:
437+
continue
438+
439+
if not original_script or "." not in original_script:
440+
continue
441+
442+
script_nodes.setdefault(original_script, []).append((node_id, container_name))
443+
script_edge_params.setdefault(original_script, [])
444+
seen_keys = {
445+
(
446+
p.get("port"),
447+
p.get("port_name"),
448+
p.get("source_node_label"),
449+
p.get("target_node_label")
450+
)
451+
for p in script_edge_params[original_script]
452+
}
453+
for edge_param in params_list:
454+
edge_key = (
455+
edge_param.get("port"),
456+
edge_param.get("port_name"),
457+
edge_param.get("source_node_label"),
458+
edge_param.get("target_node_label")
459+
)
460+
if edge_key not in seen_keys:
461+
script_edge_params[original_script].append(edge_param)
462+
seen_keys.add(edge_key)
463+
464+
for original_script, merged_params in script_edge_params.items():
465+
template_script_full_path = os.path.join(sourcedir, original_script)
466+
if not os.path.exists(template_script_full_path):
467+
logging.error(f"Cannot specialize: Original script '{template_script_full_path}' not found in '{sourcedir}'.")
468+
continue
469+
470+
new_script_relpath = copy_with_port_portname.run_specialization_script(
471+
template_script_full_path,
472+
specialized_scripts_output_dir,
473+
merged_params,
474+
python_executable,
475+
copy_script_py_path,
476+
output_relpath=original_script
477+
)
478+
479+
if not new_script_relpath:
480+
logging.error(f"Failed to generate specialized script for source '{original_script}'.")
481+
continue
482+
483+
for node_id, container_name in script_nodes.get(original_script, []):
484+
nodes_dict[node_id] = f"{container_name}:{new_script_relpath}"
485+
logging.info(f"Node ID '{node_id}' ('{container_name}') updated to use specialized script '{new_script_relpath}'.")
457486

458487
#not right for PM2_1_1 and PM2_1_2
459488
volswr = len(nodes_dict)*['']

tests/test_cli.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,47 @@ def test_run_command_subdir_source(self):
153153
self.assertEqual(result.exit_code, 0)
154154
self.assertTrue(Path('out/src/subdir/script.py').exists())
155155

156+
def test_run_command_shared_source_specialization_merges_edge_params(self):
    """
    Run the CLI on a workflow whose three nodes all reference ``common.py``.

    The graph is A -> B (edge label ``0x1000_AB``) and B -> C (edge label
    ``0x1001_BC``). The test checks that the run exits with code 0, that
    ``out/src/common.py`` exists, and that it contains the port markers
    for both edges.
    """
    template_source = (
        "import concore\n\n"
        "def step():\n"
        " return None\n"
    )
    graphml_text = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://www.yworks.com/xml/schema/graphml/1.1/ygraphml.xsd" xmlns:y="http://www.yworks.com/xml/graphml">
<key for="node" id="d6" yfiles.type="nodegraphics"/>
<key for="edge" id="d10" yfiles.type="edgegraphics"/>
<graph edgedefault="directed" id="G">
<node id="n1"><data key="d6"><y:ShapeNode><y:NodeLabel>A:common.py</y:NodeLabel></y:ShapeNode></data></node>
<node id="n2"><data key="d6"><y:ShapeNode><y:NodeLabel>B:common.py</y:NodeLabel></y:ShapeNode></data></node>
<node id="n3"><data key="d6"><y:ShapeNode><y:NodeLabel>C:common.py</y:NodeLabel></y:ShapeNode></data></node>
<edge source="n1" target="n2"><data key="d10"><y:PolyLineEdge><y:EdgeLabel>0x1000_AB</y:EdgeLabel></y:PolyLineEdge></data></edge>
<edge source="n2" target="n3"><data key="d10"><y:PolyLineEdge><y:EdgeLabel>0x1001_BC</y:EdgeLabel></y:PolyLineEdge></data></edge>
</graph>
</graphml>
"""
    cli_args = [
        'run',
        'workflow.graphml',
        '--source', 'src',
        '--output', 'out',
        '--type', 'posix',
    ]
    # Markers for both edges must appear in the single shared output file.
    expected_markers = ('PORT_NAME_A_B', 'PORT_A_B', 'PORT_NAME_B_C', 'PORT_B_C')

    with self.runner.isolated_filesystem(temp_dir=self.temp_dir):
        Path('src').mkdir()
        Path('src/common.py').write_text(template_source)
        Path('workflow.graphml').write_text(graphml_text)

        outcome = self.runner.invoke(cli, cli_args)
        self.assertEqual(outcome.exit_code, 0)

        generated = Path('out/src/common.py')
        self.assertTrue(generated.exists())
        generated_text = generated.read_text()
        for marker in expected_markers:
            self.assertIn(marker, generated_text)

196+
156197
def test_run_command_existing_output(self):
157198
with self.runner.isolated_filesystem(temp_dir=self.temp_dir):
158199
result = self.runner.invoke(cli, ['init', 'test-project'])

0 commit comments

Comments
 (0)