Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,17 @@ def step_order(self, inp: definitions.ConcreteProgramDef) -> list[str]:
@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
name: str
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.ExecutableProgram]
executor: workflow.Workflow[definitions.CompilableProgramDef, stages.CompilationArtifact]
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
transforms: workflow.Workflow[definitions.ConcreteProgramDef, definitions.CompilableProgramDef]

def compile(
    self, program: definitions.IRDefinitionT, compile_time_args: arguments.CompileTimeArgs
) -> stages.ExecutableProgram:
    """Transform, compile and load ``program``, returning the executable entry point.

    The transforms pipeline turns the IR definition plus compile-time
    arguments into a compilable program; the executor then yields a
    ``stages.CompilationArtifact`` whose ``load()`` reconstructs a directly
    callable program in the current process.
    """
    artifact = self.executor(
        self.transforms(definitions.ConcreteProgramDef(data=program, args=compile_time_args))
    )
    return artifact.load()

@property
def __gt_allocator__(
Expand Down
65 changes: 49 additions & 16 deletions src/gt4py/next/otf/compilation/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@

import factory

from gt4py._core import locking
from gt4py.next import config
from gt4py._core import definitions as core_defs, locking
from gt4py.next import config, utils as gtx_utils
from gt4py.next.otf import code_specs, definitions, stages, workflow
from gt4py.next.otf.compilation import build_data, cache, importer


T = TypeVar("T")


def is_compiled(data: build_data.BuildData) -> bool:
    """Report whether ``data`` has reached at least the COMPILED build status."""
    compiled_threshold = build_data.BuildStatus.COMPILED
    return data.status >= compiled_threshold

Expand All @@ -45,27 +42,58 @@ def __call__(


@dataclasses.dataclass(frozen=True)
class CPPCompilationArtifact(gtx_utils.MetadataBasedPickling):
    """On-disk result of a CPP-style compilation: a Python extension module.

    The default :meth:`load` is an ``importlib`` import + entry-point lookup;
    backends override to apply their own calling convention.
    """

    # Cache folder holding the built extension (presumably produced by the
    # build cache — see ``CPPCompiler.__call__``).
    src_dir: pathlib.Path
    # Module path, joined onto ``src_dir`` at load time.
    module: pathlib.Path
    # Attribute name of the program entry point inside the imported module.
    entry_point_name: str
    device_type: core_defs.DeviceType

    def load(self) -> stages.ExecutableProgram:
        """Import the .so and return the raw entry point.

        Must run in the process that will call the returned program: the
        module is registered in that process's ``sys.modules`` under the
        ``gt4py.__compiled_programs__.`` prefix.
        """
        m = importer.import_from_path(
            self.src_dir / self.module,
            sys_modules_prefix="gt4py.__compiled_programs__.",
        )
        return getattr(m, self.entry_point_name)


@dataclasses.dataclass(frozen=True)
class CPPCompiler(
workflow.ChainableWorkflowMixin[
stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
stages.ExecutableProgram,
CPPCompilationArtifact,
],
workflow.ReplaceEnabledWorkflowMixin[
stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
stages.ExecutableProgram,
CPPCompilationArtifact,
],
definitions.CompilationStep[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
):
"""Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``."""
"""Drive a CPP-style build system into a :class:`CPPCompilationArtifact`.

Backends override :meth:`_make_artifact` to use their own artifact subclass.
"""

cache_lifetime: config.BuildCacheLifetime
builder_factory: BuildSystemProjectGenerator[CPPLikeCodeSpecT, code_specs.PythonCodeSpec]
device_type: core_defs.DeviceType
force_recompile: bool = False

def __call__(
self,
inp: stages.CompilableProject[CPPLikeCodeSpecT, code_specs.PythonCodeSpec],
) -> stages.ExecutableProgram:
) -> CPPCompilationArtifact:
src_dir = cache.get_cache_folder(inp, self.cache_lifetime)

# If we are compiling the same program at the same time (e.g. multiple MPI ranks),
Expand All @@ -83,17 +111,22 @@ def __call__(
f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'."
)

m = importer.import_from_path(
src_dir / new_data.module, sys_modules_prefix="gt4py.__compiled_programs__."
)
func = getattr(m, new_data.entry_point_name)
return self._make_artifact(src_dir, new_data.module, new_data.entry_point_name)

return func
def _make_artifact(
    self, src_dir: pathlib.Path, module: pathlib.Path, entry_point_name: str
) -> CPPCompilationArtifact:
    """Assemble the compilation artifact; hook point for backend subclasses."""
    artifact_fields = dict(
        src_dir=src_dir,
        module=module,
        entry_point_name=entry_point_name,
        device_type=self.device_type,
    )
    return CPPCompilationArtifact(**artifact_fields)


class CompilerFactory(factory.Factory):
class Meta:
model = Compiler
model = CPPCompiler


class CompilationError(RuntimeError): ...
11 changes: 8 additions & 3 deletions src/gt4py/next/otf/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,17 @@ def __call__(

class CompilationStep(
    workflow.Workflow[
        stages.CompilableProject[CodeSpecT, TargetCodeSpecT], stages.CompilationArtifact
    ],
    Protocol[CodeSpecT, TargetCodeSpecT],
):
    """Run the build system and produce a :class:`stages.CompilationArtifact`.

    Each backend defines its own concrete artifact dataclass (frozen,
    picklable, with a :meth:`stages.CompilationArtifact.load` method); they all
    satisfy the :class:`stages.CompilationArtifact` Protocol structurally.
    """

    def __call__(
        self, source: stages.CompilableProject[CodeSpecT, TargetCodeSpecT]
    ) -> stages.CompilationArtifact: ...
7 changes: 4 additions & 3 deletions src/gt4py/next/otf/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@


@dataclasses.dataclass(frozen=True)
class OTFCompileWorkflow(
    workflow.NamedStepSequence[definitions.CompilableProgramDef, stages.CompilationArtifact]
):
    """The typical compiled backend steps composed into a workflow."""

    # Steps are applied in declaration order; the final step's output is a
    # loadable ``stages.CompilationArtifact`` rather than a live callable.
    translation: definitions.TranslationStep
    bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableProject]
    compilation: workflow.Workflow[stages.CompilableProject, stages.CompilationArtifact]
13 changes: 13 additions & 0 deletions src/gt4py/next/otf/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,19 @@ def build(self) -> None: ...
ExecutableProgram: TypeAlias = Callable


class CompilationArtifact(Protocol):
    """The output of an :class:`recipes.OTFCompileWorkflow`.

    Each backend defines its own concrete artifact dataclass; all share this
    Protocol. Implementations are frozen dataclasses, picklable, and have no
    live process-bound state — that is reconstructed by :meth:`load`,
    which returns a directly-callable :class:`ExecutableProgram` taking
    gt4py-shaped arguments.
    """

    def load(self) -> ExecutableProgram: ...


def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]:
"""
Filter out multiple occurrences of the same ``interface.LibraryDependency``.
Expand Down
107 changes: 91 additions & 16 deletions src/gt4py/next/program_processors/runners/dace/workflow/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,25 @@
from __future__ import annotations

import dataclasses
import json
import os
import pathlib
import warnings
from collections.abc import Callable, MutableSequence, Sequence
from typing import Any

import dace
import dace.codegen.compiler as dace_compiler
import factory

from gt4py._core import definitions as core_defs, locking
from gt4py.next import common, config
from gt4py.next import common, config, utils as gtx_utils
from gt4py.next.otf import code_specs, definitions, stages, workflow
from gt4py.next.otf.compilation import cache as gtx_cache
from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon
from gt4py.next.program_processors.runners.dace.workflow import (
common as gtx_wfdcommon,
decoration as gtx_wfddecoration,
)


class CompiledDaceProgram:
Expand Down Expand Up @@ -55,7 +61,7 @@ def __init__(
self,
program: dace.CompiledSDFG,
bind_func_name: str,
binding_source: stages.BindingSource[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
binding_source_code: str,
):
self.sdfg_program = program

Expand All @@ -64,9 +70,10 @@ def __init__(
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
self.sdfg_argtypes = list(program.sdfg.arglist().values())

# Note that `binding_source` contains Python code tailored to this specific SDFG.
# Here we dinamically compile this function and add it to the compiled program.
exec(binding_source.source_code, global_namespace := {}) # type: ignore[var-annotated]
# The binding source code is Python tailored to this specific SDFG.
# We dynamically compile that function and add it to the compiled program.
global_namespace: dict[str, Any] = {}
exec(binding_source_code, global_namespace)
self.update_sdfg_ctype_arglist = global_namespace[bind_func_name]
# For debug purpose, we set a unique module name on the compiled function.
self.update_sdfg_ctype_arglist.__module__ = os.path.basename(program.sdfg.build_folder)
Expand Down Expand Up @@ -114,19 +121,75 @@ def __call__(self, **kwargs: Any) -> None:
assert result is None


@dataclasses.dataclass(frozen=True)
class DaCeCompilationArtifact(gtx_utils.MetadataBasedPickling):
    """Outcome of a DaCe compilation: build folder, SDFG bindings and the SDFG itself.

    The SDFG travels inline as JSON: dace's load path
    (:func:`get_program_handle`) requires an SDFG instance to wrap into the
    returned :class:`CompiledSDFG`, and under the upcoming minimal-build-dir
    mode the build folder may lack a ``program.sdfg(z)`` dump.
    """

    build_folder: pathlib.Path
    sdfg_json: str
    binding_source_code: str
    bind_func_name: str
    device_type: core_defs.DeviceType

    # Process-local cache of the live :class:`CompiledDaceProgram`, filled in
    # by ``DaCeCompiler`` so that a same-process ``load()`` avoids the disk
    # round-trip. The ``pickle=False`` metadata excludes it from pickling, so
    # a receiver in another process observes ``None`` and takes the
    # disk-based path instead.
    _live_program: CompiledDaceProgram | None = dataclasses.field(
        init=False,
        default=None,
        compare=False,
        repr=False,
        metadata=gtx_utils.gt4py_metadata(pickle=False),
    )

    def load(self) -> stages.ExecutableProgram:
        """Wrap the compiled program in gt4py's calling convention.

        When the process-local cache is empty, the precompiled .so is loaded
        directly through :func:`dace.codegen.compiler.get_program_handle` —
        no recompilation and no ``dace.config`` re-entry. Must run in the
        process that will call the returned program.
        """
        if (program := self._live_program) is None:
            program = self._load_compiled_program()
            # Frozen dataclass: bypass the frozen-ness to memoize the handle.
            object.__setattr__(self, "_live_program", program)
        return gtx_wfddecoration.convert_args(program, device=self.device_type)

    def _load_compiled_program(self) -> CompiledDaceProgram:
        # TODO(phimuell): Drop ``sdfg_json`` from the artifact once dace
        # exposes a load path that doesn't require an SDFG instance to wrap
        # into the returned ``CompiledSDFG``.
        sdfg = dace.SDFG.from_json(json.loads(self.sdfg_json))
        folder_version = dace_compiler.get_folder_version(self.build_folder)
        library_path = dace_compiler.get_binary_name(
            self.build_folder, sdfg_name=sdfg.name, folder_version=folder_version
        )
        program_handle = dace_compiler.get_program_handle(library_path, sdfg)
        return CompiledDaceProgram(program_handle, self.bind_func_name, self.binding_source_code)


@dataclasses.dataclass(frozen=True)
class DaCeCompiler(
workflow.ChainableWorkflowMixin[
stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
CompiledDaceProgram,
DaCeCompilationArtifact,
],
workflow.ReplaceEnabledWorkflowMixin[
stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
CompiledDaceProgram,
DaCeCompilationArtifact,
],
definitions.CompilationStep[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
):
"""Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``."""
"""Run the DaCe build system and produce an on-disk :class:`DaCeCompilationArtifact`."""

bind_func_name: str
cache_lifetime: config.BuildCacheLifetime
Expand All @@ -136,25 +199,37 @@ class DaCeCompiler(
def __call__(
self,
inp: stages.CompilableProject[code_specs.SDFGCodeSpec, code_specs.PythonCodeSpec],
) -> CompiledDaceProgram:
) -> DaCeCompilationArtifact:
with gtx_wfdcommon.dace_context(
device_type=self.device_type,
cmake_build_type=self.cmake_build_type,
):
sdfg_build_folder = gtx_cache.get_cache_folder(inp, self.cache_lifetime)
sdfg_build_folder = pathlib.Path(gtx_cache.get_cache_folder(inp, self.cache_lifetime))
sdfg_build_folder.mkdir(parents=True, exist_ok=True)

sdfg = dace.SDFG.from_json(inp.program_source.source_code)
sdfg.build_folder = sdfg_build_folder
sdfg.build_folder = str(sdfg_build_folder)
with locking.lock(sdfg_build_folder):
# Keep the handle so the artifact's load() can skip the disk
# round-trip in the same process.
sdfg_program = sdfg.compile(validate=False)

assert inp.binding_source is not None
return CompiledDaceProgram(
sdfg_program,
self.bind_func_name,
inp.binding_source,
artifact = DaCeCompilationArtifact(
build_folder=sdfg_build_folder,
sdfg_json=json.dumps(inp.program_source.source_code),
binding_source_code=inp.binding_source.source_code,
bind_func_name=self.bind_func_name,
device_type=self.device_type,
)
object.__setattr__(
artifact,
"_live_program",
CompiledDaceProgram(
sdfg_program, artifact.bind_func_name, artifact.binding_source_code
),
)
return artifact


class DaCeCompilationStepFactory(factory.Factory):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

import functools
from typing import Any, Sequence
from typing import TYPE_CHECKING, Any, Sequence

import numpy as np

Expand All @@ -18,14 +18,16 @@
from gt4py.next.instrumentation import metrics
from gt4py.next.otf import stages
from gt4py.next.program_processors.runners.dace import sdfg_callable
from gt4py.next.program_processors.runners.dace.workflow import (
common as gtx_wfdcommon,
compilation as gtx_wfdcompilation,
)
from gt4py.next.program_processors.runners.dace.workflow import common as gtx_wfdcommon


if TYPE_CHECKING:
# Type-only: a top-level import would cycle with ``compilation``.
from gt4py.next.program_processors.runners.dace.workflow.compilation import CompiledDaceProgram


def convert_args(
fun: gtx_wfdcompilation.CompiledDaceProgram,
fun: CompiledDaceProgram,
device: core_defs.DeviceType = core_defs.DeviceType.CPU,
) -> stages.ExecutableProgram:
# Retrieve the metrics level from the GT4Py environment variable.
Expand Down
Loading