diff --git a/sumpy/codegen.py b/sumpy/codegen.py index 3b19fcb9e..2f9be4bc8 100644 --- a/sumpy/codegen.py +++ b/sumpy/codegen.py @@ -32,7 +32,7 @@ from pytools import memoize_method -from sumpy.symbolic import (SympyToPymbolicMapper as SympyToPymbolicMapperBase) +from sumpy.symbolic import SympyToPymbolicMapper import logging logger = logging.getLogger(__name__) @@ -49,27 +49,6 @@ """ -# {{{ sympy -> pymbolic mapper - -import sumpy.symbolic as sym -_SPECIAL_FUNCTION_NAMES = frozenset(dir(sym.functions)) - - -class SympyToPymbolicMapper(SympyToPymbolicMapperBase): - - def not_supported(self, expr): - if isinstance(expr, int): - return expr - elif getattr(expr, "is_Function", False): - func_name = SympyToPymbolicMapperBase.function_name(self, expr) - return prim.Variable(func_name)( - *tuple(self.rec(arg) for arg in expr.args)) - else: - return SympyToPymbolicMapperBase.not_supported(self, expr) - -# }}} - - # {{{ bessel -> loopy codegen BESSEL_PREAMBLE = """//CL// diff --git a/sumpy/kernel.py b/sumpy/kernel.py index 1be1dec93..444eb4293 100644 --- a/sumpy/kernel.py +++ b/sumpy/kernel.py @@ -519,7 +519,15 @@ def __init__(self, dim, helmholtz_k_name="k", scaling = var("I")/4 elif dim == 3: r = pymbolic_real_norm_2(make_sym_vector("d", dim)) - expr = var("exp")(var("I")*k*r)/r + if allow_evanescent: + expr = var("exp")(var("I")*k*r)/r + else: + # expi is a function that takes in a real and returns a + # complex number such that + # expi(x) = exp(I * x) + # Retaining the information that the input is real leads + # to better code generation + expr = var("expi")(k*r)/r scaling = 1/(4*var("pi")) else: raise RuntimeError("unsupported dimensionality") diff --git a/sumpy/symbolic.py b/sumpy/symbolic.py index 48d43a5ef..90a80fa06 100644 --- a/sumpy/symbolic.py +++ b/sumpy/symbolic.py @@ -313,6 +313,17 @@ def map_Mul(self, expr): # noqa: N802 return math.prod(num_args) / math.prod(den_args) + def not_supported(self, expr): + if getattr(expr, "is_Function", False): + if self.function_name(expr) == "ExpI": + arg = self.rec(expr.args[0]) + return prim.Variable("cos")(arg) + 1j * prim.Variable("sin")(arg) + else: + return prim.Variable(self.function_name(expr))( + *[self.rec(arg) for arg in expr.args]) + else: + return SympyToPymbolicMapperBase.not_supported(self, expr) + class PymbolicToSympyMapperWithSymbols(PymbolicToSympyMapper): def map_variable(self, expr): @@ -338,6 +349,9 @@ def map_call(self, expr): args = [self.rec(param) for param in expr.parameters] args.append(0) return BesselJ(*args) + elif expr.function.name == "expi": + args = [self.rec(param) for param in expr.parameters] + return ExpI(*args) else: return PymbolicToSympyMapper.map_call(self, expr) @@ -369,8 +383,20 @@ class Hankel1(_BesselOrHankel): pass +class ExpI(sympy.Function): + """A symbolic function that takes a real value as an + input and returns a complex number such that + expi(x) = exp(i*x). + """ + nargs = (1,) + + def fdiff(self, argindex=1): + return self.func(self.args[0]) * sympy.I + + _SympyBesselJ = BesselJ _SympyHankel1 = Hankel1 +_SympyExpI = ExpI if USE_SYMENGINE: def BesselJ(*args): # noqa: N802 # pylint: disable=function-redefined @@ -379,4 +405,7 @@ def BesselJ(*args): # noqa: N802 # pylint: disable=function-redefined def Hankel1(*args): # noqa: N802 # pylint: disable=function-redefined return sym.sympify(_SympyHankel1(*args)) + def ExpI(*args): # noqa: N802 # pylint: disable=function-redefined + return sym.sympify(_SympyExpI(*args)) + # vim: fdm=marker