From 70a698c2bb15e6fc3e5e27ca726149d06d4c1cd8 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Fri, 6 Aug 2021 11:52:26 +0200 Subject: [PATCH] Cosmetics --- src/operators/jax_operator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/operators/jax_operator.py b/src/operators/jax_operator.py index 9db0e68b0..5aa621c97 100644 --- a/src/operators/jax_operator.py +++ b/src/operators/jax_operator.py @@ -15,11 +15,10 @@ # Author: Philipp Arras import numpy as np -from .operator import Operator -from .energy_operators import EnergyOperator, LikelihoodEnergyOperator -from .linear_operator import LinearOperator -from .endomorphic_operator import EndomorphicOperator +from .energy_operators import LikelihoodEnergyOperator +from .linear_operator import LinearOperator +from .operator import Operator try: import jax @@ -60,8 +59,8 @@ class JaxOperator(Operator): self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1]) def apply(self, x): - from ..sugar import is_linearization, makeField from ..multi_domain import MultiDomain + from ..sugar import is_linearization, makeField self._check_input(x) if is_linearization(x): res, bwd = self._vjp(x.val.val) @@ -137,9 +136,9 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator): return self._dt, self._trafo def apply(self, x): + from ..linearization import Linearization from ..sugar import is_linearization, makeField from .simple_linear_operators import VdotOperator - from ..linearization import Linearization self._check_input(x) lin = is_linearization(x) val = x.val.val if lin else x.val -- GitLab