Commit f61cb6b2 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'cfm_changes' into 'NIFTy_6'

Small changes in CorrelatedFieldMaker

See merge request !384
parents 5a83ad0a 4e27bb9e
Pipeline #64929 passed with stages
in 8 minutes and 52 seconds
...@@ -341,7 +341,6 @@ class CorrelatedFieldMaker: ...@@ -341,7 +341,6 @@ class CorrelatedFieldMaker:
def __init__(self, amplitude_offset, prefix, total_N): def __init__(self, amplitude_offset, prefix, total_N):
assert isinstance(amplitude_offset, Operator) assert isinstance(amplitude_offset, Operator)
self._a = [] self._a = []
self._spaces = []
self._position_spaces = [] self._position_spaces = []
self._azm = amplitude_offset self._azm = amplitude_offset
...@@ -421,11 +420,9 @@ class CorrelatedFieldMaker: ...@@ -421,11 +420,9 @@ class CorrelatedFieldMaker:
if index is not None: if index is not None:
self._a.insert(index, amp) self._a.insert(index, amp)
self._position_spaces.insert(index, position_space) self._position_spaces.insert(index, position_space)
self._spaces.insert(index, space)
else: else:
self._a.append(amp) self._a.append(amp)
self._position_spaces.append(position_space) self._position_spaces.append(position_space)
self._spaces.append(space)
def _finalize_from_op(self): def _finalize_from_op(self):
n_amplitudes = len(self._a) n_amplitudes = len(self._a)
...@@ -433,29 +430,29 @@ class CorrelatedFieldMaker: ...@@ -433,29 +430,29 @@ class CorrelatedFieldMaker:
hspace = makeDomain([UnstructuredDomain(self._total_N)] + hspace = makeDomain([UnstructuredDomain(self._total_N)] +
[dd.target[-1].harmonic_partner [dd.target[-1].harmonic_partner
for dd in self._a]) for dd in self._a])
spaces = list(1 + np.arange(n_amplitudes)) spaces = tuple(range(1, n_amplitudes + 1))
amp_space = 1
else: else:
hspace = makeDomain( hspace = makeDomain(
[dd.target[0].harmonic_partner for dd in self._a]) [dd.target[0].harmonic_partner for dd in self._a])
spaces = tuple(range(n_amplitudes)) spaces = tuple(range(n_amplitudes))
spaces = list(np.arange(n_amplitudes)) amp_space = 0
expander = ContractionOperator(hspace, spaces=spaces).adjoint expander = ContractionOperator(hspace, spaces=spaces).adjoint
azm = expander @ self._azm azm = expander @ self._azm
# spaces = np.array(range(n_amplitudes)) + 1 - 1//self._total_N
ht = HarmonicTransformOperator(hspace, ht = HarmonicTransformOperator(hspace,
self._position_spaces[0][self._spaces[0]], self._position_spaces[0][amp_space],
space=spaces[0]) space=spaces[0])
for i in range(1, n_amplitudes): for i in range(1, n_amplitudes):
ht = (HarmonicTransformOperator(ht.target, ht = (HarmonicTransformOperator(ht.target,
self._position_spaces[i][self._spaces[i]], self._position_spaces[i][amp_space],
space=spaces[i]) @ ht) space=spaces[i]) @ ht)
pd = PowerDistributor(hspace, self._a[0].target[self._spaces[0]], self._spaces[0]) pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space)
for i in range(1, n_amplitudes): for i in range(1, n_amplitudes):
pd = (pd @ PowerDistributor(pd.domain, pd = (pd @ PowerDistributor(pd.domain,
self._a[i].target[self._spaces[i]], self._a[i].target[amp_space],
space=spaces[i])) space=spaces[i]))
a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0] a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0]
...@@ -523,7 +520,8 @@ class CorrelatedFieldMaker: ...@@ -523,7 +520,8 @@ class CorrelatedFieldMaker:
' no unique set of amplitudes exist because only the', ' no unique set of amplitudes exist because only the',
' relative scale is determined.') ' relative scale is determined.')
raise NotImplementedError(s) raise NotImplementedError(s)
expand = VdotOperator(full(self._a[0].target, 1)).adjoint dom = self._a[0].target
expand = ContractionOperator(dom, len(dom)-1).adjoint
return self._a[0]*(expand @ self.amplitude_total_offset) return self._a[0]*(expand @ self.amplitude_total_offset)
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment