diff --git a/testgres/node.py b/testgres/node.py index 77538896..6d67d2cf 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -812,7 +812,8 @@ def psql(self, filename=None, dbname=None, username=None, - input=None): + input=None, + **variables): """ Execute a query using psql. @@ -822,9 +823,18 @@ def psql(self, dbname: database name to connect to. username: database user name. input: raw input to be passed. + **variables: vars to be set before execution. Returns: A tuple of (code, stdout, stderr). + + Examples: + >>> psql('select 1') + (0, b'1\n', b'') + >>> psql('postgres', 'select 2') + (0, b'2\n', b'') + >>> psql(query='select 3', ON_ERROR_STOP=1) + (0, b'3\n', b'') """ # Set default arguments @@ -843,6 +853,10 @@ def psql(self, dbname ] # yapf: disable + # set variables before execution + for key, value in iteritems(variables): + psql_params.extend(["--set", '{}={}'.format(key, value)]) + # select query source if query: psql_params.extend(("-c", query)) @@ -874,10 +888,15 @@ def safe_psql(self, query=None, **kwargs): username: database user name. input: raw input to be passed. + **kwargs are passed to psql(). + Returns: psql's output as str. """ + # force this setting + kwargs['ON_ERROR_STOP'] = 1 + ret, out, err = self.psql(query=query, **kwargs) if ret: raise QueryException((err or b'').decode('utf-8'), query)