Commit 7ae2caa6 authored by Lauri Himanen's avatar Lauri Himanen

Merged the new nomad_utils folder to the new package structure under the utils-subpackage.

parents a86eb472 797171f7
......@@ -6,7 +6,7 @@
Benjamin Regler - Apache 2.0 License
@license http://www.apache.org/licenses/LICENSE-2.0
@author Benjamin Regler
@version 1.0.0
@version 2.0.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -31,11 +31,11 @@ import random
if sys.version_info.major > 2:
# For Python 3.0 and later
from urllib.parse import quote, unquote_plus
from urllib.parse import quote, unquote_plus, urlencode
from urllib.request import urlopen, Request
else:
# Fall back to Python 2's urllib2
from urllib import quote, unquote_plus
from urllib import quote, unquote_plus, urlencode
from urllib2 import urlopen, Request
......@@ -58,33 +58,66 @@ class NomadQueryResult(object):
(default: {1.0})
"""
self._uri = []
self._download_url = ''
self._query = query or {}
self._timestamp = int(time.time())
self._response = response.get('result', {})
# Load response information
self._load(response, version)
def _load(self, response, version):
"""Load response information
Arguments:
response {dict} -- Response of the Nomad Query API
version {float} -- Version of the Nomad Query data file
"""
# Set version of the Nomad Query data file
self._version = version
# Construct download path
path = response.get('path', '')
self._download_url = self._query.get('endpoint', '') + 'download/' + \
path.split('_')[-1] + '?file=' + quote(path.encode('utf-8')) + '.json'
# Initialize
if version == 1.0:
self._response = response.get('result', {})
# Get Nomad URIs
response = NomadQuery().request(self._download_url)
if response['status'] == 'success':
regex = re.compile(r'(?<=/[a-zA-Z0-9\-_]{3}/)[^\.]+')
paths = response['data'].get('result', [])
# Construct download path
path = response.get('path', '')
self._download_url = self._query.get('endpoint', '') + \
'download/' + path.split('_')[-1] + '?file=' + \
quote(path.encode('utf-8')) + '.json'
for path in paths:
match = regex.search(path)
if match:
# Substitute prefixes
groups = match.group(0).split('/')
groups[0] = 'N' + groups[0][1:] # Normalized
# Get Nomad URIs
response = NomadQuery.request(self._download_url)
if response['status'] == 'success':
regex = re.compile(r'(?<=/[a-zA-Z0-9\-_]{3}/)[^.]+')
paths = response['data'].get('result', [])
for path in paths:
match = regex.search(path)
if match:
# Substitute prefixes
groups = match.group(0).split('/')
groups[0] = 'N' + groups[0][1:] # Normalized
if len(groups) == 2:
groups[1] = 'C' + groups[1][1:] # Computed
self._uri.append('nmd://' + '/'.join(groups))
elif version == 2.0:
self._response = response.get('data', {})
# Construct and get Nomad URIs
for entry in self._response:
if not entry['type'].lower().endswith('calculation'):
continue
if len(groups) == 2:
groups[1] = 'C' + groups[1][1:] # Computed
# Get archive gid
context = entry['attributes']['metadata']['archive_context']
gid = context['archive_gid'][0]
self._uri.append('nmd://' + '/'.join(groups))
# Assemble Nomad Uri
uri = 'nmd://N' + gid[1:] + '/' + entry['id']
self._uri.append(uri)
def version(self):
"""Get the version of the Nomad Query data file.
......@@ -107,6 +140,10 @@ class NomadQueryResult(object):
Returns:
str -- The download URL of the query
Deprecated:
Since version 2.0.0, this method is no longer used by internal code
and not recommended.
"""
return self._download_url
......@@ -142,23 +179,25 @@ class NomadQuery(object):
"""
# Version of the Nomad Query API
__version__ = 1.0
# Nomad API endpoint
endpoint = os.environ.get('NOMAD_BASE_URI','https://analytics-toolkit.nomad-coe.eu') + '/api/'
__version__ = 2.0
# Private user path
user_path = '/data/private'
def __init__(self, username='', endpoint=''):
# Nomad API endpoints
endpoint = 'https://analytics-toolkit.nomad-coe.eu/api/'
query_endpoint = 'https://analytics-toolkit.nomad-coe.eu/archive/nql-api/'
def __init__(self, username='', endpoint='', query_endpoint=''):
"""Constructor.
Keyword Arguments:
username {str} -- Current username. Leave empty to auto-detect
username (default: {''})
endpoint {str} -- Endpoint of the Nomad API (default:
${NOMAD_BASE_URI}/api if set, otherwise
{'https://analytics-toolkit.nomad-coe.eu/api/'})
username {str} -- Current username. Leave empty to auto-detect
username (default: {''})
endpoint {str} -- Endpoint of the Nomad API (default:
{'https://analytics-toolkit.nomad-coe.eu/api/'})
query_endpoint {str} -- Endpoint of the Nomad Query API (default:
{'https://analytics-toolkit.nomad-coe.eu/nql-api/'})
"""
self._username = ''
self._base_path = ''
......@@ -170,11 +209,14 @@ class NomadQuery(object):
if len(paths) == 1 and paths[0].lower() != 'nomad':
username = paths[0]
# Set username and overwrite endpoint, if required
# Set username and overwrite endpoints, if required
self.username(username)
if endpoint:
self.endpoint = str(endpoint)
if query_endpoint:
self.query_endpoint = str(query_endpoint)
def username(self, username=''):
"""Get or set the username.
......@@ -303,7 +345,7 @@ class NomadQuery(object):
if not os.path.isdir(base_path):
return queries
# Get all stored queries
# Get all stored queries
for filename in os.listdir(base_path):
path = os.path.join(base_path, filename)
if os.path.isfile(path):
......@@ -322,17 +364,22 @@ class NomadQuery(object):
queries.sort(key=lambda x: -x['timestamp'])
return queries
def query(self, query, group_by='', context='', timeout=10):
def query(self, query, group_by='', timeout=10, **kwargs):
"""Query the Nomad Database.
Arguments:
query {str} -- The query string (see Nomad API reference)
Keyword Arguments:
group_by {str} -- Group-by field. (default: {''})
context {str} -- Query context. Leave empty to use
`single_configuration_calculation` (default: {''})
timeout {number} -- Timeout of the request in seconds (default: {10})
group_by {str} -- Group-by field. (default: {''})
num_results {int} -- Number of calculations to return
(default: {10000})
num_groups {int} -- Number of distinct calculation groups to return
(default: {10})
context {str} -- Deprecated: Query context. Leave empty to use
`single_configuration_calculation` (default: {''})
compat {bool} -- Compatibility mode (default: {True})
timeout {number} -- Timeout of the request in seconds (default: {10})
Returns:
NomadQueryResult -- The Nomad query result
......@@ -343,17 +390,27 @@ class NomadQuery(object):
RuntimeError -- Unknown error. Please inform the Nomad team to
solve this problem.
"""
# Set default context
if not context:
context = 'single_configuration_calculation'
# Construct URL
url = self.endpoint + ('queryGroup/' if group_by else 'query/') + context
url = self.query_endpoint + ('search_grouped' if group_by else 'search')
params = {
'source_fields': 'archive_gid',
'sort_field': 'calculation_gid',
'num_results': max(min(kwargs.get('num_results', 10000), 10000), 1),
'format': 'nested'
}
# Normalize query - compatibility fallback
if kwargs.get('compat', True):
query = self._normalize(query)
# Add query
url += '?filter=' + quote(query.strip())
params['query'] = query.strip()
if group_by:
url += quote(' GROUPBY ' + group_by.strip().lower())
params['group_by'] = group_by.strip().lower()
params['num_groups'] = max(kwargs.get('num_groups', 10), 1)
# Construct URL
url += '?' + urlencode(params).replace('+', '%20')
# Read URL
response = self.request(url, timeout=timeout)
......@@ -362,21 +419,18 @@ class NomadQuery(object):
# Check connection timeout
response = response['data']
if 'timed_out' in response['result'] and response['result']['timed_out']:
if response['meta'].get('is_timed_out', False) or \
response['meta'].get('is_terminated_early', False):
response['message'] = 'Connection timed out.'
# Check for additional error messages
if 'message' in response or 'msg' in response:
raise RuntimeError(response.get('message', response['msg']))
# Construct Nomad Query response
query = {
'context': context,
'endpoint': self.endpoint,
'filter': query.strip(),
'group_by': group_by.strip().lower(),
'endpoint': self.query_endpoint,
'query': params.get('query', ''),
'group_by': params.get('group_by', ''),
'url': url
}
return NomadQueryResult(query, response, self.__version__)
def fetch(self, name_or_index='', resolve=False, **params):
......@@ -531,6 +585,97 @@ class NomadQuery(object):
data['data'] = self._resolve(data['uri'], **params)
return data
@staticmethod
def request(url, timeout=10):
"""Request a URL
Arguments:
url {str} -- The URL of a web address
Keyword Arguments:
timeout {number} -- Timeout of the request in seconds (default: {10})
Returns:
dict -- A dictionary with success status, response data, or
error message
"""
# Default request response
result = {
'url': url,
'status': 'error',
'message': 'Unknown error. Please inform the Nomad team to '
'solve this problem.'
}
try:
# Get URL
response = urlopen(Request(url), timeout=timeout)
# Check response code
if response.code != 200:
raise RuntimeError(result['message'])
# Read response
data = json.loads(response.read().decode('utf-8'), 'utf-8')
# Populate result
result.pop('message')
result.update({
'status': 'success',
'data': data
})
except Exception as exc:
exc = sys.exc_info()[1]
response = result.copy()
# Get error message
message = exc
if sys.version_info <= (2, 5) and hasattr(exc, 'message'):
message = exc.message
elif hasattr(exc, 'reason'):
message = exc.reason
response['message'] = str(message)
# Fix error message
if response['message'].endswith('timed out'):
response['message'] = 'Connection timed out. The Nomad ' + \
'Analytics API Service is currently unavailable.'
# Return result
return result
def _normalize(self, query):
    """[Protected] Normalize query syntax.

    Converts version 1 query syntax (``field:value1,value2``) into
    version 2 syntax (``field = "value1","value2"``). Queries without
    an unescaped colon are returned unchanged.

    Arguments:
        query {str} -- The query string (see Nomad API reference)

    Returns:
        str -- The normalized query string
    """
    # Convert nomad query syntax v1 to v2. A colon preceded by a
    # backslash counts as escaped and does not trigger conversion.
    if re.search(r'(?<!\\):', query):
        # Split on "and" (case-insensitive). Raw strings avoid the
        # invalid "\s" escape warning; flags is passed by keyword
        # because the positional form is deprecated.
        values = re.split(r'\sand\s', query, flags=re.I)

        # Convert query
        regex = re.compile(r'([^:]+):(.+)')
        for i in range(len(values)):
            match = regex.search(values[i])
            if match:
                # Make sure strings are properly escaped; purely numeric
                # values are kept unquoted
                value = map(str.strip, match.group(2).split(','))
                value = ','.join((v if v.isdigit()
                                  else '"' + v.strip('\'" ') + '"')
                                 for v in value)

                # Replace colons with equal symbols
                values[i] = match.group(1) + ' = ' + value

        # Rebuild query
        query = ' AND '.join(values)
    return query
def _resolve(self, paths, size=None, seed=None, **params):
"""[Protected] Resolve Nomad URIs.
......
......@@ -16,7 +16,7 @@
import os
import logging
from nomadcore.local_meta_info import InfoKindEl, loadJsonFile
from nomadcore.metainfo.local_meta_info import InfoKindEl, loadJsonFile
logger = logging.getLogger(__name__)
baseDir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
......
......@@ -569,12 +569,17 @@ class CacheService(object):
"""Get the value identified by name. If the cachemode does not support
getting the value, an exception is raised.
returns:
Args:
name(string): The name of the cached object to return.
raises:
Returns:
The requested object from the cache
"""
cache_object = self.get_cache_object(name)
return cache_object.value
if cache_object is None:
return None
else:
return cache_object.value
def get_cache_object(self, name):
......@@ -586,8 +591,7 @@ class CacheService(object):
def __setitem__(self, name, value):
    """Set the value for a cached item.

    The CacheObject corresponding to the name has to be first declared
    by using the function add(), so that the lookup in self._cache
    succeeds.
    """
    # Fetch the previously declared CacheObject and store the value on it.
    cache_object = self._cache[name]
    cache_object.value = value
