%PDF- %PDF-
Mini Shell

Mini Shell

Direktori : /lib/python3/dist-packages/twisted/internet/test/
Upload File :
Create Path :
Current File : //lib/python3/dist-packages/twisted/internet/test/test_protocol.py

# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for L{twisted.internet.protocol}.
"""


from io import BytesIO

from zope.interface import implementer
from zope.interface.verify import verifyObject

from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import (
    IConsumer,
    ILoggingContext,
    IProtocol,
    IProtocolFactory,
)
from twisted.internet.protocol import (
    ClientCreator,
    ConsumerToProtocolAdapter,
    Factory,
    FileWrapper,
    Protocol,
    ProtocolToConsumerAdapter,
)
from twisted.internet.testing import MemoryReactorClock, StringTransport
from twisted.logger import LogLevel, globalLogPublisher
from twisted.python.failure import Failure
from twisted.trial.unittest import TestCase


class ClientCreatorTests(TestCase):
    """
    Tests for L{twisted.internet.protocol.ClientCreator}.
    """

    def _basicConnectTest(self, check):
        """
        Helper for implementing a test to verify that one of the I{connect}
        methods of L{ClientCreator} passes the right arguments to the right
        reactor method.

        @param check: A function which will be invoked with a reactor and a
            L{ClientCreator} instance and which should call one of the
            L{ClientCreator}'s I{connect} methods and assert that all of its
            arguments except for the factory are passed on as expected to the
            reactor.  The factory should be returned.
        """

        class SomeProtocol(Protocol):
            pass

        reactor = MemoryReactorClock()
        cc = ClientCreator(reactor, SomeProtocol)
        factory = check(reactor, cc)
        # The factory handed to the reactor must build the protocol class
        # that was given to ClientCreator.
        protocol = factory.buildProtocol(None)
        self.assertIsInstance(protocol, SomeProtocol)

    def test_connectTCP(self):
        """
        L{ClientCreator.connectTCP} calls C{reactor.connectTCP} with the host
        and port information passed to it, and with a factory which will
        construct the protocol passed to L{ClientCreator.__init__}.
        """

        def check(reactor, cc):
            cc.connectTCP("example.com", 1234, 4321, ("1.2.3.4", 9876))
            host, port, factory, timeout, bindAddress = reactor.tcpClients.pop()
            self.assertEqual(host, "example.com")
            self.assertEqual(port, 1234)
            self.assertEqual(timeout, 4321)
            self.assertEqual(bindAddress, ("1.2.3.4", 9876))
            return factory

        self._basicConnectTest(check)

    def test_connectUNIX(self):
        """
        L{ClientCreator.connectUNIX} calls C{reactor.connectUNIX} with the
        filename passed to it, and with a factory which will construct the
        protocol passed to L{ClientCreator.__init__}.
        """

        def check(reactor, cc):
            cc.connectUNIX("/foo/bar", 123, True)
            address, factory, timeout, checkPID = reactor.unixClients.pop()
            self.assertEqual(address, "/foo/bar")
            self.assertEqual(timeout, 123)
            self.assertTrue(checkPID)
            return factory

        self._basicConnectTest(check)

    def test_connectSSL(self):
        """
        L{ClientCreator.connectSSL} calls C{reactor.connectSSL} with the host,
        port, and context factory passed to it, and with a factory which will
        construct the protocol passed to L{ClientCreator.__init__}.
        """

        def check(reactor, cc):
            expectedContextFactory = object()
            cc.connectSSL(
                "example.com", 1234, expectedContextFactory, 4321, ("4.3.2.1", 5678)
            )
            (
                host,
                port,
                factory,
                contextFactory,
                timeout,
                bindAddress,
            ) = reactor.sslClients.pop()
            self.assertEqual(host, "example.com")
            self.assertEqual(port, 1234)
            self.assertIs(contextFactory, expectedContextFactory)
            self.assertEqual(timeout, 4321)
            self.assertEqual(bindAddress, ("4.3.2.1", 5678))
            return factory

        self._basicConnectTest(check)

    def _cancelConnectTest(self, connect):
        """
        Helper for implementing a test to verify that cancellation of the
        L{Deferred} returned by one of L{ClientCreator}'s I{connect} methods is
        implemented to cancel the underlying connector.

        @param connect: A function which will be invoked with a L{ClientCreator}
            instance as an argument and which should call one of its
            I{connect} methods and return the result.

        @return: A L{Deferred} which fires when the test is complete or fails if
            there is a problem.
        """
        reactor = MemoryReactorClock()
        cc = ClientCreator(reactor, Protocol)
        d = connect(cc)
        connector = reactor.connectors.pop()
        # The connector must still be live before cancellation and be told
        # to disconnect as a result of it.
        self.assertFalse(connector._disconnected)
        d.cancel()
        self.assertTrue(connector._disconnected)
        return self.assertFailure(d, CancelledError)

    def test_cancelConnectTCP(self):
        """
        The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled
        to abort the connection attempt before it completes.
        """

        def connect(cc):
            return cc.connectTCP("example.com", 1234)

        return self._cancelConnectTest(connect)

    def test_cancelConnectUNIX(self):
        """
        The L{Deferred} returned by L{ClientCreator.connectUNIX} can be
        cancelled to abort the connection attempt before it completes.
        """

        def connect(cc):
            return cc.connectUNIX("/foo/bar")

        return self._cancelConnectTest(connect)

    def test_cancelConnectSSL(self):
        """
        The L{Deferred} returned by L{ClientCreator.connectSSL} can be cancelled
        to abort the connection attempt before it completes.
        """

        def connect(cc):
            return cc.connectSSL("example.com", 1234, object())

        return self._cancelConnectTest(connect)

    def _cancelConnectTimeoutTest(self, connect):
        """
        Like L{_cancelConnectTest}, but for the case where the L{Deferred} is
        cancelled after the connection is set up but before it is fired with the
        resulting protocol instance.

        @param connect: A function which will be invoked with a reactor and a
            L{ClientCreator} instance, and which should call one of the
            L{ClientCreator}'s I{connect} methods, connect the protocol built
            by the resulting factory to a transport, and return the
            L{Deferred} obtained from the I{connect} method.

        @return: A L{Deferred} which fires when the test is complete or fails if
            there is a problem.
        """
        reactor = MemoryReactorClock()
        cc = ClientCreator(reactor, Protocol)
        d = connect(reactor, cc)
        connector = reactor.connectors.pop()
        # Sanity check - there is an outstanding delayed call to fire the
        # Deferred.
        self.assertEqual(len(reactor.getDelayedCalls()), 1)

        # Cancel the Deferred, disconnecting the transport just set up and
        # cancelling the delayed call.
        d.cancel()

        self.assertEqual(reactor.getDelayedCalls(), [])

        # A real connector implementation is responsible for disconnecting the
        # transport as well.  For our purposes, just check that someone told the
        # connector to disconnect.
        self.assertTrue(connector._disconnected)

        return self.assertFailure(d, CancelledError)

    def test_cancelConnectTCPTimeout(self):
        """
        L{ClientCreator.connectTCP} inserts a very short delayed call between
        the time the connection is established and the time the L{Deferred}
        returned from one of its connect methods actually fires.  If the
        L{Deferred} is cancelled in this interval, the established connection is
        closed, the timeout is cancelled, and the L{Deferred} fails with
        L{CancelledError}.
        """

        def connect(reactor, cc):
            d = cc.connectTCP("example.com", 1234)
            # Only the factory is needed; the full unpacking is kept so the
            # shape of the recorded call is still checked.
            _host, _port, factory, _timeout, _bindAddress = reactor.tcpClients.pop()
            protocol = factory.buildProtocol(None)
            transport = StringTransport()
            protocol.makeConnection(transport)
            return d

        return self._cancelConnectTimeoutTest(connect)

    def test_cancelConnectUNIXTimeout(self):
        """
        L{ClientCreator.connectUNIX} inserts a very short delayed call between
        the time the connection is established and the time the L{Deferred}
        returned from one of its connect methods actually fires.  If the
        L{Deferred} is cancelled in this interval, the established connection is
        closed, the timeout is cancelled, and the L{Deferred} fails with
        L{CancelledError}.
        """

        def connect(reactor, cc):
            d = cc.connectUNIX("/foo/bar")
            # The last element of a recorded UNIX connection is the checkPID
            # flag (see test_connectUNIX), not a bind address.
            _address, factory, _timeout, _checkPID = reactor.unixClients.pop()
            protocol = factory.buildProtocol(None)
            transport = StringTransport()
            protocol.makeConnection(transport)
            return d

        return self._cancelConnectTimeoutTest(connect)

    def test_cancelConnectSSLTimeout(self):
        """
        L{ClientCreator.connectSSL} inserts a very short delayed call between
        the time the connection is established and the time the L{Deferred}
        returned from one of its connect methods actually fires.  If the
        L{Deferred} is cancelled in this interval, the established connection is
        closed, the timeout is cancelled, and the L{Deferred} fails with
        L{CancelledError}.
        """

        def connect(reactor, cc):
            d = cc.connectSSL("example.com", 1234, object())
            (
                _host,
                _port,
                factory,
                _contextFactory,
                _timeout,
                _bindAddress,
            ) = reactor.sslClients.pop()
            protocol = factory.buildProtocol(None)
            transport = StringTransport()
            protocol.makeConnection(transport)
            return d

        return self._cancelConnectTimeoutTest(connect)

    def _cancelConnectFailedTimeoutTest(self, connect):
        """
        Like L{_cancelConnectTest}, but for the case where the L{Deferred} is
        cancelled after the connection attempt has failed but before it is fired
        with the resulting failure.

        @param connect: A function which will be invoked with a reactor and a
            L{ClientCreator} instance, and which should call one of the
            L{ClientCreator}'s I{connect} methods and return a two-tuple of
            the resulting L{Deferred} and the factory which was passed to the
            reactor.

        @return: A L{Deferred} which fires when the test is complete or fails if
            there is a problem.
        """
        reactor = MemoryReactorClock()
        cc = ClientCreator(reactor, Protocol)
        d, factory = connect(reactor, cc)
        connector = reactor.connectors.pop()
        factory.clientConnectionFailed(
            connector, Failure(Exception("Simulated failure"))
        )

        # Sanity check - there is an outstanding delayed call to fire the
        # Deferred.
        self.assertEqual(len(reactor.getDelayedCalls()), 1)

        # Cancel the Deferred, cancelling the delayed call.
        d.cancel()

        self.assertEqual(reactor.getDelayedCalls(), [])

        return self.assertFailure(d, CancelledError)

    def test_cancelConnectTCPFailedTimeout(self):
        """
        Similar to L{test_cancelConnectTCPTimeout}, but for the case where the
        connection attempt fails.
        """

        def connect(reactor, cc):
            d = cc.connectTCP("example.com", 1234)
            _host, _port, factory, _timeout, _bindAddress = reactor.tcpClients.pop()
            return d, factory

        return self._cancelConnectFailedTimeoutTest(connect)

    def test_cancelConnectUNIXFailedTimeout(self):
        """
        Similar to L{test_cancelConnectUNIXTimeout}, but for the case where the
        connection attempt fails.
        """

        def connect(reactor, cc):
            d = cc.connectUNIX("/foo/bar")
            _address, factory, _timeout, _checkPID = reactor.unixClients.pop()
            return d, factory

        return self._cancelConnectFailedTimeoutTest(connect)

    def test_cancelConnectSSLFailedTimeout(self):
        """
        Similar to L{test_cancelConnectSSLTimeout}, but for the case where the
        connection attempt fails.
        """

        def connect(reactor, cc):
            d = cc.connectSSL("example.com", 1234, object())
            (
                _host,
                _port,
                factory,
                _contextFactory,
                _timeout,
                _bindAddress,
            ) = reactor.sslClients.pop()
            return d, factory

        return self._cancelConnectFailedTimeoutTest(connect)


class ProtocolTests(TestCase):
    """
    Tests for L{twisted.internet.protocol.Protocol}.
    """

    def test_interfaces(self):
        """
        A L{Protocol} instance passes L{verifyObject} for both L{IProtocol}
        and L{ILoggingContext}.
        """
        instance = Protocol()
        for interface in (IProtocol, ILoggingContext):
            self.assertTrue(verifyObject(interface, instance))

    def test_logPrefix(self):
        """
        Calling C{logPrefix} on an instance of a L{Protocol} subclass gives
        the name of that subclass.
        """

        class SomeThing(Protocol):
            pass

        prefix = SomeThing().logPrefix()
        self.assertEqual("SomeThing", prefix)

    def test_makeConnection(self):
        """
        By the time C{connectionMade} runs as a result of
        L{Protocol.makeConnection}, the transport passed to
        C{makeConnection} is available as the C{transport} attribute.
        """
        seenTransports = []

        class SomeProtocol(Protocol):
            def connectionMade(self):
                # Record what the protocol sees as its transport at the
                # moment it is notified of the connection.
                seenTransports.append(self.transport)

        fakeTransport = object()
        SomeProtocol().makeConnection(fakeTransport)
        self.assertEqual(seenTransports, [fakeTransport])


class FactoryTests(TestCase):
    """
    Tests for L{protocol.Factory}.
    """

    def _captureLogEvents(self):
        """
        Start recording events sent to L{globalLogPublisher} and arrange for
        the recording to stop when the test finishes.

        @return: The L{list} to which each observed event is appended.
        """
        captured = []
        observer = captured.append
        globalLogPublisher.addObserver(observer)
        self.addCleanup(globalLogPublisher.removeObserver, observer)
        return captured

    def test_interfaces(self):
        """
        A L{Factory} instance passes L{verifyObject} for both
        L{IProtocolFactory} and L{ILoggingContext}.
        """
        instance = Factory()
        for interface in (IProtocolFactory, ILoggingContext):
            self.assertTrue(verifyObject(interface, instance))

    def test_logPrefix(self):
        """
        Calling C{logPrefix} on an instance of a L{Factory} subclass gives the
        name of that subclass.
        """

        class SomeKindOfFactory(Factory):
            pass

        prefix = SomeKindOfFactory().logPrefix()
        self.assertEqual("SomeKindOfFactory", prefix)

    def test_defaultBuildProtocol(self):
        """
        The default L{Factory.buildProtocol} returns an instance of the class
        stored in the factory's C{protocol} attribute, and that instance's
        C{factory} attribute refers back to the factory.
        """

        class SomeProtocol(Protocol):
            pass

        factory = Factory()
        factory.protocol = SomeProtocol
        built = factory.buildProtocol(None)
        self.assertIsInstance(built, SomeProtocol)
        self.assertIs(built.factory, factory)

    def test_forProtocol(self):
        """
        L{Factory.forProtocol} instantiates the factory class with the extra
        positional and keyword arguments it was given, and stores the given
        protocol class as the C{protocol} attribute of the result.
        """

        class ArgTakingFactory(Factory):
            def __init__(self, *args, **kwargs):
                self.args = args
                self.kwargs = kwargs

        built = ArgTakingFactory.forProtocol(Protocol, 1, 2, foo=12)
        self.assertEqual(built.protocol, Protocol)
        self.assertEqual(built.args, (1, 2))
        self.assertEqual(built.kwargs, {"foo": 12})

    def test_doStartLoggingStatement(self):
        """
        L{Factory.doStart} emits an info-level log event whose format is
        C{"Starting factory {factory!r}"} and whose C{factory} key is the
        factory being started.
        """
        captured = self._captureLogEvents()

        started = Factory()
        started.doStart()

        self.assertIs(captured[0]["factory"], started)
        self.assertEqual(captured[0]["log_level"], LogLevel.info)
        self.assertEqual(captured[0]["log_format"], "Starting factory {factory!r}")

    def test_doStopLoggingStatement(self):
        """
        L{Factory.doStop} emits an info-level log event whose format is
        C{"Stopping factory {factory!r}"} and whose C{factory} key is the
        factory being stopped.
        """
        captured = self._captureLogEvents()

        class MyFactory(Factory):
            numPorts = 1

        stopped = MyFactory()
        stopped.doStop()

        self.assertIs(captured[0]["factory"], stopped)
        self.assertEqual(captured[0]["log_level"], LogLevel.info)
        self.assertEqual(captured[0]["log_format"], "Stopping factory {factory!r}")


class AdapterTests(TestCase):
    """
    Tests for L{ProtocolToConsumerAdapter} and L{ConsumerToProtocolAdapter}.
    """

    def test_protocolToConsumer(self):
        """
        Adapting an L{IProtocol} provider to L{IConsumer} yields a
        L{ProtocolToConsumerAdapter} whose C{write} delivers the data to the
        protocol's C{dataReceived}.
        """
        received = []
        wrapped = Protocol()
        wrapped.dataReceived = received.append
        adapted = IConsumer(wrapped)
        adapted.write(b"hello")
        self.assertEqual(received, [b"hello"])
        self.assertIsInstance(adapted, ProtocolToConsumerAdapter)

    def test_consumerToProtocol(self):
        """
        Adapting an L{IConsumer} provider to L{IProtocol} yields a
        L{ConsumerToProtocolAdapter} whose C{dataReceived} delivers the data
        to the consumer's C{write}.
        """
        received = []

        @implementer(IConsumer)
        class Consumer:
            def write(self, d):
                received.append(d)

        adapted = IProtocol(Consumer())
        adapted.dataReceived(b"hello")
        self.assertEqual(received, [b"hello"])
        self.assertIsInstance(adapted, ConsumerToProtocolAdapter)


class FileWrapperTests(TestCase):
    """
    Tests for L{twisted.internet.protocol.FileWrapper}.
    """

    def test_write(self):
        """
        L{twisted.internet.protocol.FileWrapper.write} passes bytes through to
        the wrapped file, and returns normally without writing anything when
        the wrapped file rejects the data.
        """
        wrapper = FileWrapper(BytesIO())
        wrapper.write(b"test1")
        self.assertEqual(wrapper.file.getvalue(), b"test1")

        wrapper = FileWrapper(BytesIO())
        # BytesIO.write() raises TypeError when given str.  This test relies
        # on FileWrapper swallowing that error rather than propagating it, so
        # the call below must return normally.
        wrapper.write("stuff")
        # Comparing the bytes value against the str "stuff" could never fail,
        # so assert the observable outcome instead: nothing was written.
        self.assertEqual(wrapper.file.getvalue(), b"")

    def test_writeSequence(self):
        """
        L{twisted.internet.protocol.FileWrapper.writeSequence} writes the
        concatenation of a sequence of bytes to the wrapped file, and raises
        L{TypeError} when given a sequence of str.
        """
        wrapper = FileWrapper(BytesIO())
        wrapper.writeSequence([b"test1", b"test2"])
        self.assertEqual(wrapper.file.getvalue(), b"test1test2")

        wrapper = FileWrapper(BytesIO())
        # In Python 3, b"".join([u"a", u"b"]) will raise a TypeError
        self.assertRaises(TypeError, wrapper.writeSequence, ["test3", "test4"])

Zerion Mini Shell 1.0