6
0
mirror of https://github.com/FirebirdSQL/firebird-qa.git synced 2025-01-22 21:43:06 +01:00

Explicit Optional typing + encodinf/ecoding_errors for server connection and trace

This commit is contained in:
Pavel Císař 2022-04-19 11:59:48 +02:00
parent b64188c9b7
commit 8b1f46f34e

View File

@ -487,9 +487,9 @@ class Database:
Do not create instances of this class directly! Use **only** fixtures created by `db_factory`.
"""
def __init__(self, path: Path, filename: str, user: str=None, password: str=None,
charset: str=None, debug: str='', config_name: str='pytest',
utf8filename: bool=False):
def __init__(self, path: Path, filename: str, user: Optional[str]=None,
password: Optional[str]=None, charset: Optional[str]=None, debug: str='',
config_name: str='pytest', utf8filename: bool=False):
#: firebird-driver database configuration name
self.config_name: str = config_name
if driver_config.get_database(config_name) is None:
@ -520,11 +520,12 @@ class Database:
}
srv_conf = driver_config.get_server(_vars_['server'])
#: User name
self.user: str = srv_conf.user.value if user is None else user
self.user: Optional[str] = srv_conf.user.value if user is None else user
#: User password
self.password: str = srv_conf.password.value if password is None else password
def _make_config(self, *, page_size: int=None, sql_dialect: int=None, charset: str=None,
user: str=None, password: str=None) -> None:
self.password: Optional[str] = srv_conf.password.value if password is None else password
def _make_config(self, *, page_size: Optional[int]=None, sql_dialect: Optional[int]=None,
charset: Optional[str]=None, user: Optional[str]=None,
password: Optional[str]=None) -> None:
"""Helper method that sets `firebird-driver`_ configuration for `pytest` database
to work with this particular test database instance. Used in methods that need to
call `~firebird.driver.core.connect` or `~firebird.driver.core.create_database`.
@ -550,7 +551,7 @@ class Database:
"""Returns `firebird-driver`_ configuration for test database.
"""
return driver_config.get_database(self.config_name)
def create(self, page_size: int=None, sql_dialect: int=None) -> None:
def create(self, page_size: Optional[int]=None, sql_dialect: Optional[int]=None) -> None:
"""Create the test database.
Arguments:
@ -667,10 +668,11 @@ class Database:
pass
if self.db_path.is_file():
self.db_path.unlink(missing_ok=True)
def connect(self, *, user: str=None, password: str=None, role: str=None, no_gc: bool=None,
no_db_triggers: bool=None, dbkey_scope: DBKeyScope=None,
session_time_zone: str=None, charset: str=None, sql_dialect: int=None,
auth_plugin_list: str=None) -> Connection:
def connect(self, *, user: Optional[str]=None, password: Optional[str]=None,
role: Optional[str]=None, no_gc: Optional[bool]=None,
no_db_triggers: Optional[bool]=None, dbkey_scope: Optional[DBKeyScope]=None,
session_time_zone: Optional[str]=None, charset: Optional[str]=None,
sql_dialect: Optional[int]=None, auth_plugin_list: Optional[str]=None) -> Connection:
"""Create new connection to test database.
Arguments:
@ -710,10 +712,12 @@ class Database:
with connect_server(_vars_['server']) as srv:
srv.database.set_write_mode(database=self.db_path, mode=DbWriteMode.SYNC)
def db_factory(*, filename: str='test.fdb', init: str=None, from_backup: str=None,
copy_of: str=None, page_size: int=None, sql_dialect: int=None,
charset: str=None, user: str=None, password: str=None,
do_not_create: bool=False, do_not_drop: bool=False, async_write: bool=True,
def db_factory(*, filename: str='test.fdb', init: Optional[str]=None,
from_backup: Optional[str]=None, copy_of: Optional[str]=None,
page_size: Optional[int]=None, sql_dialect: Optional[int]=None,
charset: Optional[str]=None, user: Optional[str]=None,
password: Optional[str]=None, do_not_create: bool=False,
do_not_drop: bool=False, async_write: bool=True,
config_name: str='pytest', utf8filename: bool=False):
"""Factory function that returns :doc:`fixture <pytest:explanation/fixtures>` providing
the `Database` instance.
@ -825,9 +829,9 @@ class User:
Fixture created by `user_factory` does this automatically.
"""
def __init__(self, db: Database, *, name: str, password: str, plugin: str, charset: str,
active: bool=True, tags: Dict[str]=None, first_name: str=None,
middle_name: str=None, last_name: str=None, admin: bool=False,
do_not_create: bool=False):
active: bool=True, tags: Optional[Dict[str]]=None,
first_name: Optional[str]=None, middle_name: Optional[str]=None,
last_name: Optional[str]=None, admin: bool=False, do_not_create: bool=False):
#: Database used to manage test user.
self.db: Database = db
self.__name: str = name if name.startswith('"') else name.upper()
@ -837,10 +841,10 @@ class User:
#: Firebird CHARACTER SET used for connections that manage this user.
self.charset: str = charset
self.__active: bool = active
self.__first_name: str = first_name
self.__middle_name: str = middle_name
self.__last_name: str = last_name
self.__tags: Dict[str] = tags
self.__first_name: Optional[str] = first_name
self.__middle_name: Optional[str] = middle_name
self.__last_name: Optional[str] = last_name
self.__tags: Dict[str] = list() if tags is None else list(tags)
self.__admin: bool = admin
self.__create: bool = not do_not_create
def __enter__(self) -> Role:
@ -984,10 +988,10 @@ class User:
"User tags [R/O]"
return dict(self.__tags)
def user_factory(db_fixture_name: str, *, name: str, password: str='', plugin: str=None,
charset: str='utf8', active: bool=True, tags: Dict[str]=None,
first_name: str=None, middle_name: str=None, last_name: str=None,
admin: bool=False, do_not_create: bool=False):
def user_factory(db_fixture_name: str, *, name: str, password: str='', plugin: Optional[str]=None,
charset: str='utf8', active: bool=True, tags: Optional[Dict[str]]=None,
first_name: Optional[str]=None, middle_name: Optional[str]=None,
last_name: Optional[str]=None, admin: bool=False, do_not_create: bool=False):
"""Factory function that returns :doc:`fixture <pytest:explanation/fixtures>` providing
the `User` instance.
@ -1029,7 +1033,7 @@ def user_factory(db_fixture_name: str, *, name: str, password: str='', plugin: s
return user_fixture
def trace_thread(act: Action, b: Barrier, cfg: List[str], output: List[str], keep_log: bool,
encoding: str):
encoding: Optional[str], encoding_errors: Optional[str]):
"""Function used by `TraceSession` for execution in separate thread to run trace session.
Arguments:
@ -1039,6 +1043,7 @@ def trace_thread(act: Action, b: Barrier, cfg: List[str], output: List[str], kee
output: List used to store trace session output.
keep_log: When `True`, the trace session output is discarded.
encoding: Encoding for trace session output.
encoding_errors: Error handler for trace session output encoding.
"""
with act.connect_server() as srv:
srv.encoding = encoding
@ -1056,6 +1061,7 @@ class TraceSession:
config: Trace session configuration.
keep_log: When `False`, the trace session output is discarded.
encoding: Encoding for trace session output.
encoding_errors: Error handler for trace session output encoding.
.. important::
@ -1072,7 +1078,8 @@ class TraceSession:
The trace session is automatically stopped on exit from `with` context, and trace
output is copied to `Action.trace_log` attribute.
"""
def __init__(self, act: Action, config: List[str], keep_log: bool=True, encoding: str='ascii'):
def __init__(self, act: Action, config: List[str], keep_log: bool=True,
encoding: Optional[str]=None, encoding_errors: Optional[str]=None):
#: Action instance.
self.act: Action = act
#: Trace session configuration.
@ -1084,12 +1091,15 @@ class TraceSession:
#: Thread used to execute trace session
self.trace_thread: Thread = None
#: Encoding for trace session output.
self.encoding: str = encoding
self.encoding: Optional[str] = encoding
#: Encoding errors handling for trace session output.
self.encoding_errors: Optional[str] = encoding_errors
def __enter__(self) -> TraceSession:
b = Barrier(2)
self.trace_thread = Thread(target=trace_thread, args=[self.act, b, self.config,
self.output, self.keep_log,
self.encoding])
self.encoding,
self.encoding_errors])
self.trace_thread.start()
b.wait()
return self
@ -1368,8 +1378,8 @@ class Action:
if remove_white:
value = self.strip_white(value)
return value
def execute(self, *, do_not_connect: bool=False, charset: str=None, io_enc: str=None,
combine_output: bool=False) -> None:
def execute(self, *, do_not_connect: bool=False, charset: Optional[str]=None,
io_enc: Optional[str]=None, combine_output: bool=False) -> None:
"""Execute test `script` using ISQL.
Arguments:
@ -1425,7 +1435,8 @@ class Action:
out_file.write_text(self.stdout, encoding='utf8')
if self.stderr:
err_file.write_text(self.stderr, encoding='utf8')
def extract_meta(self, *, from_db: Database=None, charset: str=None, io_enc: str=None) -> str:
def extract_meta(self, *, from_db: Optional[Database]=None, charset: Optional[str]=None,
io_enc: Optional[str]=None) -> str:
"""Call ISQL to extract database metadata (using `-x` option).
Arguments:
@ -1475,8 +1486,8 @@ class Action:
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
return self.stdout
def gstat(self, *, switches: List[str], charset: str=None, io_enc: str=None,
connect_db: bool=True, credentials: bool=True) -> None:
def gstat(self, *, switches: List[str], charset: Optional[str]=None,
io_enc: Optional[str]=None, connect_db: bool=True, credentials: bool=True) -> None:
"""Run `gstat` utility.
Arguments:
@ -1537,8 +1548,8 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def gsec(self, *, switches: List[str]=None, charset: str=None, io_enc: str=None,
input: str=None, credentials: bool=True) -> None:
def gsec(self, *, switches: Optional[List[str]]=None, charset: Optional[str]=None,
io_enc: Optional[str]=None, input: Optional[str]=None, credentials: bool=True) -> None:
"""Run `gstat` utility.
Arguments:
@ -1582,7 +1593,7 @@ class Action:
if io_enc is None:
io_enc = CHARSET_MAP[charset]
params = [_vars_['gsec']]
if switches:
if switches is not None:
params.extend(switches)
if credentials:
params.extend(['-user', self.db.user, '-password', self.db.password])
@ -1601,8 +1612,8 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def gbak(self, *, switches: List[str]=None, charset: str=None, io_enc: str=None,
credentials: bool=True, combine_output: bool=False) -> None:
def gbak(self, *, switches: Optional[List[str]]=None, charset: Optional[str]=None,
io_enc: Optional[str]=None, credentials: bool=True, combine_output: bool=False) -> None:
"""Run `gbak` utility.
Arguments:
@ -1646,7 +1657,7 @@ class Action:
params = [_vars_['gbak']]
if credentials:
params.extend(['-user', self.db.user, '-password', self.db.password])
if switches:
if switches is not None:
params.extend(switches)
if combine_output:
result: CompletedProcess = run(params, encoding=io_enc, stdout=PIPE, stderr=STDOUT)
@ -1665,8 +1676,8 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def nbackup(self, *, switches: List[str], charset: str=None, io_enc: str=None,
credentials: bool=True, combine_output: bool=False) -> None:
def nbackup(self, *, switches: List[str], charset: Optional[str]=None,
io_enc: Optional[str]=None, credentials: bool=True, combine_output: bool=False) -> None:
"""Run `nbackup` utility.
Arguments:
@ -1728,8 +1739,8 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def gfix(self, *, switches: List[str]=None, charset: str=None, io_enc: str=None,
credentials: bool=True, combine_output: bool=False) -> None:
def gfix(self, *, switches: Optional[List[str]]=None, charset: Optional[str]=None,
io_enc: Optional[str]=None, credentials: bool=True, combine_output: bool=False) -> None:
"""Run `gfix` utility.
Arguments:
@ -1771,7 +1782,7 @@ class Action:
if io_enc is None:
io_enc = CHARSET_MAP[charset]
params = [_vars_['gfix']]
if switches:
if switches is not None:
params.extend(switches)
if credentials:
params.extend(['-user', self.db.user, '-password', self.db.password])
@ -1792,9 +1803,11 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def isql(self, *, switches: List[str], charset: str=None, io_enc: str=None,
input: str=None, input_file: Path=None, connect_db: bool=True,
credentials: bool=True, combine_output: bool=False, use_db: Database=None) -> None:
def isql(self, *, switches: Optional[List[str]]=None, charset: Optional[str]=None,
io_enc: Optional[str]=None, input: Optional[str]=None,
input_file: Optional[Path]=None, connect_db: bool=True,
credentials: bool=True, combine_output: bool=False,
use_db: Optional[Database]=None) -> None:
"""Run `isql` utility.
Arguments:
@ -1843,6 +1856,7 @@ class Action:
io_enc = CHARSET_MAP[charset]
if credentials:
params.extend(['-user', db.user, '-password', db.password])
if switches is not None:
params.extend(switches)
if input_file is not None:
params.extend(['-i', str(input_file)])
@ -1867,8 +1881,8 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def svcmgr(self, *, switches: List[str]=None, charset: str=None, io_enc: str=None,
connect_mngr: bool=True) -> None:
def svcmgr(self, *, switches: Optional[List[str]]=None, charset: Optional[str]=None,
io_enc: Optional[str]=None, connect_mngr: bool=True) -> None:
"""Run `fbsvcmgr` utility.
Arguments:
@ -1914,7 +1928,7 @@ class Action:
if connect_mngr:
params.extend([f"{_vars_['host']}:service_mgr" if _vars_['host'] else 'service_mgr',
'user', self.db.user, 'password', self.db.password])
if switches:
if switches is not None:
params.extend(switches)
result: CompletedProcess = run(params, encoding=io_enc, capture_output=True)
if result.returncode and not bool(self.expected_stderr):
@ -1930,7 +1944,9 @@ class Action:
out_file.write_text(self.stdout, encoding=io_enc)
if self.stderr:
err_file.write_text(self.stderr, encoding=io_enc)
def connect_server(self, *, user: str=None, password: str=None, role: str=None) -> Server:
def connect_server(self, *, user: Optional[str]=None, password: Optional[str]=None,
role: Optional[str]=None, encoding: Optional[str]=None,
encoding_errors: Optional[str]=None) -> Server:
"""Returns `~firebird.driver.Server` instance connected to service manager of tested
server.
@ -1938,6 +1954,8 @@ class Action:
user: User name [Default taken from server configuration].
password: User password [Default taken from server configuration].
role: User role.
encoding: Encoding for service output override.
encoding_errors: Error handler for service output encoding override.
.. tip::
@ -1949,7 +1967,7 @@ class Action:
return connect_server(_vars_['server'],
user=_vars_['user'] if user is None else user,
password=_vars_['password'] if password is None else password,
role=role)
role=role, encoding=encoding, encoding_errors=encoding_errors)
def get_firebird_log(self) -> List[str]:
"""Returns content of Firebird server log.
"""
@ -1969,7 +1987,7 @@ class Action:
"""Returns server architecture: `SuperServer`, `Classic`, `SuperClassic` or `Embedded`.
"""
return _vars_['server-arch']
def get_dsn(self, filename: Union[str, Path], protocol: NetProtocol=None) -> str:
def get_dsn(self, filename: Union[str, Path], protocol: Optional[NetProtocol]=None) -> str:
"""Returns DSN for connection to database.
Arguments:
@ -2027,9 +2045,10 @@ class Action:
print(f'{prefix}{fieldDesc[DESCRIPTION_NAME].ljust(32)}{row[i]}')
i += 1
print()
def trace(self, *, db_events: List[str]=None, svc_events: List[str]=None,
config: List[str]=None, keep_log: bool=True, encoding: str='ascii',
database: str=None) -> TraceSession:
def trace(self, *, db_events: Optional[List[str]]=None, svc_events: Optional[List[str]]=None,
config: Optional[List[str]]=None, keep_log: bool=True,
encoding: Optional[str]=None, encoding_errors: Optional[str]=None,
database: Optional[str]=None) -> TraceSession:
"""Run Firebird trace session.
Arguments:
@ -2039,6 +2058,7 @@ class Action:
keep_log: When `False`, the trace session output is discarded. Otherwise is stored
in `.trace_log`.
encoding: Encoding used to decode trace session output returned by server.
encoding_errors: Handler for encoding errors used to decode trace session output returned by server.
database: Traced database filename specification. If value is not specified, uses
primary test database.
@ -2062,7 +2082,8 @@ class Action:
# Actions that are not traced
"""
if config is not None:
return TraceSession(self, config, keep_log=keep_log, encoding=encoding)
return TraceSession(self, config, keep_log=keep_log, encoding=encoding,
encoding_errors=encoding_errors)
else:
config = []
if db_events:
@ -2075,7 +2096,8 @@ class Action:
config.extend(['services', '{', 'enabled = true', 'log_initfini = false'])
config.extend(svc_events)
config.append('}')
return TraceSession(self, config, keep_log=keep_log, encoding=encoding)
return TraceSession(self, config, keep_log=keep_log, encoding=encoding,
encoding_errors=encoding_errors)
def trace_to_stdout(self, *, upper: bool=False) -> None:
"""Copy trace session log to `.stdout`.
@ -2244,7 +2266,7 @@ class Action:
"Current execution platform identifier."
return _platform
def isql_act(db_fixture_name: str, script: str='', *, substitutions: Substitutions=None):
def isql_act(db_fixture_name: str, script: str='', *, substitutions: Optional[Substitutions]=None):
"""Factory function that returns :doc:`fixture <pytest:explanation/fixtures>` providing
the `Action` instance.