Files
2025-12-30 11:27:14 +07:00

569 lines
19 KiB
Python

# WARNING: this file is auto-generated by 'async_to_sync.py'
# from the original file 'connection_async.py'
# DO NOT CHANGE! Change the original file instead.
"""
Psycopg connection object (sync version)
"""
# Copyright (C) 2020 The Psycopg Team
from __future__ import annotations
import logging
import warnings
from time import monotonic
from types import TracebackType
from typing import TYPE_CHECKING, Any, cast, overload
from contextlib import contextmanager
from collections.abc import Generator, Iterator
from . import errors as e
from . import pq, waiting
from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query
from .abc import QueryNoTemplate
from ._tpc import Xid
from .rows import Row, RowFactory, args_row, tuple_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
from .cursor import Cursor
from ._compat import Self, Template
from ._acompat import Lock
from .conninfo import conninfo_attempts, conninfo_to_dict, make_conninfo
from .conninfo import timeout_from_conninfo
from ._pipeline import Pipeline
from .generators import notifies
from .transaction import Transaction
from ._capabilities import capabilities
from ._server_cursor import ServerCursor
from ._conninfo_utils import gssapi_requested
from ._connection_base import BaseConnection, CursorRow, Notify
if TYPE_CHECKING:
    from .pq.abc import PGconn

# Interval (in seconds) passed to the waiting functions. Long waits (e.g. in
# notifies()) are split into slices of at most this length so that they can
# be interrupted by a signal.
_WAIT_INTERVAL = 0.1

# Short aliases for the libpq enum values used in this module.
TEXT, BINARY = pq.Format.TEXT, pq.Format.BINARY
IDLE, ACTIVE, INTRANS = (
    pq.TransactionStatus.IDLE,
    pq.TransactionStatus.ACTIVE,
    pq.TransactionStatus.INTRANS,
)

# Exception(s) that Connection.wait() treats as an interruption of the
# current operation, triggering a cancel attempt on the server.
_INTERRUPTED = KeyboardInterrupt

logger = logging.getLogger("psycopg")
class Connection(BaseConnection[Row]):
    """
    Wrapper for a connection to the database.
    """

    __module__ = "psycopg"

    cursor_factory: type[Cursor[Row]]
    server_cursor_factory: type[ServerCursor[Row]]
    row_factory: RowFactory[Row]
    _pipeline: Pipeline | None

    def __init__(
        self,
        pgconn: PGconn,
        row_factory: RowFactory[Row] = cast(RowFactory[Row], tuple_row),
    ):
        super().__init__(pgconn)
        self.row_factory = row_factory
        # Acquired by the methods below that operate on the connection
        # (commit, rollback, settings setters, TPC, notifies, pipeline).
        self.lock = Lock()
        self.cursor_factory = Cursor
        self.server_cursor_factory = ServerCursor

    @classmethod
    def connect(
        cls,
        conninfo: str = "",
        *,
        autocommit: bool = False,
        prepare_threshold: int | None = 5,
        context: AdaptContext | None = None,
        row_factory: RowFactory[Row] | None = None,
        cursor_factory: type[Cursor[Row]] | None = None,
        **kwargs: ConnParam,
    ) -> Self:
        """
        Connect to a database server and return a new `Connection` instance.

        :param conninfo: connection string; it is combined with `!kwargs`
            by `_get_connection_params()` to obtain the connection parameters.
        :param autocommit: initial autocommit state of the connection.
        :param prepare_threshold: value assigned to the connection
            `!prepare_threshold` attribute.
        :param context: if set, the connection gets a new adapters map
            initialized from ``context.adapters``.
        :param row_factory: if set, replaces the default row factory.
        :param cursor_factory: if set, replaces the default cursor class.

        The attempts returned by `conninfo_attempts()` are tried in order and
        the first one succeeding is returned. If every attempt fails, the
        last error is raised; if more than one attempt was made, a new
        exception of the same type is raised whose message lists all the
        failures.
        """
        params = cls._get_connection_params(conninfo, **kwargs)
        timeout = timeout_from_conninfo(params)
        rv = None
        attempts = conninfo_attempts(params)
        conn_errors: list[tuple[e.Error, str]] = []
        for attempt in attempts:
            tdescr = (attempt.get("host"), attempt.get("port"), attempt.get("hostaddr"))
            descr = "host: %r, port: %r, hostaddr: %r" % tdescr
            logger.debug("connection attempt: %s", descr)
            try:
                # Use a dedicated name: don't clobber the `conninfo` argument.
                attempt_conninfo = make_conninfo("", **attempt)
                gen = cls._connect_gen(attempt_conninfo, timeout=timeout)
                rv = waiting.wait_conn(gen, interval=_WAIT_INTERVAL)
            except e.Error as ex:
                logger.debug("connection failed: %s: %s", descr, ex)
                conn_errors.append((ex, descr))
            except e._NO_TRACEBACK as ex:
                raise ex.with_traceback(None)
            else:
                logger.debug("connection succeeded: %s", descr)
                break

        if rv is None:
            last_ex = conn_errors[-1][0]
            if len(conn_errors) == 1:
                raise last_ex.with_traceback(None)

            # Create a new exception with the same type as the last one, containing
            # all attempt errors while preserving backward compatibility.
            lines = [str(last_ex)]
            lines.append("Multiple connection attempts failed. All failures were:")
            lines.extend(f"- {descr}: {error}" for error, descr in conn_errors)
            raise type(last_ex)("\n".join(lines)).with_traceback(None)

        # Warn if GSSAPI was used although the user didn't ask for it
        # explicitly (i.e. it was only picked because of the libpq default).
        if (
            capabilities.has_used_gssapi()
            and rv.pgconn.used_gssapi
            and not gssapi_requested(params)
        ):
            warnings.warn(
                "the connection was obtained using the GSSAPI relying on the"
                " 'gssencmode=prefer' libpq default. The value for this default might"
                " be 'disable' instead, in certain psycopg[binary] implementations."
                " If you wish to interact with the GSSAPI reliably please set the"
                " 'gssencmode' parameter in the connection string or the"
                " 'PGGSSENCMODE' environment variable to 'prefer' or 'require'",
                RuntimeWarning,
            )

        rv._autocommit = bool(autocommit)
        if row_factory:
            rv.row_factory = row_factory
        if cursor_factory:
            rv.cursor_factory = cursor_factory
        if context:
            rv._adapters = AdaptersMap(context.adapters)
        rv.prepare_threshold = prepare_threshold
        return rv

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Commit on success, roll back on error, then close the connection.

        The connection is not closed if it belongs to a pool.
        """
        if self.closed:
            return

        if exc_type:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                self.rollback()
            except Exception as exc2:
                logger.warning("error ignored in rollback on %s: %s", self, exc2)
        else:
            self.commit()

        # Close the connection only if it doesn't belong to a pool.
        if not getattr(self, "_pool", None):
            self.close()

    @classmethod
    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
        """Manipulate connection parameters before connecting."""
        return conninfo_to_dict(conninfo, **kwargs)

    def close(self) -> None:
        """
        Close the database connection.

        If the connection belongs to a pool whose `!close_returns` attribute
        is true, the connection is returned to the pool instead.
        """
        if self.closed:
            return

        pool = getattr(self, "_pool", None)
        if pool and getattr(pool, "close_returns", False):
            pool.putconn(self)
            return

        self._closed = True

        # TODO: maybe send a cancel on close, if the connection is ACTIVE?

        self.pgconn.finish()

    @overload
    def cursor(self, *, binary: bool = False) -> Cursor[Row]: ...

    @overload
    def cursor(
        self, *, binary: bool = False, row_factory: RowFactory[CursorRow]
    ) -> Cursor[CursorRow]: ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> ServerCursor[Row]: ...

    @overload
    def cursor(
        self,
        name: str,
        *,
        binary: bool = False,
        row_factory: RowFactory[CursorRow],
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> ServerCursor[CursorRow]: ...

    def cursor(
        self,
        name: str = "",
        *,
        binary: bool = False,
        row_factory: RowFactory[Any] | None = None,
        scrollable: bool | None = None,
        withhold: bool = False,
    ) -> Cursor[Any] | ServerCursor[Any]:
        """
        Return a new `Cursor` to send commands and queries to the connection.

        If `!name` is specified, a `ServerCursor` (built by
        `server_cursor_factory`) is returned instead. If `!row_factory` is
        not specified, the connection `row_factory` is used.
        """
        self._check_connection_ok()

        if not row_factory:
            row_factory = self.row_factory

        cur: Cursor[Any] | ServerCursor[Any]
        if name:
            cur = self.server_cursor_factory(
                self,
                name=name,
                row_factory=row_factory,
                scrollable=scrollable,
                withhold=withhold,
            )
        else:
            cur = self.cursor_factory(self, row_factory=row_factory)

        if binary:
            cur.format = BINARY

        return cur

    @overload
    def execute(
        self,
        query: QueryNoTemplate,
        params: Params | None = None,
        *,
        prepare: bool | None = None,
        binary: bool = False,
    ) -> Cursor[Row]: ...

    @overload
    def execute(
        self, query: Template, *, prepare: bool | None = None, binary: bool = False
    ) -> Cursor[Row]: ...

    def execute(
        self,
        query: Query,
        params: Params | None = None,
        *,
        prepare: bool | None = None,
        binary: bool = False,
    ) -> Cursor[Row]:
        """
        Execute a query and return a cursor to read its results.

        :raise TypeError: if `!query` is a `Template` and `!params` is given.
        """
        try:
            # Let cursor() handle the binary format, as it does for everyone.
            cur = self.cursor(binary=binary)
            if isinstance(query, Template):
                if params is not None:
                    raise TypeError(
                        "'execute()' with string template query doesn't support parameters"
                    )
                return cur.execute(query, prepare=prepare)
            else:
                return cur.execute(query, params, prepare=prepare)
        except e._NO_TRACEBACK as ex:
            raise ex.with_traceback(None)

    def commit(self) -> None:
        """Commit any pending transaction to the database."""
        with self.lock:
            self.wait(self._commit_gen())

    def rollback(self) -> None:
        """Roll back to the start of any pending transaction."""
        with self.lock:
            self.wait(self._rollback_gen())

    def cancel_safe(self, *, timeout: float = 30.0) -> None:
        """Cancel the current operation on the connection.

        :param timeout: raise a `~errors.CancellationTimeout` if the
            cancellation request does not succeed within `timeout` seconds.

        Note that a successful cancel attempt on the client is not a guarantee
        that the server will successfully manage to cancel the operation.

        This is a non-blocking version of `~Connection.cancel()` which
        leverages a more secure and improved cancellation feature of the libpq,
        which is only available from version 17.

        If the underlying libpq is older than version 17, the method will fall
        back to using the same implementation of `!cancel()`.
        """
        if not self._should_cancel():
            return

        if capabilities.has_cancel_safe():
            waiting.wait_conn(
                self._cancel_gen(timeout=timeout), interval=_WAIT_INTERVAL
            )
        else:
            self.cancel()

    def _try_cancel(self, *, timeout: float = 5.0) -> None:
        """
        Best-effort `cancel_safe()`: failures are logged, not raised.
        """
        try:
            self.cancel_safe(timeout=timeout)
        except Exception as ex:
            logger.warning("query cancellation failed: %s", ex)

    @contextmanager
    def transaction(
        self, savepoint_name: str | None = None, force_rollback: bool = False
    ) -> Iterator[Transaction]:
        """
        Start a context block with a new transaction or nested transaction.

        :param savepoint_name: Name of the savepoint used to manage a nested
            transaction. If `!None`, one will be chosen automatically.
        :param force_rollback: Roll back the transaction at the end of the
            block even if there was no error (e.g. to try a no-op process).
        :rtype: Transaction
        """
        tx = Transaction(self, savepoint_name, force_rollback)
        if self._pipeline:
            # In pipeline mode, wrap the transaction block in nested
            # pipeline() blocks on both sides.
            with self.pipeline(), tx, self.pipeline():
                yield tx
        else:
            with tx:
                yield tx

    def notifies(
        self, *, timeout: float | None = None, stop_after: int | None = None
    ) -> Generator[Notify]:
        """
        Yield `Notify` objects as soon as they are received from the database.

        :param timeout: maximum amount of time to wait for notifications.
            `!None` means no timeout. A value <= 0 polls the connection once.
        :param stop_after: stop after receiving this number of notifications.
            You might actually receive more than this number if more than one
            notification arrives in the same packet.

        The connection lock is held from the first iteration until the
        generator terminates, including while it is suspended at a `!yield`.
        """
        # Allow interrupting the wait with a signal by reducing a long timeout
        # into shorter intervals.
        if timeout is not None:
            deadline = monotonic() + timeout
            # Never pass a negative interval to the wait functions: a negative
            # timeout behaves like timeout=0 (a single poll).
            interval = max(0.0, min(timeout, _WAIT_INTERVAL))
        else:
            deadline = None
            interval = _WAIT_INTERVAL

        nreceived = 0

        if self._notify_handlers:
            warnings.warn(
                "using 'notifies()' together with notifies handlers on the same"
                " connection is not reliable. Please use only one of these methods",
                RuntimeWarning,
            )

        with self.lock:
            enc = self.pgconn._encoding

            # Remove the backlog deque for the duration of this critical
            # section to avoid reporting notifies twice.
            self._notifies_backlog, d = (None, self._notifies_backlog)

            try:
                while True:
                    # if notifies were received when the generator was off,
                    # return them in a first batch.
                    if d:
                        while d:
                            yield d.popleft()
                            nreceived += 1
                    else:
                        try:
                            pgns = self.wait(notifies(self.pgconn), interval=interval)
                        except e._NO_TRACEBACK as ex:
                            raise ex.with_traceback(None)

                        # Emit the notifications received.
                        for pgn in pgns:
                            yield Notify(
                                pgn.relname.decode(enc),
                                pgn.extra.decode(enc),
                                pgn.be_pid,
                            )
                            nreceived += 1

                    # Stop if we have received enough notifications.
                    if stop_after is not None and nreceived >= stop_after:
                        break

                    # Check the deadline after the loop to ensure that timeout=0
                    # polls at least once. Compare against None explicitly: a
                    # deadline is a float and could in principle be 0.0.
                    if deadline is not None:
                        interval = min(_WAIT_INTERVAL, deadline - monotonic())
                        if interval < 0.0:
                            break
            finally:
                # Restore the backlog so notifies received later are kept.
                self._notifies_backlog = d

    @contextmanager
    def pipeline(self) -> Iterator[Pipeline]:
        """Context manager to switch the connection into pipeline mode."""
        with self.lock:
            self._check_connection_ok()
            if (pipeline := self._pipeline) is None:
                # WARNING: reference loop, broken ahead.
                pipeline = self._pipeline = Pipeline(self)

        try:
            with pipeline:
                yield pipeline
        finally:
            # Only the outermost block (level back to 0) detaches the pipeline.
            if pipeline.level == 0:
                with self.lock:
                    assert pipeline is self._pipeline
                    self._pipeline = None

    @contextmanager
    def _pipeline_nolock(self) -> Iterator[Pipeline]:
        """like pipeline() but don't acquire a lock.

        Assume that the caller is holding the lock.
        """

        # Currently only used internally by Cursor.executemany() in a branch
        # in which we already established that the connection has no pipeline.
        # If this changes we may relax the asserts.
        assert not self._pipeline
        # WARNING: reference loop, broken ahead.
        pipeline = self._pipeline = Pipeline(self, _no_lock=True)
        try:
            with pipeline:
                yield pipeline
        finally:
            assert pipeline.level == 0
            assert pipeline is self._pipeline
            self._pipeline = None

    def wait(self, gen: PQGen[RV], interval: float = _WAIT_INTERVAL) -> RV:
        """
        Consume a generator operating on the connection.

        The function must be used on generators that don't change connection
        fd (i.e. not on connect and reset).

        :param interval: passed to `waiting.wait()` as the wait interval.

        If the wait is interrupted while a query is running (ACTIVE state),
        the query is cancelled on the server and the generator is consumed
        to completion before the interruption is re-raised.
        """
        try:
            return waiting.wait(gen, self.pgconn.socket, interval=interval)
        except _INTERRUPTED:
            if self.pgconn.transaction_status == ACTIVE:
                # On Ctrl-C, try to cancel the query in the server, otherwise
                # the connection will remain stuck in ACTIVE state.
                self._try_cancel(timeout=5.0)
                try:
                    waiting.wait(gen, self.pgconn.socket, interval=interval)
                except e.QueryCanceled:
                    pass  # as expected
            raise

    def _set_autocommit(self, value: bool) -> None:
        self.set_autocommit(value)

    def set_autocommit(self, value: bool) -> None:
        """Method version of the `~Connection.autocommit` setter."""
        with self.lock:
            self.wait(self._set_autocommit_gen(value))

    def _set_isolation_level(self, value: IsolationLevel | None) -> None:
        self.set_isolation_level(value)

    def set_isolation_level(self, value: IsolationLevel | None) -> None:
        """Method version of the `~Connection.isolation_level` setter."""
        with self.lock:
            self.wait(self._set_isolation_level_gen(value))

    def _set_read_only(self, value: bool | None) -> None:
        self.set_read_only(value)

    def set_read_only(self, value: bool | None) -> None:
        """Method version of the `~Connection.read_only` setter."""
        with self.lock:
            self.wait(self._set_read_only_gen(value))

    def _set_deferrable(self, value: bool | None) -> None:
        self.set_deferrable(value)

    def set_deferrable(self, value: bool | None) -> None:
        """Method version of the `~Connection.deferrable` setter."""
        with self.lock:
            self.wait(self._set_deferrable_gen(value))

    def tpc_begin(self, xid: Xid | str) -> None:
        """
        Begin a TPC transaction with the given transaction ID `!xid`.
        """
        with self.lock:
            self.wait(self._tpc_begin_gen(xid))

    def tpc_prepare(self) -> None:
        """
        Perform the first phase of a transaction started with `tpc_begin()`.

        :raise NotSupportedError: if the server reports the
            ObjectNotInPrerequisiteState error.
        """
        try:
            with self.lock:
                self.wait(self._tpc_prepare_gen())
        except e.ObjectNotInPrerequisiteState as ex:
            raise e.NotSupportedError(str(ex)) from None

    def tpc_commit(self, xid: Xid | str | None = None) -> None:
        """
        Commit a prepared two-phase transaction.
        """
        with self.lock:
            self.wait(self._tpc_finish_gen("COMMIT", xid))

    def tpc_rollback(self, xid: Xid | str | None = None) -> None:
        """
        Roll back a prepared two-phase transaction.
        """
        with self.lock:
            self.wait(self._tpc_finish_gen("ROLLBACK", xid))

    def tpc_recover(self) -> list[Xid]:
        """
        Return the list of `Xid` of the prepared transactions.

        If running the recovery query opened a new transaction (the
        connection was IDLE before and INTRANS after), it is rolled back.
        """
        self._check_tpc()
        status = self.info.transaction_status
        with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
            cur.execute(Xid._get_recover_query())
            res = cur.fetchall()

        # Don't leave behind a transaction that was only opened by us.
        if status == IDLE and self.info.transaction_status == INTRANS:
            self.rollback()

        return res