Source code for ufl.algorithms.apply_coefficient_split

"""Apply coefficient split.

This module contains the apply_coefficient_split function that
decomposes mixed coefficients in the given Expr into components.

"""

from functools import singledispatchmethod

import numpy as np

from ufl.classes import (
    Coefficient,
    ComponentTensor,
    Expr,
    MultiIndex,
    NegativeRestricted,
    PositiveRestricted,
    ReferenceGrad,
    ReferenceValue,
    Restricted,
    Terminal,
    Zero,
)
from ufl.core.multiindex import indices
from ufl.corealg.dag_traverser import DAGTraverser
from ufl.tensors import as_tensor


[docs] class CoefficientSplitter(DAGTraverser): """DAGTraverser to split mixed coefficients.""" def __init__( self, coefficient_split: dict, compress: bool | None = True, visited_cache: dict[tuple, Expr] | None = None, result_cache: dict[Expr, Expr] | None = None, ) -> None: """Initialise. Args: coefficient_split: `dict` that maps mixed coefficients to their components. compress: If True, ``result_cache`` will be used. visited_cache: cache of intermediate results; expr -> r = self.process(expr, ...). result_cache: cache of result objects for memory reuse, r -> r. """ super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache) self._coefficient_split = coefficient_split
[docs] @singledispatchmethod def process( self, o: Expr, reference_value: bool | None = False, reference_grad: int | None = 0, restricted: str | None = None, ) -> Expr: """Split mixed coefficients. Args: o: `Expr` to be processed. reference_value: Whether `ReferenceValue` has been applied or not. reference_grad: Number of `ReferenceGrad`s that have been applied. restricted: '+', '-', or None. Returns: This ``o`` wrapped with `ReferenceValue` (if ``reference_value``), `ReferenceGrad` (``reference_grad`` times), and `Restricted` (if ``restricted`` is '+' or '-'). The underlying terminal will be decomposed into components according to ``self._coefficient_split``. """ return super().process(o)
@process.register(Expr) def _( self, o: Expr, reference_value: bool | None = False, reference_grad: int | None = 0, restricted: str | None = None, ) -> Expr: """Handle Expr.""" return self.reuse_if_untouched( o, reference_value=reference_value, reference_grad=reference_grad, restricted=restricted, ) @process.register(ReferenceValue) def _( self, o: Expr, reference_value: bool | None = False, reference_grad: int | None = 0, restricted: str | None = None, ) -> Expr: """Handle ReferenceValue.""" if reference_value: raise RuntimeError(f"Can not apply ReferenceValue on a ReferenceValue: got {o}") (op,) = o.ufl_operands if not op._ufl_terminal_modifiers_: raise ValueError(f"Must be a terminal modifier: {op!r}.") return self( op, reference_value=True, reference_grad=reference_grad, restricted=restricted, ) @process.register(ReferenceGrad) def _( self, o: Expr, reference_value: bool = False, reference_grad: int = 0, restricted: str | None = None, ) -> Expr: """Handle ReferenceGrad.""" (op,) = o.ufl_operands if not op._ufl_terminal_modifiers_: raise ValueError(f"Must be a terminal modifier: {op!r}.") return self( op, reference_value=reference_value, reference_grad=reference_grad + 1, restricted=restricted, ) @process.register(Restricted) def _( self, o: Restricted, reference_value: bool | None = False, reference_grad: int | None = 0, restricted: str | None = None, ) -> Expr: """Handle Restricted.""" if restricted is not None: raise RuntimeError(f"Can not apply Restricted on a Restricted: got {o}") (op,) = o.ufl_operands if not op._ufl_terminal_modifiers_: raise ValueError(f"Must be a terminal modifier: {op!r}.") return self( op, reference_value=reference_value, reference_grad=reference_grad, restricted=o._side, ) @process.register(Terminal) def _( self, o: Expr, reference_value: bool | None = False, reference_grad: int = 0, restricted: str | None = None, ) -> Expr: """Handle Terminal.""" return self._handle_terminal( o, reference_value=reference_value, reference_grad=reference_grad, restricted=restricted, ) @process.register(Coefficient) def _( self, o: Expr, reference_value: bool | None = False, reference_grad: int = 0, restricted: str | None = None, ) -> Expr: """Handle Coefficient.""" if o not in self._coefficient_split: return self._handle_terminal( o, reference_value=reference_value, reference_grad=reference_grad, restricted=restricted, ) if not reference_value: raise RuntimeError("ReferenceValue expected") beta = indices(reference_grad) components = [] for coeff in self._coefficient_split[o]: c = self._handle_terminal( coeff, reference_value=reference_value, reference_grad=reference_grad, restricted=restricted, ) for alpha in np.ndindex(coeff.ufl_element().reference_value_shape): components.append(c[alpha + beta]) (i,) = indices(1) return ComponentTensor(as_tensor(components)[i], MultiIndex((i,) + beta)) def _handle_terminal( self, o: Expr, reference_value: bool | None = False, reference_grad: int = 0, restricted: str | None = None, ) -> Expr: """Wrap terminal as needed.""" c = o if reference_value: c = ReferenceValue(c) for k in range(reference_grad): c = ReferenceGrad(c) if isinstance(c, Zero): gdim = c.ufl_shape[-1] c = Zero( c.ufl_shape + (gdim,) * (reference_grad - k - 1), c.ufl_free_indices, c.ufl_index_dimensions, ) break if restricted == "+": c = PositiveRestricted(c) elif restricted == "-": c = NegativeRestricted(c) elif restricted is not None: raise RuntimeError(f"Got unknown restriction: {restricted}") return c
[docs] def apply_coefficient_split(expr: Expr, coefficient_split: dict) -> Expr: """Split mixed coefficients. Args: expr: UFL expression. coefficient_split: `dict` that maps mixed coefficients to their components. Returns: ``expr`` with uderlying mixed coefficients split according to ``coefficient_split``. """ if not coefficient_split: return expr return CoefficientSplitter(coefficient_split)(expr)