......
import ase.calculators.calculator
if 'potential_energy' not in calculator.all_properties:
calculator.all_properties += ['potential_energy', 'kinetic_energy']
This diff is collapsed.
"""File formats.
This module implements the read(), iread() and write() functions in ase.io.
For each file format there is a namedtuple (IOFormat) that has the following
elements:
* a read(filename, index, **kwargs) generator that will yield Atoms objects
* a write(filename, images) function
* a 'single' boolean (False if multiple configurations is supported)
* a 'acceptsfd' boolean (True if file-descriptors are accepted)
There is a dict 'ioformats' that is filled with IOFormat objects as they are
needed. The 'initialize()' function will create the IOFormat object by
looking at the all_formats dict and by importing the correct read/write
functions from the correct module. The 'single' and 'acceptsfd' bools are
parsed from a two-character string in the all_formats dict below.
Example
=======
The xyz format is implemented in the ase/io/xyz.py file which has a
read_xyz() generator and a write_xyz() function.
"""
import collections
import functools
import inspect
import os
import sys
from ase.atoms import Atoms
from ase.utils import import_module, basestring, PurePath
from ase.parallel import parallel_function, parallel_generator
class UnknownFileTypeError(Exception):
    """Raised when a file's format cannot be recognized."""
# Per-format record: reader/writer callables plus capability flags.
IOFormat = collections.namedtuple('IOFormat',
                                  'read, write, single, acceptsfd, isbinary')
