Initial commit (Clean history)
This commit is contained in:
100
path/to/venv/lib/python3.12/site-packages/wsproto/__init__.py
Normal file
100
path/to/venv/lib/python3.12/site-packages/wsproto/__init__.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
wsproto
|
||||
~~~~~~~
|
||||
|
||||
A WebSocket implementation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .connection import Connection, ConnectionState, ConnectionType
|
||||
from .handshake import H11Handshake
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from .events import Event
|
||||
from .typing import Headers
|
||||
|
||||
__version__ = "1.3.2"
|
||||
|
||||
|
||||
class WSConnection:
|
||||
"""
|
||||
Represents the local end of a WebSocket connection to a remote peer.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_type: ConnectionType) -> None:
|
||||
"""
|
||||
Constructor
|
||||
|
||||
:param wsproto.connection.ConnectionType connection_type: Controls
|
||||
whether the library behaves as a client or as a server.
|
||||
"""
|
||||
self.client = connection_type is ConnectionType.CLIENT
|
||||
self.handshake = H11Handshake(connection_type)
|
||||
self.connection: Connection | None = None
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
"""
|
||||
:returns: Connection state
|
||||
:rtype: wsproto.connection.ConnectionState
|
||||
"""
|
||||
if self.connection is None:
|
||||
return self.handshake.state
|
||||
return self.connection.state
|
||||
|
||||
def initiate_upgrade_connection(
|
||||
self, headers: Headers, path: bytes | str,
|
||||
) -> None:
|
||||
self.handshake.initiate_upgrade_connection(headers, path)
|
||||
|
||||
def send(self, event: Event) -> bytes:
|
||||
"""
|
||||
Generate network data for the specified event.
|
||||
|
||||
When you want to communicate with a WebSocket peer, you should construct
|
||||
an event and pass it to this method. This method will return the bytes
|
||||
that you should send to the peer.
|
||||
|
||||
:param wsproto.events.Event event: The event to generate data for
|
||||
:returns bytes: The data to send to the peer
|
||||
"""
|
||||
data = b""
|
||||
if self.connection is None:
|
||||
data += self.handshake.send(event)
|
||||
self.connection = self.handshake.connection
|
||||
else:
|
||||
data += self.connection.send(event)
|
||||
return data
|
||||
|
||||
def receive_data(self, data: bytes | None) -> None:
|
||||
"""
|
||||
Feed network data into the connection instance.
|
||||
|
||||
After calling this method, you should call :meth:`events` to see if the
|
||||
received data triggered any new events.
|
||||
|
||||
:param bytes data: Data received from remote peer
|
||||
"""
|
||||
if self.connection is None:
|
||||
self.handshake.receive_data(data)
|
||||
self.connection = self.handshake.connection
|
||||
else:
|
||||
self.connection.receive_data(data)
|
||||
|
||||
def events(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
A generator that yields pending events.
|
||||
|
||||
Each event is an instance of a subclass of
|
||||
:class:`wsproto.events.Event`.
|
||||
"""
|
||||
yield from self.handshake.events()
|
||||
if self.connection is not None:
|
||||
yield from self.connection.events()
|
||||
|
||||
|
||||
__all__ = ("ConnectionType", "WSConnection")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
202
path/to/venv/lib/python3.12/site-packages/wsproto/connection.py
Normal file
202
path/to/venv/lib/python3.12/site-packages/wsproto/connection.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
wsproto/connection
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
An implementation of a WebSocket connection.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .events import (
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Event,
|
||||
Message,
|
||||
Ping,
|
||||
Pong,
|
||||
TextMessage,
|
||||
)
|
||||
from .frame_protocol import CloseReason, FrameProtocol, Opcode, ParseFailed
|
||||
from .utilities import LocalProtocolError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from .extensions import Extension
|
||||
|
||||
|
||||
class ConnectionState(Enum):
|
||||
"""
|
||||
RFC 6455, Section 4 - Opening Handshake
|
||||
"""
|
||||
|
||||
#: The opening handshake is in progress.
|
||||
CONNECTING = 0
|
||||
#: The opening handshake is complete.
|
||||
OPEN = 1
|
||||
#: The remote WebSocket has initiated a connection close.
|
||||
REMOTE_CLOSING = 2
|
||||
#: The local WebSocket (i.e. this instance) has initiated a connection close.
|
||||
LOCAL_CLOSING = 3
|
||||
#: The closing handshake has completed.
|
||||
CLOSED = 4
|
||||
#: The connection was rejected during the opening handshake.
|
||||
REJECTING = 5
|
||||
|
||||
|
||||
class ConnectionType(Enum):
|
||||
"""An enumeration of connection types."""
|
||||
|
||||
#: This connection will act as client and talk to a remote server
|
||||
CLIENT = 1
|
||||
|
||||
#: This connection will as as server and waits for client connections
|
||||
SERVER = 2
|
||||
|
||||
|
||||
CLIENT = ConnectionType.CLIENT
|
||||
SERVER = ConnectionType.SERVER
|
||||
|
||||
|
||||
class Connection:
|
||||
"""
|
||||
A low-level WebSocket connection object.
|
||||
|
||||
This wraps two other protocol objects, an HTTP/1.1 protocol object used
|
||||
to do the initial HTTP upgrade handshake and a WebSocket frame protocol
|
||||
object used to exchange messages and other control frames.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_type: ConnectionType,
|
||||
extensions: list[Extension] | None = None,
|
||||
trailing_data: bytes = b"",
|
||||
) -> None:
|
||||
"""
|
||||
Constructor
|
||||
|
||||
:param wsproto.connection.ConnectionType connection_type: Whether this
|
||||
object is on the client- or server-side of a connection.
|
||||
To initialise as a client pass ``CLIENT`` otherwise pass ``SERVER``.
|
||||
:param list extensions: The proposed extensions.
|
||||
:param bytes trailing_data: Data that has been received, but not yet
|
||||
processed.
|
||||
"""
|
||||
self.client = connection_type is ConnectionType.CLIENT
|
||||
self._events: deque[Event] = deque()
|
||||
self._proto = FrameProtocol(self.client, extensions or [])
|
||||
self._state = ConnectionState.OPEN
|
||||
self.receive_data(trailing_data)
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
return self._state
|
||||
|
||||
def send(self, event: Event) -> bytes:
|
||||
data = b""
|
||||
if isinstance(event, Message) and self.state == ConnectionState.OPEN:
|
||||
data += self._proto.send_data(event.data, event.message_finished)
|
||||
elif isinstance(event, Ping) and self.state == ConnectionState.OPEN:
|
||||
data += self._proto.ping(event.payload)
|
||||
elif isinstance(event, Pong) and self.state == ConnectionState.OPEN:
|
||||
data += self._proto.pong(event.payload)
|
||||
elif isinstance(event, CloseConnection) and self.state in {
|
||||
ConnectionState.OPEN,
|
||||
ConnectionState.REMOTE_CLOSING,
|
||||
}:
|
||||
data += self._proto.close(event.code, event.reason)
|
||||
if self.state == ConnectionState.REMOTE_CLOSING:
|
||||
self._state = ConnectionState.CLOSED
|
||||
else:
|
||||
self._state = ConnectionState.LOCAL_CLOSING
|
||||
else:
|
||||
msg = f"Event {event} cannot be sent in state {self.state}."
|
||||
raise LocalProtocolError(
|
||||
msg,
|
||||
)
|
||||
return data
|
||||
|
||||
def receive_data(self, data: bytes | None) -> None:
|
||||
"""
|
||||
Pass some received data to the connection for handling.
|
||||
|
||||
A list of events that the remote peer triggered by sending this data can
|
||||
be retrieved with :meth:`~wsproto.connection.Connection.events`.
|
||||
|
||||
:param data: The data received from the remote peer on the network.
|
||||
:type data: ``bytes``
|
||||
"""
|
||||
if data is None:
|
||||
# "If _The WebSocket Connection is Closed_ and no Close control
|
||||
# frame was received by the endpoint (such as could occur if the
|
||||
# underlying transport connection is lost), _The WebSocket
|
||||
# Connection Close Code_ is considered to be 1006."
|
||||
self._events.append(CloseConnection(code=CloseReason.ABNORMAL_CLOSURE))
|
||||
self._state = ConnectionState.CLOSED
|
||||
return
|
||||
|
||||
if self.state in (ConnectionState.OPEN, ConnectionState.LOCAL_CLOSING):
|
||||
self._proto.receive_bytes(data)
|
||||
elif self.state is ConnectionState.CLOSED:
|
||||
msg = "Connection already closed."
|
||||
raise LocalProtocolError(msg)
|
||||
else:
|
||||
pass # pragma: no cover
|
||||
|
||||
def events(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
Return a generator that provides any events that have been generated
|
||||
by protocol activity.
|
||||
|
||||
:returns: generator of :class:`Event <wsproto.events.Event>` subclasses
|
||||
"""
|
||||
while self._events:
|
||||
yield self._events.popleft()
|
||||
|
||||
try:
|
||||
for frame in self._proto.received_frames():
|
||||
if frame.opcode is Opcode.PING:
|
||||
assert frame.frame_finished
|
||||
assert frame.message_finished
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
yield Ping(payload=frame.payload)
|
||||
|
||||
elif frame.opcode is Opcode.PONG:
|
||||
assert frame.frame_finished
|
||||
assert frame.message_finished
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
yield Pong(payload=frame.payload)
|
||||
|
||||
elif frame.opcode is Opcode.CLOSE:
|
||||
assert isinstance(frame.payload, tuple)
|
||||
code, reason = frame.payload
|
||||
if self.state is ConnectionState.LOCAL_CLOSING:
|
||||
self._state = ConnectionState.CLOSED
|
||||
else:
|
||||
self._state = ConnectionState.REMOTE_CLOSING
|
||||
yield CloseConnection(code=code, reason=reason)
|
||||
|
||||
elif frame.opcode is Opcode.TEXT:
|
||||
assert isinstance(frame.payload, str)
|
||||
yield TextMessage(
|
||||
data=frame.payload,
|
||||
frame_finished=frame.frame_finished,
|
||||
message_finished=frame.message_finished,
|
||||
)
|
||||
|
||||
elif frame.opcode is Opcode.BINARY:
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
yield BytesMessage(
|
||||
data=frame.payload,
|
||||
frame_finished=frame.frame_finished,
|
||||
message_finished=frame.message_finished,
|
||||
)
|
||||
|
||||
else:
|
||||
pass # pragma: no cover
|
||||
except ParseFailed as exc:
|
||||
yield CloseConnection(code=exc.code, reason=str(exc))
|
||||
302
path/to/venv/lib/python3.12/site-packages/wsproto/events.py
Normal file
302
path/to/venv/lib/python3.12/site-packages/wsproto/events.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
wsproto/events
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
Events that result from processing data on a WebSocket connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from .extensions import Extension
|
||||
from .typing import Headers
|
||||
|
||||
|
||||
class Event(ABC):
|
||||
"""
|
||||
Base class for wsproto events.
|
||||
"""
|
||||
|
||||
pass # noqa
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Request(Event):
|
||||
"""
|
||||
The beginning of a Websocket connection, the HTTP Upgrade request
|
||||
|
||||
This event is fired when a SERVER connection receives a WebSocket
|
||||
handshake request (HTTP with upgrade header).
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: host
|
||||
|
||||
(Required) The hostname, or host header value.
|
||||
|
||||
.. attribute:: target
|
||||
|
||||
(Required) The request target (path and query string)
|
||||
|
||||
.. attribute:: extensions
|
||||
|
||||
The proposed extensions.
|
||||
|
||||
.. attribute:: extra_headers
|
||||
|
||||
The additional request headers, excluding extensions, host, subprotocols,
|
||||
and version headers.
|
||||
|
||||
.. attribute:: subprotocols
|
||||
|
||||
A list of the subprotocols proposed in the request, as a list
|
||||
of strings.
|
||||
"""
|
||||
|
||||
host: str
|
||||
target: str
|
||||
extensions: Sequence[Extension] | Sequence[str] = field(default_factory=list)
|
||||
extra_headers: Headers = field(default_factory=list)
|
||||
subprotocols: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AcceptConnection(Event):
|
||||
"""
|
||||
The acceptance of a Websocket upgrade request.
|
||||
|
||||
This event is fired when a CLIENT receives an acceptance response
|
||||
from a server. It is also used to accept an upgrade request when
|
||||
acting as a SERVER.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: extra_headers
|
||||
|
||||
Any additional (non websocket related) headers present in the
|
||||
acceptance response.
|
||||
|
||||
.. attribute:: subprotocol
|
||||
|
||||
The accepted subprotocol to use.
|
||||
|
||||
"""
|
||||
|
||||
subprotocol: str | None = None
|
||||
extensions: list[Extension] = field(default_factory=list)
|
||||
extra_headers: Headers = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RejectConnection(Event):
|
||||
"""
|
||||
The rejection of a Websocket upgrade request, the HTTP response.
|
||||
|
||||
The ``RejectConnection`` event sends the appropriate HTTP headers to
|
||||
communicate to the peer that the handshake has been rejected. You may also
|
||||
send an HTTP body by setting the ``has_body`` attribute to ``True`` and then
|
||||
sending one or more :class:`RejectData` events after this one. When sending
|
||||
a response body, the caller should set the ``Content-Length``,
|
||||
``Content-Type``, and/or ``Transfer-Encoding`` headers as appropriate.
|
||||
|
||||
When receiving a ``RejectConnection`` event, the ``has_body`` attribute will
|
||||
in almost all cases be ``True`` (even if the server set it to ``False``) and
|
||||
will be followed by at least one ``RejectData`` events, even though the data
|
||||
itself might be just ``b""``. (The only scenario in which the caller
|
||||
receives a ``RejectConnection`` with ``has_body == False`` is if the peer
|
||||
violates sends an informational status code (1xx) other than 101.)
|
||||
|
||||
The ``has_body`` attribute should only be used when receiving the event. (It
|
||||
has ) is False the headers must include a
|
||||
content-length or transfer encoding.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: headers (Headers)
|
||||
|
||||
The headers to send with the response.
|
||||
|
||||
.. attribute:: has_body
|
||||
|
||||
This defaults to False, but set to True if there is a body. See
|
||||
also :class:`~RejectData`.
|
||||
|
||||
.. attribute:: status_code
|
||||
|
||||
The response status code.
|
||||
|
||||
"""
|
||||
|
||||
status_code: int = 400
|
||||
headers: Headers = field(default_factory=list)
|
||||
has_body: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RejectData(Event):
|
||||
"""
|
||||
The rejection HTTP response body.
|
||||
|
||||
The caller may send multiple ``RejectData`` events. The final event should
|
||||
have the ``body_finished`` attribute set to ``True``.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: body_finished
|
||||
|
||||
True if this is the final chunk of the body data.
|
||||
|
||||
.. attribute:: data (bytes)
|
||||
|
||||
(Required) The raw body data.
|
||||
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
body_finished: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CloseConnection(Event):
|
||||
"""
|
||||
The end of a Websocket connection, represents a closure frame.
|
||||
|
||||
**wsproto does not automatically send a response to a close event.** To
|
||||
comply with the RFC you MUST send a close event back to the remote WebSocket
|
||||
if you have not already sent one. The :meth:`response` method provides a
|
||||
suitable event for this purpose, and you should check if a response needs
|
||||
to be sent by checking :func:`wsproto.WSConnection.state`.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: code
|
||||
|
||||
(Required) The integer close code to indicate why the connection
|
||||
has closed.
|
||||
|
||||
.. attribute:: reason
|
||||
|
||||
Additional reasoning for why the connection has closed.
|
||||
|
||||
"""
|
||||
|
||||
code: int
|
||||
reason: str | None = None
|
||||
|
||||
def response(self) -> CloseConnection:
|
||||
"""Generate an RFC-compliant close frame to send back to the peer."""
|
||||
return CloseConnection(code=self.code, reason=self.reason)
|
||||
|
||||
|
||||
T = TypeVar("T", bytes | bytearray, str)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Message(Event, Generic[T]):
|
||||
"""
|
||||
The websocket data message.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: data
|
||||
|
||||
(Required) The message data as byte string, can be decoded as UTF-8 for
|
||||
TEXT messages. This only represents a single chunk of data and
|
||||
not a full WebSocket message. You need to buffer and
|
||||
reassemble these chunks to get the full message.
|
||||
|
||||
.. attribute:: frame_finished
|
||||
|
||||
This has no semantic content, but is provided just in case some
|
||||
weird edge case user wants to be able to reconstruct the
|
||||
fragmentation pattern of the original stream.
|
||||
|
||||
.. attribute:: message_finished
|
||||
|
||||
True if this frame is the last one of this message, False if
|
||||
more frames are expected.
|
||||
|
||||
"""
|
||||
|
||||
data: T
|
||||
frame_finished: bool = True
|
||||
message_finished: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextMessage(Message[str]): # pylint: disable=unsubscriptable-object
|
||||
"""
|
||||
Fired when a data frame with TEXT payload is received.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: data
|
||||
|
||||
The message data as string, This only represents a single chunk
|
||||
of data and not a full WebSocket message. You need to buffer
|
||||
and reassemble these chunks to get the full message.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BytesMessage(
|
||||
Message[bytearray | bytes] # pylint: disable=unsubscriptable-object
|
||||
):
|
||||
"""
|
||||
Fired when a data frame with BINARY payload is received.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: data
|
||||
|
||||
The message data as bytes or a bytearray, can be decoded as UTF-8 for
|
||||
TEXT messages. This only represents a single chunk of data and
|
||||
not a full WebSocket message. You need to buffer and
|
||||
reassemble these chunks to get the full message.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ping(Event):
|
||||
"""
|
||||
The Ping event can be sent to trigger a ping frame and is fired
|
||||
when a Ping is received.
|
||||
|
||||
**wsproto does not automatically send a pong response to a ping event.** To
|
||||
comply with the RFC you MUST send a pong even as soon as is practical. The
|
||||
:meth:`response` method provides a suitable event for this purpose.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: payload
|
||||
|
||||
An optional payload to emit with the ping frame.
|
||||
"""
|
||||
|
||||
payload: bytes = b""
|
||||
|
||||
def response(self) -> Pong:
|
||||
"""Generate an RFC-compliant :class:`Pong` response to this ping."""
|
||||
return Pong(payload=self.payload)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Pong(Event):
|
||||
"""
|
||||
The Pong event is fired when a Pong is received.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: payload
|
||||
|
||||
An optional payload to emit with the pong frame.
|
||||
|
||||
"""
|
||||
|
||||
payload: bytes = b""
|
||||
320
path/to/venv/lib/python3.12/site-packages/wsproto/extensions.py
Normal file
320
path/to/venv/lib/python3.12/site-packages/wsproto/extensions.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
wsproto/extensions
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
WebSocket extensions.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import zlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from .frame_protocol import CloseReason, FrameDecoder, FrameProtocol, Opcode, RsvBits
|
||||
|
||||
|
||||
class Extension(ABC):
|
||||
name: str
|
||||
|
||||
def enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def offer(self) -> bool | str:
|
||||
pass
|
||||
|
||||
def accept(self, offer: str) -> bool | str | None:
|
||||
pass
|
||||
|
||||
def finalize(self, offer: str) -> None:
|
||||
pass
|
||||
|
||||
def frame_inbound_header(
|
||||
self,
|
||||
proto: FrameDecoder | FrameProtocol,
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
payload_length: int,
|
||||
) -> CloseReason | RsvBits:
|
||||
return RsvBits(False, False, False)
|
||||
|
||||
def frame_inbound_payload_data(
|
||||
self, proto: FrameDecoder | FrameProtocol, data: bytes,
|
||||
) -> bytes | CloseReason:
|
||||
return data
|
||||
|
||||
def frame_inbound_complete(
|
||||
self, proto: FrameDecoder | FrameProtocol, fin: bool,
|
||||
) -> bytes | CloseReason | None:
|
||||
pass
|
||||
|
||||
def frame_outbound(
|
||||
self,
|
||||
proto: FrameDecoder | FrameProtocol,
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
data: bytes,
|
||||
fin: bool,
|
||||
) -> tuple[RsvBits, bytes]:
|
||||
return (rsv, data)
|
||||
|
||||
|
||||
class PerMessageDeflate(Extension):
|
||||
name = "permessage-deflate"
|
||||
|
||||
DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
|
||||
DEFAULT_SERVER_MAX_WINDOW_BITS = 15
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_no_context_takeover: bool = False,
|
||||
client_max_window_bits: int | None = None,
|
||||
server_no_context_takeover: bool = False,
|
||||
server_max_window_bits: int | None = None,
|
||||
) -> None:
|
||||
self.client_no_context_takeover = client_no_context_takeover
|
||||
self.server_no_context_takeover = server_no_context_takeover
|
||||
self._client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
|
||||
self._server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
|
||||
if client_max_window_bits is not None:
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
if server_max_window_bits is not None:
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
|
||||
self._compressor: Optional[zlib._Compress] = None # noqa
|
||||
self._decompressor: Optional[zlib._Decompress] = None # noqa
|
||||
# This refers to the current frame
|
||||
self._inbound_is_compressible: bool | None = None
|
||||
# This refers to the ongoing message (which might span multiple
|
||||
# frames). Only the first frame in a fragmented message is flagged for
|
||||
# compression, so this carries that bit forward.
|
||||
self._inbound_compressed: bool | None = None
|
||||
|
||||
self._enabled = False
|
||||
|
||||
@property
|
||||
def client_max_window_bits(self) -> int:
|
||||
return self._client_max_window_bits
|
||||
|
||||
@client_max_window_bits.setter
|
||||
def client_max_window_bits(self, value: int) -> None:
|
||||
if value < 9 or value > 15:
|
||||
msg = "Window size must be between 9 and 15 inclusive"
|
||||
raise ValueError(msg)
|
||||
self._client_max_window_bits = value
|
||||
|
||||
@property
|
||||
def server_max_window_bits(self) -> int:
|
||||
return self._server_max_window_bits
|
||||
|
||||
@server_max_window_bits.setter
|
||||
def server_max_window_bits(self, value: int) -> None:
|
||||
if value < 9 or value > 15:
|
||||
msg = "Window size must be between 9 and 15 inclusive"
|
||||
raise ValueError(msg)
|
||||
self._server_max_window_bits = value
|
||||
|
||||
def _compressible_opcode(self, opcode: Opcode) -> bool:
|
||||
return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)
|
||||
|
||||
def enabled(self) -> bool:
|
||||
return self._enabled
|
||||
|
||||
def offer(self) -> bool | str:
|
||||
parameters = [
|
||||
f"client_max_window_bits={self.client_max_window_bits}",
|
||||
f"server_max_window_bits={self.server_max_window_bits}",
|
||||
]
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
parameters.append("client_no_context_takeover")
|
||||
if self.server_no_context_takeover:
|
||||
parameters.append("server_no_context_takeover")
|
||||
|
||||
return "; ".join(parameters)
|
||||
|
||||
def finalize(self, offer: str) -> None:
|
||||
bits = [b.strip() for b in offer.split(";")]
|
||||
for bit in bits[1:]:
|
||||
if bit.startswith("client_no_context_takeover"):
|
||||
self.client_no_context_takeover = True
|
||||
elif bit.startswith("server_no_context_takeover"):
|
||||
self.server_no_context_takeover = True
|
||||
elif bit.startswith("client_max_window_bits"):
|
||||
self.client_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
elif bit.startswith("server_max_window_bits"):
|
||||
self.server_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
|
||||
self._enabled = True
|
||||
|
||||
def _parse_params(self, params: str) -> tuple[int | None, int | None]:
|
||||
client_max_window_bits = None
|
||||
server_max_window_bits = None
|
||||
|
||||
bits = [b.strip() for b in params.split(";")]
|
||||
for bit in bits[1:]:
|
||||
if bit.startswith("client_no_context_takeover"):
|
||||
self.client_no_context_takeover = True
|
||||
elif bit.startswith("server_no_context_takeover"):
|
||||
self.server_no_context_takeover = True
|
||||
elif bit.startswith("client_max_window_bits"):
|
||||
if "=" in bit:
|
||||
client_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
else:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif bit.startswith("server_max_window_bits"):
|
||||
if "=" in bit:
|
||||
server_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
else:
|
||||
server_max_window_bits = self.server_max_window_bits
|
||||
|
||||
return client_max_window_bits, server_max_window_bits
|
||||
|
||||
def accept(self, offer: str) -> bool | None | str:
|
||||
client_max_window_bits, server_max_window_bits = self._parse_params(offer)
|
||||
|
||||
parameters = []
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
parameters.append("client_no_context_takeover")
|
||||
if self.server_no_context_takeover:
|
||||
parameters.append("server_no_context_takeover")
|
||||
try:
|
||||
if client_max_window_bits is not None:
|
||||
parameters.append(f"client_max_window_bits={client_max_window_bits}")
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
if server_max_window_bits is not None:
|
||||
parameters.append(f"server_max_window_bits={server_max_window_bits}")
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
self._enabled = True
|
||||
return "; ".join(parameters)
|
||||
|
||||
def frame_inbound_header(
|
||||
self,
|
||||
proto: FrameDecoder | FrameProtocol,
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
payload_length: int,
|
||||
) -> CloseReason | RsvBits:
|
||||
if rsv.rsv1 and opcode.iscontrol():
|
||||
return CloseReason.PROTOCOL_ERROR
|
||||
if rsv.rsv1 and opcode is Opcode.CONTINUATION:
|
||||
return CloseReason.PROTOCOL_ERROR
|
||||
|
||||
self._inbound_is_compressible = self._compressible_opcode(opcode)
|
||||
|
||||
if self._inbound_compressed is None:
|
||||
self._inbound_compressed = rsv.rsv1
|
||||
if self._inbound_compressed:
|
||||
assert self._inbound_is_compressible
|
||||
if proto.client:
|
||||
bits = self.server_max_window_bits
|
||||
else:
|
||||
bits = self.client_max_window_bits
|
||||
if self._decompressor is None:
|
||||
self._decompressor = zlib.decompressobj(-int(bits))
|
||||
|
||||
return RsvBits(True, False, False)
|
||||
|
||||
def frame_inbound_payload_data(
|
||||
self, proto: FrameDecoder | FrameProtocol, data: bytes,
|
||||
) -> bytes | CloseReason:
|
||||
if not self._inbound_compressed or not self._inbound_is_compressible:
|
||||
return data
|
||||
assert self._decompressor is not None
|
||||
|
||||
try:
|
||||
return self._decompressor.decompress(bytes(data))
|
||||
except zlib.error:
|
||||
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
|
||||
|
||||
def frame_inbound_complete(
|
||||
self, proto: FrameDecoder | FrameProtocol, fin: bool,
|
||||
) -> bytes | CloseReason | None:
|
||||
if not fin:
|
||||
return None
|
||||
if not self._inbound_is_compressible:
|
||||
self._inbound_compressed = None
|
||||
return None
|
||||
if not self._inbound_compressed:
|
||||
self._inbound_compressed = None
|
||||
return None
|
||||
assert self._decompressor is not None
|
||||
|
||||
try:
|
||||
data = self._decompressor.decompress(b"\x00\x00\xff\xff")
|
||||
data += self._decompressor.flush()
|
||||
except zlib.error:
|
||||
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
|
||||
|
||||
if proto.client:
|
||||
no_context_takeover = self.server_no_context_takeover
|
||||
else:
|
||||
no_context_takeover = self.client_no_context_takeover
|
||||
|
||||
if no_context_takeover:
|
||||
self._decompressor = None
|
||||
|
||||
self._inbound_compressed = None
|
||||
|
||||
return data
|
||||
|
||||
def frame_outbound(
|
||||
self,
|
||||
proto: FrameDecoder | FrameProtocol,
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
data: bytes,
|
||||
fin: bool,
|
||||
) -> tuple[RsvBits, bytes]:
|
||||
if not self._compressible_opcode(opcode):
|
||||
return (rsv, data)
|
||||
|
||||
if opcode is not Opcode.CONTINUATION:
|
||||
rsv = RsvBits(True, rsv[1], rsv[2])
|
||||
|
||||
if self._compressor is None:
|
||||
assert opcode is not Opcode.CONTINUATION
|
||||
if proto.client:
|
||||
bits = self.client_max_window_bits
|
||||
else:
|
||||
bits = self.server_max_window_bits
|
||||
self._compressor = zlib.compressobj(
|
||||
zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits),
|
||||
)
|
||||
|
||||
data = self._compressor.compress(bytes(data))
|
||||
|
||||
if fin:
|
||||
data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
|
||||
data = data[:-4]
|
||||
|
||||
if proto.client:
|
||||
no_context_takeover = self.client_no_context_takeover
|
||||
else:
|
||||
no_context_takeover = self.server_no_context_takeover
|
||||
|
||||
if no_context_takeover:
|
||||
self._compressor = None
|
||||
|
||||
return (rsv, data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
descr = [f"client_max_window_bits={self.client_max_window_bits}"]
|
||||
if self.client_no_context_takeover:
|
||||
descr.append("client_no_context_takeover")
|
||||
descr.append(f"server_max_window_bits={self.server_max_window_bits}")
|
||||
if self.server_no_context_takeover:
|
||||
descr.append("server_no_context_takeover")
|
||||
|
||||
return "<{} {}>".format(self.__class__.__name__, "; ".join(descr))
|
||||
|
||||
|
||||
#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
|
||||
#: This can be used to iterate all supported extensions of wsproto, instantiate
|
||||
#: new extensions based on their name, or check if a given extension is
|
||||
#: supported or not.
|
||||
SUPPORTED_EXTENSIONS = {PerMessageDeflate.name: PerMessageDeflate}
|
||||
@@ -0,0 +1,698 @@
|
||||
"""
|
||||
wsproto/frame_protocol
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
WebSocket frame protocol implementation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import struct
|
||||
from codecs import IncrementalDecoder, getincrementaldecoder
|
||||
from enum import IntEnum
|
||||
from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from .extensions import Extension # pragma: no cover
|
||||
|
||||
|
||||
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]
|
||||
|
||||
|
||||
class XorMaskerSimple:
|
||||
def __init__(self, masking_key: bytearray | bytes) -> None:
|
||||
self._masking_key = masking_key
|
||||
|
||||
def process(self, data: bytearray) -> bytearray:
|
||||
data = bytearray(data)
|
||||
if data:
|
||||
data_array = data
|
||||
a, b, c, d = (_XOR_TABLE[n] for n in self._masking_key)
|
||||
data_array[::4] = data_array[::4].translate(a)
|
||||
data_array[1::4] = data_array[1::4].translate(b)
|
||||
data_array[2::4] = data_array[2::4].translate(c)
|
||||
data_array[3::4] = data_array[3::4].translate(d)
|
||||
|
||||
# Rotate the masking key so that the next usage continues
|
||||
# with the next key element, rather than restarting.
|
||||
key_rotation = len(data) % 4
|
||||
self._masking_key = (
|
||||
self._masking_key[key_rotation:] + self._masking_key[:key_rotation]
|
||||
)
|
||||
|
||||
return data_array
|
||||
return data
|
||||
|
||||
|
||||
class XorMaskerNull:
|
||||
def process(self, data: bytearray) -> bytearray:
|
||||
return data
|
||||
|
||||
|
||||
# RFC6455, Section 5.2 - Base Framing Protocol
|
||||
|
||||
# Payload length constants
|
||||
PAYLOAD_LENGTH_TWO_BYTE = 126
|
||||
PAYLOAD_LENGTH_EIGHT_BYTE = 127
|
||||
MAX_PAYLOAD_NORMAL = 125
|
||||
MAX_PAYLOAD_TWO_BYTE = 2**16 - 1
|
||||
MAX_PAYLOAD_EIGHT_BYTE = 2**64 - 1
|
||||
MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE
|
||||
|
||||
# MASK and PAYLOAD LEN are packed into a byte
|
||||
MASK_MASK = 0x80
|
||||
PAYLOAD_LEN_MASK = 0x7F
|
||||
|
||||
# FIN, RSV[123] and OPCODE are packed into a single byte
|
||||
FIN_MASK = 0x80
|
||||
RSV1_MASK = 0x40
|
||||
RSV2_MASK = 0x20
|
||||
RSV3_MASK = 0x10
|
||||
OPCODE_MASK = 0x0F
|
||||
|
||||
|
||||
class Opcode(IntEnum):
|
||||
"""
|
||||
RFC 6455, Section 5.2 - Base Framing Protocol
|
||||
"""
|
||||
|
||||
#: Continuation frame
|
||||
CONTINUATION = 0x0
|
||||
|
||||
#: Text message
|
||||
TEXT = 0x1
|
||||
|
||||
#: Binary message
|
||||
BINARY = 0x2
|
||||
|
||||
#: Close frame
|
||||
CLOSE = 0x8
|
||||
|
||||
#: Ping frame
|
||||
PING = 0x9
|
||||
|
||||
#: Pong frame
|
||||
PONG = 0xA
|
||||
|
||||
def iscontrol(self) -> bool:
|
||||
return bool(self & 0x08)
|
||||
|
||||
|
||||
class CloseReason(IntEnum):
|
||||
"""
|
||||
RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||
"""
|
||||
|
||||
#: indicates a normal closure, meaning that the purpose for
|
||||
#: which the connection was established has been fulfilled.
|
||||
NORMAL_CLOSURE = 1000
|
||||
|
||||
#: indicates that an endpoint is "going away", such as a server
|
||||
#: going down or a browser having navigated away from a page.
|
||||
GOING_AWAY = 1001
|
||||
|
||||
#: indicates that an endpoint is terminating the connection due
|
||||
#: to a protocol error.
|
||||
PROTOCOL_ERROR = 1002
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received a type of data it cannot accept (e.g., an
|
||||
#: endpoint that understands only text data MAY send this if it
|
||||
#: receives a binary message).
|
||||
UNSUPPORTED_DATA = 1003
|
||||
|
||||
#: Reserved. The specific meaning might be defined in the future.
|
||||
# DON'T DEFINE THIS: RESERVED_1004 = 1004
|
||||
|
||||
#: is a reserved value and MUST NOT be set as a status code in a
|
||||
#: Close control frame by an endpoint. It is designated for use in
|
||||
#: applications expecting a status code to indicate that no status
|
||||
#: code was actually present.
|
||||
NO_STATUS_RCVD = 1005
|
||||
|
||||
#: is a reserved value and MUST NOT be set as a status code in a
|
||||
#: Close control frame by an endpoint. It is designated for use in
|
||||
#: applications expecting a status code to indicate that the
|
||||
#: connection was closed abnormally, e.g., without sending or
|
||||
#: receiving a Close control frame.
|
||||
ABNORMAL_CLOSURE = 1006
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received data within a message that was not
|
||||
#: consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
|
||||
#: data within a text message).
|
||||
INVALID_FRAME_PAYLOAD_DATA = 1007
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received a message that violates its policy. This
|
||||
#: is a generic status code that can be returned when there is no
|
||||
#: other more suitable status code (e.g., 1003 or 1009) or if there
|
||||
#: is a need to hide specific details about the policy.
|
||||
POLICY_VIOLATION = 1008
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received a message that is too big for it to
|
||||
#: process.
|
||||
MESSAGE_TOO_BIG = 1009
|
||||
|
||||
#: indicates that an endpoint (client) is terminating the
|
||||
#: connection because it has expected the server to negotiate one or
|
||||
#: more extension, but the server didn't return them in the response
|
||||
#: message of the WebSocket handshake. The list of extensions that
|
||||
#: are needed SHOULD appear in the /reason/ part of the Close frame.
|
||||
#: Note that this status code is not used by the server, because it
|
||||
#: can fail the WebSocket handshake instead.
|
||||
MANDATORY_EXT = 1010
|
||||
|
||||
#: indicates that a server is terminating the connection because
|
||||
#: it encountered an unexpected condition that prevented it from
|
||||
#: fulfilling the request.
|
||||
INTERNAL_ERROR = 1011
|
||||
|
||||
#: Server/service is restarting
|
||||
#: (not part of RFC6455)
|
||||
SERVICE_RESTART = 1012
|
||||
|
||||
#: Temporary server condition forced blocking client's request
|
||||
#: (not part of RFC6455)
|
||||
TRY_AGAIN_LATER = 1013
|
||||
|
||||
#: is a reserved value and MUST NOT be set as a status code in a
|
||||
#: Close control frame by an endpoint. It is designated for use in
|
||||
#: applications expecting a status code to indicate that the
|
||||
#: connection was closed due to a failure to perform a TLS handshake
|
||||
#: (e.g., the server certificate can't be verified).
|
||||
TLS_HANDSHAKE_FAILED = 1015
|
||||
|
||||
|
||||
# RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||
LOCAL_ONLY_CLOSE_REASONS = (
|
||||
CloseReason.NO_STATUS_RCVD,
|
||||
CloseReason.ABNORMAL_CLOSURE,
|
||||
CloseReason.TLS_HANDSHAKE_FAILED,
|
||||
)
|
||||
|
||||
|
||||
# RFC 6455, Section 7.4.2 - Status Code Ranges
|
||||
MIN_CLOSE_REASON = 1000
|
||||
MIN_PROTOCOL_CLOSE_REASON = 1000
|
||||
MAX_PROTOCOL_CLOSE_REASON = 2999
|
||||
MIN_LIBRARY_CLOSE_REASON = 3000
|
||||
MAX_LIBRARY_CLOSE_REASON = 3999
|
||||
MIN_PRIVATE_CLOSE_REASON = 4000
|
||||
MAX_PRIVATE_CLOSE_REASON = 4999
|
||||
MAX_CLOSE_REASON = 4999
|
||||
|
||||
|
||||
NULL_MASK = struct.pack("!I", 0)
|
||||
|
||||
|
||||
class ParseFailed(Exception):
|
||||
def __init__(
|
||||
self, msg: str, code: CloseReason = CloseReason.PROTOCOL_ERROR,
|
||||
) -> None:
|
||||
super().__init__(msg)
|
||||
self.code = code
|
||||
|
||||
|
||||
class RsvBits(NamedTuple):
|
||||
rsv1: bool
|
||||
rsv2: bool
|
||||
rsv3: bool
|
||||
|
||||
|
||||
class Header(NamedTuple):
|
||||
fin: bool
|
||||
rsv: RsvBits
|
||||
opcode: Opcode
|
||||
payload_len: int
|
||||
masking_key: bytes | None
|
||||
|
||||
|
||||
class Frame(NamedTuple):
|
||||
opcode: Opcode
|
||||
payload: bytes | str | tuple[int, str]
|
||||
frame_finished: bool
|
||||
message_finished: bool
|
||||
|
||||
|
||||
def _truncate_utf8(data: bytes, nbytes: int) -> bytes:
|
||||
if len(data) <= nbytes:
|
||||
return data
|
||||
|
||||
# Truncate
|
||||
data = data[:nbytes]
|
||||
# But we might have cut a codepoint in half, in which case we want to
|
||||
# discard the partial character so the data is at least
|
||||
# well-formed. This is a little inefficient since it processes the
|
||||
# whole message twice when in theory we could just peek at the last
|
||||
# few characters, but since this is only used for close messages (max
|
||||
# length = 125 bytes) it really doesn't matter.
|
||||
return data.decode("utf-8", errors="ignore").encode("utf-8")
|
||||
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, initial_bytes: bytes | None = None) -> None:
|
||||
self.buffer = bytearray()
|
||||
self.bytes_used = 0
|
||||
if initial_bytes:
|
||||
self.feed(initial_bytes)
|
||||
|
||||
def feed(self, new_bytes: bytes) -> None:
|
||||
self.buffer += new_bytes
|
||||
|
||||
def consume_at_most(self, nbytes: int) -> bytearray:
|
||||
if not nbytes:
|
||||
return bytearray()
|
||||
|
||||
data = self.buffer[self.bytes_used : self.bytes_used + nbytes]
|
||||
self.bytes_used += len(data)
|
||||
return data
|
||||
|
||||
def consume_exactly(self, nbytes: int) -> bytearray | None:
|
||||
if len(self.buffer) - self.bytes_used < nbytes:
|
||||
return None
|
||||
|
||||
return self.consume_at_most(nbytes)
|
||||
|
||||
def commit(self) -> None:
|
||||
# In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic
|
||||
del self.buffer[: self.bytes_used]
|
||||
self.bytes_used = 0
|
||||
|
||||
def rollback(self) -> None:
|
||||
self.bytes_used = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.buffer)
|
||||
|
||||
|
||||
class MessageDecoder:
|
||||
def __init__(self) -> None:
|
||||
self.opcode: Opcode | None = None
|
||||
self.decoder: IncrementalDecoder | None = None
|
||||
|
||||
def process_frame(self, frame: Frame) -> Frame:
|
||||
assert not frame.opcode.iscontrol()
|
||||
|
||||
if self.opcode is None:
|
||||
if frame.opcode is Opcode.CONTINUATION:
|
||||
msg = "unexpected CONTINUATION"
|
||||
raise ParseFailed(msg)
|
||||
self.opcode = frame.opcode
|
||||
elif frame.opcode is not Opcode.CONTINUATION:
|
||||
msg = f"expected CONTINUATION, got {frame.opcode!r}"
|
||||
raise ParseFailed(msg)
|
||||
|
||||
if frame.opcode is Opcode.TEXT:
|
||||
self.decoder = getincrementaldecoder("utf-8")()
|
||||
|
||||
finished = frame.frame_finished and frame.message_finished
|
||||
|
||||
if self.decoder is None:
|
||||
data = frame.payload
|
||||
else:
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
try:
|
||||
data = self.decoder.decode(frame.payload, finished)
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
|
||||
|
||||
frame = Frame(self.opcode, data, frame.frame_finished, finished)
|
||||
|
||||
if finished:
|
||||
self.opcode = None
|
||||
self.decoder = None
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
class FrameDecoder:
|
||||
def __init__(
|
||||
self, client: bool, extensions: list[Extension] | None = None,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.extensions = extensions or []
|
||||
|
||||
self.buffer = Buffer()
|
||||
|
||||
self.header: Header | None = None
|
||||
self.effective_opcode: Opcode | None = None
|
||||
self.masker: None | XorMaskerNull | XorMaskerSimple = None
|
||||
self.payload_required = 0
|
||||
self.payload_consumed = 0
|
||||
|
||||
def receive_bytes(self, data: bytes) -> None:
|
||||
self.buffer.feed(data)
|
||||
|
||||
def process_buffer(self) -> Frame | None:
|
||||
if not self.header and not self.parse_header():
|
||||
return None
|
||||
# parse_header() sets these.
|
||||
assert self.header is not None
|
||||
assert self.masker is not None
|
||||
assert self.effective_opcode is not None
|
||||
|
||||
if len(self.buffer) < self.payload_required:
|
||||
return None
|
||||
|
||||
payload_remaining = self.header.payload_len - self.payload_consumed
|
||||
payload = self.buffer.consume_at_most(payload_remaining)
|
||||
if not payload and self.header.payload_len > 0:
|
||||
return None
|
||||
self.buffer.commit()
|
||||
|
||||
self.payload_consumed += len(payload)
|
||||
finished = self.payload_consumed == self.header.payload_len
|
||||
|
||||
payload = self.masker.process(payload)
|
||||
|
||||
for extension in self.extensions:
|
||||
payload_ = extension.frame_inbound_payload_data(self, bytes(payload))
|
||||
if isinstance(payload_, CloseReason):
|
||||
msg = "error in extension"
|
||||
raise ParseFailed(msg, payload_)
|
||||
payload = bytearray(payload_)
|
||||
|
||||
if finished:
|
||||
final = bytearray()
|
||||
for extension in self.extensions:
|
||||
result = extension.frame_inbound_complete(self, self.header.fin)
|
||||
if isinstance(result, CloseReason):
|
||||
msg = "error in extension"
|
||||
raise ParseFailed(msg, result)
|
||||
if result is not None:
|
||||
final += result
|
||||
payload += final
|
||||
|
||||
frame = Frame(self.effective_opcode, bytes(payload), finished, self.header.fin)
|
||||
|
||||
if finished:
|
||||
self.header = None
|
||||
self.effective_opcode = None
|
||||
self.masker = None
|
||||
else:
|
||||
self.effective_opcode = Opcode.CONTINUATION
|
||||
|
||||
return frame
|
||||
|
||||
def parse_header(self) -> bool:
|
||||
data = self.buffer.consume_exactly(2)
|
||||
if data is None:
|
||||
self.buffer.rollback()
|
||||
return False
|
||||
|
||||
fin = bool(data[0] & FIN_MASK)
|
||||
rsv = RsvBits(
|
||||
bool(data[0] & RSV1_MASK),
|
||||
bool(data[0] & RSV2_MASK),
|
||||
bool(data[0] & RSV3_MASK),
|
||||
)
|
||||
opcode = data[0] & OPCODE_MASK
|
||||
try:
|
||||
opcode = Opcode(opcode)
|
||||
except ValueError:
|
||||
msg = f"Invalid opcode {opcode:#x}"
|
||||
raise ParseFailed(msg)
|
||||
|
||||
if opcode.iscontrol() and not fin:
|
||||
msg = "Invalid attempt to fragment control frame"
|
||||
raise ParseFailed(msg)
|
||||
|
||||
has_mask = bool(data[1] & MASK_MASK)
|
||||
payload_len_short = data[1] & PAYLOAD_LEN_MASK
|
||||
payload_len = self.parse_extended_payload_length(opcode, payload_len_short)
|
||||
if payload_len is None:
|
||||
self.buffer.rollback()
|
||||
return False
|
||||
|
||||
self.extension_processing(opcode, rsv, payload_len)
|
||||
|
||||
if has_mask and self.client:
|
||||
msg = "client received unexpected masked frame"
|
||||
raise ParseFailed(msg)
|
||||
if not has_mask and not self.client:
|
||||
msg = "server received unexpected unmasked frame"
|
||||
raise ParseFailed(msg)
|
||||
if has_mask:
|
||||
masking_key = self.buffer.consume_exactly(4)
|
||||
if masking_key is None:
|
||||
self.buffer.rollback()
|
||||
return False
|
||||
self.masker = XorMaskerSimple(masking_key)
|
||||
else:
|
||||
self.masker = XorMaskerNull()
|
||||
|
||||
self.buffer.commit()
|
||||
self.header = Header(fin, rsv, opcode, payload_len, None)
|
||||
self.effective_opcode = self.header.opcode
|
||||
if self.header.opcode.iscontrol():
|
||||
self.payload_required = payload_len
|
||||
else:
|
||||
self.payload_required = 0
|
||||
self.payload_consumed = 0
|
||||
return True
|
||||
|
||||
def parse_extended_payload_length(
|
||||
self, opcode: Opcode, payload_len: int,
|
||||
) -> int | None:
|
||||
if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
|
||||
msg = "Control frame with payload len > 125"
|
||||
raise ParseFailed(msg)
|
||||
if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
|
||||
data = self.buffer.consume_exactly(2)
|
||||
if data is None:
|
||||
return None
|
||||
(payload_len,) = struct.unpack("!H", data)
|
||||
if payload_len <= MAX_PAYLOAD_NORMAL:
|
||||
msg = "Payload length used 2 bytes when 1 would have sufficed"
|
||||
raise ParseFailed(
|
||||
msg,
|
||||
)
|
||||
elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE:
|
||||
data = self.buffer.consume_exactly(8)
|
||||
if data is None:
|
||||
return None
|
||||
(payload_len,) = struct.unpack("!Q", data)
|
||||
if payload_len <= MAX_PAYLOAD_TWO_BYTE:
|
||||
msg = "Payload length used 8 bytes when 2 would have sufficed"
|
||||
raise ParseFailed(
|
||||
msg,
|
||||
)
|
||||
if payload_len >> 63:
|
||||
# I'm not sure why this is illegal, but that's what the RFC
|
||||
# says, so...
|
||||
msg = "8-byte payload length with non-zero MSB"
|
||||
raise ParseFailed(msg)
|
||||
|
||||
return payload_len
|
||||
|
||||
def extension_processing(
|
||||
self, opcode: Opcode, rsv: RsvBits, payload_len: int,
|
||||
) -> None:
|
||||
rsv_used = [False, False, False]
|
||||
for extension in self.extensions:
|
||||
result = extension.frame_inbound_header(self, opcode, rsv, payload_len)
|
||||
if isinstance(result, CloseReason):
|
||||
msg = "error in extension"
|
||||
raise ParseFailed(msg, result)
|
||||
for bit, used in enumerate(result):
|
||||
if used:
|
||||
rsv_used[bit] = True
|
||||
for expected, found in zip(rsv_used, rsv):
|
||||
if found and not expected:
|
||||
msg = "Reserved bit set unexpectedly"
|
||||
raise ParseFailed(msg)
|
||||
|
||||
|
||||
class FrameProtocol:
|
||||
def __init__(self, client: bool, extensions: list[Extension]) -> None:
|
||||
self.client = client
|
||||
self.extensions = [ext for ext in extensions if ext.enabled()]
|
||||
|
||||
# Global state
|
||||
self._frame_decoder = FrameDecoder(self.client, self.extensions)
|
||||
self._message_decoder = MessageDecoder()
|
||||
self._parse_more = self._parse_more_gen()
|
||||
|
||||
self._outbound_opcode: Opcode | None = None
|
||||
|
||||
def _process_close(self, frame: Frame) -> Frame:
|
||||
data = frame.payload
|
||||
assert isinstance(data, (bytes, bytearray))
|
||||
|
||||
if not data:
|
||||
# "If this Close control frame contains no status code, _The
|
||||
# WebSocket Connection Close Code_ is considered to be 1005"
|
||||
data = (CloseReason.NO_STATUS_RCVD, "")
|
||||
elif len(data) == 1:
|
||||
msg = "CLOSE with 1 byte payload"
|
||||
raise ParseFailed(msg)
|
||||
else:
|
||||
(code,) = struct.unpack("!H", data[:2])
|
||||
if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
|
||||
msg = "CLOSE with invalid code"
|
||||
raise ParseFailed(msg)
|
||||
with contextlib.suppress(ValueError):
|
||||
code = CloseReason(code)
|
||||
if code in LOCAL_ONLY_CLOSE_REASONS:
|
||||
msg = "remote CLOSE with local-only reason"
|
||||
raise ParseFailed(msg)
|
||||
if not isinstance(code, CloseReason) and code <= MAX_PROTOCOL_CLOSE_REASON:
|
||||
msg = "CLOSE with unknown reserved code"
|
||||
raise ParseFailed(msg)
|
||||
try:
|
||||
reason = data[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ParseFailed(
|
||||
"Error decoding CLOSE reason: " + str(exc),
|
||||
CloseReason.INVALID_FRAME_PAYLOAD_DATA,
|
||||
)
|
||||
data = (code, reason)
|
||||
|
||||
return Frame(frame.opcode, data, frame.frame_finished, frame.message_finished)
|
||||
|
||||
def _parse_more_gen(self) -> Generator[Frame | None, None, None]:
|
||||
# Consume as much as we can from self._buffer, yielding events, and
|
||||
# then yield None when we need more data. Or raise ParseFailed.
|
||||
|
||||
# XX FIXME this should probably be refactored so that we never see
|
||||
# disabled extensions in the first place...
|
||||
self.extensions = [ext for ext in self.extensions if ext.enabled()]
|
||||
closed = False
|
||||
|
||||
while not closed:
|
||||
frame = self._frame_decoder.process_buffer()
|
||||
|
||||
if frame is not None:
|
||||
if not frame.opcode.iscontrol():
|
||||
frame = self._message_decoder.process_frame(frame)
|
||||
elif frame.opcode == Opcode.CLOSE:
|
||||
frame = self._process_close(frame)
|
||||
closed = True
|
||||
|
||||
yield frame
|
||||
|
||||
def receive_bytes(self, data: bytes) -> None:
|
||||
self._frame_decoder.receive_bytes(data)
|
||||
|
||||
def received_frames(self) -> Generator[Frame, None, None]:
|
||||
for event in self._parse_more:
|
||||
if event is None:
|
||||
break
|
||||
else:
|
||||
yield event
|
||||
|
||||
def close(self, code: int | None = None, reason: str | None = None) -> bytearray:
|
||||
payload = bytearray()
|
||||
if code is CloseReason.NO_STATUS_RCVD:
|
||||
code = None
|
||||
if code is None and reason:
|
||||
msg = "cannot specify a reason without a code"
|
||||
raise TypeError(msg)
|
||||
if code in LOCAL_ONLY_CLOSE_REASONS:
|
||||
code = CloseReason.NORMAL_CLOSURE
|
||||
if code is not None:
|
||||
payload += bytearray(struct.pack("!H", code))
|
||||
if reason is not None:
|
||||
payload += _truncate_utf8(
|
||||
reason.encode("utf-8"), MAX_PAYLOAD_NORMAL - 2,
|
||||
)
|
||||
|
||||
return self._serialize_frame(Opcode.CLOSE, payload)
|
||||
|
||||
def ping(self, payload: bytes = b"") -> bytearray:
|
||||
return self._serialize_frame(Opcode.PING, payload)
|
||||
|
||||
def pong(self, payload: bytes = b"") -> bytearray:
|
||||
return self._serialize_frame(Opcode.PONG, payload)
|
||||
|
||||
def send_data(
|
||||
self, payload: bytes | bytearray | str = b"", fin: bool = True,
|
||||
) -> bytearray:
|
||||
if isinstance(payload, (bytes, bytearray, memoryview)):
|
||||
opcode = Opcode.BINARY
|
||||
elif isinstance(payload, str):
|
||||
opcode = Opcode.TEXT
|
||||
payload = payload.encode("utf-8")
|
||||
else:
|
||||
msg = "Must provide bytes or text"
|
||||
raise TypeError(msg)
|
||||
|
||||
if self._outbound_opcode is None:
|
||||
self._outbound_opcode = opcode
|
||||
elif self._outbound_opcode is not opcode:
|
||||
msg = "Data type mismatch inside message"
|
||||
raise TypeError(msg)
|
||||
else:
|
||||
opcode = Opcode.CONTINUATION
|
||||
|
||||
if fin:
|
||||
self._outbound_opcode = None
|
||||
|
||||
return self._serialize_frame(opcode, payload, fin)
|
||||
|
||||
def _make_fin_rsv_opcode(self, fin: bool, rsv: RsvBits, opcode: Opcode) -> int:
|
||||
fin_bits = int(fin) << 7
|
||||
rsv_bits = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + (int(rsv.rsv3) << 4)
|
||||
opcode_bits = int(opcode)
|
||||
|
||||
return fin_bits | rsv_bits | opcode_bits
|
||||
|
||||
def _serialize_frame(
|
||||
self, opcode: Opcode, payload: bytes | bytearray = b"", fin: bool = True,
|
||||
) -> bytearray:
|
||||
payload = bytearray(payload)
|
||||
|
||||
rsv = RsvBits(False, False, False)
|
||||
for extension in reversed(self.extensions):
|
||||
rsv, payload = extension.frame_outbound(self, opcode, rsv, bytes(payload), fin)
|
||||
|
||||
fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
|
||||
|
||||
payload_length = len(payload)
|
||||
quad_payload = False
|
||||
if payload_length <= MAX_PAYLOAD_NORMAL:
|
||||
first_payload = payload_length
|
||||
second_payload = None
|
||||
elif payload_length <= MAX_PAYLOAD_TWO_BYTE:
|
||||
first_payload = PAYLOAD_LENGTH_TWO_BYTE
|
||||
second_payload = payload_length
|
||||
else:
|
||||
first_payload = PAYLOAD_LENGTH_EIGHT_BYTE
|
||||
second_payload = payload_length
|
||||
quad_payload = True
|
||||
|
||||
if self.client:
|
||||
first_payload |= 1 << 7
|
||||
|
||||
header = bytearray([fin_rsv_opcode, first_payload])
|
||||
if second_payload is not None:
|
||||
if opcode.iscontrol():
|
||||
msg = "payload too long for control frame"
|
||||
raise ValueError(msg)
|
||||
if quad_payload:
|
||||
header += bytearray(struct.pack("!Q", second_payload))
|
||||
else:
|
||||
header += bytearray(struct.pack("!H", second_payload))
|
||||
|
||||
if self.client:
|
||||
# "The masking key is a 32-bit value chosen at random by the
|
||||
# client. When preparing a masked frame, the client MUST pick a
|
||||
# fresh masking key from the set of allowed 32-bit values. The
|
||||
# masking key needs to be unpredictable; thus, the masking key
|
||||
# MUST be derived from a strong source of entropy, and the masking
|
||||
# key for a given frame MUST NOT make it simple for a server/proxy
|
||||
# to predict the masking key for a subsequent frame. The
|
||||
# unpredictability of the masking key is essential to prevent
|
||||
# authors of malicious applications from selecting the bytes that
|
||||
# appear on the wire."
|
||||
# -- https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
masking_key = bytearray(os.urandom(4))
|
||||
masker = XorMaskerSimple(masking_key)
|
||||
return bytearray(header + masking_key + masker.process(bytearray(payload)))
|
||||
|
||||
return bytearray(header + payload)
|
||||
515
path/to/venv/lib/python3.12/site-packages/wsproto/handshake.py
Normal file
515
path/to/venv/lib/python3.12/site-packages/wsproto/handshake.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
wsproto/handshake
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
An implementation of WebSocket handshakes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
cast,
|
||||
)
|
||||
|
||||
import h11
|
||||
|
||||
from .connection import Connection, ConnectionState, ConnectionType
|
||||
from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
|
||||
from .extensions import Extension
|
||||
from .utilities import (
|
||||
LocalProtocolError,
|
||||
RemoteProtocolError,
|
||||
generate_accept_token,
|
||||
generate_nonce,
|
||||
normed_header_dict,
|
||||
split_comma_header,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
|
||||
from .typing import Headers
|
||||
|
||||
# RFC6455, Section 4.2.1/6 - Reading the Client's Opening Handshake
|
||||
WEBSOCKET_VERSION = b"13"
|
||||
|
||||
# RFC6455, Section 4.2.1/3 - Value of the Upgrade header
|
||||
WEBSOCKET_UPGRADE = b"websocket"
|
||||
|
||||
|
||||
class H11Handshake:
|
||||
"""A Handshake implementation for HTTP/1.1 connections."""
|
||||
|
||||
def __init__(self, connection_type: ConnectionType) -> None:
|
||||
self.client = connection_type is ConnectionType.CLIENT
|
||||
self._state = ConnectionState.CONNECTING
|
||||
|
||||
if self.client:
|
||||
self._h11_connection = h11.Connection(h11.CLIENT)
|
||||
else:
|
||||
self._h11_connection = h11.Connection(h11.SERVER)
|
||||
|
||||
self._connection: Connection | None = None
|
||||
self._events: deque[Event] = deque()
|
||||
self._initiating_request: Request | None = None
|
||||
self._nonce: bytes | None = None
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def connection(self) -> Connection | None:
|
||||
"""
|
||||
Return the established connection.
|
||||
|
||||
This will either return the connection or raise a
|
||||
LocalProtocolError if the connection has not yet been
|
||||
established.
|
||||
|
||||
:rtype: h11.Connection
|
||||
"""
|
||||
return self._connection
|
||||
|
||||
def initiate_upgrade_connection(
|
||||
self, headers: Headers, path: bytes | str,
|
||||
) -> None:
|
||||
"""
|
||||
Initiate an upgrade connection.
|
||||
|
||||
This should be used if the request has already be received and
|
||||
parsed.
|
||||
|
||||
:param list headers: HTTP headers represented as a list of 2-tuples.
|
||||
:param str path: A URL path.
|
||||
"""
|
||||
if self.client:
|
||||
msg = "Cannot initiate an upgrade connection when acting as the client"
|
||||
raise LocalProtocolError(
|
||||
msg,
|
||||
)
|
||||
upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
|
||||
h11_client = h11.Connection(h11.CLIENT)
|
||||
self.receive_data(h11_client.send(upgrade_request))
|
||||
|
||||
def send(self, event: Event) -> bytes:
|
||||
"""
|
||||
Send an event to the remote.
|
||||
|
||||
This will return the bytes to send based on the event or raise
|
||||
a LocalProtocolError if the event is not valid given the
|
||||
state.
|
||||
|
||||
:returns: Data to send to the WebSocket peer.
|
||||
:rtype: bytes
|
||||
"""
|
||||
data = b""
|
||||
if isinstance(event, Request):
|
||||
data += self._initiate_connection(event)
|
||||
elif isinstance(event, AcceptConnection):
|
||||
data += self._accept(event)
|
||||
elif isinstance(event, RejectConnection):
|
||||
data += self._reject(event)
|
||||
elif isinstance(event, RejectData):
|
||||
data += self._send_reject_data(event)
|
||||
else:
|
||||
msg = f"Event {event} cannot be sent during the handshake"
|
||||
raise LocalProtocolError(
|
||||
msg,
|
||||
)
|
||||
return data
|
||||
|
||||
def receive_data(self, data: bytes | None) -> None:
|
||||
"""
|
||||
Receive data from the remote.
|
||||
|
||||
A list of events that the remote peer triggered by sending
|
||||
this data can be retrieved with :meth:`events`.
|
||||
|
||||
:param bytes data: Data received from the WebSocket peer.
|
||||
"""
|
||||
self._h11_connection.receive_data(data or b"")
|
||||
while True:
|
||||
try:
|
||||
event = self._h11_connection.next_event()
|
||||
except h11.RemoteProtocolError:
|
||||
msg = "Bad HTTP message"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
if (
|
||||
isinstance(event, h11.ConnectionClosed)
|
||||
or event is h11.NEED_DATA
|
||||
or event is h11.PAUSED
|
||||
):
|
||||
break
|
||||
|
||||
if self.client:
|
||||
if isinstance(event, h11.InformationalResponse):
|
||||
if event.status_code == 101:
|
||||
self._events.append(self._establish_client_connection(event))
|
||||
else:
|
||||
self._events.append(
|
||||
RejectConnection(
|
||||
headers=list(event.headers),
|
||||
status_code=event.status_code,
|
||||
has_body=False,
|
||||
),
|
||||
)
|
||||
self._state = ConnectionState.CLOSED
|
||||
elif isinstance(event, h11.Response):
|
||||
self._state = ConnectionState.REJECTING
|
||||
self._events.append(
|
||||
RejectConnection(
|
||||
headers=list(event.headers),
|
||||
status_code=event.status_code,
|
||||
has_body=True,
|
||||
),
|
||||
)
|
||||
elif isinstance(event, h11.Data):
|
||||
self._events.append(
|
||||
RejectData(data=event.data, body_finished=False),
|
||||
)
|
||||
elif isinstance(event, h11.EndOfMessage):
|
||||
self._events.append(RejectData(data=b"", body_finished=True))
|
||||
self._state = ConnectionState.CLOSED
|
||||
elif isinstance(event, h11.Request):
|
||||
self._events.append(self._process_connection_request(event))
|
||||
|
||||
def events(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
Return a generator that provides any events that have been generated
|
||||
by protocol activity.
|
||||
|
||||
:returns: a generator that yields H11 events.
|
||||
"""
|
||||
while self._events:
|
||||
yield self._events.popleft()
|
||||
|
||||
# Server mode methods
|
||||
|
||||
def _process_connection_request(
|
||||
self, event: h11.Request,
|
||||
) -> Request:
|
||||
if event.method != b"GET":
|
||||
msg = "Request method must be GET"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
connection_tokens = None
|
||||
extensions: list[str] = []
|
||||
host = None
|
||||
key = None
|
||||
subprotocols: list[str] = []
|
||||
upgrade = b""
|
||||
version = None
|
||||
headers: Headers = []
|
||||
for name, value in event.headers:
|
||||
name = name.lower()
|
||||
if name == b"connection":
|
||||
connection_tokens = split_comma_header(value)
|
||||
elif name == b"host":
|
||||
host = value.decode("idna")
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-extensions":
|
||||
extensions.extend(split_comma_header(value))
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-key":
|
||||
key = value
|
||||
elif name == b"sec-websocket-protocol":
|
||||
subprotocols.extend(split_comma_header(value))
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-version":
|
||||
version = value
|
||||
elif name == b"upgrade":
|
||||
upgrade = value
|
||||
headers.append((name, value))
|
||||
if connection_tokens is None or not any(
|
||||
token.lower() == "upgrade" for token in connection_tokens
|
||||
):
|
||||
msg = "Missing header, 'Connection: Upgrade'"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
if version != WEBSOCKET_VERSION:
|
||||
msg = "Missing header, 'Sec-WebSocket-Version'"
|
||||
raise RemoteProtocolError(
|
||||
msg,
|
||||
event_hint=RejectConnection(
|
||||
headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
|
||||
status_code=426 if version else 400,
|
||||
),
|
||||
)
|
||||
if key is None:
|
||||
msg = "Missing header, 'Sec-WebSocket-Key'"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
if upgrade.lower() != WEBSOCKET_UPGRADE:
|
||||
msg = f"Missing header, 'Upgrade: {WEBSOCKET_UPGRADE.decode()}'"
|
||||
raise RemoteProtocolError(
|
||||
msg,
|
||||
event_hint=RejectConnection(),
|
||||
)
|
||||
if host is None:
|
||||
msg = "Missing header, 'Host'"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
|
||||
self._initiating_request = Request(
|
||||
extensions=extensions,
|
||||
extra_headers=headers,
|
||||
host=host,
|
||||
subprotocols=subprotocols,
|
||||
target=event.target.decode("ascii"),
|
||||
)
|
||||
return self._initiating_request
|
||||
|
||||
def _accept(self, event: AcceptConnection) -> bytes:
|
||||
# _accept is always called after _process_connection_request.
|
||||
assert self._initiating_request is not None
|
||||
request_headers = normed_header_dict(self._initiating_request.extra_headers)
|
||||
|
||||
nonce = request_headers[b"sec-websocket-key"]
|
||||
accept_token = generate_accept_token(nonce)
|
||||
|
||||
headers = [
|
||||
(b"Upgrade", WEBSOCKET_UPGRADE),
|
||||
(b"Connection", b"Upgrade"),
|
||||
(b"Sec-WebSocket-Accept", accept_token),
|
||||
]
|
||||
|
||||
if event.subprotocol is not None:
|
||||
if event.subprotocol not in self._initiating_request.subprotocols:
|
||||
msg = f"unexpected subprotocol {event.subprotocol}"
|
||||
raise LocalProtocolError(msg)
|
||||
headers.append(
|
||||
(b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii")),
|
||||
)
|
||||
|
||||
if event.extensions:
|
||||
accepts = server_extensions_handshake(
|
||||
cast("Sequence[str]", self._initiating_request.extensions),
|
||||
event.extensions,
|
||||
)
|
||||
if accepts:
|
||||
headers.append((b"Sec-WebSocket-Extensions", accepts))
|
||||
|
||||
response = h11.InformationalResponse(
|
||||
status_code=101,
|
||||
headers=headers + event.extra_headers,
|
||||
reason=b"Switching Protocols",
|
||||
)
|
||||
self._connection = Connection(
|
||||
ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
|
||||
event.extensions,
|
||||
)
|
||||
self._state = ConnectionState.OPEN
|
||||
return self._h11_connection.send(response) or b""
|
||||
|
||||
def _reject(self, event: RejectConnection) -> bytes:
|
||||
if self.state != ConnectionState.CONNECTING:
|
||||
msg = f"Connection cannot be rejected in state {self.state}"
|
||||
raise LocalProtocolError(
|
||||
msg,
|
||||
)
|
||||
|
||||
headers = list(event.headers)
|
||||
if not event.has_body:
|
||||
headers.append((b"content-length", b"0"))
|
||||
response = h11.Response(status_code=event.status_code, headers=headers)
|
||||
data = self._h11_connection.send(response) or b""
|
||||
self._state = ConnectionState.REJECTING
|
||||
if not event.has_body:
|
||||
data += self._h11_connection.send(h11.EndOfMessage()) or b""
|
||||
self._state = ConnectionState.CLOSED
|
||||
return data
|
||||
|
||||
def _send_reject_data(self, event: RejectData) -> bytes:
|
||||
if self.state != ConnectionState.REJECTING:
|
||||
msg = f"Cannot send rejection data in state {self.state}"
|
||||
raise LocalProtocolError(
|
||||
msg,
|
||||
)
|
||||
|
||||
data = self._h11_connection.send(h11.Data(data=event.data)) or b""
|
||||
if event.body_finished:
|
||||
data += self._h11_connection.send(h11.EndOfMessage()) or b""
|
||||
self._state = ConnectionState.CLOSED
|
||||
return data
|
||||
|
||||
# Client mode methods
|
||||
|
||||
def _initiate_connection(self, request: Request) -> bytes:
|
||||
self._initiating_request = request
|
||||
self._nonce = generate_nonce()
|
||||
|
||||
headers = [
|
||||
(b"Host", request.host.encode("idna")),
|
||||
(b"Upgrade", WEBSOCKET_UPGRADE),
|
||||
(b"Connection", b"Upgrade"),
|
||||
(b"Sec-WebSocket-Key", self._nonce),
|
||||
(b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
|
||||
]
|
||||
|
||||
if request.subprotocols:
|
||||
headers.append(
|
||||
(
|
||||
b"Sec-WebSocket-Protocol",
|
||||
(", ".join(request.subprotocols)).encode("ascii"),
|
||||
),
|
||||
)
|
||||
|
||||
if request.extensions:
|
||||
offers: dict[str, str | bool] = {}
|
||||
for e in request.extensions:
|
||||
assert isinstance(e, Extension)
|
||||
offers[e.name] = e.offer()
|
||||
extensions = []
|
||||
for name, params in offers.items():
|
||||
bname = name.encode("ascii")
|
||||
if isinstance(params, bool):
|
||||
if params:
|
||||
extensions.append(bname)
|
||||
else:
|
||||
extensions.append(b"%s; %s" % (bname, params.encode("ascii")))
|
||||
if extensions:
|
||||
headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))
|
||||
|
||||
upgrade = h11.Request(
|
||||
method=b"GET",
|
||||
target=request.target.encode("ascii"),
|
||||
headers=headers + request.extra_headers,
|
||||
)
|
||||
return self._h11_connection.send(upgrade) or b""
|
||||
|
||||
def _establish_client_connection(
|
||||
self, event: h11.InformationalResponse,
|
||||
) -> AcceptConnection:
|
||||
# _establish_client_connection is always called after _initiate_connection.
|
||||
assert self._initiating_request is not None
|
||||
assert self._nonce is not None
|
||||
|
||||
accept = None
|
||||
connection_tokens = None
|
||||
accepts: list[str] = []
|
||||
subprotocol = None
|
||||
upgrade = b""
|
||||
headers: Headers = []
|
||||
for name, value in event.headers:
|
||||
name = name.lower()
|
||||
if name == b"connection":
|
||||
connection_tokens = split_comma_header(value)
|
||||
continue # Skip appending to headers
|
||||
if name == b"sec-websocket-extensions":
|
||||
accepts = split_comma_header(value)
|
||||
continue # Skip appending to headers
|
||||
if name == b"sec-websocket-accept":
|
||||
accept = value
|
||||
continue # Skip appending to headers
|
||||
if name == b"sec-websocket-protocol":
|
||||
subprotocol = value.decode("ascii")
|
||||
continue # Skip appending to headers
|
||||
if name == b"upgrade":
|
||||
upgrade = value
|
||||
continue # Skip appending to headers
|
||||
headers.append((name, value))
|
||||
|
||||
if connection_tokens is None or not any(
|
||||
token.lower() == "upgrade" for token in connection_tokens
|
||||
):
|
||||
msg = "Missing header, 'Connection: Upgrade'"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
if upgrade.lower() != WEBSOCKET_UPGRADE:
|
||||
msg = f"Missing header, 'Upgrade: {WEBSOCKET_UPGRADE.decode()}'"
|
||||
raise RemoteProtocolError(
|
||||
msg,
|
||||
event_hint=RejectConnection(),
|
||||
)
|
||||
accept_token = generate_accept_token(self._nonce)
|
||||
if accept != accept_token:
|
||||
msg = "Bad accept token"
|
||||
raise RemoteProtocolError(msg, event_hint=RejectConnection())
|
||||
if subprotocol is not None and subprotocol not in self._initiating_request.subprotocols:
|
||||
msg = f"unrecognized subprotocol {subprotocol}"
|
||||
raise RemoteProtocolError(
|
||||
msg,
|
||||
event_hint=RejectConnection(),
|
||||
)
|
||||
extensions = client_extensions_handshake(
|
||||
accepts, cast("Sequence[Extension]", self._initiating_request.extensions),
|
||||
)
|
||||
|
||||
self._connection = Connection(
|
||||
ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
|
||||
extensions,
|
||||
self._h11_connection.trailing_data[0],
|
||||
)
|
||||
self._state = ConnectionState.OPEN
|
||||
return AcceptConnection(
|
||||
extensions=extensions, extra_headers=headers, subprotocol=subprotocol,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(client={self.client}, state={self.state})"
|
||||
|
||||
|
||||
def server_extensions_handshake(
|
||||
requested: Iterable[str], supported: list[Extension],
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Agree on the extensions to use returning an appropriate header value.
|
||||
|
||||
This returns None if there are no agreed extensions
|
||||
"""
|
||||
accepts: dict[str, bool | bytes] = {}
|
||||
for offer in requested:
|
||||
name = offer.split(";", 1)[0].strip()
|
||||
for extension in supported:
|
||||
if extension.name == name:
|
||||
accept = extension.accept(offer)
|
||||
if isinstance(accept, bool):
|
||||
if accept:
|
||||
accepts[extension.name] = True
|
||||
elif accept is not None:
|
||||
accepts[extension.name] = accept.encode("ascii")
|
||||
|
||||
if accepts:
|
||||
extensions: list[bytes] = []
|
||||
for name, params in accepts.items():
|
||||
name_bytes = name.encode("ascii")
|
||||
if isinstance(params, bool):
|
||||
assert params
|
||||
extensions.append(name_bytes)
|
||||
elif params == b"":
|
||||
extensions.append(b"%s" % (name_bytes))
|
||||
else:
|
||||
extensions.append(b"%s; %s" % (name_bytes, params))
|
||||
return b", ".join(extensions)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def client_extensions_handshake(
|
||||
accepted: Iterable[str], supported: Sequence[Extension],
|
||||
) -> list[Extension]:
|
||||
# This raises RemoteProtocolError is the accepted extension is not
|
||||
# supported.
|
||||
extensions = []
|
||||
for accept in accepted:
|
||||
name = accept.split(";", 1)[0].strip()
|
||||
for extension in supported:
|
||||
if extension.name == name:
|
||||
extension.finalize(accept)
|
||||
extensions.append(extension)
|
||||
break
|
||||
else:
|
||||
msg = f"unrecognized extension {name}"
|
||||
raise RemoteProtocolError(
|
||||
msg, event_hint=RejectConnection(),
|
||||
)
|
||||
return extensions
|
||||
@@ -0,0 +1 @@
|
||||
Marker
|
||||
@@ -0,0 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
Headers = list[tuple[bytes, bytes]]
|
||||
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
wsproto/utilities
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Utility functions that do not belong in a separate module.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from h11._headers import Headers as H11Headers
|
||||
|
||||
from .events import Event
|
||||
from .typing import Headers
|
||||
|
||||
# RFC6455, Section 1.3 - Opening Handshake
|
||||
ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LocalProtocolError(ProtocolError):
|
||||
"""
|
||||
Indicates an error due to local/programming errors.
|
||||
|
||||
This is raised when the connection is asked to do something that
|
||||
is either incompatible with the state or the websocket standard.
|
||||
|
||||
"""
|
||||
|
||||
pass # noqa
|
||||
|
||||
|
||||
class RemoteProtocolError(ProtocolError):
|
||||
"""
|
||||
Indicates an error due to the remote's actions.
|
||||
|
||||
This is raised when processing the bytes from the remote if the
|
||||
remote has sent data that is incompatible with the websocket
|
||||
standard.
|
||||
|
||||
.. attribute:: event_hint
|
||||
|
||||
This is a suggested wsproto Event to send to the client based
|
||||
on the error. It could be None if no hint is available.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, event_hint: Event | None = None) -> None:
|
||||
self.event_hint = event_hint
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# Some convenience utilities for working with HTTP headers
|
||||
def normed_header_dict(h11_headers: Headers | H11Headers) -> dict[bytes, bytes]:
|
||||
# This mangles Set-Cookie headers. But it happens that we don't care about
|
||||
# any of those, so it's OK. For every other HTTP header, if there are
|
||||
# multiple instances then you're allowed to join them together with
|
||||
# commas.
|
||||
name_to_values: dict[bytes, list[bytes]] = {}
|
||||
for name, value in h11_headers:
|
||||
name_to_values.setdefault(name, []).append(value)
|
||||
name_to_normed_value = {}
|
||||
for name, values in name_to_values.items():
|
||||
name_to_normed_value[name] = b", ".join(values)
|
||||
return name_to_normed_value
|
||||
|
||||
|
||||
# We use this for parsing the proposed protocol list, and for parsing the
|
||||
# proposed and accepted extension lists. For the proposed protocol list it's
|
||||
# fine, because the ABNF is just 1#token. But for the extension lists, it's
|
||||
# wrong, because those can contain quoted strings, which can in turn contain
|
||||
# commas. XX FIXME
|
||||
def split_comma_header(value: bytes) -> list[str]:
|
||||
return [piece.decode("ascii").strip() for piece in value.split(b",")]
|
||||
|
||||
|
||||
def generate_nonce() -> bytes:
|
||||
# os.urandom may be overkill for this use case, but I don't think this
|
||||
# is a bottleneck, and better safe than sorry...
|
||||
return base64.b64encode(os.urandom(16))
|
||||
|
||||
|
||||
def generate_accept_token(token: bytes) -> bytes:
|
||||
accept_token = token + ACCEPT_GUID
|
||||
accept_token = hashlib.sha1(accept_token).digest()
|
||||
return base64.b64encode(accept_token)
|
||||
Reference in New Issue
Block a user