Commit 9b6d9324 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent 5edadf46
Pipeline #22717 passed with stage
in 4 minutes and 47 seconds
...@@ -59,7 +59,7 @@ class Field(object): ...@@ -59,7 +59,7 @@ class Field(object):
""" """
def __init__(self, domain=None, val=None, dtype=None, copy=False): def __init__(self, domain=None, val=None, dtype=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._infer_domain(domain=domain, val=val)
dtype = self._infer_dtype(dtype=dtype, val=val) dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field): if isinstance(val, Field):
...@@ -128,7 +128,7 @@ class Field(object): ...@@ -128,7 +128,7 @@ class Field(object):
return Field.empty(field.domain, dtype) return Field.empty(field.domain, dtype)
@staticmethod @staticmethod
def _parse_domain(domain, val=None): def _infer_domain(domain, val=None):
if domain is None: if domain is None:
if isinstance(val, Field): if isinstance(val, Field):
return val.domain return val.domain
......
...@@ -71,8 +71,7 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator): ...@@ -71,8 +71,7 @@ class WienerFilterCurvature(InversionEnabler, EndomorphicOperator):
noise = self.N.diagonal().weight(-1) noise = self.N.diagonal().weight(-1)
mock_noise = Field.from_random(random_type="normal", mock_noise = Field.from_random(random_type="normal",
domain=self.N.domain, domain=self.N.domain, dtype=noise.dtype)
dtype=noise.dtype.type)
mock_noise *= sqrt(noise) mock_noise *= sqrt(noise)
mock_data = self.R(mock_signal) + mock_noise mock_data = self.R(mock_signal) + mock_noise
......
...@@ -21,7 +21,6 @@ from builtins import range ...@@ -21,7 +21,6 @@ from builtins import range
from builtins import object from builtins import object
import numpy as np import numpy as np
from ..field import Field, DomainTuple from ..field import Field, DomainTuple
from .. import utilities
class Prober(object): class Prober(object):
...@@ -38,15 +37,16 @@ class Prober(object): ...@@ -38,15 +37,16 @@ class Prober(object):
compute_variance=False, ncpu=1): compute_variance=False, ncpu=1):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._probe_count = self._parse_probe_count(probe_count) self._probe_count = int(probe_count)
self._ncpu = self._parse_probe_count(ncpu) self._ncpu = int(ncpu)
self._random_type = self._parse_random_type(random_type) if random_type not in ["pm1", "normal"]:
raise ValueError(
"unsupported random type: '" + str(random_type) + "'.")
self._random_type = random_type
self.compute_variance = bool(compute_variance) self.compute_variance = bool(compute_variance)
self.probe_dtype = np.dtype(probe_dtype) self.probe_dtype = np.dtype(probe_dtype)
self._uid_counter = 0 self._uid_counter = 0
# ---Properties---
@property @property
def domain(self): def domain(self):
return self._domain return self._domain
...@@ -55,22 +55,11 @@ class Prober(object): ...@@ -55,22 +55,11 @@ class Prober(object):
def probe_count(self): def probe_count(self):
return self._probe_count return self._probe_count
def _parse_probe_count(self, probe_count):
return int(probe_count)
@property @property
def random_type(self): def random_type(self):
return self._random_type return self._random_type
def _parse_random_type(self, random_type): def gen_parallel_probe(self, callee):
if random_type not in ["pm1", "normal"]:
raise ValueError(
"unsupported random type: '" + str(random_type) + "'.")
return random_type
# ---Probing methods---
def gen_parallel_probe(self,callee):
for i in range(self.probe_count): for i in range(self.probe_count):
yield (callee, self.get_probe(i)) yield (callee, self.get_probe(i))
...@@ -87,7 +76,7 @@ class Prober(object): ...@@ -87,7 +76,7 @@ class Prober(object):
pool = Pool(self._ncpu) pool = Pool(self._ncpu)
for i in pool.imap_unordered(self.evaluate_probe_parallel, for i in pool.imap_unordered(self.evaluate_probe_parallel,
self.gen_parallel_probe(callee)): self.gen_parallel_probe(callee)):
self.finish_probe(i[0],i[1]) self.finish_probe(i[0], i[1])
def evaluate_probe_parallel(self, argtuple): def evaluate_probe_parallel(self, argtuple):
callee = argtuple[0] callee = argtuple[0]
...@@ -104,8 +93,7 @@ class Prober(object): ...@@ -104,8 +93,7 @@ class Prober(object):
def generate_probe(self): def generate_probe(self):
""" a random-probe generator """ """ a random-probe generator """
f = Field.from_random(random_type=self.random_type, f = Field.from_random(random_type=self.random_type,
domain=self.domain, domain=self.domain, dtype=self.probe_dtype)
dtype=self.probe_dtype.type)
uid = self._uid_counter uid = self._uid_counter
self._uid_counter += 1 self._uid_counter += 1
return (uid, f) return (uid, f)
......
...@@ -53,7 +53,19 @@ class RGSpace(Space): ...@@ -53,7 +53,19 @@ class RGSpace(Space):
if np.isscalar(shape): if np.isscalar(shape):
shape = (shape,) shape = (shape,)
self._shape = tuple(int(i) for i in shape) self._shape = tuple(int(i) for i in shape)
self._distances = self._parse_distances(distances)
if distances is None:
if self.harmonic:
self._distances = (1.,) * len(self._shape)
else:
self._distances = tuple(1./s for s in self._shape)
elif np.isscalar(distances):
self._distances = (float(distances),) * len(self._shape)
else:
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
self._distances = tuple(temp)
self._dvol = float(reduce(lambda x, y: x*y, self._distances)) self._dvol = float(reduce(lambda x, y: x*y, self._distances))
self._dim = int(reduce(lambda x, y: x*y, self._shape)) self._dim = int(reduce(lambda x, y: x*y, self._shape))
...@@ -169,14 +181,3 @@ class RGSpace(Space): ...@@ -169,14 +181,3 @@ class RGSpace(Space):
distance between neighboring grid points along the n-th dimension. distance between neighboring grid points along the n-th dimension.
""" """
return self._distances return self._distances
def _parse_distances(self, distances):
if distances is None:
if self.harmonic:
temp = np.ones_like(self.shape, dtype=np.float64)
else:
temp = 1./np.array(self.shape, dtype=np.float64)
else:
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
return tuple(temp)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment