Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions sumpy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from pytools import memoize_method

from sumpy.symbolic import (SympyToPymbolicMapper as SympyToPymbolicMapperBase)
from sumpy.symbolic import SympyToPymbolicMapper

import logging
logger = logging.getLogger(__name__)
Expand All @@ -49,27 +49,6 @@
"""


# {{{ sympy -> pymbolic mapper

import sumpy.symbolic as sym
_SPECIAL_FUNCTION_NAMES = frozenset(dir(sym.functions))


class SympyToPymbolicMapper(SympyToPymbolicMapperBase):

def not_supported(self, expr):
if isinstance(expr, int):
return expr
elif getattr(expr, "is_Function", False):
func_name = SympyToPymbolicMapperBase.function_name(self, expr)
return prim.Variable(func_name)(
*tuple(self.rec(arg) for arg in expr.args))
else:
return SympyToPymbolicMapperBase.not_supported(self, expr)

# }}}


# {{{ bessel -> loopy codegen

BESSEL_PREAMBLE = """//CL//
Expand Down
10 changes: 9 additions & 1 deletion sumpy/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,15 @@ def __init__(self, dim, helmholtz_k_name="k",
scaling = var("I")/4
elif dim == 3:
r = pymbolic_real_norm_2(make_sym_vector("d", dim))
expr = var("exp")(var("I")*k*r)/r
if allow_evanescent:
expr = var("exp")(var("I")*k*r)/r
else:
# expi is a function that takes in a real and returns a
# complex number such that
# expi(x) = exp(I * x)
# Retaining the information that the input is real leads
# to better code generation
expr = var("expi")(k*r)/r
scaling = 1/(4*var("pi"))
else:
raise RuntimeError("unsupported dimensionality")
Expand Down
29 changes: 29 additions & 0 deletions sumpy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,17 @@ def map_Mul(self, expr): # noqa: N802

return math.prod(num_args) / math.prod(den_args)

def not_supported(self, expr):
if getattr(expr, "is_Function", False):
if self.function_name(expr) == "ExpI":
arg = self.rec(expr.args[0])
return prim.Variable("cos")(arg) + 1j * prim.Variable("sin")(arg)
else:
return prim.Variable(self.function_name(expr))(
*[self.rec(arg) for arg in expr.args])
else:
return SympyToPymbolicMapperBase.not_supported(self, expr)


class PymbolicToSympyMapperWithSymbols(PymbolicToSympyMapper):
def map_variable(self, expr):
Expand All @@ -338,6 +349,9 @@ def map_call(self, expr):
args = [self.rec(param) for param in expr.parameters]
args.append(0)
return BesselJ(*args)
elif expr.function.name == "expi":
args = [self.rec(param) for param in expr.parameters]
return ExpI(*args)
else:
return PymbolicToSympyMapper.map_call(self, expr)

Expand Down Expand Up @@ -369,8 +383,20 @@ class Hankel1(_BesselOrHankel):
pass


class ExpI(sympy.Function):
"""A symbolic function that takes a real value as an
input and returns a complex number such that
expi(x) = exp(i*x).
"""
nargs = (1,)

def fdiff(self, argindex=1):
return self.func(self.args[0]) * sympy.I


_SympyBesselJ = BesselJ
_SympyHankel1 = Hankel1
_SympyExpI = ExpI

if USE_SYMENGINE:
def BesselJ(*args): # noqa: N802 # pylint: disable=function-redefined
Expand All @@ -379,4 +405,7 @@ def BesselJ(*args): # noqa: N802 # pylint: disable=function-redefined
def Hankel1(*args): # noqa: N802 # pylint: disable=function-redefined
return sym.sympify(_SympyHankel1(*args))

def ExpI(*args): # noqa: N802 # pylint: disable=function-redefined
return sym.sympify(_SympyExpI(*args))

# vim: fdm=marker