diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 6123da97e0..ae599ece6d 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -147,16 +147,17 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]: @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): name: str - executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram] + executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompilationArtifact] allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef] def compile( self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs ) -> stages.ExecutableProgram: - return self.executor( + artifact = self.executor( self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args)) ) + return artifact.load() @property def __gt_allocator__( diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 3748d95192..8f5da88b77 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -14,15 +14,12 @@ import factory -from gt4py._core import locking -from gt4py.next import config +from gt4py._core import definitions as core_defs, locking +from gt4py.next import config, utils as gtx_utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import build_data, cache, importer -T = TypeVar("T") - - def is_compiled(data: build_data.BuildData) -> bool: return data.status >= build_data.BuildStatus.COMPILED @@ -45,27 +42,58 @@ def __call__( @dataclasses.dataclass(frozen=True) -class Compiler( +class CPPCompilationArtifact(gtx_utils.MetadataBasedPickling): + """On-disk result of a CPP-style compilation: a Python extension module. + + The default :meth:`load` is an ``importlib`` import + entry-point lookup; + backends override to apply their own calling convention. + """ + + src_dir: pathlib.Path + module: pathlib.Path + entry_point_name: str + device_type: core_defs.DeviceType + + def load(self) -> stages.ExecutableProgram: + """Import the .so and return the raw entry point. + + Must run in the process that will call the returned program: the + module is registered in that process's ``sys.modules`` under the + ``gt4py.__compiled_programs__.`` prefix. + """ + m = importer.import_from_path( + self.src_dir / self.module, + sys_modules_prefix="gt4py.__compiled_programs__.", + ) + return getattr(m, self.entry_point_name) + + +@dataclasses.dataclass(frozen=True) +class CPPCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + CPPCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - stages.ExecutableProgram, + CPPCompilationArtifact, ], definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], ): - """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Drive a CPP-style build system into a :class:`CPPCompilationArtifact`. + + Backends override :meth:`_make_artifact` to use their own artifact subclass. + """ cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec] + device_type: core_defs.DeviceType force_recompile: bool = False def __call__( self, inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec], - ) -> stages.ExecutableProgram: + ) -> CPPCompilationArtifact: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) # If we are compiling the same program at the same time (e.g. multiple MPI ranks), @@ -83,17 +111,22 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - m = importer.import_from_path( - src_dir / new_data.module, sys_modules_prefix="gt4py.__compiled_programs__." - ) - func = getattr(m, new_data.entry_point_name) + return self._make_artifact(src_dir, new_data.module, new_data.entry_point_name) - return func + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> CPPCompilationArtifact: + return CPPCompilationArtifact( + src_dir=src_dir, + module=module, + entry_point_name=entry_point_name, + device_type=self.device_type, + ) class CompilerFactory(factory.Factory): class Meta: - model = Compiler + model = CPPCompiler class CompilationError(RuntimeError): ... diff --git a/src/gt4py/next/otf/definitions.py b/src/gt4py/next/otf/definitions.py index 11b42dc6ce..6b33465949 100644 --- a/src/gt4py/next/otf/definitions.py +++ b/src/gt4py/next/otf/definitions.py @@ -57,12 +57,17 @@ def __call__( class CompilationStep( workflow.Workflow[ - stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.ExecutableProgram + stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.CompilationArtifact ], Protocol[CodeSpecT, TargetCodeSpecT], ): - """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" + """Run the build system and produce a :class:`stages.CompilationArtifact`. + + Each backend defines its own concrete artifact dataclass (frozen, + picklable, with a :meth:`stages.CompilationArtifact.load` method); they all + satisfy the :class:`stages.CompilationArtifact` Protocol structurally. + """ def __call__( self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT] - ) -> stages.ExecutableProgram: ... + ) -> stages.CompilationArtifact: ... diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 79cd17162b..0b809e4731 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,10 +14,11 @@ @dataclasses.dataclass(frozen=True) -class OTFCompileWorkflow(workflow.NamedStepSequence): +class OTFCompileWorkflow( + workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact] +): """The typical compiled backend steps composed into a workflow.""" translation: definitions.TranslationStep bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject] - compilation: workflow.Workflow[stages.CompilableProject, stages.ExecutableProgram] - decoration: workflow.Workflow[stages.ExecutableProgram, stages.ExecutableProgram] + compilation: workflow.Workflow[stages.CompilableProject, stages.CompilationArtifact] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index b6816b1cc3..27ee8b45a6 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -129,6 +129,19 @@ def build(self) -> None: ... ExecutableProgram: TypeAlias = Callable +class CompilationArtifact(Protocol): + """The output of an :class:`recipes.OTFCompileWorkflow`. + + Each backend defines its own concrete artifact dataclass; all share this + Protocol. Implementations are frozen dataclasses, picklable, and have no + live process-bound state — that is reconstructed by :meth:`load`, + which returns a directly-callable :class:`ExecutableProgram` taking + gt4py-shaped arguments. + """ + + def load(self) -> ExecutableProgram: ... + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py index e1747b7ac3..1f69f1ad71 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -9,19 +9,25 @@ from __future__ import annotations import dataclasses +import json import os +import pathlib import warnings from collections.abc import Callable, MutableSequence, Sequence from typing import Any import dace +import dace.codegen.compiler as dace_compiler import factory from gt4py._core import definitions as core_defs, locking -from gt4py.next import common, config +from gt4py.next import common, config, utils as gtx_utils from gt4py.next.otf import code_specs, definitions, stages, workflow from gt4py.next.otf.compilation import cache as gtx_cache -from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon +from gt4py.next.program_processors.runners.dace.workflow import ( + common as gtx_wfdcommon, + decoration as gtx_wfddecoration, +) class CompiledDaceProgram: @@ -55,7 +61,7 @@ def __init__( self, program: dace.CompiledSDFG, bind_func_name: str, - binding_source: stages.BindingSource[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], + binding_source_code: str, ): self.sdfg_program = program @@ -64,9 +70,10 @@ def __init__( # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. self.sdfg_argtypes = list(program.sdfg.arglist().values()) - # Note that `binding_source` contains Python code tailored to this specific SDFG. - # Here we dinamically compile this function and add it to the compiled program. - exec(binding_source.source_code, global_namespace := {}) # type: ignore[var-annotated] + # The binding source code is Python tailored to this specific SDFG. + # We dynamically compile that function and add it to the compiled program. + global_namespace: dict[str, Any] = {} + exec(binding_source_code, global_namespace) self.update_sdfg_ctype_arglist = global_namespace[bind_func_name] # For debug purpose, we set a unique module name on the compiled function. self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder) @@ -114,19 +121,75 @@ def __call__(self, **kwargs: Any) -> None: assert result is None +@dataclasses.dataclass(frozen=True) +class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling): + """Result of a DaCe compilation: build folder + SDFG bindings + the SDFG itself. + + The SDFG is carried inline as JSON because dace's load path + (:func:`get_program_handle`) needs an SDFG instance to wrap into the + returned :class:`CompiledSDFG`, and the build folder may not contain a + ``program.sdfg(z)`` dump under the upcoming minimal-build-dir mode. + """ + + build_folder: pathlib.Path + sdfg_json: str + binding_source_code: str + bind_func_name: str + device_type: core_defs.DeviceType + + # Process-local cache of the live :class:`CompiledDaceProgram`. Populated by + # ``DaCeCompiler`` to skip the disk round-trip when the artifact stays in + # the same process. Excluded from pickle (``pickle=False`` metadata) so + # receivers in other processes see ``None`` and fall through to the + # disk-based load. + _live_program: CompiledDaceProgram | None = dataclasses.field( + init=False, + default=None, + compare=False, + repr=False, + metadata=gtx_utils.gt4py_metadata(pickle=False), + ) + + def load(self) -> stages.ExecutableProgram: + """Wrap the compiled program in gt4py's calling convention. + + On a miss, loads the precompiled .so directly via + :func:`dace.codegen.compiler.get_program_handle` — no recompilation, + no ``dace.config`` re-entry. Must run in the process that will call + the returned program. + """ + program = self._live_program + if program is None: + program = self._load_compiled_program() + object.__setattr__(self, "_live_program", program) + return gtx_wfddecoration.convert_args(program, device=self.device_type) + + def _load_compiled_program(self) -> CompiledDaceProgram: + # TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace + # exposes a load path that doesn't require an SDFG instance to wrap + # into the returned ``CompiledSDFG``. + sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json)) + folder_version = dace_compiler.get_folder_version(self.build_folder) + library_path = dace_compiler.get_binary_name( + self.build_folder, sdfg_name=sdfg.name, folder_version=folder_version + ) + sdfg_program = dace_compiler.get_program_handle(library_path, sdfg) + return CompiledDaceProgram(sdfg_program, self.bind_func_name, self.binding_source_code) + + @dataclasses.dataclass(frozen=True) class DaCeCompiler( workflow.ChainableWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeCompilationArtifact, ], workflow.ReplaceEnabledWorkflowMixin[ stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - CompiledDaceProgram, + DaCeCompilationArtifact, ], definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], ): - """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + """Run the DaCe build system and produce an on-disk :class:`DaCeCompilationArtifact`.""" bind_func_name: str cache_lifetime: config.BuildCacheLifetime @@ -136,25 +199,37 @@ class DaCeCompiler( def __call__( self, inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec], - ) -> CompiledDaceProgram: + ) -> DaCeCompilationArtifact: with gtx_wfdcommon.dace_context( device_type=self.device_type, cmake_build_type=self.cmake_build_type, ): - sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime) + sdfg_build_folder = pathlib.Path(gtx_cache.get_cache_folder(inp, self.cache_lifetime)) sdfg_build_folder.mkdir(parents=True, exist_ok=True) sdfg = dace.SDFG.from_json(inp.program_source.source_code) - sdfg.build_folder = sdfg_build_folder + sdfg.build_folder = str(sdfg_build_folder) with locking.lock(sdfg_build_folder): + # Keep the handle so the artifact's load() can skip the disk + # round-trip in the same process. sdfg_program = sdfg.compile(validate=False) assert inp.binding_source is not None - return CompiledDaceProgram( - sdfg_program, - self.bind_func_name, - inp.binding_source, + artifact = DaCeCompilationArtifact( + build_folder=sdfg_build_folder, + sdfg_json=json.dumps(inp.program_source.source_code), + binding_source_code=inp.binding_source.source_code, + bind_func_name=self.bind_func_name, + device_type=self.device_type, + ) + object.__setattr__( + artifact, + "_live_program", + CompiledDaceProgram( + sdfg_program, artifact.bind_func_name, artifact.binding_source_code + ), ) + return artifact class DaCeCompilationStepFactory(factory.Factory): diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 103e7af33b..f9e9f7181b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence import numpy as np @@ -18,14 +18,16 @@ from gt4py.next.instrumentation import metrics from gt4py.next.otf import stages from gt4py.next.program_processors.runners.dace import sdfg_callable -from gt4py.next.program_processors.runners.dace.workflow import ( - common as gtx_wfdcommon, - compilation as gtx_wfdcompilation, -) +from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon + + +if TYPE_CHECKING: + # Type-only: a top-level import would cycle with ``compilation``. + from gt4py.next.program_processors.runners.dace.workflow.compilation import CompiledDaceProgram def convert_args( - fun: gtx_wfdcompilation.CompiledDaceProgram, + fun: CompiledDaceProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU, ) -> stages.ExecutableProgram: # Retieve metrics level from GT4Py environment variable. diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 62febd0965..069854a586 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -16,10 +16,7 @@ from gt4py._core import definitions as core_defs, filecache from gt4py.next import config from gt4py.next.otf import recipes, stages, workflow -from gt4py.next.program_processors.runners.dace.workflow import ( - bindings as bindings_step, - decoration as decoration_step, -) +from gt4py.next.program_processors.runners.dace.workflow import bindings as bindings_step from gt4py.next.program_processors.runners.dace.workflow.compilation import ( DaCeCompilationStepFactory, ) @@ -72,9 +69,3 @@ class Params: device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - decoration_step.convert_args, - device=o.device_type, - ) - ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index c1743dea6a..6a8ff1fc69 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -6,7 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import functools +import dataclasses +import pathlib from typing import Any import factory @@ -106,6 +107,30 @@ def extract_connectivity_args( return args +@dataclasses.dataclass(frozen=True) +class GTFNCompilationArtifact(compiler.CPPCompilationArtifact): + def load(self) -> stages.ExecutableProgram: + return convert_args(super().load(), device=self.device_type) + + +@dataclasses.dataclass(frozen=True) +class GTFNCompiler(compiler.CPPCompiler): + def _make_artifact( + self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str + ) -> GTFNCompilationArtifact: + return GTFNCompilationArtifact( + src_dir=src_dir, + module=module, + entry_point_name=entry_point_name, + device_type=self.device_type, + ) + + +class GTFNCompilerFactory(factory.Factory): + class Meta: + model = GTFNCompiler + + class GTFNCompileWorkflowFactory(factory.Factory): class Meta: model = recipes.OTFCompileWorkflow @@ -140,12 +165,10 @@ class Params: nanobind.bind_source ) compilation = factory.SubFactory( - compiler.CompilerFactory, + GTFNCompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial(convert_args, device=o.device_type) + device_type=factory.SelfAttribute("..device_type"), ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 5ee0a67f25..396eecc173 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -11,11 +11,10 @@ import dataclasses import functools import importlib.util -import pathlib import tempfile import textwrap -import typing -from collections.abc import Callable, Iterable +import types +from collections.abc import Iterable from typing import Any, Optional from gt4py.eve import codegen @@ -106,28 +105,20 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_FENCIL_CACHE: dict[int, Callable] = {} +# Caches the generated source by IR hash so re-codegen is skipped within a process. +_SOURCE_CACHE: dict[int, tuple[str, str]] = {} +# Caches the loaded module by source string so re-exec is skipped within a process. +_MODULE_CACHE: dict[str, types.ModuleType] = {} -def fencil_generator( +def _generate_source( ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, transforms: itir_transforms.GTIRTransform, -) -> stages.ExecutableProgram: - """ - Generate a directly executable fencil from an ITIR node. - - Arguments: - ir: The iterator IR (ITIR) node. - debug: Keep module source containing fencil implementation. - extract_temporaries: Extract intermediate field values into temporaries. - use_embedded: Directly use builtins from embedded backend instead of - generic dispatcher. Gives faster performance and is easier - to debug. - offset_provider: A mapping from offset names to offset providers. - """ +) -> tuple[str, str]: + """Generate the Python source for an ITIR program. Returns ``(source_code, entry_point_name)``.""" # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism cache_key = hash( @@ -139,10 +130,10 @@ def fencil_generator( tuple(common.offset_provider_to_type(offset_provider).items()), ) ) - if cache_key in _FENCIL_CACHE: + if cache_key in _SOURCE_CACHE: if debug: - print(f"Using cached fencil for key {cache_key}") - return _FENCIL_CACHE[cache_key] # A CompiledProgram is just a Callable + print(f"Using cached source for key {cache_key}") + return _SOURCE_CACHE[cache_key] ir = transforms(ir, offset_provider=offset_provider) @@ -178,80 +169,113 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile( - mode="w", suffix=".py", encoding="utf-8", delete=False - ) as source_file: - source_file_name = source_file.name - if debug: - print(source_file_name) - offset_literals = [f'{o} = offset("{o}")' for o in offset_literals] - axis_literals = [ - f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' - for o in axis_literals_set - ] - source_file.write(header) - source_file.write("\n".join(offset_literals)) - source_file.write("\n") - source_file.write("\n".join(axis_literals)) - source_file.write("\n") - source_file.write(program) - try: - spec = importlib.util.spec_from_file_location("module.name", source_file_name) - mod = importlib.util.module_from_spec(spec) # type: ignore - spec.loader.exec_module(mod) # type: ignore - finally: - if not debug: - pathlib.Path(source_file_name).unlink(missing_ok=True) + offset_literals_src = "\n".join(f'{o} = offset("{o}")' for o in offset_literals) + axis_literals_src = "\n".join( + f'{o.value} = gtx.Dimension("{o.value}", kind=gtx.DimensionKind("{o.kind}"))' + for o in axis_literals_set + ) + source_code = f"{header}{offset_literals_src}\n{axis_literals_src}\n{program}" assert isinstance(ir, itir.Program) - fencil_name = ir.id - fencil = getattr(mod, fencil_name) + entry_point_name = ir.id + + _SOURCE_CACHE[cache_key] = (source_code, entry_point_name) + return source_code, entry_point_name - _FENCIL_CACHE[cache_key] = fencil - return typing.cast(stages.ExecutableProgram, fencil) +def _load_module(source_code: str, debug: bool) -> types.ModuleType: + if source_code in _MODULE_CACHE: + return _MODULE_CACHE[source_code] + + if debug: + # Write to a real .py so debuggers/tracebacks have file/line info. + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: + source_file.write(source_code) + source_file_name = source_file.name + print(source_file_name) + spec = importlib.util.spec_from_file_location("module.name", source_file_name) + mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] + spec.loader.exec_module(mod) # type: ignore[union-attr] + else: + mod = types.ModuleType("roundtrip_module") + exec(compile(source_code, "", "exec"), mod.__dict__) + + _MODULE_CACHE[source_code] = mod + return mod @dataclasses.dataclass(frozen=True) -class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]): - debug: Optional[bool] = None - use_embedded: bool = True - dispatch_backend: Optional[next_backend.Backend] = None - transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` +class RoundtripArtifact: + """Source-string artifact for the roundtrip backend. - def __call__(self, inp: definitions.CompilableProgramDef) -> stages.ExecutableProgram: - debug = config.DEBUG if self.debug is None else self.debug + The generated Python source is the artifact: picklable, re-execed on + :meth:`load`. When ``debug`` is true, ``load`` writes a temporary ``.py`` + so debuggers/tracebacks resolve to source lines. + """ - fencil = fencil_generator( - inp.data, - offset_provider=inp.args.offset_provider, - debug=debug, - use_embedded=self.use_embedded, - transforms=self.transforms, - ) + source_code: str + entry_point_name: str + column_axis: common.Dimension | None + dispatch_backend: next_backend.Backend | None + debug: bool + + def load(self) -> stages.ExecutableProgram: + mod = _load_module(self.source_code, self.debug) + fencil = getattr(mod, self.entry_point_name) + captured_column_axis = self.column_axis + dispatch_backend = self.dispatch_backend def decorated_fencil( *args: Any, offset_provider: dict[str, common.Connectivity | common.Dimension], out: Any = None, - column_axis: Optional[common.Dimension] = None, + column_axis: Optional[ + common.Dimension + ] = None, # TODO(tehrengruber): unused, kept for signature compat **kwargs: Any, ) -> None: if out is not None: args = (*args, out) - if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? - column_axis = inp.args.column_axis fencil( *args, offset_provider=offset_provider, - backend=self.dispatch_backend, - column_axis=inp.args.column_axis, + backend=dispatch_backend, + column_axis=captured_column_axis, **kwargs, ) return decorated_fencil +@dataclasses.dataclass(frozen=True) +class Roundtrip(workflow.Workflow[definitions.CompilableProgramDef, RoundtripArtifact]): + debug: Optional[bool] = None + use_embedded: bool = True + dispatch_backend: Optional[next_backend.Backend] = None + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` + + def __call__(self, inp: definitions.CompilableProgramDef) -> RoundtripArtifact: + debug = config.DEBUG if self.debug is None else self.debug + + source_code, entry_point_name = _generate_source( + inp.data, + offset_provider=inp.args.offset_provider, + debug=debug, + use_embedded=self.use_embedded, + transforms=self.transforms, + ) + + return RoundtripArtifact( + source_code=source_code, + entry_point_name=entry_point_name, + column_axis=inp.args.column_axis, + dispatch_backend=self.dispatch_backend, + debug=debug, + ) + + # TODO(tehrengruber): introduce factory default = next_backend.Backend( name="roundtrip", diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index 49bd7b8f87..84226c4e03 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -10,6 +10,7 @@ import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind @@ -24,11 +25,13 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( - cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory() + compiler.CPPCompiler( + cache_lifetime=config.BuildCacheLifetime.SESSION, + builder_factory=cmake.CMakeFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), @@ -42,12 +45,13 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( - compiler.Compiler( + compiler.CPPCompiler( cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), + device_type=core_defs.DeviceType.CPU, ) ) - compiled_program = build_the_program(example_program_source) + compiled_program = build_the_program(example_program_source).load() buf = (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)) tup = [ (np.zeros(shape=(6, 5), dtype=np.float32), (0, 0)), diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py new file mode 100644 index 0000000000..7dbaeaf719 --- /dev/null +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/test_compiler.py @@ -0,0 +1,26 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for :class:`compiler.CPPCompilationArtifact`.""" + +import pathlib +import pickle + +from gt4py._core import definitions as core_defs +from gt4py.next.otf.compilation import compiler + + +def test_cpp_compilation_artifact_pickle_round_trip(tmp_path: pathlib.Path): + artifact = compiler.CPPCompilationArtifact( + src_dir=tmp_path, + module=pathlib.Path("entry.so"), + entry_point_name="entry", + device_type=core_defs.DeviceType.CPU, + ) + restored = pickle.loads(pickle.dumps(artifact)) + assert restored == artifact diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index def8800c98..ed881c9495 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -114,12 +114,19 @@ def test_inlining_of_scalar_works_integration(testee_prog): hijacked_program = None + @dataclasses.dataclass(frozen=True) + class _NoOpArtifact: + """A trivial CompilationArtifact that loads to a no-op callable.""" + + def load(self): + return lambda *args, **kwargs: None + def pirate(program: toolchain.ConcreteArtifact): - # Replaces the gtfn otf_workflow: and steals the compilable program, - # then returns a dummy "CompiledProgram" that does nothing. + # Replaces the gtfn otf_workflow: steals the compilable program, then + # returns a dummy artifact whose materialization is a no-op callable. nonlocal hijacked_program hijacked_program = program - return lambda *args, **kwargs: None + return _NoOpArtifact() hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", executor=pirate) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py new file mode 100644 index 0000000000..29d0ded9e1 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_compilation.py @@ -0,0 +1,36 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal contract tests for :class:`compilation.DaCeCompilationArtifact`.""" + +import pathlib +import pickle + +import pytest + +pytest.importorskip("dace") + +from gt4py._core import definitions as core_defs # noqa: E402 +from gt4py.next.program_processors.runners.dace.workflow import compilation # noqa: E402 + + +def test_dace_compilation_artifact_pickle_round_trip_drops_live_program(tmp_path: pathlib.Path): + artifact = compilation.DaCeCompilationArtifact( + build_folder=tmp_path, + sdfg_json="{}", + binding_source_code="def update_sdfg_args(*a, **k): ...", + bind_func_name="update_sdfg_args", + device_type=core_defs.DeviceType.CPU, + ) + object.__setattr__(artifact, "_live_program", "") + + restored = pickle.loads(pickle.dumps(artifact)) + + # The data fields round-trip, the live in-process handle does not. + assert restored == artifact + assert restored._live_program is None diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index 96d8c6e27c..712e0500f5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -37,8 +37,9 @@ def test_backend_factory_trait_device(): assert cpu_version.executor.translation.device_type is core_defs.DeviceType.CPU assert gpu_version.executor.translation.device_type is core_defs.DeviceType.CUDA - assert cpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CPU - assert gpu_version.executor.decoration.keywords["device"] is core_defs.DeviceType.CUDA + # The compilation step now also carries device_type so it can stamp the artifact. + assert cpu_version.executor.compilation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.compilation.device_type is core_defs.DeviceType.CUDA assert custom_layout_allocators.is_field_allocator_for( cpu_version.allocator, core_defs.DeviceType.CPU