# %PDF- %PDF-
# Mini Shell

# Mini Shell

# Direktori : /lib/python3/dist-packages/twisted/conch/test/
# Upload File :
# Create Path :
# Current File : //lib/python3/dist-packages/twisted/conch/test/test_channel.py

# Copyright Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Test ssh/channel.py.
"""
from __future__ import annotations

from unittest import skipIf

from zope.interface.verify import verifyObject

try:
    from twisted.conch.ssh import channel
    from twisted.conch.ssh.address import SSHTransportAddress
    from twisted.conch.ssh.service import SSHService
    from twisted.conch.ssh.transport import SSHServerTransport
    from twisted.internet import interfaces
    from twisted.internet.address import IPv4Address
    from twisted.internet.testing import StringTransport

    skipTest = ""
except ImportError:
    skipTest = "Conch SSH not supported."
    SSHService = object  # type: ignore[assignment,misc]
from twisted.trial.unittest import TestCase


class MockConnection(SSHService):
    """
    A stand-in for twisted.conch.ssh.connection.SSHConnection which simply
    remembers what channels ask it to send, and which channels ask to be
    closed.

    @ivar data: a L{dict} mapping each channel to the list of data chunks
        that channel has sent, in the order they were sent.
    @ivar extData: a L{dict} mapping each channel to the list of 2-tuples
        (extended data type, data) that channel has sent, in the order they
        were sent.
    @ivar closes: a L{dict} mapping each channel to C{True} once that channel
        has sent a close message.
    """

    def __init__(self) -> None:
        self.closes: dict[channel.SSHChannel, bool] = {}
        self.extData: dict[channel.SSHChannel, list[tuple[int, bytes]]] = {}
        self.data: dict[channel.SSHChannel, list[bytes]] = {}

    def logPrefix(self) -> str:
        """
        Give the prefix used to identify this connection in log messages.
        """
        return "MockConnection"

    def sendData(self, channel: channel.SSHChannel, data: bytes) -> None:
        """
        Append C{data} to the list of data recorded for C{channel}.
        """
        # First write from this channel: start its list.
        if channel not in self.data:
            self.data[channel] = []
        self.data[channel].append(data)

    def sendExtendedData(
        self, channel: channel.SSHChannel, type: int, data: bytes
    ) -> None:
        """
        Append a C{(type, data)} pair to the list of extended data recorded
        for C{channel}.
        """
        # First extended write from this channel: start its list.
        if channel not in self.extData:
            self.extData[channel] = []
        record = (type, data)
        self.extData[channel].append(record)

    def sendClose(self, channel: channel.SSHChannel) -> None:
        """
        Remember that C{channel} asked for a close message to be sent.
        """
        self.closes[channel] = True


def connectSSHTransport(
    service: SSHService,
    hostAddress: interfaces.IAddress | None = None,
    peerAddress: interfaces.IAddress | None = None,
) -> None:
    """
    Create an L{SSHServerTransport}, connect it to a L{StringTransport}
    reporting the given addresses, and install C{service} on it.

    @param service: Service to set on the newly connected SSH transport.
    @type service: L{SSHService}

    @param hostAddress: Local address given to the underlying
        L{StringTransport}.
    @type hostAddress: L{interfaces.IAddress}

    @param peerAddress: Remote address given to the underlying
        L{StringTransport}.
    @type peerAddress: L{interfaces.IAddress}
    """
    sshTransport = SSHServerTransport()
    # In-memory transport that carries the host and peer addresses.
    underlying = StringTransport(hostAddress=hostAddress, peerAddress=peerAddress)
    sshTransport.makeConnection(underlying)
    sshTransport.setService(service)


@skipIf(skipTest, skipTest)
class ChannelTests(TestCase):
    """
    Tests for L{SSHChannel}.

    Every test runs against a channel attached to a L{MockConnection}, so
    what the channel sends can be inspected through the mock's C{data},
    C{extData} and C{closes} dictionaries.
    """

    def setUp(self) -> None:
        """
        Initialize the channel.  remoteMaxPacket is 10 so that data is able
        to be sent (the default of 0 means no data is sent because no packets
        are made).
        """
        self.conn = MockConnection()
        self.channel = channel.SSHChannel(conn=self.conn, remoteMaxPacket=10)
        self.channel.name = b"channel"

    def test_interface(self) -> None:
        """
        L{SSHChannel} instances provide L{interfaces.ITransport}.
        """
        self.assertTrue(verifyObject(interfaces.ITransport, self.channel))

    def test_init(self) -> None:
        """
        Test that SSHChannel initializes correctly.  localWindowSize defaults
        to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
        defaults (what OpenSSH uses for those variables).

        The values in the second set of assertions are meaningless; they serve
        only to verify that the instance variables are assigned in the correct
        order.
        """
        # Defaults: only the connection is supplied.
        c = channel.SSHChannel(conn=self.conn)
        self.assertEqual(c.localWindowSize, 131072)
        self.assertEqual(c.localWindowLeft, 131072)
        self.assertEqual(c.localMaxPacket, 32768)
        self.assertEqual(c.remoteWindowLeft, 0)
        self.assertEqual(c.remoteMaxPacket, 0)
        self.assertEqual(c.conn, self.conn)
        self.assertIsNone(c.data)
        self.assertIsNone(c.avatar)

        # Positional arguments: distinct dummy values show which argument
        # ends up in which attribute.
        c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
        self.assertEqual(c2.localWindowSize, 1)
        self.assertEqual(c2.localWindowLeft, 1)
        self.assertEqual(c2.localMaxPacket, 2)
        self.assertEqual(c2.remoteWindowLeft, 3)
        self.assertEqual(c2.remoteMaxPacket, 4)
        self.assertEqual(c2.conn, 5)
        self.assertEqual(c2.data, 6)
        self.assertEqual(c2.avatar, 7)

    def test_str(self) -> None:
        """
        Test that str(SSHChannel) gives the channel name and the local and
        remote windows at a glance.
        """
        self.assertEqual(str(self.channel), "<SSHChannel channel (lw 131072 rw 0)>")
        # A channel whose name was never set is rendered as "None".
        self.assertEqual(
            str(channel.SSHChannel(localWindow=1)), "<SSHChannel None (lw 1 rw 0)>"
        )

    def test_bytes(self) -> None:
        """
        Test that bytes(SSHChannel) gives the channel name and the local and
        remote windows at a glance.
        """
        self.assertEqual(
            self.channel.__bytes__(), b"<SSHChannel channel (lw 131072 rw 0)>"
        )
        # A channel whose name was never set is rendered as b"None".
        self.assertEqual(
            channel.SSHChannel(localWindow=1).__bytes__(),
            b"<SSHChannel None (lw 1 rw 0)>",
        )

    def test_logPrefix(self) -> None:
        """
        Test that SSHChannel.logPrefix gives the name of the channel, the
        local channel ID and the underlying connection.
        """
        self.assertEqual(
            self.channel.logPrefix(), "SSHChannel channel (unknown) on MockConnection"
        )

    def test_addWindowBytes(self) -> None:
        """
        Test that addWindowBytes adds bytes to the window and resumes writing
        if it was paused.
        """
        # One-element list so the nested stub can record that it was called.
        cb = [False]

        def stubStartWriting() -> None:
            cb[0] = True

        self.channel.startWriting = stubStartWriting  # type: ignore[method-assign]
        # Both writes happen before any window has been added; the
        # assertions below show they are only delivered to the connection
        # once addWindowBytes() has been called.
        self.channel.write(b"test")
        self.channel.writeExtended(1, b"test")
        self.channel.addWindowBytes(50)
        # 4 bytes of regular data and 4 bytes of extended data were consumed.
        self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
        self.assertTrue(self.channel.areWriting)
        self.assertTrue(cb[0])
        self.assertEqual(self.channel.buf, b"")
        self.assertEqual(self.conn.data[self.channel], [b"test"])
        self.assertEqual(self.channel.extBuf, [])
        self.assertEqual(self.conn.extData[self.channel], [(1, b"test")])

        # The channel is already writing, so more window must not trigger
        # startWriting again.
        cb[0] = False
        self.channel.addWindowBytes(20)
        self.assertFalse(cb[0])

        # Once loseConnection() has been called, adding window must not
        # trigger startWriting either.
        self.channel.write(b"a" * 80)
        self.channel.loseConnection()
        self.channel.addWindowBytes(20)
        self.assertFalse(cb[0])

    def test_requestReceived(self) -> None:
        """
        Test that requestReceived handles requests by dispatching them to
        request_* methods.
        """
        # The request name b"test-method" is served by request_test_method,
        # and the handler's result is what requestReceived reports.
        self.channel.request_test_method = lambda data: data == b""  # type: ignore[attr-defined]
        self.assertTrue(self.channel.requestReceived(b"test-method", b""))
        self.assertFalse(self.channel.requestReceived(b"test-method", b"a"))
        # No request_bad_method exists, so the request is reported as failed.
        self.assertFalse(self.channel.requestReceived(b"bad-method", b""))

    def test_closeReceieved(self) -> None:
        """
        Test that the default closeReceived marks the channel as closing.
        """
        self.assertFalse(self.channel.closing)
        self.channel.closeReceived()
        self.assertTrue(self.channel.closing)

    def test_write(self) -> None:
        """
        Test that write handles data correctly.  Send data up to the size
        of the remote window, splitting the data into packets of length
        remoteMaxPacket.
        """
        # One-element list so the nested stub can record that it was called.
        cb = [False]

        def stubStopWriting() -> None:
            cb[0] = True

        # no window to start with
        self.channel.stopWriting = stubStopWriting  # type: ignore[method-assign]
        self.channel.write(b"d")
        self.channel.write(b"a")
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])
        # regular write
        self.channel.addWindowBytes(20)
        self.channel.write(b"ta")
        # The two buffered bytes went out together as b"da" once the window
        # opened; b"ta" followed as its own chunk.
        data = self.conn.data[self.channel]
        self.assertEqual(data, [b"da", b"ta"])
        self.assertEqual(self.channel.remoteWindowLeft, 16)
        # larger than max packet
        self.channel.write(b"12345678901")
        self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1"])
        self.assertEqual(self.channel.remoteWindowLeft, 5)
        # running out of window
        cb[0] = False
        self.channel.write(b"123456")
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])
        # Only the 5 bytes that fit were sent; the last byte stays buffered.
        self.assertEqual(data, [b"da", b"ta", b"1234567890", b"1", b"12345"])
        self.assertEqual(self.channel.buf, b"6")
        self.assertEqual(self.channel.remoteWindowLeft, 0)

    def test_writeExtended(self) -> None:
        """
        Test that writeExtended handles data correctly.  Send extended data
        up to the size of the window, splitting the extended data into packets
        of length remoteMaxPacket.
        """
        # One-element list so the nested stub can record that it was called.
        cb = [False]

        def stubStopWriting() -> None:
            cb[0] = True

        # no window to start with
        self.channel.stopWriting = stubStopWriting  # type: ignore[method-assign]
        self.channel.writeExtended(1, b"d")
        self.channel.writeExtended(1, b"a")
        self.channel.writeExtended(2, b"t")
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])
        # regular write
        self.channel.addWindowBytes(20)
        self.channel.writeExtended(2, b"a")
        # The two buffered type-1 writes went out together as (1, b"da");
        # the type-2 writes were sent separately.
        data = self.conn.extData[self.channel]
        self.assertEqual(data, [(1, b"da"), (2, b"t"), (2, b"a")])
        self.assertEqual(self.channel.remoteWindowLeft, 16)
        # larger than max packet
        self.channel.writeExtended(3, b"12345678901")
        self.assertEqual(
            data, [(1, b"da"), (2, b"t"), (2, b"a"), (3, b"1234567890"), (3, b"1")]
        )
        self.assertEqual(self.channel.remoteWindowLeft, 5)
        # running out of window
        cb[0] = False
        self.channel.writeExtended(4, b"123456")
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])
        self.assertEqual(
            data,
            [
                (1, b"da"),
                (2, b"t"),
                (2, b"a"),
                (3, b"1234567890"),
                (3, b"1"),
                (4, b"12345"),
            ],
        )
        # The byte that did not fit stays buffered together with its type.
        self.assertEqual(self.channel.extBuf, [[4, b"6"]])
        self.assertEqual(self.channel.remoteWindowLeft, 0)

    def test_writeSequence(self) -> None:
        """
        Test that writeSequence is equivalent to write(b''.join(sequence)).
        """
        self.channel.addWindowBytes(20)
        # Ten one-byte items are expected to arrive as a single chunk.
        self.channel.writeSequence(b"%d" % (i,) for i in range(10))
        self.assertEqual(self.conn.data[self.channel], [b"0123456789"])

    def test_loseConnection(self) -> None:
        """
        Test that loseConnection() doesn't close the channel until all
        the data is sent.
        """
        # No window has been added yet, so both writes are still pending
        # when loseConnection() is called.
        self.channel.write(b"data")
        self.channel.writeExtended(1, b"datadata")
        self.channel.loseConnection()
        self.assertIsNone(self.conn.closes.get(self.channel))
        self.channel.addWindowBytes(4)  # send regular data
        self.assertIsNone(self.conn.closes.get(self.channel))
        self.channel.addWindowBytes(8)  # send extended data
        self.assertTrue(self.conn.closes.get(self.channel))

    def test_getPeer(self) -> None:
        """
        L{SSHChannel.getPeer} returns an L{SSHTransportAddress} wrapping the
        peer address of the underlying transport.
        """
        peer = IPv4Address("TCP", "192.168.0.1", 54321)
        connectSSHTransport(service=self.channel.conn, peerAddress=peer)

        self.assertEqual(SSHTransportAddress(peer), self.channel.getPeer())

    def test_getHost(self) -> None:
        """
        L{SSHChannel.getHost} returns an L{SSHTransportAddress} wrapping the
        host address of the underlying transport.
        """
        host = IPv4Address("TCP", "127.0.0.1", 12345)
        connectSSHTransport(service=self.channel.conn, hostAddress=host)

        self.assertEqual(SSHTransportAddress(host), self.channel.getHost())

# Zerion Mini Shell 1.0