From d5b6620e6259c41654375b280a4a54d05413853c Mon Sep 17 00:00:00 2001 From: Pavel Cisar Date: Tue, 14 Dec 2021 20:51:09 +0100 Subject: [PATCH] Plugin enhancements --- firebird/qa/__init__.py | 4 +- firebird/qa/plugin.py | 329 +++++++++++++++++++++++++++++----------- 2 files changed, 246 insertions(+), 87 deletions(-) diff --git a/firebird/qa/__init__.py b/firebird/qa/__init__.py index 38ce5e1b..fcb340ba 100644 --- a/firebird/qa/__init__.py +++ b/firebird/qa/__init__.py @@ -36,5 +36,5 @@ """ -from .plugin import db_factory, user_factory, isql_act, python_act, Database, Action, User, \ - temp_file, temp_files +from .plugin import db_factory, Database, user_factory, User, isql_act, python_act, Action, \ + temp_file, temp_files, role_factory, Role, envar_factory, Envar, ServerKeeper diff --git a/firebird/qa/plugin.py b/firebird/qa/plugin.py index 6437aa43..32835cd7 100644 --- a/firebird/qa/plugin.py +++ b/firebird/qa/plugin.py @@ -37,7 +37,7 @@ """ from __future__ import annotations -from typing import List, Dict +from typing import List, Dict, Union import os import re import shutil @@ -56,6 +56,7 @@ from firebird.driver import connect, connect_server, create_database, driver_con NetProtocol, Server, CHARSET_MAP, Connection, Cursor, \ DESCRIPTION_NAME, DESCRIPTION_DISPLAY_SIZE, DatabaseConfig, DBKeyScope, DbInfoCode, \ DbWriteMode +from firebird.driver.core import _connect_helper _vars_ = {'server': None, 'bin-dir': None, @@ -134,6 +135,7 @@ def pytest_configure(config): _vars_['save-output'] = config.getoption('save_output') srv_conf = driver_config.get_server(_vars_['server']) _vars_['host'] = srv_conf.host.value if srv_conf is not None else '' + _vars_['port'] = srv_conf.port.value if srv_conf is not None else '' _vars_['password'] = srv_conf.password.value # driver_config.register_database('employee') @@ -341,14 +343,39 @@ class Database: db.drop_database() def connect(self, *, user: str=None, password: str=None, role: str=None, no_gc: bool=None, no_db_triggers: bool=None, dbkey_scope: DBKeyScope=None, - session_time_zone: str=None, charset: str=None, sql_dialect: int=None) -> Connection: + session_time_zone: str=None, charset: str=None, sql_dialect: int=None, + auth_plugin_list: str=None) -> Connection: self._make_config(user=user, password=password, charset=charset, sql_dialect=sql_dialect) return connect('pytest', role=role, no_gc=no_gc, no_db_triggers=no_db_triggers, - dbkey_scope=dbkey_scope, session_time_zone=session_time_zone) + dbkey_scope=dbkey_scope, session_time_zone=session_time_zone, + auth_plugin_list=auth_plugin_list) def set_async_write(self) -> None: with connect_server(_vars_['server']) as srv: srv.database.set_write_mode(database=self.db_path, mode=DbWriteMode.ASYNC) +def db_factory(*, filename: str='test.fdb', init: str=None, from_backup: str=None, + copy_of: str=None, page_size: int=None, sql_dialect: int=None, + charset: str=None, user: str=None, password: str=None, + do_not_create: bool=False, do_not_drop: bool=False): + + @pytest.fixture + def database_fixture(request: FixtureRequest, db_path) -> Database: + db = Database(db_path, filename, user, password) + if not do_not_create: + if from_backup is None and copy_of is None: + db.create(page_size, sql_dialect, charset) + elif from_backup is not None: + db.restore(from_backup) + elif copy_of is not None: + db.copy(copy_of) + if init is not None: + db.init(init) + yield db + if not do_not_drop: + db.drop() + + return database_fixture + @pytest.fixture def db_path(tmp_path) -> Path: if platform.system != 'Windows': @@ -362,50 +389,133 @@ def db_path(tmp_path) -> Path: return tmp_path class User: - def __init__(self, name: str, password: str, server: Server, encoding: str): - self.name: str = name - self.password: str = password - self.server: Server = server - self.encoding = encoding - self.__srv_encoding = None - def _enter_encode(self) -> None: - if self.encoding is not None: - self.__srv_encoding = self.server.encoding - self.server.encoding = self.encoding - def _exit_encode(self) -> None: - if self.encoding is not None: - self.server.encoding = self.__srv_encoding - def create(self) -> None: - self._enter_encode() - if self.server.user.exists(self.name): + def __init__(self, db: Database, name: str, password: str, plugin: str, charset: str, + active: bool=True, tags: Dict[str]=None, first_name: str=None, + middle_name: str=None, last_name: str=None, admin: bool=False): + self.db: Database = db + self.__name: str = name if name.startswith('"') else name.upper() + self.__password: str = password + self.__plugin: str = plugin + self.__p = '' if self.__plugin is None else f" USING PLUGIN {self.__plugin}" + self.charset: str = charset + self.__active: bool = active + self.__first_name: str = first_name + self.__middle_name: str = middle_name + self.__last_name: str = last_name + self.__tags: Dict[str] = tags + self.__admin: bool = admin + def __enter__(self) -> Role: + if not self.exists(): + self.create() + return self + def __exit__(self, exc_type, exc_value, traceback) -> None: + if self.exists(): self.drop() - self.server.user.add(user_name=self.name, password=self.password) - #print(f"User {self.name} created") - self._exit_encode() - def drop(self) -> None: - self._enter_encode() - with connect('employee') as con: + def exists(self) -> bool: + with self.db.connect(charset=self.charset) as con: c = con.cursor() - grants = c.execute('select count(*) from rdb$user_privileges where rdb$user = ?', [self.name]).fetchone()[0] + name = self.name[1:-1] if self.name.startswith('"') else self.name + cmd = f"SELECT COUNT(*) FROM SEC$USERS WHERE SEC$USER_NAME = '{name}'" + if self.__plugin is not None: + cmd += f" AND SEC$PLUGIN = '{self.__plugin}'" + cnt = c.execute(cmd).fetchone()[0] + return cnt > 0 + def create(self) -> None: + if self.exists(): + self.drop() + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"CREATE USER {self.name} PASSWORD '{self.__password}' {'ACTIVE' if self.__active else 'INACTIVE'}{self.__p}{' GRANT ADMIN ROLE' if self.__admin else ''}") + if self.__first_name is not None: + con.execute_immediate(f"ALTER USER {self.name} SET FIRSTNAME '{self.__first_name}'{self.__p}") + if self.__middle_name is not None: + con.execute_immediate(f"ALTER USER {self.name} SET MIDDLENAME '{self.__middle_name}'{self.__p}") + if self.__last_name is not None: + con.execute_immediate(f"ALTER USER {self.name} SET LASTNAME '{self.__last_name}'{self.__p}") + if self.__tags is not None: + tags = ', '.join([f"{name} = '{value}'" for name, value in self.__tags.items()]) + con.execute_immediate(f"ALTER USER {self.name} SET TAGS ({tags}){self.__p}") + con.commit() + print(f"CREATE user: {self.name} PLUGIN: {self.plugin}") + def drop(self) -> None: + with self.db.connect(charset=self.charset) as con: + c = con.cursor() + grants = c.execute('select count(*) from rdb$user_privileges where rdb$user = ?', + [self.name]).fetchone()[0] if grants > 0: c.execute(f'revoke all on all from {self.name}') con.commit() - self.server.user.delete(self.name) - #print(f"User {self.name} dropped") - self._exit_encode() - def update(self, **kwargs) -> None: - self._enter_encode() - self.server.user.update(user_name=self.name, **kwargs) - self._exit_encode() + # + print(f"DROP user: {self.name} PLUGIN: {self.plugin}") + con.execute_immediate(f"DROP USER {self.name}{self.__p}") + con.commit() + def set_tag(self, name: str, *, value: str) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"ALTER USER {self.name} TAGS ({name} = '{value}'){self.__p}") + con.commit() + def drop_tag(self, name: str) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"ALTER USER {self.name} TAGS (DROP {name}){self.__p}") + con.commit() + @property + def name(self) -> str: + return self.__name + @property + def plugin(self) -> str: + if self.__plugin is None: + with self.db.connect() as con: + c = con.cursor() + self.__plugin = c.execute(f'SELECT SEC$PLUGIN FROM SEC$USERS').fetchone()[0].strip() + return self.__plugin + @property + def password(self) -> str: + return self.__password + @password.setter + def password(self, value: str) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"ALTER USER {self.__name} SET PASSWORD '{value}'") + con.commit() + self.__password = value + @property + def first_name(self) -> str: + return self.__first_name + @first_name.setter + def first_name(self, value: str) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"ALTER USER {self.__name} SET FIRSTNAME '{value}'") + con.commit() + self.__first_name = value + @property + def middle_name(self) -> str: + return self.__middle_name + @middle_name.setter + def middle_name(self, value: str) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"ALTER USER {self.__name} SET MIDDLENAME '{value}'") + con.commit() + self.__middle_name = value + @property + def last_name(self) -> str: + return self.__last_name + @last_name.setter + def last_name(self, value: str) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f"ALTER USER {self.__name} SET LASTNAME '{value}'") + con.commit() + self.__last_name = value + @property + def tags(self) -> Dict[str]: + return dict(self.__tags) -def user_factory(*, name: str, password: str, encoding: str=None): +def user_factory(db_fixture_name: str, *, name: str, password: str, plugin: str=None, + charset: str='utf8', active: bool=True, tags: Dict[str]=None, + first_name: str=None, middle_name: str=None, last_name: str=None, + admin: bool=False): @pytest.fixture - def user_fixture(request: FixtureRequest, firebird_server) -> User: - user = User(name, password, firebird_server, encoding) - user.create() - yield user - user.drop() + def user_fixture(request: FixtureRequest) -> User: + with User(request.getfixturevalue(db_fixture_name), name, password, plugin, + charset, active, tags, first_name, middle_name, last_name, admin) as user: + yield user return user_fixture @@ -444,6 +554,19 @@ class TraceSession: pytest.fail('Trace thread still alive') self.act.trace_log = self.output +class ServerKeeper: + def __init__(self, act: Action, server: str): + self.act: Action = act + self.server:str = server + self.old_value = None + def __enter__(self): + self.old_value = self.act.vars['server'] + self.act.vars['server'] = self.server + return self + def __exit__(self, exc_type, exc_value, traceback) -> None: + if self.old_value is not None: + self.act.vars['server'] = self.old_value + class Envar: def __init__(self, name: str, value: str): self.name = name @@ -458,43 +581,54 @@ class Envar: if self.old_value is not None: os.environ[self.name] = self.old_value -class Role: - def __init__(self, name: str, database: Database, charset: str=None): - self.name: str = name - self.db: Connection = database.connect(charset=charset) - def __enter__(self) -> Role: - try: - self.db.execute_immediate(f'create role {self.name}') - self.db.commit() - except: - pass - return self - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.db.execute_immediate(f'drop role {self.name}') - self.db.commit() - self.db.close() - -def db_factory(*, filename: str='test.fdb', init: str=None, from_backup: str=None, - copy_of: str=None, page_size: int=None, sql_dialect: int=None, - charset: str=None, user: str=None, password: str=None, - do_not_create: bool=False): +def envar_factory(*, name: str, value: str): @pytest.fixture - def database_fixture(request: FixtureRequest, db_path) -> Database: - db = Database(db_path, filename, user, password) - if not do_not_create: - if from_backup is None and copy_of is None: - db.create(page_size, sql_dialect, charset) - elif from_backup is not None: - db.restore(from_backup) - elif copy_of is not None: - db.copy(copy_of) - if init is not None: - db.init(init) - yield db - db.drop() + def envar_fixture() -> User: + with Envar(name, value) as role: + yield role - return database_fixture + return envar_fixture + +class Role: + def __init__(self, database: Database, name: str, charset: str, do_not_create: bool): + self.db: Database = database + self.name: str = name if name.startswith('"') else name.upper() + self.charset = charset + self.do_not_create: bool = do_not_create + def __enter__(self) -> Role: + if not self.exists() and not self.do_not_create: + self.create() + return self + def __exit__(self, exc_type, exc_value, traceback) -> None: + if self.exists(): + self.drop() + def create(self) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f'CREATE ROLE {self.name}') + con.commit() + print(f"CREATE role: {self.name}") + def drop(self) -> None: + with self.db.connect(charset=self.charset) as con: + con.execute_immediate(f'DROP ROLE {self.name}') + con.commit() + print(f"DROP role: {self.name}") + def exists(self) -> bool: + with self.db.connect(charset=self.charset) as con: + c = con.cursor() + name = self.name[1:-1] if self.name.startswith('"') else self.name + cmd = f"SELECT COUNT(*) FROM RDB$ROLES WHERE RDB$ROLE_NAME = '{name}'" + cnt = c.execute(cmd).fetchone()[0] + return cnt > 0 + +def role_factory(db_fixture_name: str, *, name: str, charset: str='utf8', do_not_create: bool=False): + + @pytest.fixture + def role_fixture(request: FixtureRequest) -> User: + with Role(request.getfixturevalue(db_fixture_name), name, charset, do_not_create) as role: + yield role + + return role_fixture class Action: def __init__(self, db: Database, script: str, substitutions: List[str], outfile: Path, @@ -650,7 +784,7 @@ class Action: if self.stderr: err_file.write_text(self.stderr, encoding=io_enc) def gbak(self, *, switches: List[str]=None, charset: str='utf8', io_enc: str=None, - input: str=None, credentials: bool=True) -> None: + input: str=None, credentials: bool=True, combine_output: bool=False) -> None: __tracebackhide__ = True out_file: Path = self.outfile.with_suffix('.out') err_file: Path = self.outfile.with_suffix('.err') @@ -669,9 +803,13 @@ class Action: params.extend(['-user', self.db.user, '-password', self.db.password]) if switches: params.extend(switches) - result: CompletedProcess = run(params, input=input, - encoding=io_enc, capture_output=True) - if result.returncode and not bool(self.expected_stderr): + if combine_output: + result: CompletedProcess = run(params, input=input, + encoding=io_enc, stdout=PIPE, stderr=STDOUT) + else: + result: CompletedProcess = run(params, input=input, + encoding=io_enc, capture_output=True) + if result.returncode and not (bool(self.expected_stderr) or combine_output): print(f"-- gbak stdout {'-' * 20}") print(result.stdout) print(f"-- gbak stderr {'-' * 20}") @@ -686,7 +824,8 @@ class Action: out_file.write_text(self.stdout, encoding=io_enc) if self.stderr: err_file.write_text(self.stderr, encoding=io_enc) - def nbackup(self, *, switches: List[str], charset: str='utf8', credentials: bool=True) -> None: + def nbackup(self, *, switches: List[str], charset: str='utf8', credentials: bool=True, + combine_output: bool=False) -> None: __tracebackhide__ = True out_file: Path = self.outfile.with_suffix('.out') err_file: Path = self.outfile.with_suffix('.err') @@ -703,9 +842,11 @@ class Action: params.extend(switches) if credentials: params.extend(['-user', self.db.user, '-password', self.db.password]) - result: CompletedProcess = run(params, - encoding=self.db.io_enc, capture_output=True) - if result.returncode and not bool(self.expected_stderr): + if combine_output: + result: CompletedProcess = run(params, encoding=self.db.io_enc, stdout=PIPE, stderr=STDOUT) + else: + result: CompletedProcess = run(params, encoding=self.db.io_enc, capture_output=True) + if result.returncode and not (bool(self.expected_stderr) or combine_output): print(f"-- nbackup stdout {'-' * 20}") print(result.stdout) print(f"-- nbackup stderr {'-' * 20}") @@ -875,6 +1016,9 @@ class Action: f2 = con1.info.get_info(DbInfoCode.FETCHES) result = 'SC' if f1 == f2 else 'SS' return result + def get_dsn(self, filename: Union[str, Path], protocol: str=None) -> str: + return _connect_helper('', self.host, self.port, str(filename), + protocol if protocol else self.protocol) def reset(self) -> None: self.return_code: int = 0 self._clean_stdout = None @@ -911,8 +1055,7 @@ class Action: for fieldDesc in cursor.description: print(f'{prefix}{fieldDesc[DESCRIPTION_NAME].ljust(32)}{row[i]}') i += 1 - def test_role(self, name: str, charset: str=None) -> Role: - return Role(name, self.db, charset) + print() def trace(self, *, db_events: List[str]=None, svc_events: List[str]=None, config: List[str]=None, keep_log: bool=True, encoding: str='ascii', database: str=None) -> TraceSession: @@ -933,8 +1076,6 @@ class Action: def trace_to_stdout(self, *, upper: bool=False) -> None: log = ''.join(self.trace_log) self.stdout = log.upper() if upper else log - def envar(self, name: str, value: str) -> Envar: - return Envar(name, value) @property def clean_stdout(self) -> str: if self._clean_stdout is None: @@ -958,6 +1099,24 @@ class Action: @property def vars(self) -> Dict[str]: return _vars_ + @property + def host(self) -> str: + return _vars_['host'] + @property + def port(self) -> str: + return _vars_['port'] + @property + def protocol(self) -> str: + return _vars_['protocol'] + @property + def security_db(self) -> str: + return _vars_['security-db'] + @property + def home_dir(self) -> Path: + return _vars_['home-dir'] + @property + def bin_dir(self) -> Path: + return _vars_['bin-dir'] def isql_act(db_fixture_name: str, script: str, *, substitutions: List[str]=None, charset: str='utf8'):