ioformats = {}  # will be filled at run-time by initialize()

# 1=single, +=multiple, F=accepts a file-descriptor, S=needs a file-name str,
# B=like F, but opens in binary mode
all_formats = {
    'nomad-json': ('JSON from Nomad archive', '+F'),
    'nomad-ziptxt': ('ZIPPED TXT from Nomad archive', '+F'),
}

# Special cases: formats whose module name is not format.replace('-', '_')
format2modulename = {
}

# File-extension to format-name overrides (empty here).
extension2format = {
}

# Maps a NetCDF 'Conventions' attribute value to an ASE format name.
netcdfconventions2format = {
    'http://www.etsf.eu/fileformats': 'etsf',
    'AMBER': 'netcdftrajectory'
}
def initialize(format):
    """Import the read/write functions for *format* and register them.

    Fills ``ioformats[format]`` with an :class:`IOFormat` tuple derived
    from the ``all_formats`` registry. Calling it again for an already
    registered format is a no-op.
    """
    if format in ioformats:
        # Registration is idempotent -- nothing to do.
        return

    module_suffix = format.replace('-', '_')
    module_name = format2modulename.get(format, module_suffix)

    try:
        module = import_module('ase.io.' + module_name)
    except ImportError as err:
        raise ValueError('File format not recognized: %s. Error: %s'
                         % (format, err))

    reader = getattr(module, 'read_' + module_suffix, None)
    writer = getattr(module, 'write_' + module_suffix, None)

    # Plain (non-generator) readers are wrapped so every reader behaves
    # like a generator of configurations.
    if reader and not inspect.isgeneratorfunction(reader):
        reader = functools.partial(wrap_read_function, reader)
    if not reader and not writer:
        raise ValueError('File format not recognized: ' + format)

    # Decode the two-character capability code from the registry.
    code = all_formats[format][1]
    single = code[0] == '1'
    assert code[1] in 'BFS'
    acceptsfd = code[1] != 'S'
    isbinary = code[1] == 'B'
    ioformats[format] = IOFormat(reader, writer, single, acceptsfd, isbinary)
