diff --git a/nomad/api/admin.py b/nomad/api/admin.py index 0627414c23b2c2e46d844598d3d26bdb5b0d8523..062a7d8f5369efe4783e51a485d399bb95b99502 100644 --- a/nomad/api/admin.py +++ b/nomad/api/admin.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from flask import g -from flask_restplus import abort, Resource +from flask import g, request +from flask_restplus import abort, Resource, fields from nomad import infrastructure, config @@ -70,3 +70,29 @@ class AdminResetResource(Resource): infrastructure.remove() return dict(messager='Remove performed.'), 200 + + +pidprefix_model = api.model('PidPrefix', { + 'prefix': fields.Integer(description='The prefix. All new calculations will get an id that is greater.', required=True) +}) + + +@ns.route('/pidprefix') +class AdminPidPrefixResource(Resource): + @api.doc('exec_pidprefix_command') + @api.response(200, 'Pid prefix set') + @api.response(400, 'Bad pid prefix data') + @api.expect(pidprefix_model) + @login_really_required + def post(self): + """ + The ``pidprefix``command will set the pid counter to the given value. + + This might be useful while migrating data with old pids. + """ + if not g.user.is_admin: + abort(401, message='Only the admin user can perform remove.') + + infrastructure.set_pid_prefix(**request.get_json()) + + return dict(messager='PID prefix set.'), 200 diff --git a/nomad/api/auth.py b/nomad/api/auth.py index f6db27ae3e644acb00523337b33798b05141817c..013bbddb0bee04b4491a79372bc5e3ca8575a573 100644 --- a/nomad/api/auth.py +++ b/nomad/api/auth.py @@ -131,10 +131,13 @@ ns = api.namespace( user_model = api.model('User', { + 'user_id': fields.Integer(description='The id to use in the repo db, make sure it does not already exist.'), 'first_name': fields.String(description='The user\'s first name'), 'last_name': fields.String(description='The user\'s last name'), 'email': fields.String(description='Guess what, the user\'s email'), - 'affiliation': fields.String(description='The user\'s affiliation'), + 'affiliation': fields.Nested(model=api.model('Affiliation', { + 'name': fields.String(description='The name of the affiliation', default='not given'), + 'address': fields.String(description='The address of the affiliation', default='not given')})), 'password': fields.String(description='The bcrypt 2y-indented password for initial and changed password'), 'token': fields.String( description='The access token that authenticates the user with the API. ' @@ -164,6 +167,7 @@ class UserResource(Resource): @api.doc('create_user') @api.expect(user_model) + @api.response(400, 'Invalid user data') @api.marshal_with(user_model, skip_none=True, code=200, description='User created') @login_really_required def put(self): @@ -183,15 +187,21 @@ class UserResource(Resource): if required_key not in data: abort(400, message='The %s is missing' % required_key) + if 'user_id' in data: + if coe_repo.User.from_user_id(data['user_id']) is not None: + abort(400, 'User with given user_id %d already exists.' % data['user_id']) + user = coe_repo.User.create_user( email=data['email'], password=data.get('password', None), crypted=True, first_name=data['first_name'], last_name=data['last_name'], - affiliation=data.get('affiliation', None)) + affiliation=data.get('affiliation', None), token=data.get('token', None), + user_id=data.get('user_id', None)) return user, 200 @api.doc('update_user') @api.expect(user_model) + @api.response(400, 'Invalid user data') @api.marshal_with(user_model, skip_none=True, code=200, description='User updated') @login_really_required def post(self): diff --git a/nomad/client/migration.py b/nomad/client/migration.py index 53b71ea9a91a19e7567c0129618cb0707520d42c..55b9b4187f29cec8d84a1c1430d2253870ed2346 100644 --- a/nomad/client/migration.py +++ b/nomad/client/migration.py @@ -68,26 +68,14 @@ def index(drop, with_metadata, per_query): @migration.command(help='Copy users from source into empty target db') -@click.option('-h', '--host', default=config.repository_db.host, help='The migration repository target db host, default is "%s".' % config.repository_db.host) -@click.option('-p', '--port', default=config.repository_db.port, help='The migration repository target db port, default is %d.' % config.repository_db.port) -@click.option('-u', '--user', default=config.repository_db.user, help='The migration repository target db user, default is %s.' % config.repository_db.user) -@click.option('-w', '--password', default=config.repository_db.password, help='The migration repository target db password.') -@click.option('-db', '--dbname', default=config.repository_db.dbname, help='The migration repository target db name, default is %s.' % config.repository_db.dbname) def copy_users(**kwargs): _setup() - _, db = infrastructure.sqlalchemy_repository_db(readonly=False, **kwargs) - _migration.copy_users(db) - - -@migration.command(help='Set the pid auto increment to the given prefix') -@click.option('--prefix', default=7000000, help='The int to set the pid auto increment counter to') -def prefix(prefix: int): - _setup() - _migration.set_new_pid_prefix(prefix) + _migration.copy_users() @migration.command(help='Upload the given upload locations. Uses the existing index to provide user metadata') @click.argument('paths', nargs=-1) -def upload(paths: list): +@click.option('--prefix', default=None, type=int, help='Set the pid counter to this value. The counter will not be changed if not given.') +def upload(paths: list, prefix: int): _setup() - _migration.migrate(*paths) + _migration.migrate(*paths, prefix=prefix) diff --git a/nomad/coe_repo/user.py b/nomad/coe_repo/user.py index 18309eff7590d51eb92ad7a264e2ded116036b2a..515a45ead97b89467d7e1e709b7200d88687f758 100644 --- a/nomad/coe_repo/user.py +++ b/nomad/coe_repo/user.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict from passlib.hash import bcrypt from sqlalchemy import Column, Integer, String, ForeignKey from sqlalchemy.orm import relationship @@ -39,6 +40,14 @@ class LoginException(Exception): pass +class Affiliation(Base): + __tablename__ = 'affiliations' + a_id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String) + address = Column(String) + email_domain = Column(String) + + class User(Base): # type: ignore """ SQLAlchemy model class that represents NOMAD-coe repository postgresdb *users*. @@ -50,10 +59,11 @@ class User(Base): # type: ignore __tablename__ = 'users' user_id = Column(Integer, primary_key=True) + affiliation_id = Column(Integer, ForeignKey('affiliations.a_id'), name='affiliation') email = Column(String) first_name = Column(String, name='firstname') last_name = Column(String, name='lastname') - affiliation = Column(String) + affiliation = relationship('Affiliation', lazy='joined') password = Column(String) _token_chars = string.ascii_uppercase + string.ascii_lowercase + string.digits @@ -62,16 +72,23 @@ class User(Base): # type: ignore return '<User(email="%s")>' % self.email @staticmethod - def create_user(email: str, password: str, crypted: bool, **kwargs): + def create_user( + email: str, password: str, crypted: bool, user_id: int = None, + affiliation: Dict[str, str] = None, token: str = None, **kwargs): repo_db = infrastructure.repository_db repo_db.begin() try: - user = User(email=email, **kwargs) + if affiliation is not None: + affiliation = Affiliation(**affiliation) + repo_db.add(affiliation) + + user = User(email=email, user_id=user_id, affiliation=affiliation, **kwargs) repo_db.add(user) user.set_password(password, crypted) # TODO this has to change, e.g. trade for JWTs - token = ''.join(random.choices(User._token_chars, k=64)) + if token is None: + token = ''.join(random.choices(User._token_chars, k=64)) repo_db.add(Session(token=token, user=user)) repo_db.commit() diff --git a/nomad/infrastructure.py b/nomad/infrastructure.py index 709b8b8552abeceacdf8c902436e8ec63092c25b..20957e5a5c71a1897412ec86e81c942c2cff5dce 100644 --- a/nomad/infrastructure.py +++ b/nomad/infrastructure.py @@ -188,6 +188,16 @@ def sqlalchemy_repository_db(exists: bool = False, readonly: bool = True, **kwar return repository_db_conn, repository_db +def set_pid_prefix(prefix=7000000, target_db=None): + if target_db is None: + target_db = repository_db + + target_db.begin() + target_db.execute('ALTER SEQUENCE calculations_calc_id_seq RESTART WITH %d' % prefix) + target_db.commit() + logger.info('set pid prefix', pid_prefix=prefix) + + def reset(): """ Resets the databases mongo, elastic/calcs, repository db and all files. Be careful. diff --git a/nomad/migration.py b/nomad/migration.py index 427e418b9292598efc254925bd62d3054bbbb4b1..1b6fc1329e9c53ade8b4766bc52c0940f1fe2372 100644 --- a/nomad/migration.py +++ b/nomad/migration.py @@ -26,17 +26,20 @@ import zipstream import zipfile import math from mongoengine import Document, IntField, StringField, DictField -from passlib.hash import bcrypt from werkzeug.contrib.iterio import IterIO import time -from bravado.exception import HTTPNotFound +from bravado.exception import HTTPNotFound, HTTPBadRequest -from nomad import utils, config, infrastructure -from nomad.coe_repo import User, Calc +from nomad import utils, infrastructure +from nomad.coe_repo import User, Calc, LoginException from nomad.datamodel import CalcWithMetadata from nomad.processing import FAILURE, SUCCESS +default_pid_prefix = 7000000 +""" The default pid prefix for new non migrated calcualtions """ + + class SourceCalc(Document): """ Mongo document used as a calculation, upload, and metadata db and index @@ -175,25 +178,36 @@ class NomadCOEMigration: return self._client - def copy_users(self, target_db): - """ Copy all users, keeping their ids, within a single transaction. """ - target_db.begin() + def copy_users(self): + """ Copy all users. """ for source_user in self.source.query(User).all(): - self.source.expunge(source_user) # removes user from the source session - target_db.merge(source_user) - - admin = target_db.query(User).filter_by(email='admin').first() - if admin is None: - admin = User( - user_id=0, email='admin', first_name='admin', last_name='admin', - password=bcrypt.encrypt(config.services.admin_password, ident='2y')) - target_db.add(admin) - target_db.commit() - - def set_new_pid_prefix(self, target_db, prefix=7000000): - target_db.begin() - target_db.execute('ALTER SEQUENCE calculations_calc_id_seq RESTART WITH %d' % prefix) - target_db.commit() + if source_user.user_id <= 2: + # skip first two users to keep example users + # they probably are either already the example users, or [root, Evgeny] + continue + + create_user_payload = dict( + user_id=source_user.user_id, + email=source_user.email, + first_name=source_user.first_name, + last_name=source_user.last_name, + password=source_user.password + ) + + try: + create_user_payload.update(token=source_user.token) + except LoginException: + pass + + if source_user.affiliation is not None: + create_user_payload.update(affiliation=dict( + name=source_user.affiliation.name, + address=source_user.affiliation.address)) + + try: + self.client.auth.create_user(payload=create_user_payload).response() + except HTTPBadRequest as e: + self.logger.error('could not create user due to bad data', exc_info=e, user_id=source_user.user_id) def _to_comparable_list(self, list): for item in list: @@ -250,7 +264,7 @@ class NomadCOEMigration: return is_valid - def migrate(self, *args): + def migrate(self, *args, prefix: int = default_pid_prefix): """ Migrate the given uploads. @@ -265,8 +279,14 @@ class NomadCOEMigration: Uses PIDs of identified old calculations. Will create new PIDs for previously unknown uploads. New PIDs will be choosed from a `prefix++` range of ints. + Arguments: + prefix: The PID prefix that should be used for new non migrated calcualtions. + Returns: Yields a dictionary with status and statistics for each given upload. """ + if prefix is not None: + self.logger.info('set pid prefix', pid_prefix=prefix) + self.client.admin.exec_pidprefix_command(payload=dict(prefix=prefix)).response() upload_specs = args for upload_spec in upload_specs: diff --git a/tests/data/migration/example_source_db.sql b/tests/data/migration/example_source_db.sql index c18d4a799ef4cc5b51df41059fd0a338403eb2ee..de2cdca997f1fa3846ced578ed7dd15ae7d12001 100644 --- a/tests/data/migration/example_source_db.sql +++ b/tests/data/migration/example_source_db.sql @@ -7,8 +7,8 @@ SET check_function_bodies = false; SET client_min_messages = warning; TRUNCATE TABLE public.users CASCADE; -INSERT INTO public.users VALUES (1, 'one', 'one', 'one', 'one', NULL, NULL, NULL); -INSERT INTO public.users VALUES (2, 'two', 'two', 'two', 'two', NULL, NULL, NULL); +INSERT INTO public.users VALUES (3, 'one', 'one', 'one', 'one', NULL, '$2y$12$jths1LQPsLofuBQ3evVIluhQeQ/BZfbdTSZHFcPGdcNmHz2WvDj.y', NULL); +INSERT INTO public.users VALUES (4, 'two', 'two', 'two', 'two', NULL, '$2y$12$jths1LQPsLofuBQ3evVIluhQeQ/BZfbdTSZHFcPGdcNmHz2WvDj.y', NULL); INSERT INTO public.calculations VALUES (NULL, NULL, NULL, NULL, 0, false, 1, NULL); INSERT INTO public.calculations VALUES (NULL, NULL, NULL, NULL, 0, false, 2, NULL); INSERT INTO public.codefamilies VALUES (1, 'VASP'); @@ -46,10 +46,10 @@ INSERT INTO public.spacegroups VALUES (1, 123); INSERT INTO public.spacegroups VALUES (2, 123); INSERT INTO public.user_metadata VALUES (1, 0, 'label1'); INSERT INTO public.user_metadata VALUES (2, 1, 'label2'); -INSERT INTO public.ownerships VALUES (1, 1); -INSERT INTO public.ownerships VALUES (2, 2); -INSERT INTO public.coauthorships VALUES (1, 2); -INSERT INTO public.shareships VALUES (2, 1); +INSERT INTO public.ownerships VALUES (1, 3); +INSERT INTO public.ownerships VALUES (2, 4); +INSERT INTO public.coauthorships VALUES (1, 4); +INSERT INTO public.shareships VALUES (2, 3); -- example dataset INSERT INTO public.calculations VALUES (NULL, NULL, NULL, NULL, 1, false, 3, NULL); diff --git a/tests/test_api.py b/tests/test_api.py index 0e24828347e6538c47396a189b5ab7195bc745c7..0d781cfde6d973cc65fd7840e21feea39c85b94b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -129,12 +129,20 @@ class TestAuth: def test_signature_token(self, test_user_signature_token, no_warn): assert test_user_signature_token is not None - def test_put_user(self, client, postgres, admin_user_auth): + @pytest.mark.parametrize('token, affiliation', [ + ('test_token', dict(name='HU Berlin', address='Unter den Linden 6')), + (None, None)]) + def test_put_user(self, client, postgres, admin_user_auth, token, affiliation): + data = dict( + email='test@email.com', last_name='Tester', first_name='Testi', + token=token, affiliation=affiliation, + password=bcrypt.encrypt('test_password', ident='2y')) + + data = {key: value for key, value in data.items() if value is not None} + rv = client.put( '/auth/user', headers=admin_user_auth, - content_type='application/json', data=json.dumps(dict( - email='test@email.com', last_name='Tester', first_name='Testi', - password=bcrypt.encrypt('test_password', ident='2y')))) + content_type='application/json', data=json.dumps(data)) assert rv.status_code == 200 self.assert_user(client, json.loads(rv.data)) diff --git a/tests/test_migration.py b/tests/test_migration.py index 16dffb300472c6505918f3032749fff698877f1a..e3007059619ba8535f510873c82eb50bb58d4eba 100644 --- a/tests/test_migration.py +++ b/tests/test_migration.py @@ -69,7 +69,10 @@ def source_repo(monkeysession, postgres_infra): @pytest.fixture(scope='function') def target_repo(postgres): with create_postgres_infra(readonly=False, exists=False, dbname=test_target_db_name) as db: - db.execute('TRUNCATE users CASCADE;') + db.execute('DELETE FROM affiliations;') + db.execute('DELETE FROM sessions WHERE user_id >= 3;') + db.execute('DELETE FROM users WHERE user_id >= 3;') + assert db.query(coe_repo.User).filter_by(email='admin').first() is not None yield db db.execute('TRUNCATE uploads CASCADE;') @@ -80,13 +83,6 @@ def migration(source_repo, target_repo): yield migration -def test_copy_users(migration, target_repo): - migration.copy_users(target_repo) - assert target_repo.query(coe_repo.User).count() == 3 - assert target_repo.query(coe_repo.User).filter_by(user_id=1).first().email == 'one' - assert target_repo.query(coe_repo.User).filter_by(user_id=2).first().email == 'two' - - def perform_index(migration, has_indexed, with_metadata, **kwargs): has_source_calc = False for source_calc, total in SourceCalc.index(migration.source, with_metadata=with_metadata, **kwargs): @@ -102,7 +98,7 @@ def perform_index(migration, has_indexed, with_metadata, **kwargs): assert test_calc is not None if with_metadata: - assert test_calc.metadata['uploader']['id'] == 1 + assert test_calc.metadata['uploader']['id'] == 3 assert test_calc.metadata['comment'] == 'label1' @@ -135,9 +131,6 @@ def migrate_infra(migration, target_repo, proc_infra, client, monkeysession): # source repo is the infrastructure repo indexed = list(migration.index(drop=True, with_metadata=True)) assert len(indexed) == 2 - # source repo is the infrastructure repo - migration.copy_users(target_repo) - migration.set_new_pid_prefix(target_repo) # target repo is the infrastructure repo def create_client(): @@ -149,11 +142,19 @@ def migrate_infra(migration, target_repo, proc_infra, client, monkeysession): monkeysession.setattr('nomad.infrastructure.repository_db', target_repo) monkeysession.setattr('nomad.client.create_client', create_client) + # source repo is the still the original infrastructure repo + migration.copy_users() + yield migration monkeysession.setattr('nomad.infrastructure.repository_db', old_repo) +def test_copy_users(migrate_infra, target_repo): + assert target_repo.query(coe_repo.User).filter_by(user_id=3).first().email == 'one' + assert target_repo.query(coe_repo.User).filter_by(user_id=4).first().email == 'two' + + mirgation_test_specs = [ ('baseline', dict(migrated=2, source=2)), ('archive', dict(migrated=2, source=2)), @@ -172,7 +173,7 @@ mirgation_test_specs = [ def test_migrate(migrate_infra, test, assertions, caplog): uploads_path = os.path.join('tests', 'data', 'migration', test) reports = list(migrate_infra.migrate( - *[os.path.join(uploads_path, dir) for dir in os.listdir(uploads_path)])) + *[os.path.join(uploads_path, dir) for dir in os.listdir(uploads_path)], prefix=7000000)) assert len(reports) == 1 report = reports[0] @@ -192,7 +193,7 @@ def test_migrate(migrate_infra, test, assertions, caplog): assert calc_1 is not None metadata = calc_1.to_calc_with_metadata() assert metadata.pid <= 2 - assert metadata.uploader['id'] == 1 + assert metadata.uploader['id'] == 3 assert metadata.upload_time.isoformat() == '2019-01-01T12:00:00+00:00' assert len(metadata.datasets) == 1 assert metadata.datasets[0]['id'] == 3 @@ -200,7 +201,7 @@ def test_migrate(migrate_infra, test, assertions, caplog): assert metadata.datasets[0]['doi']['value'] == 'internal_ref' assert metadata.comment == 'label1' assert len(metadata.coauthors) == 1 - assert metadata.coauthors[0]['id'] == 2 + assert metadata.coauthors[0]['id'] == 4 assert len(metadata.references) == 1 assert metadata.references[0]['value'] == 'external_ref' @@ -209,7 +210,7 @@ def test_migrate(migrate_infra, test, assertions, caplog): assert calc_1 is not None metadata = calc_2.to_calc_with_metadata() assert len(metadata.shared_with) == 1 - assert metadata.shared_with[0]['id'] == 1 + assert metadata.shared_with[0]['id'] == 3 # assert pid prefix of new calcs if assertions.get('new', 0) > 0: