diff --git a/firebird/qa/plugin.py b/firebird/qa/plugin.py index d7f4ba4f..6437aa43 100644 --- a/firebird/qa/plugin.py +++ b/firebird/qa/plugin.py @@ -45,7 +45,7 @@ import platform import difflib import pytest from _pytest.fixtures import FixtureRequest -from subprocess import run, CompletedProcess +from subprocess import run, CompletedProcess, PIPE, STDOUT from pathlib import Path #from configparser import ConfigParser, ExtendedInterpolation from packaging.specifiers import SpecifierSet @@ -89,7 +89,7 @@ def pytest_report_header(config): f" home: {_vars_['home-dir']}", f" bin: {_vars_['bin-dir']}", f" security db: {_vars_['security-db']}", - f" run slow test: {_vars_['runslow']}", + f" run slow tests: {_vars_['runslow']}", f" save test output: {_vars_['save-output']}", ] @@ -444,6 +444,20 @@ class TraceSession: pytest.fail('Trace thread still alive') self.act.trace_log = self.output +class Envar: + def __init__(self, name: str, value: str): + self.name = name + self.value = value + self.old_value = None + def __enter__(self) -> Envar: + if self.name in os.environ: + self.old_value = os.environ[self.name] + os.environ[self.name] = self.value + return self + def __exit__(self, exc_type, exc_value, traceback) -> None: + if self.old_value is not None: + os.environ[self.name] = self.old_value + class Role: def __init__(self, name: str, database: Database, charset: str=None): self.name: str = name @@ -745,7 +759,7 @@ class Action: err_file.write_text(self.stderr, encoding=io_enc) def isql(self, *, switches: List[str], charset: str='utf8', io_enc: str=None, input: str=None, input_file: Path=None, connect_db: bool=True, - credentials: bool=True) -> None: + credentials: bool=True, combine_output: bool=False) -> None: __tracebackhide__ = True out_file: Path = self.outfile.with_suffix('.out') err_file: Path = self.outfile.with_suffix('.err') @@ -768,9 +782,13 @@ class Action: params.extend(['-i', str(input_file)]) if connect_db: params.append(str(self.db.dsn)) - result: CompletedProcess = run(params, input=input, - encoding=io_enc, capture_output=True) - if result.returncode and not bool(self.expected_stderr): + if combine_output: + result: CompletedProcess = run(params, input=input, + encoding=io_enc, stdout=PIPE, stderr=STDOUT) + else: + result: CompletedProcess = run(params, input=input, + encoding=io_enc, capture_output=True) + if result.returncode and not (bool(self.expected_stderr) or combine_output): print(f"-- isql stdout {'-' * 20}") print(result.stdout) print(f"-- isql stderr {'-' * 20}") @@ -915,6 +933,8 @@ class Action: def trace_to_stdout(self, *, upper: bool=False) -> None: log = ''.join(self.trace_log) self.stdout = log.upper() if upper else log + def envar(self, name: str, value: str) -> Envar: + return Envar(name, value) @property def clean_stdout(self) -> str: if self._clean_stdout is None: