195 lines
6.1 KiB
Python
195 lines
6.1 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import socket
|
||
|
from abc import abstractmethod
|
||
|
from collections.abc import Callable, Collection, Mapping
|
||
|
from contextlib import AsyncExitStack
|
||
|
from io import IOBase
|
||
|
from ipaddress import IPv4Address, IPv6Address
|
||
|
from socket import AddressFamily
|
||
|
from types import TracebackType
|
||
|
from typing import Any, TypeVar, Union
|
||
|
|
||
|
from .._core._typedattr import (
|
||
|
TypedAttributeProvider,
|
||
|
TypedAttributeSet,
|
||
|
typed_attribute,
|
||
|
)
|
||
|
from ._streams import ByteStream, Listener, UnreliableObjectStream
|
||
|
from ._tasks import TaskGroup
|
||
|
|
||
|
# Generic type variable for the return value of callables.
T_Retval = TypeVar("T_Retval")

# Any value accepted as an IP address: its textual form or an ``ipaddress`` object.
IPAddressType = Union[str, IPv4Address, IPv6Address]

# An IP socket address, expressed as a ``(host, port)`` pair.
IPSockAddrType = tuple[str, int]

# Any socket address: either an IP ``(host, port)`` pair or a string (a path, as
# used by the UNIX socket classes below).
SockAddrType = Union[IPSockAddrType, str]

# A UDP datagram: the payload together with the remote ``(host, port)``.
UDPPacketType = tuple[bytes, IPSockAddrType]

# A UNIX datagram: the payload together with the remote socket path.
UNIXDatagramPacketType = tuple[bytes, str]
|
||
|
|
||
|
|
||
|
class _NullAsyncContextManager:
|
||
|
async def __aenter__(self) -> None:
|
||
|
pass
|
||
|
|
||
|
async def __aexit__(
|
||
|
self,
|
||
|
exc_type: type[BaseException] | None,
|
||
|
exc_val: BaseException | None,
|
||
|
exc_tb: TracebackType | None,
|
||
|
) -> bool | None:
|
||
|
return None
|
||
|
|
||
|
|
||
|
class SocketAttribute(TypedAttributeSet):
    # Typed attributes describing the underlying socket. Which of them a given
    # object actually provides is decided by _SocketProvider.extra_attributes.

    #: the address family of the underlying socket
    family: AddressFamily = typed_attribute()
    #: the local socket address of the underlying socket (an IP ``(host, port)``
    #: pair or a string address, per ``SockAddrType``)
    local_address: SockAddrType = typed_attribute()
    #: for IP addresses, the local port the underlying socket is bound to
    local_port: int = typed_attribute()
    #: the underlying stdlib socket object
    raw_socket: socket.socket = typed_attribute()
    #: the remote address the underlying socket is connected to (only provided
    #: when the socket has a peer)
    remote_address: SockAddrType = typed_attribute()
    #: for IP addresses, the remote port the underlying socket is connected to
    #: (only provided when the socket has a peer)
    remote_port: int = typed_attribute()
