Commit db4db3c2 authored by Philipp Haim

New behaviour of domain tuples

parent 7453ea72
@@ -33,21 +33,21 @@ from ..operators.operator import Operator
 from ..operators.simple_linear_operators import VdotOperator, ducktape
 from ..operators.value_inserter import ValueInserter
 from ..probing import StatCalculator
-from ..sugar import from_global_data, full, makeDomain, get_default_codomain
+from ..sugar import from_global_data, from_random, full, makeDomain, get_default_codomain


-def _reshaper(x, shape):
-    x = np.array(x)
-    if x.shape == shape:
-        return np.asfarray(x)
-    elif x.shape in [(), (1,)]:
-        return np.full(shape, x, dtype=np.float)
+def _reshaper(x, N):
+    x = np.asfarray(x)
+    if x.shape in [(), (1,)]:
+        return np.full(N, x) if N != 1 else x.reshape(())
+    elif x.shape == (N,):
+        return x
     else:
         raise TypeError("Shape of parameters cannot be interpreted")


-def _lognormal_moments(mean, sig, shape = ()):
-    mean, sig = (_reshaper(param, shape) for param in (mean, sig))
+def _lognormal_moments(mean, sig, N = 1):
+    mean, sig = (_reshaper(param, N) for param in (mean, sig))
     assert np.all(mean > 0)
     assert np.all(sig > 0)
     logsig = np.sqrt(np.log((sig/mean)**2 + 1))
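A minimal standalone sketch of the new _reshaper contract, copied from the right-hand side of the hunk above (the demo inputs and printed values are illustrative only): scalars and length-1 arrays are broadcast to N copies, shape-(N,) arrays pass through unchanged, and anything else raises.

import numpy as np

def _reshaper(x, N):
    # scalar or (1,)-shaped input: broadcast to N copies, or squeeze to 0-d when N == 1
    x = np.asfarray(x)
    if x.shape in [(), (1,)]:
        return np.full(N, x) if N != 1 else x.reshape(())
    elif x.shape == (N,):
        return x
    else:
        raise TypeError("Shape of parameters cannot be interpreted")

print(_reshaper(0.3, 3))             # three copies of the scalar
print(_reshaper(0.3, 1).shape)       # () -- a 0-d array for the single-field case
print(_reshaper([1., 2., 3.], 3))    # already the right shape, passed through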
@@ -55,9 +55,12 @@ def _lognormal_moments(mean, sig, N = 1):
     return logmean, logsig


-def _normal(mean, sig, key, domain = DomainTuple.scalar_domain()):
-    domain = makeDomain(domain)
-    mean, sig = (_reshaper(param, domain.shape) for param in (mean, sig))
+def _normal(mean, sig, key, N = 1):
+    if N == 1:
+        domain = DomainTuple.scalar_domain()
+    else:
+        domain = UnstructuredDomain(N)
+    mean, sig = (_reshaper(param, N) for param in (mean, sig))
     return Adder(from_global_data(domain, mean)) @ (
         DiagonalOperator(from_global_data(domain,sig))
         @ ducktape(domain, None, key))
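Up to the NIFTy bookkeeping, the operator chain returned by the new _normal is the affine map xi -> mean + sig*xi applied to the standard-normal excitation that ducktape reads from the latent multi-field under key; _LognormalMomentMatching below simply exponentiates it. A plain-numpy sketch of that composition with illustrative parameter values:

import numpy as np

rng = np.random.default_rng(42)          # illustrative seed

def normal_like(mean, sig, N=1):
    # same structure as Adder(mean) @ (DiagonalOperator(sig) @ ducktape(...)):
    # a standard-normal excitation xi mapped through mean + sig*xi
    xi = rng.standard_normal(N)
    return mean + sig*xi

logmean, logsig = 0.1, 0.5               # illustrative moment-matched parameters
print(np.exp(normal_like(logmean, logsig, N=4)))   # one lognormal sample per copy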
@@ -102,13 +105,12 @@ def _stats(op, samples):
 class _LognormalMomentMatching(Operator):
-    def __init__(self, mean, sig, key,
-                 domain = DomainTuple.scalar_domain()):
+    def __init__(self, mean, sig, key, N_copies):
         key = str(key)
-        logmean, logsig = _lognormal_moments(mean, sig, domain.shape)
+        logmean, logsig = _lognormal_moments(mean, sig, N_copies)
         self._mean = mean
         self._sig = sig
-        op = _normal(logmean, logsig, key, domain).exp()
+        op = _normal(logmean, logsig, key, N_copies).exp()
         self._domain, self._target = op.domain, op.target
         self.apply = op.apply
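The visible logsig line above is the standard lognormal moment-matching relation; together with its usual counterpart logmean = log(mean) - logsig**2/2, exponentiating a Normal(logmean, logsig) sample reproduces the requested mean and standard deviation. A quick numerical check with illustrative numbers:

import numpy as np

m, s = 2.0, 0.5                                   # illustrative target mean and stddev
logsig = np.sqrt(np.log((s/m)**2 + 1))            # the line visible in the hunk above
logmean = np.log(m) - logsig**2/2                 # standard lognormal counterpart

samples = np.exp(np.random.default_rng(0).normal(logmean, logsig, 1_000_000))
print(samples.mean(), samples.std())              # close to 2.0 and 0.5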
@@ -241,11 +243,32 @@ class _slice_extractor(LinearOperator):
         res = np.zeros(self._domain.shape)
         res[self._sl] = x
         return from_global_data(self._tgt(mode), res)


+class _Distributor(LinearOperator):
+    def __init__(self, dofdex, domain, target, space = 0):
+        self._dofdex = dofdex
+        self._target = makeDomain(target)
+        self._domain = makeDomain(domain)
+        self._sl = (slice(None),)*space
+        self._capability = self.TIMES | self.ADJOINT_TIMES
+
+    def apply(self, x, mode):
+        self._check_input(x, mode)
+        x = x.to_global_data()
+        if mode == self.TIMES:
+            res = x[self._dofdex]
+        else:
+            res = np.empty(self._tgt(mode).shape)
+            res[self._dofdex] = x
+        return from_global_data(self._tgt(mode), res)
+
+
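The new _Distributor copies the N_copies independent amplitude degrees of freedom out to the total_N fields that share them, as selected by the integer index array dofdex; its TIMES action is plain fancy indexing, and ADJOINT_TIMES scatters back onto the models. A small numpy illustration with made-up values:

import numpy as np

dofdex = np.array([0, 0, 1])             # fields 0 and 1 share model 0, field 2 uses model 1
amplitude_dofs = np.array([2.5, 7.0])    # one value per independent model (N_copies = 2)

# TIMES: each of the len(dofdex) outputs picks the model assigned to it
print(amplitude_dofs[dofdex])            # [2.5 2.5 7. ]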
 class _Amplitude(Operator):
     def __init__(self, target, fluctuations, flexibility, asperity,
-                 loglogavgslope, key, space = 0):
+                 loglogavgslope, azm, key, dofdex):
         """
         fluctuations > 0
         flexibility > 0
@@ -256,10 +279,19 @@ class _Amplitude(Operator):
         assert isinstance(flexibility, Operator)
         assert isinstance(asperity, Operator)
         assert isinstance(loglogavgslope, Operator)
-        target = makeDomain(target)
-        assert isinstance(target[space], PowerSpace)
+        N_copies = max(dofdex) + 1
+        assert N_copies > 0
+        if N_copies > 1:
+            space = 1
+            distributed_tgt = makeDomain((UnstructuredDomain(len(dofdex)), target))
+            target = makeDomain((UnstructuredDomain(N_copies), target))
+            Distributor = _Distributor(dofdex, target, distributed_tgt, 0)
+        else:
+            space = 0
+            target = makeDomain(target)
+        assert isinstance(target[space], PowerSpace)

         twolog = _TwoLogIntegrations(target, space)
         dom = twolog.domain
         shp = dom[space].shape
@@ -299,12 +331,18 @@ class _Amplitude(Operator):
         sig_flex = vflex @ expander @ flexibility
         sig_asp = vasp @ expander @ asperity
         sig_fluc = vol1 @ ps_expander @ fluctuations
-        sig_fluc = vol1 @ ps_expander @ fluctuations

         xi = ducktape(dom, None, key)
         sigma = sig_flex*(Adder(shift) @ sig_asp).sqrt()
         smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi)
         op = _Normalization(target, space) @ (slope + smooth)
-        op = Adder(vol0) @ (sig_fluc*op)
+        if space == 1:
+            op = Distributor @ op
+            sig_fluc = Distributor @ sig_fluc
+            op = (Distributor @ Adder(vol0)) @ (sig_fluc*(ps_expander @ azm.one_over())*op)
+        else:
+            op = (Adder(vol0)) @ (sig_fluc*(ps_expander @ azm.one_over())*op)

         self.apply = op.apply
         self._fluc = fluctuations
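Both branches above now scale the normalized spectrum by the fluctuation amplitude divided by the zero-mode amplitude, sig_fluc*(ps_expander @ azm.one_over()), before the volume offset is added; in the multi-copy branch the Distributor additionally copies the result to every field selected by dofdex. The scalar structure of the output, on toy numbers (the real vol0 is built further up in the file and not shown here):

import numpy as np

normalized = np.array([1.0, 0.7, 0.4, 0.2])   # toy output of _Normalization @ (slope + smooth)
sig_fluc, azm_value = 3.0, 2.0                # toy fluctuation and zero-mode amplitudes
vol0 = np.zeros(4)                            # placeholder offset; stands in for Adder(vol0)

print(vol0 + sig_fluc*(1.0/azm_value)*normalized)   # 1.5, 1.05, 0.6, 0.3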
@@ -317,24 +355,26 @@ class _Amplitude(Operator):
 class CorrelatedFieldMaker:
-    def __init__(self, amplitude_offset, prefix):
+    def __init__(self, amplitude_offset, prefix, total_N):
         self._a = []
         self._spaces = []
         self._position_spaces = []
         self._azm = amplitude_offset
         self._prefix = prefix
+        self._total_N = total_N

     @staticmethod
-    def make(offset_amplitude_mean, offset_amplitude_stddev, prefix):
+    def make(offset_amplitude_mean, offset_amplitude_stddev, prefix, total_N = 1):
         offset_amplitude_stddev = float(offset_amplitude_stddev)
         offset_amplitude_mean = float(offset_amplitude_mean)
         assert offset_amplitude_stddev > 0
         assert offset_amplitude_mean > 0
         zm = _LognormalMomentMatching(offset_amplitude_mean,
                                       offset_amplitude_stddev,
-                                      prefix + 'zeromode')
-        return CorrelatedFieldMaker(zm, prefix)
+                                      prefix + 'zeromode',
+                                      total_N)
+        return CorrelatedFieldMaker(zm, prefix, total_N)

     def add_fluctuations(self,
                          position_space,
@@ -346,36 +386,49 @@ class CorrelatedFieldMaker:
                          asperity_stddev,
                          loglogavgslope_mean,
                          loglogavgslope_stddev,
-                         prefix='',
-                         index=None,
-                         space=0):
-        position_space = makeDomain(position_space)
-        power_space = list(position_space)
-        power_space[space] = PowerSpace(position_space[space].get_default_codomain())
-        power_space = makeDomain(power_space)
+                         prefix = '',
+                         index = None,
+                         dofdex = None):
+        if dofdex is None:
+            dofdex = np.full(self._total_N, 0)
+        else:
+            assert len(dofdex) == self._total_N
+        N = max(dofdex)
+        if self._total_N > 1:
+            space = 1
+            position_space = makeDomain((UnstructuredDomain(self._total_N), position_space))
+        else:
+            space = 0
+            position_space = makeDomain(position_space)
+            N = 1
+        power_space = PowerSpace(position_space[space].get_default_codomain())
         prefix = str(prefix)
         #assert isinstance(position_space[space], (RGSpace, HPSpace, GLSpace)

-        #NOTE alternative to get auxilliary domain
-        #auxdom = ContractionOperator(position_space, space).domain
-        auxdom = makeDomain(tuple(dom for i, dom in enumerate(position_space)
-                                  if i != space))
         fluct = _LognormalMomentMatching(fluctuations_mean,
                                          fluctuations_stddev,
                                          prefix + 'fluctuations',
-                                         auxdom)
-        #FIXME How should this work on domain tuples?
-        #fluct = fluct*self._azm.one_over()
+                                         N)
+        #if copies:
+        #    fluct = fluct*self._azm.one_over()
+        #else:
+        #    #print(fluct.
+        #    co = ContractionOperator(self._azm.target, None).adjoint
+        #    fluct = (co @ fluct)*self._azm.one_over()
         flex = _LognormalMomentMatching(flexibility_mean, flexibility_stddev,
                                         prefix + 'flexibility',
-                                        auxdom)
+                                        N)
         asp = _LognormalMomentMatching(asperity_mean, asperity_stddev,
                                        prefix + 'asperity',
-                                       auxdom)
+                                       N)
         avgsl = _normal(loglogavgslope_mean, loglogavgslope_stddev,
-                        prefix + 'loglogavgslope', auxdom)
+                        prefix + 'loglogavgslope', N)
         amp = _Amplitude(power_space,
-                         fluct, flex, asp, avgsl, prefix + 'spectrum', space)
+                         fluct, flex, asp, avgsl, self._azm, prefix + 'spectrum', dofdex)
         if index is not None:
             self._a.insert(index, amp)
             self._position_spaces.insert(index, position_space)
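With the new dofdex keyword, add_fluctuations defaults to giving every one of the total_N field copies the same amplitude model; inside _Amplitude the number of independent models then becomes max(dofdex) + 1. A toy sketch of that bookkeeping (values are illustrative):

import numpy as np

total_N = 3
dofdex = None                      # the default of the new keyword argument

if dofdex is None:
    dofdex = np.full(total_N, 0)   # every copy shares amplitude model 0
print(dofdex)                      # [0 0 0]
print(max(dofdex) + 1)             # 1 independent amplitude model (N_copies in _Amplitude)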
@@ -385,18 +438,20 @@ class CorrelatedFieldMaker:
             self._position_spaces.append(position_space)
             self._spaces.append(space)

-    def finalize_from_op(self, zeromode, prefix='', space = 0):
+    def finalize_from_op(self, zeromode, prefix=''):
         assert isinstance(zeromode, Operator)
         self._azm = zeromode
-        hspace = []
-        tuple(hspace.extend(tuple(get_default_codomain(dd, space)))
-              for dd, space in zip(self._position_spaces, self._spaces))
-        hspace = makeDomain(hspace)
-        zeroind = ()
-        for i, dd in enumerate(self._position_spaces):
-            zeroind += (slice(None),)*(self._spaces[i])
-            zeroind += (0,)*len(dd[self._spaces[i]].shape)
-            zeroind += (slice(None),)*(len(dd)-self._spaces[i]-1)
+        n_amplitudes = len(self._a)
+        if self._total_N > 1:
+            hspace = makeDomain([UnstructuredDomain(self._total_N)] +
+                                [dd[-1].get_default_codomain() for dd in self._position_spaces])
+            spaces = tuple(len(dd) for dd in self._position_spaces)
+            spaces = 1 + np.cumsum(spaces)
+        else:
+            hspace = makeDomain(
+                [dd[-1].get_default_codomain() for dd in self._position_spaces])
+            spaces = tuple(range(n_amplitudes))
+        zeroind = (slice(None),)*(1 - 1//self._total_N) + (0,)*(len(hspace.shape)-1+1//self._total_N)
         foo = np.ones(hspace.shape)
         foo[zeroind] = 0
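The zeroind one-liner replaces the old per-space loop: for total_N == 1 the term 1//self._total_N equals 1, so the index is all zeros and selects the single zero mode; for total_N > 1 it equals 0, so a leading slice(None) keeps the copy axis and the zero mode is picked per copy. Checking the arithmetic on toy shapes:

total_N, shape = 3, (3, 4)    # toy hspace shape: copy axis of length total_N, then 4 modes
zeroind = (slice(None),)*(1 - 1//total_N) + (0,)*(len(shape) - 1 + 1//total_N)
print(zeroind)                # (slice(None, None, None), 0)

total_N, shape = 1, (4,)      # single-field case: no copy axis
zeroind = (slice(None),)*(1 - 1//total_N) + (0,)*(len(shape) - 1 + 1//total_N)
print(zeroind)                # (0,)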
@@ -405,13 +460,7 @@ class CorrelatedFieldMaker:
                                            self._azm.target, zeroind).adjoint
         azm = Adder(from_global_data(hspace, foo)) @ ZeroModeInserter @ zeromode

-        n_amplitudes = len(self._a)
-        spaces = [self._spaces[0],]
-        for i in range(1,n_amplitudes):
-            spaces.extend(
-                [len(self._position_spaces[i-1])
-                 - self._spaces[i-1] + self._spaces[i]])
-        spaces = list(np.cumsum(spaces))
+        spaces = np.array(range(n_amplitudes)) + 1 - 1//self._total_N
         ht = HarmonicTransformOperator(hspace,
                                        self._position_spaces[0][self._spaces[0]],
                                        space=spaces[0])
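The hand-rolled cumulative-sum bookkeeping is replaced here by a closed-form index array: for a single field it counts from 0, and with a leading copy axis it counts from 1. Checking the arithmetic on toy values:

import numpy as np

n_amplitudes = 3
for total_N in (1, 4):
    spaces = np.array(range(n_amplitudes)) + 1 - 1//total_N
    print(total_N, spaces)    # total_N == 1 -> [0 1 2]; total_N > 1 -> [1 2 3]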
@@ -426,6 +475,7 @@ class CorrelatedFieldMaker:
                                            self._a[i].target[self._spaces[i]],
                                            space=spaces[i]))
+        #breakpoint()
         all_spaces = list(range(len(hspace)))
         a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0]
         for i in range(1, n_amplitudes):
......