# WARNING: this file is auto-generated by 'async_to_sync.py'
|
|
# from the original file 'connection_async.py'
|
|
# DO NOT CHANGE! Change the original file instead.
|
|
"""
|
|
Psycopg connection object (sync version)
|
|
"""
|
|
|
|
# Copyright (C) 2020 The Psycopg Team
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import warnings
|
|
from time import monotonic
|
|
from types import TracebackType
|
|
from typing import TYPE_CHECKING, Any, cast, overload
|
|
from contextlib import contextmanager
|
|
from collections.abc import Generator, Iterator
|
|
|
|
from . import errors as e
|
|
from . import pq, waiting
|
|
from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query
|
|
from .abc import QueryNoTemplate
|
|
from ._tpc import Xid
|
|
from .rows import Row, RowFactory, args_row, tuple_row
|
|
from .adapt import AdaptersMap
|
|
from ._enums import IsolationLevel
|
|
from .cursor import Cursor
|
|
from ._compat import Self, Template
|
|
from ._acompat import Lock
|
|
from .conninfo import conninfo_attempts, conninfo_to_dict, make_conninfo
|
|
from .conninfo import timeout_from_conninfo
|
|
from ._pipeline import Pipeline
|
|
from .generators import notifies
|
|
from .transaction import Transaction
|
|
from ._capabilities import capabilities
|
|
from ._server_cursor import ServerCursor
|
|
from ._conninfo_utils import gssapi_requested
|
|
from ._connection_base import BaseConnection, CursorRow, Notify
|
|
|
|
if TYPE_CHECKING:
|
|
from .pq.abc import PGconn
|
|
|
|
_WAIT_INTERVAL = 0.1
|
|
|
|
TEXT = pq.Format.TEXT
|
|
BINARY = pq.Format.BINARY
|
|
|
|
IDLE = pq.TransactionStatus.IDLE
|
|
ACTIVE = pq.TransactionStatus.ACTIVE
|
|
INTRANS = pq.TransactionStatus.INTRANS
|
|
|
|
_INTERRUPTED = KeyboardInterrupt
|
|
|
|
logger = logging.getLogger("psycopg")
|
|
|
|
|
|
class Connection(BaseConnection[Row]):
    """
    Wrapper for a connection to the database.
    """

    __module__ = "psycopg"

    # Class used by `cursor()` to create client-side (unnamed) cursors.
    cursor_factory: type[Cursor[Row]]
    # Class used by `cursor(name)` to create server-side (named) cursors.
    server_cursor_factory: type[ServerCursor[Row]]
    # Default row factory used by cursors created on this connection.
    row_factory: RowFactory[Row]
    # Pipeline currently installed on the connection, if any; managed by
    # `pipeline()` / `_pipeline_nolock()`.
    _pipeline: Pipeline | None

    def __init__(
        self,
        pgconn: PGconn,
        row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
    ):
        super().__init__(pgconn)
        self.row_factory = row_factory
        # Serializes operations on the libpq connection (commit, rollback,
        # waits, state changes) across threads.
        self.lock = Lock()
        self.cursor_factory = Cursor
        self.server_cursor_factory = ServerCursor
    @classmethod
    def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: int | None = 5,
        context: AdaptContext | None = None,
        row_factory: RowFactory[Row] | None = None,
        cursor_factory: type[Cursor[Row]] | None = None,
        **kwargs: ConnParam,
    ) -> Self:
        """
        Connect to a database server and return a new `Connection` instance.

        :param conninfo: the connection string; extra parameters can be
            passed as keyword arguments and are merged into it.
        :param autocommit: if true, don't start transactions automatically.
        :param prepare_threshold: number of executions after which a query is
            prepared (stored on the connection as `!prepare_threshold`).
        :param context: adaptation context providing the adapters map to copy.
        :param row_factory: default row factory for the connection's cursors.
        :param cursor_factory: default class for client-side cursors.
        """

        params = cls._get_connection_params(conninfo, **kwargs)
        timeout = timeout_from_conninfo(params)
        rv = None
        # One attempt per candidate host/port/hostaddr combination derived
        # from the connection parameters.
        attempts = conninfo_attempts(params)
        conn_errors: list[tuple[e.Error, str]] = []
        for attempt in attempts:
            tdescr = (attempt.get("host"), attempt.get("port"), attempt.get("hostaddr"))
            descr = "host: %r, port: %r, hostaddr: %r" % tdescr
            logger.debug("connection attempt: %s", descr)
            try:
                conninfo = make_conninfo("", **attempt)
                gen = cls._connect_gen(conninfo, timeout=timeout)
                rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL)
            except e.Error as ex:
                # Record the failure and move on to the next candidate.
                logger.debug("connection failed: %s: %s", descr, str(ex))
                conn_errors.append((ex, descr))
            except e._NO_TRACEBACK as ex:
                # Don't expose psycopg internals in the traceback.
                raise ex.with_traceback(None)
            else:
                logger.debug("connection succeeded: %s", descr)
                break

        if not rv:
            # Every attempt failed. With a single failure re-raise it as is;
            # with several, aggregate them into one error of the same type as
            # the last (preserving backward compatibility for callers
            # catching specific exception classes).
            last_ex = conn_errors[-1][0]
            if len(conn_errors) == 1:
                raise last_ex.with_traceback(None)

            # Create a new exception with the same type as the last one, containing
            # all attempt errors while preserving backward compatibility.
            lines = [str(last_ex)]
            lines.append("Multiple connection attempts failed. All failures were:")
            lines.extend((f"- {descr}: {error}" for error, descr in conn_errors))
            raise type(last_ex)("\n".join(lines)).with_traceback(None)

        # Warn when GSSAPI encryption was used without being explicitly
        # requested: the libpq 'gssencmode' default may differ between
        # psycopg implementations.
        if (
            capabilities.has_used_gssapi()
            and rv.pgconn.used_gssapi
            and (not gssapi_requested(params))
        ):
            warnings.warn(
                "the connection was obtained using the GSSAPI relying on the 'gssencmode=prefer' libpq default. The value for this default might be 'disable' instead, in certain psycopg[binary] implementations. If you wish to interact with the GSSAPI reliably please set the 'gssencmode' parameter in the connection string or the 'PGGSSENCMODE' environment variable to 'prefer' or 'require'",
                RuntimeWarning,
            )

        # Apply the configuration requested by the caller.
        rv._autocommit = bool(autocommit)
        if row_factory:
            rv.row_factory = row_factory
        if cursor_factory:
            rv.cursor_factory = cursor_factory
        if context:
            rv._adapters = AdaptersMap(context.adapters)
        rv.prepare_threshold = prepare_threshold
        return rv
    def __enter__(self) -> Self:
        # Context manager support: on exit the transaction is committed or
        # rolled back and the connection closed (see `__exit__()`).
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.closed:
            return

        if exc_type:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                self.rollback()
            except Exception as exc2:
                logger.warning("error ignored in rollback on %s: %s", self, exc2)
        else:
            # Clean exit: commit the pending transaction.
            self.commit()

        # Close the connection only if it doesn't belong to a pool.
        if not getattr(self, "_pool", None):
            self.close()
    @classmethod
    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
        """Manipulate connection parameters before connecting.

        Hook called by `connect()`; merges the keyword arguments into the
        connection string and returns the resulting parameters dict.
        """
        return conninfo_to_dict(conninfo, **kwargs)
    def close(self) -> None:
        """Close the database connection."""
        if self.closed:
            return

        # If the connection belongs to a pool configured so that `close()`
        # returns connections, give it back instead of really closing it.
        pool = getattr(self, "_pool", None)
        if pool and getattr(pool, "close_returns", False):
            pool.putconn(self)
            return

        self._closed = True

        # TODO: maybe send a cancel on close, if the connection is ACTIVE?

        self.pgconn.finish()
    @overload
    def cursor(self, *, binary: bool = False) -> Cursor[Row]: ...

    @overload
    def cursor(
        self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
    ) -> Cursor[CursorRow]: ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> ServerCursor[Row]: ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        row_factory: RowFactory[CursorRow],
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> ServerCursor[CursorRow]: ...

    def cursor(
        self,
        name: str = "",
        *,
        binary: bool = False,
        row_factory: RowFactory[Any] | None = None,
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> Cursor[Any] | ServerCursor[Any]:
        """
        Return a new `Cursor` to send commands and queries to the connection.

        :param name: if not empty, create a server-side (named) cursor.
        :param binary: if true, set the cursor to fetch results in binary
            format.
        :param row_factory: override the connection's `!row_factory` for the
            returned cursor.
        :param scrollable: forwarded to the server cursor factory (server-side
            cursors only).
        :param withhold: forwarded to the server cursor factory (server-side
            cursors only).
        """
        self._check_connection_ok()

        if not row_factory:
            row_factory = self.row_factory

        cur: Cursor[Any] | ServerCursor[Any]
        if name:
            # A non-empty name selects a server-side cursor.
            cur = self.server_cursor_factory(
                self,
                name=name,
                row_factory=row_factory,
                scrollable=scrollable,
                withhold=withhold,
            )
        else:
            cur = self.cursor_factory(self, row_factory=row_factory)

        if binary:
            cur.format = BINARY

        return cur
    @overload
    def execute(
        self,
        query: QueryNoTemplate,
        params: Params | None = None,
        *,
        prepare: bool | None = None,
        binary: bool = False,
    ) -> Cursor[Row]: ...

    @overload
    def execute(
        self, query: Template, *, prepare: bool | None = None, binary: bool = False
    ) -> Cursor[Row]: ...

    def execute(
        self,
        query: Query,
        params: Params | None = None,
        *,
        prepare: bool | None = None,
        binary: bool = False,
    ) -> Cursor[Row]:
        """Execute a query and return a cursor to read its results.

        :param query: the query to execute; either a plain query or a string
            `Template` (in which case `!params` must not be passed).
        :param params: the parameters for the query, if any.
        :param prepare: forwarded to `Cursor.execute()` to control query
            preparation (`!None` = automatic).
        :param binary: if true, fetch results in binary format.
        """
        try:
            cur = self.cursor()
            if binary:
                cur.format = BINARY

            if isinstance(query, Template):
                # A template carries its own values: separate parameters
                # would be ambiguous, so reject them.
                if params is not None:
                    raise TypeError(
                        "'execute()' with string template query doesn't support parameters"
                    )
                return cur.execute(query, prepare=prepare)
            else:
                return cur.execute(query, params, prepare=prepare)
        except e._NO_TRACEBACK as ex:
            # Don't expose psycopg internals in the traceback.
            raise ex.with_traceback(None)
    def commit(self) -> None:
        """Commit any pending transaction to the database."""
        # Take the connection lock to serialize with other threads using it.
        with self.lock:
            self.wait(self._commit_gen())

    def rollback(self) -> None:
        """Roll back to the start of any pending transaction."""
        with self.lock:
            self.wait(self._rollback_gen())
    def cancel_safe(self, *, timeout: float = 30.0) -> None:
        """Cancel the current operation on the connection.

        :param timeout: raise a `~errors.CancellationTimeout` if the
            cancellation request does not succeed within `timeout` seconds.

        Note that a successful cancel attempt on the client is not a guarantee
        that the server will successfully manage to cancel the operation.

        This is a non-blocking version of `~Connection.cancel()` which
        leverages a more secure and improved cancellation feature of the libpq,
        which is only available from version 17.

        If the underlying libpq is older than version 17, the method will fall
        back to using the same implementation of `!cancel()`.
        """
        # Nothing to cancel (e.g. no operation in progress): bail out early.
        if not self._should_cancel():
            return

        if capabilities.has_cancel_safe():
            # libpq >= 17: use the improved, non-blocking cancellation.
            waiting.wait_conn(
                self._cancel_gen(timeout=timeout), interval=_WAIT_INTERVAL
            )
        else:
            # Older libpq: fall back to the legacy cancel implementation.
            self.cancel()

    def _try_cancel(self, *, timeout: float = 5.0) -> None:
        # Best-effort cancellation: swallow and log any failure, because this
        # is called while already handling another condition (e.g. Ctrl-C in
        # `wait()`) which must keep propagating.
        try:
            self.cancel_safe(timeout=timeout)
        except Exception as ex:
            logger.warning("query cancellation failed: %s", ex)
    @contextmanager
    def transaction(
        self, savepoint_name: str | None = None, force_rollback: bool = False
    ) -> Iterator[Transaction]:
        """
        Start a context block with a new transaction or nested transaction.

        :param savepoint_name: Name of the savepoint used to manage a nested
            transaction. If `!None`, one will be chosen automatically.
        :param force_rollback: Roll back the transaction at the end of the
            block even if there were no error (e.g. to try a no-op process).
        :rtype: Transaction
        """
        tx = Transaction(self, savepoint_name, force_rollback)
        if self._pipeline:
            # With a pipeline active, wrap the transaction between pipeline
            # sync points so its begin/end commands are coordinated with the
            # pipeline's own state.
            with self.pipeline(), tx, self.pipeline():
                yield tx
        else:
            with tx:
                yield tx
    def notifies(
        self, *, timeout: float | None = None, stop_after: int | None = None
    ) -> Generator[Notify]:
        """
        Yield `Notify` objects as soon as they are received from the database.

        :param timeout: maximum amount of time to wait for notifications.
            `!None` means no timeout.
        :param stop_after: stop after receiving this number of notifications.
            You might actually receive more than this number if more than one
            notifications arrives in the same packet.
        """
        # Allow interrupting the wait with a signal by reducing a long timeout
        # into shorter intervals.
        if timeout is not None:
            deadline = monotonic() + timeout
            interval = min(timeout, _WAIT_INTERVAL)
        else:
            deadline = None
            interval = _WAIT_INTERVAL

        nreceived = 0

        # Handlers and this generator would compete for the same
        # notifications, so each would see only part of them.
        if self._notify_handlers:
            warnings.warn(
                "using 'notifies()' together with notifies handlers on the same connection is not reliable. Please use only one of these methods",
                RuntimeWarning,
            )

        with self.lock:
            enc = self.pgconn._encoding

            # Remove the backlog deque for the duration of this critical
            # section to avoid reporting notifies twice.
            self._notifies_backlog, d = (None, self._notifies_backlog)

            try:
                while True:
                    # if notifies were received when the generator was off,
                    # return them in a first batch.
                    if d:
                        while d:
                            yield d.popleft()
                            nreceived += 1
                    else:
                        try:
                            pgns = self.wait(notifies(self.pgconn), interval=interval)
                        except e._NO_TRACEBACK as ex:
                            raise ex.with_traceback(None)
                        # Emit the notifications received.
                        for pgn in pgns:
                            yield Notify(
                                pgn.relname.decode(enc),
                                pgn.extra.decode(enc),
                                pgn.be_pid,
                            )
                            nreceived += 1

                    # Stop if we have received enough notifications.
                    if stop_after is not None and nreceived >= stop_after:
                        break

                    # Check the deadline after the loop to ensure that timeout=0
                    # polls at least once.
                    if deadline:
                        interval = min(_WAIT_INTERVAL, deadline - monotonic())
                        if interval < 0.0:
                            break
            finally:
                # Restore the backlog deque removed above, whatever happened.
                self._notifies_backlog = d
    @contextmanager
    def pipeline(self) -> Iterator[Pipeline]:
        """Context manager to switch the connection into pipeline mode."""
        with self.lock:
            self._check_connection_ok()

            if (pipeline := self._pipeline) is None:
                # WARNING: reference loop, broken ahead.
                pipeline = self._pipeline = Pipeline(self)

        # Enter/exit the pipeline outside the lock: entering an already
        # active pipeline nests (tracked by its `level`).
        try:
            with pipeline:
                yield pipeline
        finally:
            if pipeline.level == 0:
                # Outermost level exited: drop the connection -> pipeline
                # reference created above, breaking the reference loop.
                with self.lock:
                    assert pipeline is self._pipeline
                    self._pipeline = None
    @contextmanager
    def _pipeline_nolock(self) -> Iterator[Pipeline]:
        """like pipeline() but don't acquire a lock.

        Assume that the caller is holding the lock.
        """

        # Currently only used internally by Cursor.executemany() in a branch
        # in which we already established that the connection has no pipeline.
        # If this changes we may relax the asserts.
        assert not self._pipeline
        # WARNING: reference loop, broken ahead.
        pipeline = self._pipeline = Pipeline(self, _no_lock=True)
        try:
            with pipeline:
                yield pipeline
        finally:
            # Under the no-pipeline precondition above there can be no
            # nesting, so the level must be back to 0 here; drop the
            # reference loop created above.
            assert pipeline.level == 0
            assert pipeline is self._pipeline
            self._pipeline = None
    def wait(self, gen: PQGen[RV], interval: float = _WAIT_INTERVAL) -> RV:
        """
        Consume a generator operating on the connection.

        The function must be used on generators that don't change connection
        fd (i.e. not on connect and reset).

        :param gen: the generator to consume.
        :param interval: how often to wake up to check for interruptions.
        """
        try:
            return waiting.wait(gen, self.pgconn.socket, interval=interval)
        except _INTERRUPTED:
            if self.pgconn.transaction_status == ACTIVE:
                # On Ctrl-C, try to cancel the query in the server, otherwise
                # the connection will remain stuck in ACTIVE state.
                self._try_cancel(timeout=5.0)
                try:
                    # Consume the generator to completion so the connection
                    # goes back to a usable state.
                    waiting.wait(gen, self.pgconn.socket, interval=interval)
                except e.QueryCanceled:
                    pass  # as expected
            # Re-raise the KeyboardInterrupt to the caller.
            raise
    def _set_autocommit(self, value: bool) -> None:
        # Non-public hook: delegates to the public method version.
        self.set_autocommit(value)

    def set_autocommit(self, value: bool) -> None:
        """Method version of the `~Connection.autocommit` setter."""
        with self.lock:
            self.wait(self._set_autocommit_gen(value))
    def _set_isolation_level(self, value: IsolationLevel | None) -> None:
        # Non-public hook: delegates to the public method version.
        self.set_isolation_level(value)

    def set_isolation_level(self, value: IsolationLevel | None) -> None:
        """Method version of the `~Connection.isolation_level` setter."""
        with self.lock:
            self.wait(self._set_isolation_level_gen(value))
    def _set_read_only(self, value: bool | None) -> None:
        # Non-public hook: delegates to the public method version.
        self.set_read_only(value)

    def set_read_only(self, value: bool | None) -> None:
        """Method version of the `~Connection.read_only` setter."""
        with self.lock:
            self.wait(self._set_read_only_gen(value))
    def _set_deferrable(self, value: bool | None) -> None:
        # Non-public hook: delegates to the public method version.
        self.set_deferrable(value)

    def set_deferrable(self, value: bool | None) -> None:
        """Method version of the `~Connection.deferrable` setter."""
        with self.lock:
            self.wait(self._set_deferrable_gen(value))
    def tpc_begin(self, xid: Xid | str) -> None:
        """
        Begin a TPC transaction with the given transaction ID `!xid`.

        :param xid: the id of the transaction, as an `Xid` object or a string.
        """
        with self.lock:
            self.wait(self._tpc_begin_gen(xid))
    def tpc_prepare(self) -> None:
        """
        Perform the first phase of a transaction started with `tpc_begin()`.

        :raises ~errors.NotSupportedError: if the server refuses to prepare
            the transaction.
        """
        try:
            with self.lock:
                self.wait(self._tpc_prepare_gen())
        except e.ObjectNotInPrerequisiteState as ex:
            # Map the server error to NotSupportedError (e.g. prepared
            # transactions disabled on the server); drop the chained context.
            raise e.NotSupportedError(str(ex)) from None
    def tpc_commit(self, xid: Xid | str | None = None) -> None:
        """
        Commit a prepared two-phase transaction.

        :param xid: the id of a previously prepared transaction; if `!None`,
            finish the transaction currently associated with the connection.
        """
        with self.lock:
            self.wait(self._tpc_finish_gen("COMMIT", xid))

    def tpc_rollback(self, xid: Xid | str | None = None) -> None:
        """
        Roll back a prepared two-phase transaction.

        :param xid: the id of a previously prepared transaction; if `!None`,
            finish the transaction currently associated with the connection.
        """
        with self.lock:
            self.wait(self._tpc_finish_gen("ROLLBACK", xid))
    def tpc_recover(self) -> list[Xid]:
        """
        Return a list of `Xid` representing pending two-phase transactions.
        """
        self._check_tpc()
        # Remember the transaction status before querying, to restore the
        # connection state afterwards.
        status = self.info.transaction_status
        with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
            cur.execute(Xid._get_recover_query())
            res = cur.fetchall()

        # If the connection was idle and the query implicitly opened a
        # transaction, don't leave it open.
        if status == IDLE and self.info.transaction_status == INTRANS:
            self.rollback()

        return res