|
||
|
|
||
|
|
||
|
class _SocketProvider(TypedAttributeProvider):
    """
    Mixin that derives :class:`SocketAttribute` extra attributes from a raw socket.

    Subclasses only need to implement the :attr:`_raw_socket` property.
    """

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Build the mapping of socket attributes to zero-argument value getters.

        ``family``, ``local_address`` and ``raw_socket`` are always present.
        ``remote_address`` is present only when ``getpeername()`` succeeds, that
        is, when the socket has a peer. ``local_port`` is present only for
        ``AF_INET``/``AF_INET6`` sockets. ``remote_port`` additionally requires
        a peer.

        :return: a fresh dict mapping attribute keys to getter callables
        """
        # Imported lazily here, presumably to avoid an import cycle with the core
        # module (NOTE(review): confirm). The helper appears to normalize IPv6
        # socket addresses; its exact semantics live in _core._sockets.
        from .._core._sockets import convert_ipv6_sockaddr as convert

        # The local-side getters re-query the socket on every call.
        attributes: dict[Any, Callable[[], Any]] = {
            SocketAttribute.family: lambda: self._raw_socket.family,
            SocketAttribute.local_address: lambda: convert(
                self._raw_socket.getsockname()
            ),
            SocketAttribute.raw_socket: lambda: self._raw_socket,
        }

        # The peer address, by contrast, is resolved once, right now. An
        # unconnected socket raises OSError. For non-IP families the address is
        # not a (host, port) tuple, so the annotation uses the full SockAddrType
        # union rather than tuple[str, int].
        try:
            peername: SockAddrType | None = convert(self._raw_socket.getpeername())
        except OSError:
            peername = None

        # Provide the remote address for connected sockets
        if peername is not None:
            attributes[SocketAttribute.remote_address] = lambda: peername

        # Provide local and remote ports for IP based sockets
        if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
            attributes[SocketAttribute.local_port] = (
                lambda: self._raw_socket.getsockname()[1]
            )
            if peername is not None:
                # IP families always yield a (host, port)-shaped address here.
                remote_port = peername[1]
                attributes[SocketAttribute.remote_port] = lambda: remote_port

        return attributes

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """The stdlib socket object that the attribute getters query."""
|
||
|
|
||
|
|
||
|
class SocketStream(ByteStream, _SocketProvider):
    """
    Transports bytes over a socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    Concrete implementations supply the underlying socket through the
    ``_raw_socket`` property required by the socket attribute mixin.
    """
|
||
|
|
||
|
|
||
|
class UNIXSocketStream(SocketStream):
    """
    A socket stream that can also pass file descriptors to and from its peer.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send file descriptors along with a message to the peer.

        :param message: a non-empty bytestring
        :param fds: a collection of files (either numeric file descriptors or open file
            or socket objects)
        """

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive file descriptors along with a message from the peer.

        :param msglen: length of the message to expect from the peer
        :param maxfds: maximum number of file descriptors to expect from the peer
        :return: a tuple of (message, file descriptors), where the descriptors are
            numeric file descriptors
        """
|
||
|
|
||
|
|
||
|
class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    Listens to incoming socket connections.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept an incoming connection."""

    async def serve(
        self,
        handler: Callable[[SocketStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Accept connections in a loop and hand each one to ``handler``.

        Every accepted stream is scheduled with ``start_soon(handler, stream)``
        on the task group. The loop only ends when ``accept()`` (or scheduling)
        raises, or when the call is otherwise interrupted.

        :param handler: callable invoked with each accepted :class:`SocketStream`
        :param task_group: the task group to run handlers in. If omitted, a new
            one is created and kept open for the duration of this call.
        """
        from .. import create_task_group

        async with AsyncExitStack() as exit_stack:
            group = task_group
            if group is None:
                # Only a group created here is entered on the exit stack. A
                # caller-supplied group's lifetime stays with the caller.
                group = await exit_stack.enter_async_context(create_task_group())

            while True:
                connection = await self.accept()
                group.start_soon(handler, connection)
|
||
|
|
||
|
|
||
|
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    Represents an unconnected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).

        :param data: the datagram payload
        :param host: the destination host
        :param port: the destination port
        """
        destination = (host, port)
        packet = (data, destination)
        return await self.send(packet)
|
||
|
|
||
|
|
||
|
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Represents a connected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    Unlike :class:`UDPSocket`, items are bare payload bytes with no address.
    """
|
||
|
|
||
|
|
||
|
class UNIXDatagramSocket(
    UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
):
    """
    Represents an unconnected Unix datagram socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, path: str) -> None:
        """
        Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path)).

        :param data: the datagram payload
        :param path: the destination socket path
        """
        packet = (data, path)
        return await self.send(packet)
|
||
|
|
||
|
|
||
|
class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Represents a connected Unix datagram socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    Unlike :class:`UNIXDatagramSocket`, items are bare payload bytes with no path.
    """
|