def get_ioformat(format):
    """Return the registered IOFormat tuple for *format*.

    The format is initialized on demand via :func:`initialize`.
    """
    initialize(format)
    return ioformats[format]
def get_compression(filename):
    """Split *filename* into its root and a recognised compression extension.

    Recognised compression extensions are ``gz``, ``bz2`` and ``xz``.

    >>> get_compression('H2O.pdb.gz')
    ('H2O.pdb', 'gz')
    >>> get_compression('crystal.cif')
    ('crystal.cif', None)

    Parameters
    ==========
    filename: str
        Full filename including extension.

    Returns
    =======
    (root, extension): (str, str or None)
        Filename split into root without extension, and the extension
        indicating compression format. Will not split if compression
        is not recognised.
    """
    # splitext from the stdlib handles most edge cases for us; it keeps
    # the leading '.' on the extension, so drop it before comparing.
    root, ext = os.path.splitext(filename)
    ext = ext.strip('.')
    if ext in ('gz', 'bz2', 'xz'):
        return root, ext
    return filename, None
def open_with_compression(filename, mode='r'):
"""
Wrapper around builtin `open` that will guess compression of a file
from the filename and open it for reading or writing as if it were
a standard file.
Implemented for ``gz``(gzip), ``bz2``(bzip2) and ``xz``(lzma). Either
Python 3 or the ``backports.lzma`` module are required for ``xz``.
Supported modes are:
* 'r', 'rt', 'w', 'wt' for text mode read and write.
* 'rb, 'wb' for binary read and write.
Depending on the Python version, you may get errors trying to write the
wrong string type to the file.
Parameters
==========
filename: str
Path to the file to open, including any extensions that indicate
the compression used.
mode: str
Mode to open the file, same as for builtin ``open``, e.g 'r', 'w'.
Returns
=======
fd: file
File-like object open with the specified mode.
"""
if sys.version_info[0] > 2:
# Compressed formats sometimes default to binary, so force
# text mode in Python 3.
if mode == 'r':
mode = 'rt'
elif mode == 'w':
mode = 'wt'
elif mode == 'a':
mode = 'at'
else:
# The version of gzip in Anaconda Python 2 on Windows forcibly
# adds a 'b', so strip any 't' and let the string conversions
# be carried out implicitly by Python.
mode = mode.strip('t')
root, compression = get_compression(filename)
if compression is None:
return open(filename, mode)
elif compression == 'gz':
import gzip
fd = gzip.open(filename, mode=mode)
elif compression == 'bz2':
import bz2
if hasattr(bz2, 'open'):
# Python 3 only
fd = bz2.open(filename, mode=mode)
else:
# Python 2
fd = bz2.BZ2File(filename, mode=mode)
elif compression == 'xz':
try:
from lzma import open as lzma_open
except ImportError:
from backports.lzma import open as lzma_open
fd = lzma_open(filename, mode)