Skip to content
Snippets Groups Projects
Commit b81cb205 authored by Theo Steininger's avatar Theo Steininger
Browse files

Added dtype option to prober.

parent e7db670d
No related branches found
No related tags found
2 merge requests!171Master,!169Added dtype option to prober.
Pipeline #
...@@ -44,9 +44,6 @@ class LineSearch(Loggable, object): ...@@ -44,9 +44,6 @@ class LineSearch(Loggable, object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self): def __init__(self):
self.line_energy = None self.line_energy = None
self.f_k_minus_1 = None self.f_k_minus_1 = None
self.preferred_initial_step_size = None self.preferred_initial_step_size = None
......
...@@ -37,7 +37,8 @@ class Prober(object): ...@@ -37,7 +37,8 @@ class Prober(object):
""" """
def __init__(self, domain=None, distribution_strategy=None, probe_count=8, def __init__(self, domain=None, distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False): random_type='pm1', probe_dtype=np.float,
compute_variance=False):
self._domain = utilities.parse_domain(domain) self._domain = utilities.parse_domain(domain)
self._distribution_strategy = \ self._distribution_strategy = \
...@@ -45,6 +46,7 @@ class Prober(object): ...@@ -45,6 +46,7 @@ class Prober(object):
self._probe_count = self._parse_probe_count(probe_count) self._probe_count = self._parse_probe_count(probe_count)
self._random_type = self._parse_random_type(random_type) self._random_type = self._parse_random_type(random_type)
self.compute_variance = bool(compute_variance) self.compute_variance = bool(compute_variance)
self.probe_dtype = np.dtype(probe_dtype)
# ---Properties--- # ---Properties---
...@@ -104,6 +106,7 @@ class Prober(object): ...@@ -104,6 +106,7 @@ class Prober(object):
""" a random-probe generator """ """ a random-probe generator """
f = Field.from_random(random_type=self.random_type, f = Field.from_random(random_type=self.random_type,
domain=self.domain, domain=self.domain,
dtype=self.probe_dtype,
distribution_strategy=self.distribution_strategy) distribution_strategy=self.distribution_strategy)
uid = np.random.randint(1e18) uid = np.random.randint(1e18)
return (uid, f) return (uid, f)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment