624 lines
21 KiB
Python
624 lines
21 KiB
Python
"""
|
|
Psycopg connection object (async version)
|
|
"""
|
|
|
|
# Copyright (C) 2020 The Psycopg Team
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import warnings
|
|
from time import monotonic
|
|
from types import TracebackType
|
|
from typing import TYPE_CHECKING, Any, cast, overload
|
|
from contextlib import asynccontextmanager
|
|
from collections.abc import AsyncGenerator, AsyncIterator
|
|
|
|
from . import errors as e
|
|
from . import pq, waiting
|
|
from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query
|
|
from .abc import QueryNoTemplate
|
|
from ._tpc import Xid
|
|
from .rows import AsyncRowFactory, Row, args_row, tuple_row
|
|
from .adapt import AdaptersMap
|
|
from ._enums import IsolationLevel
|
|
from ._compat import Self, Template
|
|
from ._acompat import ALock
|
|
from .conninfo import conninfo_attempts_async, conninfo_to_dict, make_conninfo
|
|
from .conninfo import timeout_from_conninfo
|
|
from .generators import notifies
|
|
from .transaction import AsyncTransaction
|
|
from .cursor_async import AsyncCursor
|
|
from ._capabilities import capabilities
|
|
from ._conninfo_utils import gssapi_requested
|
|
from ._pipeline_async import AsyncPipeline
|
|
from ._connection_base import BaseConnection, CursorRow, Notify
|
|
from ._server_cursor_async import AsyncServerCursor
|
|
|
|
if True: # ASYNC
|
|
import sys
|
|
import asyncio
|
|
|
|
if TYPE_CHECKING:
|
|
from .pq.abc import PGconn
|
|
|
|
_WAIT_INTERVAL = 0.1
|
|
|
|
TEXT = pq.Format.TEXT
|
|
BINARY = pq.Format.BINARY
|
|
|
|
IDLE = pq.TransactionStatus.IDLE
|
|
ACTIVE = pq.TransactionStatus.ACTIVE
|
|
INTRANS = pq.TransactionStatus.INTRANS
|
|
|
|
if True: # ASYNC
|
|
_INTERRUPTED = (asyncio.CancelledError, KeyboardInterrupt)
|
|
else:
|
|
_INTERRUPTED = KeyboardInterrupt
|
|
|
|
logger = logging.getLogger("psycopg")
|
|
|
|
|
|
class AsyncConnection(BaseConnection[Row]):
|
|
"""
|
|
Wrapper for a connection to the database.
|
|
"""
|
|
|
|
__module__ = "psycopg"
|
|
|
|
cursor_factory: type[AsyncCursor[Row]]
|
|
server_cursor_factory: type[AsyncServerCursor[Row]]
|
|
row_factory: AsyncRowFactory[Row]
|
|
_pipeline: AsyncPipeline | None
|
|
|
|
def __init__(
|
|
self,
|
|
pgconn: PGconn,
|
|
row_factory: AsyncRowFactory[Row] = cast(AsyncRowFactory[Row], tuple_row),
|
|
):
|
|
super().__init__(pgconn)
|
|
self.row_factory = row_factory
|
|
self.lock = ALock()
|
|
self.cursor_factory = AsyncCursor
|
|
self.server_cursor_factory = AsyncServerCursor
|
|
|
|
@classmethod
|
|
async def connect(
|
|
cls,
|
|
conninfo: str = "",
|
|
*,
|
|
autocommit: bool = False,
|
|
prepare_threshold: int | None = 5,
|
|
context: AdaptContext | None = None,
|
|
row_factory: AsyncRowFactory[Row] | None = None,
|
|
cursor_factory: type[AsyncCursor[Row]] | None = None,
|
|
**kwargs: ConnParam,
|
|
) -> Self:
|
|
"""
|
|
Connect to a database server and return a new `AsyncConnection` instance.
|
|
"""
|
|
if True: # ASYNC
|
|
if sys.platform == "win32":
|
|
loop = asyncio.get_running_loop()
|
|
if isinstance(loop, asyncio.ProactorEventLoop):
|
|
|
|
from ._compat import _asyncio_run_snippet
|
|
|
|
raise e.InterfaceError(
|
|
"Psycopg cannot use the 'ProactorEventLoop' to run in async"
|
|
" mode. Please use a compatible event loop, for instance by"
|
|
+ f" {_asyncio_run_snippet}"
|
|
)
|
|
|
|
params = await cls._get_connection_params(conninfo, **kwargs)
|
|
timeout = timeout_from_conninfo(params)
|
|
rv = None
|
|
attempts = await conninfo_attempts_async(params)
|
|
conn_errors: list[tuple[e.Error, str]] = []
|
|
for attempt in attempts:
|
|
tdescr = (attempt.get("host"), attempt.get("port"), attempt.get("hostaddr"))
|
|
descr = "host: %r, port: %r, hostaddr: %r" % tdescr
|
|
logger.debug("connection attempt: %s", descr)
|
|
try:
|
|
conninfo = make_conninfo("", **attempt)
|
|
gen = cls._connect_gen(conninfo, timeout=timeout)
|
|
rv = await waiting.wait_conn_async(gen, interval=_WAIT_INTERVAL)
|
|
except e.Error as ex:
|
|
logger.debug("connection failed: %s: %s", descr, str(ex))
|
|
conn_errors.append((ex, descr))
|
|
except e._NO_TRACEBACK as ex:
|
|
raise ex.with_traceback(None)
|
|
else:
|
|
logger.debug("connection succeeded: %s", descr)
|
|
break
|
|
|
|
if not rv:
|
|
last_ex = conn_errors[-1][0]
|
|
if len(conn_errors) == 1:
|
|
raise last_ex.with_traceback(None)
|
|
|
|
# Create a new exception with the same type as the last one, containing
|
|
# all attempt errors while preserving backward compatibility.
|
|
lines = [str(last_ex)]
|
|
lines.append("Multiple connection attempts failed. All failures were:")
|
|
lines.extend(f"- {descr}: {error}" for error, descr in conn_errors)
|
|
raise type(last_ex)("\n".join(lines)).with_traceback(None)
|
|
|
|
if (
|
|
capabilities.has_used_gssapi()
|
|
and rv.pgconn.used_gssapi
|
|
and not gssapi_requested(params)
|
|
):
|
|
warnings.warn(
|
|
"the connection was obtained using the GSSAPI relying on the"
|
|
" 'gssencmode=prefer' libpq default. The value for this default might"
|
|
" be 'disable' instead, in certain psycopg[binary] implementations."
|
|
" If you wish to interact with the GSSAPI reliably please set the"
|
|
" 'gssencmode' parameter in the connection string or the"
|
|
" 'PGGSSENCMODE' environment variable to 'prefer' or 'require'",
|
|
RuntimeWarning,
|
|
)
|
|
|
|
rv._autocommit = bool(autocommit)
|
|
if row_factory:
|
|
rv.row_factory = row_factory
|
|
if cursor_factory:
|
|
rv.cursor_factory = cursor_factory
|
|
if context:
|
|
rv._adapters = AdaptersMap(context.adapters)
|
|
rv.prepare_threshold = prepare_threshold
|
|
return rv
|
|
|
|
async def __aenter__(self) -> Self:
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> None:
|
|
if self.closed:
|
|
return
|
|
|
|
if exc_type:
|
|
# try to rollback, but if there are problems (connection in a bad
|
|
# state) just warn without clobbering the exception bubbling up.
|
|
try:
|
|
await self.rollback()
|
|
except Exception as exc2:
|
|
logger.warning("error ignored in rollback on %s: %s", self, exc2)
|
|
else:
|
|
await self.commit()
|
|
|
|
# Close the connection only if it doesn't belong to a pool.
|
|
if not getattr(self, "_pool", None):
|
|
await self.close()
|
|
|
|
@classmethod
|
|
async def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> ConnDict:
|
|
"""Manipulate connection parameters before connecting."""
|
|
return conninfo_to_dict(conninfo, **kwargs)
|
|
|
|
async def close(self) -> None:
|
|
"""Close the database connection."""
|
|
if self.closed:
|
|
return
|
|
|
|
pool = getattr(self, "_pool", None)
|
|
if pool and getattr(pool, "close_returns", False):
|
|
await pool.putconn(self)
|
|
return
|
|
|
|
self._closed = True
|
|
|
|
# TODO: maybe send a cancel on close, if the connection is ACTIVE?
|
|
|
|
self.pgconn.finish()
|
|
|
|
@overload
|
|
def cursor(self, *, binary: bool = False) -> AsyncCursor[Row]: ...
|
|
|
|
@overload
|
|
def cursor(
|
|
self, *, binary: bool = False, row_factory: AsyncRowFactory[CursorRow]
|
|
) -> AsyncCursor[CursorRow]: ...
|
|
|
|
@overload
|
|
def cursor(
|
|
self,
|
|
name: str,
|
|
*,
|
|
binary: bool = False,
|
|
scrollable: bool | None = None,
|
|
withhold: bool = False,
|
|
) -> AsyncServerCursor[Row]: ...
|
|
|
|
@overload
|
|
def cursor(
|
|
self,
|
|
name: str,
|
|
*,
|
|
binary: bool = False,
|
|
row_factory: AsyncRowFactory[CursorRow],
|
|
scrollable: bool | None = None,
|
|
withhold: bool = False,
|
|
) -> AsyncServerCursor[CursorRow]: ...
|
|
|
|
def cursor(
|
|
self,
|
|
name: str = "",
|
|
*,
|
|
binary: bool = False,
|
|
row_factory: AsyncRowFactory[Any] | None = None,
|
|
scrollable: bool | None = None,
|
|
withhold: bool = False,
|
|
) -> AsyncCursor[Any] | AsyncServerCursor[Any]:
|
|
"""
|
|
Return a new `AsyncCursor` to send commands and queries to the connection.
|
|
"""
|
|
self._check_connection_ok()
|
|
|
|
if not row_factory:
|
|
row_factory = self.row_factory
|
|
|
|
cur: AsyncCursor[Any] | AsyncServerCursor[Any]
|
|
if name:
|
|
cur = self.server_cursor_factory(
|
|
self,
|
|
name=name,
|
|
row_factory=row_factory,
|
|
scrollable=scrollable,
|
|
withhold=withhold,
|
|
)
|
|
else:
|
|
cur = self.cursor_factory(self, row_factory=row_factory)
|
|
|
|
if binary:
|
|
cur.format = BINARY
|
|
|
|
return cur
|
|
|
|
@overload
|
|
async def execute(
|
|
self,
|
|
query: QueryNoTemplate,
|
|
params: Params | None = None,
|
|
*,
|
|
prepare: bool | None = None,
|
|
binary: bool = False,
|
|
) -> AsyncCursor[Row]: ...
|
|
|
|
@overload
|
|
async def execute(
|
|
self,
|
|
query: Template,
|
|
*,
|
|
prepare: bool | None = None,
|
|
binary: bool = False,
|
|
) -> AsyncCursor[Row]: ...
|
|
|
|
async def execute(
|
|
self,
|
|
query: Query,
|
|
params: Params | None = None,
|
|
*,
|
|
prepare: bool | None = None,
|
|
binary: bool = False,
|
|
) -> AsyncCursor[Row]:
|
|
"""Execute a query and return a cursor to read its results."""
|
|
try:
|
|
cur = self.cursor()
|
|
if binary:
|
|
cur.format = BINARY
|
|
|
|
if isinstance(query, Template):
|
|
if params is not None:
|
|
raise TypeError(
|
|
"'execute()' with string template query"
|
|
" doesn't support parameters"
|
|
)
|
|
return await cur.execute(query, prepare=prepare)
|
|
else:
|
|
return await cur.execute(query, params, prepare=prepare)
|
|
|
|
except e._NO_TRACEBACK as ex:
|
|
raise ex.with_traceback(None)
|
|
|
|
async def commit(self) -> None:
|
|
"""Commit any pending transaction to the database."""
|
|
async with self.lock:
|
|
await self.wait(self._commit_gen())
|
|
|
|
async def rollback(self) -> None:
|
|
"""Roll back to the start of any pending transaction."""
|
|
async with self.lock:
|
|
await self.wait(self._rollback_gen())
|
|
|
|
async def cancel_safe(self, *, timeout: float = 30.0) -> None:
|
|
"""Cancel the current operation on the connection.
|
|
|
|
:param timeout: raise a `~errors.CancellationTimeout` if the
|
|
cancellation request does not succeed within `timeout` seconds.
|
|
|
|
Note that a successful cancel attempt on the client is not a guarantee
|
|
that the server will successfully manage to cancel the operation.
|
|
|
|
This is a non-blocking version of `~Connection.cancel()` which
|
|
leverages a more secure and improved cancellation feature of the libpq,
|
|
which is only available from version 17.
|
|
|
|
If the underlying libpq is older than version 17, the method will fall
|
|
back to using the same implementation of `!cancel()`.
|
|
"""
|
|
if not self._should_cancel():
|
|
return
|
|
|
|
if capabilities.has_cancel_safe():
|
|
await waiting.wait_conn_async(
|
|
self._cancel_gen(timeout=timeout), interval=_WAIT_INTERVAL
|
|
)
|
|
else:
|
|
if True: # ASYNC
|
|
await asyncio.to_thread(self.cancel)
|
|
else:
|
|
self.cancel()
|
|
|
|
async def _try_cancel(self, *, timeout: float = 5.0) -> None:
|
|
try:
|
|
await self.cancel_safe(timeout=timeout)
|
|
except Exception as ex:
|
|
logger.warning("query cancellation failed: %s", ex)
|
|
|
|
@asynccontextmanager
|
|
async def transaction(
|
|
self, savepoint_name: str | None = None, force_rollback: bool = False
|
|
) -> AsyncIterator[AsyncTransaction]:
|
|
"""
|
|
Start a context block with a new transaction or nested transaction.
|
|
|
|
:param savepoint_name: Name of the savepoint used to manage a nested
|
|
transaction. If `!None`, one will be chosen automatically.
|
|
:param force_rollback: Roll back the transaction at the end of the
|
|
block even if there were no error (e.g. to try a no-op process).
|
|
:rtype: AsyncTransaction
|
|
"""
|
|
tx = AsyncTransaction(self, savepoint_name, force_rollback)
|
|
if self._pipeline:
|
|
async with self.pipeline(), tx, self.pipeline():
|
|
yield tx
|
|
else:
|
|
async with tx:
|
|
yield tx
|
|
|
|
async def notifies(
|
|
self, *, timeout: float | None = None, stop_after: int | None = None
|
|
) -> AsyncGenerator[Notify]:
|
|
"""
|
|
Yield `Notify` objects as soon as they are received from the database.
|
|
|
|
:param timeout: maximum amount of time to wait for notifications.
|
|
`!None` means no timeout.
|
|
:param stop_after: stop after receiving this number of notifications.
|
|
You might actually receive more than this number if more than one
|
|
notifications arrives in the same packet.
|
|
"""
|
|
# Allow interrupting the wait with a signal by reducing a long timeout
|
|
# into shorter intervals.
|
|
if timeout is not None:
|
|
deadline = monotonic() + timeout
|
|
interval = min(timeout, _WAIT_INTERVAL)
|
|
else:
|
|
deadline = None
|
|
interval = _WAIT_INTERVAL
|
|
|
|
nreceived = 0
|
|
|
|
if self._notify_handlers:
|
|
warnings.warn(
|
|
"using 'notifies()' together with notifies handlers on the"
|
|
" same connection is not reliable."
|
|
" Please use only one of these methods",
|
|
RuntimeWarning,
|
|
)
|
|
|
|
async with self.lock:
|
|
enc = self.pgconn._encoding
|
|
|
|
# Remove the backlog deque for the duration of this critical
|
|
# section to avoid reporting notifies twice.
|
|
self._notifies_backlog, d = None, self._notifies_backlog
|
|
|
|
try:
|
|
while True:
|
|
# if notifies were received when the generator was off,
|
|
# return them in a first batch.
|
|
if d:
|
|
while d:
|
|
yield d.popleft()
|
|
nreceived += 1
|
|
else:
|
|
try:
|
|
pgns = await self.wait(
|
|
notifies(self.pgconn), interval=interval
|
|
)
|
|
except e._NO_TRACEBACK as ex:
|
|
raise ex.with_traceback(None)
|
|
|
|
# Emit the notifications received.
|
|
for pgn in pgns:
|
|
yield Notify(
|
|
pgn.relname.decode(enc),
|
|
pgn.extra.decode(enc),
|
|
pgn.be_pid,
|
|
)
|
|
nreceived += 1
|
|
|
|
# Stop if we have received enough notifications.
|
|
if stop_after is not None and nreceived >= stop_after:
|
|
break
|
|
|
|
# Check the deadline after the loop to ensure that timeout=0
|
|
# polls at least once.
|
|
if deadline:
|
|
interval = min(_WAIT_INTERVAL, deadline - monotonic())
|
|
if interval < 0.0:
|
|
break
|
|
finally:
|
|
self._notifies_backlog = d
|
|
|
|
@asynccontextmanager
|
|
async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
|
|
"""Context manager to switch the connection into pipeline mode."""
|
|
async with self.lock:
|
|
self._check_connection_ok()
|
|
|
|
if (pipeline := self._pipeline) is None:
|
|
# WARNING: reference loop, broken ahead.
|
|
pipeline = self._pipeline = AsyncPipeline(self)
|
|
|
|
try:
|
|
async with pipeline:
|
|
yield pipeline
|
|
finally:
|
|
if pipeline.level == 0:
|
|
async with self.lock:
|
|
assert pipeline is self._pipeline
|
|
self._pipeline = None
|
|
|
|
@asynccontextmanager
|
|
async def _pipeline_nolock(self) -> AsyncIterator[AsyncPipeline]:
|
|
"""like pipeline() but don't acquire a lock.
|
|
|
|
Assume that the caller is holding the lock.
|
|
"""
|
|
|
|
# Currently only used internally by Cursor.executemany() in a branch
|
|
# in which we already established that the connection has no pipeline.
|
|
# If this changes we may relax the asserts.
|
|
assert not self._pipeline
|
|
# WARNING: reference loop, broken ahead.
|
|
pipeline = self._pipeline = AsyncPipeline(self, _no_lock=True)
|
|
try:
|
|
async with pipeline:
|
|
yield pipeline
|
|
finally:
|
|
assert pipeline.level == 0
|
|
assert pipeline is self._pipeline
|
|
self._pipeline = None
|
|
|
|
async def wait(self, gen: PQGen[RV], interval: float = _WAIT_INTERVAL) -> RV:
|
|
"""
|
|
Consume a generator operating on the connection.
|
|
|
|
The function must be used on generators that don't change connection
|
|
fd (i.e. not on connect and reset).
|
|
"""
|
|
try:
|
|
return await waiting.wait_async(gen, self.pgconn.socket, interval=interval)
|
|
except _INTERRUPTED:
|
|
if self.pgconn.transaction_status == ACTIVE:
|
|
# On Ctrl-C, try to cancel the query in the server, otherwise
|
|
# the connection will remain stuck in ACTIVE state.
|
|
await self._try_cancel(timeout=5.0)
|
|
try:
|
|
await waiting.wait_async(gen, self.pgconn.socket, interval=interval)
|
|
except e.QueryCanceled:
|
|
pass # as expected
|
|
raise
|
|
|
|
def _set_autocommit(self, value: bool) -> None:
|
|
if True: # ASYNC
|
|
self._no_set_async("autocommit")
|
|
else:
|
|
self.set_autocommit(value)
|
|
|
|
async def set_autocommit(self, value: bool) -> None:
|
|
"""Method version of the `~Connection.autocommit` setter."""
|
|
async with self.lock:
|
|
await self.wait(self._set_autocommit_gen(value))
|
|
|
|
def _set_isolation_level(self, value: IsolationLevel | None) -> None:
|
|
if True: # ASYNC
|
|
self._no_set_async("isolation_level")
|
|
else:
|
|
self.set_isolation_level(value)
|
|
|
|
async def set_isolation_level(self, value: IsolationLevel | None) -> None:
|
|
"""Method version of the `~Connection.isolation_level` setter."""
|
|
async with self.lock:
|
|
await self.wait(self._set_isolation_level_gen(value))
|
|
|
|
def _set_read_only(self, value: bool | None) -> None:
|
|
if True: # ASYNC
|
|
self._no_set_async("read_only")
|
|
else:
|
|
self.set_read_only(value)
|
|
|
|
async def set_read_only(self, value: bool | None) -> None:
|
|
"""Method version of the `~Connection.read_only` setter."""
|
|
async with self.lock:
|
|
await self.wait(self._set_read_only_gen(value))
|
|
|
|
def _set_deferrable(self, value: bool | None) -> None:
|
|
if True: # ASYNC
|
|
self._no_set_async("deferrable")
|
|
else:
|
|
self.set_deferrable(value)
|
|
|
|
async def set_deferrable(self, value: bool | None) -> None:
|
|
"""Method version of the `~Connection.deferrable` setter."""
|
|
async with self.lock:
|
|
await self.wait(self._set_deferrable_gen(value))
|
|
|
|
if True: # ASYNC
|
|
|
|
def _no_set_async(self, attribute: str) -> None:
|
|
raise AttributeError(
|
|
f"'the {attribute!r} property is read-only on async connections:"
|
|
f" please use 'await .set_{attribute}()' instead."
|
|
)
|
|
|
|
async def tpc_begin(self, xid: Xid | str) -> None:
|
|
"""
|
|
Begin a TPC transaction with the given transaction ID `!xid`.
|
|
"""
|
|
async with self.lock:
|
|
await self.wait(self._tpc_begin_gen(xid))
|
|
|
|
async def tpc_prepare(self) -> None:
|
|
"""
|
|
Perform the first phase of a transaction started with `tpc_begin()`.
|
|
"""
|
|
try:
|
|
async with self.lock:
|
|
await self.wait(self._tpc_prepare_gen())
|
|
except e.ObjectNotInPrerequisiteState as ex:
|
|
raise e.NotSupportedError(str(ex)) from None
|
|
|
|
async def tpc_commit(self, xid: Xid | str | None = None) -> None:
|
|
"""
|
|
Commit a prepared two-phase transaction.
|
|
"""
|
|
async with self.lock:
|
|
await self.wait(self._tpc_finish_gen("COMMIT", xid))
|
|
|
|
async def tpc_rollback(self, xid: Xid | str | None = None) -> None:
|
|
"""
|
|
Roll back a prepared two-phase transaction.
|
|
"""
|
|
async with self.lock:
|
|
await self.wait(self._tpc_finish_gen("ROLLBACK", xid))
|
|
|
|
async def tpc_recover(self) -> list[Xid]:
|
|
self._check_tpc()
|
|
status = self.info.transaction_status
|
|
async with self.cursor(row_factory=args_row(Xid._from_record)) as cur:
|
|
await cur.execute(Xid._get_recover_query())
|
|
res = await cur.fetchall()
|
|
|
|
if status == IDLE and self.info.transaction_status == INTRANS:
|
|
await self.rollback()
|
|
|
|
return res
|