[Python-modules-commits] [python-asyncssh] 01/09: Import python-asyncssh_1.3.0.orig.tar.gz

Vincent Bernat bernat at moszumanska.debian.org
Sat Oct 31 22:54:26 UTC 2015


This is an automated email from the git hooks/post-receive script.

bernat pushed a commit to branch master
in repository python-asyncssh.

commit 142bd9768421394f9f6880fb63c8de0932182c5b
Author: Vincent Bernat <bernat at debian.org>
Date:   Sat Oct 31 23:13:43 2015 +0100

    Import python-asyncssh_1.3.0.orig.tar.gz
---
 .coveragerc                          |    7 +
 README.rst                           |    6 +-
 asyncssh/__init__.py                 |    2 +-
 asyncssh/asn1.py                     |   78 ++-
 asyncssh/auth.py                     |  210 +++---
 asyncssh/channel.py                  |   25 +-
 asyncssh/cipher.py                   |    2 +-
 asyncssh/connection.py               |  229 +++++--
 asyncssh/crypto/__init__.py          |   38 +-
 asyncssh/crypto/chacha.py            |    8 +-
 asyncssh/crypto/cipher.py            |    4 +-
 asyncssh/crypto/curve25519.py        |   58 +-
 asyncssh/crypto/ec.py                |  257 ++++++++
 asyncssh/crypto/ecdh.py              |   48 ++
 asyncssh/crypto/pyca/cipher.py       |    3 -
 asyncssh/crypto/pyca/dsa.py          |   80 ++-
 asyncssh/crypto/pyca/ec.py           |  118 ++++
 asyncssh/crypto/pyca/rsa.py          |   99 ++-
 asyncssh/crypto/pycrypto/__init__.py |   15 -
 asyncssh/crypto/pycrypto/cipher.py   |   67 --
 asyncssh/crypto/pycrypto/dsa.py      |   50 --
 asyncssh/crypto/pycrypto/rsa.py      |   51 --
 asyncssh/dh.py                       |   21 +-
 asyncssh/dsa.py                      |   77 ++-
 asyncssh/ec.py                       |  667 -------------------
 asyncssh/{curve25519.py => ecdh.py}  |   47 +-
 asyncssh/ecdsa.py                    |  285 ++++++++
 asyncssh/ed25519.py                  |   19 +-
 asyncssh/known_hosts.py              |    7 +-
 asyncssh/packet.py                   |    9 +-
 asyncssh/pbe.py                      |  150 ++---
 asyncssh/public_key.py               |  186 +++---
 asyncssh/rsa.py                      |   86 +--
 asyncssh/saslprep.py                 |   10 +-
 asyncssh/sftp.py                     |  615 ++++++++++++------
 asyncssh/version.py                  |    2 +-
 docs/api.rst                         |   14 +-
 docs/changes.rst                     |   80 +++
 pylintrc                             |    2 +-
 setup.py                             |    7 +-
 tests/__init__.py                    |    3 +-
 tests/test_asn1.py                   |  183 ++++++
 tests/test_cipher.py                 |  105 +++
 tests/test_compression.py            |   39 ++
 tests/test_kex.py                    |  318 +++++++++
 tests/test_keys.py                   |  361 -----------
 tests/test_known_hosts.py            |  190 ++++++
 tests/test_mac.py                    |   47 ++
 tests/test_native_ec.py              |  138 ++++
 tests/test_public_key.py             | 1183 ++++++++++++++++++++++++++++++++++
 tests/test_saslprep.py               |   75 +++
 tests/util.py                        |   44 ++
 52 files changed, 4369 insertions(+), 2056 deletions(-)

diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..f2c3d0b
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,7 @@
+[run]
+branch = True
+
+[report]
+exclude_lines =
+    pragma: no cover
+    raise NotImplementedError
diff --git a/README.rst b/README.rst
index 021872b..25fbf3c 100644
--- a/README.rst
+++ b/README.rst
@@ -78,7 +78,7 @@ Prerequisites
 To use ``asyncssh``, you need the following:
 
 * Python 3.4 or later
-* PyCrypto 2.6 or later and/or PyCA 0.6.1 or later
+* cryptography (PyCA) 1.0.0 or later
 
 Installation
 ------------
@@ -106,8 +106,6 @@ functionality:
 AsyncSSH defines the following optional PyPI extra packages to make it
 easy to install any or all of these dependencies:
 
-  | pycrypto
-  | pyca
   | bcrypt
   | libnacl
 
@@ -115,7 +113,7 @@ For example, to install all of these, you can run:
 
   ::
 
-    pip install 'asyncssh[pycrypto,pyca,bcrypt,libnacl]'
+    pip install 'asyncssh[bcrypt,libnacl]'
 
 Note that you will still need to manually install the libsodium library
 listed above for libnacl to work correctly. Unfortunately, since
diff --git a/asyncssh/__init__.py b/asyncssh/__init__.py
index 86d1991..842ac10 100644
--- a/asyncssh/__init__.py
+++ b/asyncssh/__init__.py
@@ -54,4 +54,4 @@ from .sftp import SEEK_SET, SEEK_CUR, SEEK_END
 from .stream import SSHReader, SSHWriter
 
 # Import these explicitly to trigger register calls in them
-from . import curve25519, ed25519, ec, rsa, dsa, dh
+from . import ed25519, ecdsa, rsa, dsa, ecdh, dh
diff --git a/asyncssh/asn1.py b/asyncssh/asn1.py
index 41ac3cc..38aedbf 100644
--- a/asyncssh/asn1.py
+++ b/asyncssh/asn1.py
@@ -55,6 +55,9 @@ _der_class_by_type = {}
 def _encode_identifier(asn1_class, constructed, tag):
     """Encode a DER object's identifier"""
 
+    if asn1_class not in (UNIVERSAL, APPLICATION, CONTEXT_SPECIFIC, PRIVATE):
+        raise ASN1EncodeError('Invalid ASN.1 class')
+
     flags = (asn1_class << 6) | (0x20 if constructed else 0x00)
 
     if tag < 0x20:
@@ -128,7 +131,7 @@ class RawDERObject:
 
     """
 
-    def __init__(self, asn1_class, tag, content):
+    def __init__(self, tag, content, asn1_class):
         self.asn1_class = asn1_class
         self.tag = tag
         self.content = content
@@ -137,6 +140,14 @@ class RawDERObject:
         return ('RawDERObject(%s, %s, %r)' %
                 (_asn1_class[self.asn1_class], self.tag, self.content))
 
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.asn1_class == other.asn1_class and
+                self.tag == other.tag and self.content == other.content)
+
+    def __hash__(self):
+        return hash((self.asn1_class, self.tag, self.content))
+
     def encode_identifier(self):
         """Encode the DER identifier for this object as a byte string"""
 
@@ -171,6 +182,14 @@ class TaggedDERObject:
             return ('TaggedDERObject(%s, %s, %r)' %
                     (_asn1_class[self.asn1_class], self.tag, self.value))
 
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.asn1_class == other.asn1_class and
+                self.tag == other.tag and self.value == other.value)
+
+    def __hash__(self):
+        return hash((self.asn1_class, self.tag, self.value))
+
     def encode_identifier(self):
         """Encode the DER identifier for this object as a byte string"""
 
@@ -239,7 +258,8 @@ class _Integer:
 
         l = value.bit_length()
         l = l // 8 + 1 if l % 8 == 0 else (l + 7) // 8
-        return value.to_bytes(l, 'big', signed=True)
+        result = value.to_bytes(l, 'big', signed=True)
+        return result[1:] if result.startswith(b'\xff\x80') else result
 
     @classmethod
     def decode(cls, constructed, content):
@@ -402,7 +422,6 @@ class BitString:
 
         self.value = value
         self.unused = unused
-        self.named = named
 
     def __str__(self):
         result = ''.join(bin(b)[2:].zfill(8) for b in self.value)
@@ -411,7 +430,14 @@ class BitString:
         return result
 
     def __repr__(self):
-        return 'BitString(%s)' % self
+        return "BitString('%s')" % self
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.value == other.value and self.unused == other.unused)
+
+    def __hash__(self):
+        return hash((self.value, self.unused))
 
     def encode(self):
         """Encode a DER bit string"""
@@ -448,8 +474,11 @@ class ObjectIdentifier:
     def __init__(self, value):
         self.value = value
 
+    def __str__(self):
+        return self.value
+
     def __repr__(self):
-        return 'ObjectIdentifier(%s)' % self.value
+        return "ObjectIdentifier('%s')" % self.value
 
     def __eq__(self, other):
         return isinstance(other, self.__class__) and self.value == other.value
@@ -463,6 +492,10 @@ class ObjectIdentifier:
         def _bytes(component):
             """Convert a single element of an OID to a DER byte string"""
 
+            if component < 0:
+                raise ASN1EncodeError('Components of object identifier must '
+                                      'be greater than or equal to 0')
+
             result = [component & 0x7f]
             while component >= 0x80:
                 component >>= 7
@@ -470,14 +503,18 @@ class ObjectIdentifier:
 
             return bytes(result[::-1])
 
-        components = [int(c) for c in self.value.split('.')]
+        try:
+            components = [int(c) for c in self.value.split('.')]
+        except ValueError:
+            raise ASN1EncodeError('Component values must be integers')
+
         if len(components) < 2:
             raise ASN1EncodeError('Object identifiers must have at least two '
                                   'components')
         elif components[0] < 0 or components[0] > 2:
             raise ASN1EncodeError('First component of object identifier must '
                                   'be between 0 and 2')
-        elif components[0] < 2 and (components[1] < 0 or components[1] > 40):
+        elif components[0] < 2 and (components[1] < 0 or components[1] > 39):
             raise ASN1EncodeError('Second component of object identifier must '
                                   'be between 0 and 39')
 
@@ -489,26 +526,30 @@ class ObjectIdentifier:
         """Decode a DER object identifier"""
 
         if constructed:
-            raise ASN1DecodeError('OBJECT IDENTIFIER  should not be '
+            raise ASN1DecodeError('OBJECT IDENTIFIER should not be '
                                   'constructed')
 
         if not content:
             raise ASN1DecodeError('Empty object identifier')
 
-        components = [str(component) for component in divmod(content[0], 40)]
+        b = content[0]
+        components = list(divmod(b, 40)) if b < 80 else [2, b-80]
+
         component = 0
         for b in content[1:]:
-            if b < 0x80:
-                components.append(str(component | b))
+            if b == 0x80 and component == 0:
+                raise ASN1DecodeError('Invalid component')
+            elif b < 0x80:
+                components.append(component | b)
                 component = 0
             else:
                 component |= b & 0x7f
                 component <<= 7
 
         if component:
-            raise ASN1DecodeError('Incomplete object identifier')
+            raise ASN1DecodeError('Incomplete component')
 
-        return cls('.'.join(components))
+        return cls('.'.join(str(c) for c in components))
 
 
 def der_encode(value):
@@ -545,7 +586,7 @@ def der_encode(value):
         identifier = cls.identifier
         content = cls.encode(value)
     else:
-        raise TypeError('Cannot DER encode type %s' % t.__name__)
+        raise ASN1EncodeError('Cannot DER encode type %s' % t.__name__)
 
     length = len(content)
     if length < 0x80:
@@ -604,7 +645,10 @@ def der_decode(data, partial_ok=False):
                 tag |= b & 0x7f
                 tag <<= 7
         else:
-            raise ASN1DecodeError('Incomplete data')
+            raise ASN1DecodeError('Incomplete tag')
+
+    if offset >= len(data):
+        raise ASN1DecodeError('Incomplete data')
 
     length = data[offset]
     offset += 1
@@ -626,9 +670,9 @@ def der_decode(data, partial_ok=False):
         value = cls.decode(constructed, data[offset:offset+length])
     elif constructed:
         value = TaggedDERObject(tag, der_decode(data[offset:offset+length]),
-                                asn1_class=asn1_class)
+                                asn1_class)
     else:
-        value = RawDERObject(asn1_class, tag, data[offset:offset+length])
+        value = RawDERObject(tag, data[offset:offset+length], asn1_class)
 
     if partial_ok:
         return value, offset+length
diff --git a/asyncssh/auth.py b/asyncssh/auth.py
index 4eeb711..1fd0a71 100644
--- a/asyncssh/auth.py
+++ b/asyncssh/auth.py
@@ -12,6 +12,8 @@
 
 """SSH authentication handlers"""
 
+import asyncio
+
 from .constants import DISC_PROTOCOL_ERROR
 from .misc import DisconnectError
 from .packet import Boolean, Byte, String, UInt32, SSHPacketHandler
@@ -37,16 +39,35 @@ _client_auth_handlers = {}
 _server_auth_handlers = {}
 
 
-class _SSHAuthError(Exception):
-    """This is raised when we can't proceed with the current form of auth."""
+class _Auth(SSHPacketHandler):
+    """Parent class for authentication"""
+
+    def __init__(self):
+        self._coro = None
+
+    def cancel(self):
+        """Cancel any authentication in progress"""
 
+        if self._coro:
+            self._coro.cancel()
+            self._coro = None
 
-class _ClientAuth(SSHPacketHandler):
-    """Parent class for client auth"""
+class _ClientAuth(_Auth):
+    """Parent class for client authentication"""
 
     def __init__(self, conn, method):
+        super().__init__()
+
         self._conn = conn
         self._method = method
+        self._coro = asyncio.async(self._start())
+
+    @asyncio.coroutine
+    def _start(self):
+        """Abstract method for starting client authentication"""
+
+        # Provided by subclass
+        raise NotImplementedError
 
     def auth_succeeded(self):
         """Callback when auth succeeds"""
@@ -54,16 +75,6 @@ class _ClientAuth(SSHPacketHandler):
     def auth_failed(self):
         """Callback when auth fails"""
 
-    def process_packet(self, pkttype, packet):
-        try:
-            processed = super().process_packet(pkttype, packet)
-        except _SSHAuthError:
-            # We can't complete the current auth - move to the next one
-            processed = True
-            self._conn.try_next_auth()
-
-        return processed
-
     def send_request(self, *args, key=None):
         """Send a user authentication request"""
 
@@ -73,8 +84,9 @@ class _ClientAuth(SSHPacketHandler):
 class _ClientNullAuth(_ClientAuth):
     """Client side implementation of null auth"""
 
-    def __init__(self, conn, method):
-        super().__init__(conn, method)
+    @asyncio.coroutine
+    def _start(self):
+        """Start client null authentication"""
 
         self.send_request()
 
@@ -84,12 +96,16 @@ class _ClientNullAuth(_ClientAuth):
 class _ClientPublicKeyAuth(_ClientAuth):
     """Client side implementation of public key auth"""
 
-    def __init__(self, conn, method):
-        super().__init__(conn, method)
+    @asyncio.coroutine
+    def _start(self):
+        """Start client public key authentication"""
+
+        self._alg, self._key, self._key_data = \
+            yield from self._conn.public_key_auth_requested()
 
-        self._alg, self._key, self._key_data = conn.public_key_auth_requested()
         if self._alg is None:
-            raise _SSHAuthError()
+            self._conn.try_next_auth()
+            return
 
         self.send_request(Boolean(False), String(self._alg),
                           String(self._key_data))
@@ -118,15 +134,34 @@ class _ClientPublicKeyAuth(_ClientAuth):
 class _ClientKbdIntAuth(_ClientAuth):
     """Client side implementation of keyboard-interactive auth"""
 
-    def __init__(self, conn, method):
-        super().__init__(conn, method)
+    @asyncio.coroutine
+    def _start(self):
+        """Start client keyboard interactive authentication"""
+
+        submethods = yield from self._conn.kbdint_auth_requested()
 
-        submethods = conn.kbdint_auth_requested()
         if submethods is None:
-            raise _SSHAuthError()
+            self._conn.try_next_auth()
+            return
 
         self.send_request(String(''), String(submethods))
 
+    @asyncio.coroutine
+    def _receive_challenge(self, name, instruction, lang, prompts):
+        """Receive and respond to a keyboard interactive challenge"""
+
+        responses = \
+            yield from self._conn.kbdint_challenge_received(name, instruction,
+                                                            lang, prompts)
+
+        if responses is None:
+            self._conn.try_next_auth()
+            return
+
+        self._conn.send_packet(Byte(MSG_USERAUTH_INFO_RESPONSE),
+                               UInt32(len(responses)),
+                               b''.join(String(r) for r in responses))
+
     def _process_info_request(self, pkttype, packet):
         """Process a keyboard interactive authentication request"""
 
@@ -158,15 +193,11 @@ class _ClientKbdIntAuth(_ClientAuth):
 
             prompts.append((prompt, echo))
 
-        responses = self._conn.kbdint_challenge_received(name, instruction,
-                                                         lang, prompts)
-
-        if responses is None:
-            raise _SSHAuthError()
+        self.cancel()
+        self._coro = asyncio.async(self._receive_challenge(name, instruction,
+                                                           lang, prompts))
 
-        self._conn.send_packet(Byte(MSG_USERAUTH_INFO_RESPONSE),
-                               UInt32(len(responses)),
-                               b''.join(String(r) for r in responses))
+        return True
 
     packet_handlers = {
         MSG_USERAUTH_INFO_REQUEST: _process_info_request
@@ -181,12 +212,37 @@ class _ClientPasswordAuth(_ClientAuth):
 
         self._password_change = False
 
-        password = conn.password_auth_requested()
+    @asyncio.coroutine
+    def _start(self):
+        """Start client password authentication"""
+
+        password = yield from self._conn.password_auth_requested()
+
         if password is None:
-            raise _SSHAuthError()
+            self._conn.try_next_auth()
+            return
 
         self.send_request(Boolean(False), String(password))
 
+    @asyncio.coroutine
+    def _change_password(self):
+        """Start password change"""
+
+        result = yield from self._conn.password_change_requested()
+
+        if result == NotImplemented:
+            # Password change not supported - move on to the next auth method
+            self._conn.try_next_auth()
+            return
+
+        old_password, new_password = result
+
+        self._password_change = True
+
+        self.send_request(Boolean(True),
+                          String(old_password.encode('utf-8')),
+                          String(new_password.encode('utf-8')))
+
     def auth_succeeded(self):
         if self._password_change:
             self._password_change = False
@@ -212,18 +268,8 @@ class _ClientPasswordAuth(_ClientAuth):
             raise DisconnectError(DISC_PROTOCOL_ERROR,
                                   'Invalid password change request') from None
 
-        result = self._conn.password_change_requested()
-        if result == NotImplemented:
-            # Password change not supported - move on to the next auth method
-            raise _SSHAuthError()
-        else:
-            old_password, new_password = result
-
-            self._password_change = True
-
-            self.send_request(Boolean(True),
-                              String(old_password.encode('utf-8')),
-                              String(new_password.encode('utf-8')))
+        self.cancel()
+        self._coro = asyncio.async(self._change_password())
 
         return True
 
@@ -232,12 +278,22 @@ class _ClientPasswordAuth(_ClientAuth):
     }
 
 
-class _ServerAuth(SSHPacketHandler):
-    """Parent class for server side auth"""
+class _ServerAuth(_Auth):
+    """Parent class for server authentication"""
+
+    def __init__(self, conn, username, packet):
+        super().__init__()
 
-    def __init__(self, conn, username):
         self._conn = conn
         self._username = username
+        self._coro = asyncio.async(self._start(packet))
+
+    @asyncio.coroutine
+    def _start(self, packet):
+        """Abstract method for starting server authentication"""
+
+        # Provided by subclass
+        raise NotImplementedError
 
     def send_failure(self, partial_success=False):
         """Send a user authentication failure response"""
@@ -260,14 +316,13 @@ class _ServerNullAuth(_ServerAuth):
         # pylint: disable=unused-argument
         return False
 
-    def __init__(self, conn, username, packet):
-        super().__init__(conn, username)
+    @asyncio.coroutine
+    def _start(self, packet):
+        """Always fail null server authentication"""
 
         packet.check_end()
-
         self.send_failure()
 
-
 class _ServerPublicKeyAuth(_ServerAuth):
     """Server side implementation of public key auth"""
 
@@ -277,8 +332,9 @@ class _ServerPublicKeyAuth(_ServerAuth):
 
         return conn.public_key_auth_supported()
 
-    def __init__(self, conn, username, packet):
-        super().__init__(conn, username)
+    @asyncio.coroutine
+    def _start(self, packet):
+        """Start server public key authentication"""
 
         sig_present = packet.get_boolean()
         algorithm = packet.get_string()
@@ -293,8 +349,8 @@ class _ServerPublicKeyAuth(_ServerAuth):
 
         packet.check_end()
 
-        if self._conn.validate_public_key(self._username, key_data,
-                                          msg, signature):
+        if (yield from self._conn.validate_public_key(self._username, key_data,
+                                                      msg, signature)):
             if sig_present:
                 self.send_success()
             else:
@@ -313,8 +369,9 @@ class _ServerKbdIntAuth(_ServerAuth):
 
         return conn.kbdint_auth_supported()
 
-    def __init__(self, conn, username, packet):
-        super().__init__(conn, username)
+    @asyncio.coroutine
+    def _start(self, packet):
+        """Start server keyboard interactive authentication"""
 
         lang = packet.get_string()
         submethods = packet.get_string()
@@ -327,8 +384,9 @@ class _ServerKbdIntAuth(_ServerAuth):
             raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid keyboard '
                                   'interactive auth request') from None
 
-        challenge = self._conn.get_kbdint_challenge(self._username,
-                                                    lang, submethods)
+        challenge = yield from self._conn.get_kbdint_challenge(self._username,
+                                                               lang,
+                                                               submethods)
         self._send_challenge(challenge)
 
     def _send_challenge(self, challenge):
@@ -350,6 +408,15 @@ class _ServerKbdIntAuth(_ServerAuth):
         else:
             self.send_failure()
 
+    @asyncio.coroutine
+    def _validate_response(self, responses):
+        """Validate a keyboard interactive authentication response"""
+
+        next_challenge = \
+            yield from self._conn.validate_kbdint_response(self._username,
+                                                           responses)
+        self._send_challenge(next_challenge)
+
     def _process_info_response(self, pkttype, packet):
         """Process a keyboard interactive authentication response"""
 
@@ -370,9 +437,8 @@ class _ServerKbdIntAuth(_ServerAuth):
 
         packet.check_end()
 
-        next_challenge = self._conn.validate_kbdint_response(self._username,
-                                                             responses)
-        self._send_challenge(next_challenge)
+        self.cancel()
+        self._coro = asyncio.async(self._validate_response(responses))
 
     packet_handlers = {
         MSG_USERAUTH_INFO_RESPONSE: _process_info_response
@@ -388,8 +454,9 @@ class _ServerPasswordAuth(_ServerAuth):
 
         return conn.password_auth_supported()
 
-    def __init__(self, conn, username, packet):
-        super().__init__(conn, username)
+    @asyncio.coroutine
+    def _start(self, packet):
+        """Start server password authentication"""
 
         password_change = packet.get_boolean()
         password = packet.get_string()
@@ -405,7 +472,7 @@ class _ServerPasswordAuth(_ServerAuth):
 
         # TODO: Handle password change request
 
-        if self._conn.validate_password(self._username, password):
+        if (yield from self._conn.validate_password(self._username, password)):
             self.send_success()
         else:
             self.send_failure()
@@ -423,12 +490,9 @@ def lookup_client_auth(conn, method):
     """Look up the client authentication method to use"""
 
     if method in _auth_methods:
-        try:
-            return _client_auth_handlers[method](conn, method)
-        except _SSHAuthError:
-            pass
-
-    return None
+        return _client_auth_handlers[method](conn, method)
+    else:
+        return None
 
 
 def get_server_auth_methods(conn):
diff --git a/asyncssh/channel.py b/asyncssh/channel.py
index 13db4c0..b4a1923 100644
--- a/asyncssh/channel.py
+++ b/asyncssh/channel.py
@@ -187,8 +187,8 @@ class SSHChannel(SSHPacketHandler):
                 raise DisconnectError(DISC_PROTOCOL_ERROR,
                                       'Unicode decode error')
 
-            if not self._session.eof_received():
-                self.close()
+            if not self._session.eof_received() and self._send_state == 'open':
+                self.write_eof()
         else:
             self._recv_window -= len(data)
 
@@ -200,18 +200,20 @@ class SSHChannel(SSHPacketHandler):
 
             if self._encoding:
                 if datatype in self._recv_partial:
-                    input = self._recv_partial.pop(datatype) + data
+                    encdata = self._recv_partial.pop(datatype) + data
                 else:
-                    input = data
+                    encdata = data
 
-                while input:
+                while encdata:
                     try:
-                        data = input.decode(self._encoding)
-                        input = b''
+                        data = encdata.decode(self._encoding)
+                        encdata = b''
                     except UnicodeDecodeError as exc:
                         if exc.start > 0:
-                            data = input[:exc.start].decode()
-                            input = input[exc.start:]
+                            # Avoid pylint false positive
+                            # pylint: disable=invalid-slice-index
+                            data = encdata[:exc.start].decode()
+                            encdata = encdata[exc.start:]
                         elif exc.reason == 'unexpected end of data':
                             break
                         else:
@@ -220,8 +222,8 @@ class SSHChannel(SSHPacketHandler):
 
                     self._session.data_received(data, datatype)
 
-                if input:
-                    self._recv_partial[datatype] = input
+                if encdata:
+                    self._recv_partial[datatype] = encdata
             else:
                 self._session.data_received(data, datatype)
 
@@ -401,6 +403,7 @@ class SSHChannel(SSHPacketHandler):
         # If we haven't yet sent a close, send one now
         if self._send_state not in {'close_sent', 'closed'}:
             self._send_packet(MSG_CHANNEL_CLOSE)
+            self._send_state = 'close_sent'
 
         self._loop.call_soon(self._cleanup)
 
diff --git a/asyncssh/cipher.py b/asyncssh/cipher.py
index a494d9f..333a156 100644
--- a/asyncssh/cipher.py
+++ b/asyncssh/cipher.py
@@ -25,7 +25,7 @@ def register_encryption_alg(alg, cipher_name, mode_name, key_size,
     """Register an encryption algorithm"""
 
     cipher = lookup_cipher(cipher_name, mode_name)
-    if cipher:
+    if cipher: # pragma: no branch
         _enc_algs.append(alg)
         _enc_params[alg] = (key_size, cipher.iv_size,
                             cipher.block_size, cipher.mode_name)
diff --git a/asyncssh/connection.py b/asyncssh/connection.py
index e498ccd..55ecacc 100644
--- a/asyncssh/connection.py
+++ b/asyncssh/connection.py
@@ -68,7 +68,7 @@ from .mac import get_mac_algs, get_mac_params, get_mac
 from .misc import ChannelOpenError, DisconnectError, ip_address
 
 from .packet import Boolean, Byte, NameList, String, UInt32, UInt64
-from .packet import SSHPacket, SSHPacketHandler
+from .packet import PacketDecodeError, SSHPacket, SSHPacketHandler
 
 from .public_key import CERT_TYPE_HOST, CERT_TYPE_USER
 from .public_key import get_public_key_algs, get_certificate_algs
@@ -159,7 +159,7 @@ def _load_private_key(key):
     elif isinstance(cert, bytes):
         cert = import_certificate(cert)
 
-    if cert and key.encode_ssh_public() != cert.key.encode_ssh_public():
+    if cert and key.get_ssh_public_key() != cert.key.get_ssh_public_key():
         raise ValueError('Certificate key mismatch')
 
     return key, cert
@@ -380,6 +380,10 @@ class SSHConnection(SSHPacketHandler):
     def _cleanup(self, exc):
         """Clean up this connection"""
 
+        if self._auth:
+            self._auth.cancel()
+            self._auth = None
+
         if self._channels:
             for chan in list(self._channels.values()):
                 chan.process_connection_close(exc)
@@ -679,19 +683,23 @@ class SSHConnection(SSHPacketHandler):
                                    not self._decompress_after_auth):
             payload = self._decompressor.decompress(payload)
 
-        packet = SSHPacket(payload)
-        pkttype = packet.get_byte()
-
-        if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
-            if self._ignore_first_kex:
-                self._ignore_first_kex = False
-                processed = True
+        try:
+            packet = SSHPacket(payload)
+            pkttype = packet.get_byte()
+
+            if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
+                if self._ignore_first_kex:
+                    self._ignore_first_kex = False
+                    processed = True
+                else:
+                    processed = self._kex.process_packet(pkttype, packet)
+            elif (self._auth and
+                  MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST):
+                processed = self._auth.process_packet(pkttype, packet)
             else:
-                processed = self._kex.process_packet(pkttype, packet)
-        elif self._auth and MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST:
-            processed = self._auth.process_packet(pkttype, packet)
-        else:
-            processed = self.process_packet(pkttype, packet)
+                processed = self.process_packet(pkttype, packet)
+        except PacketDecodeError as exc:
+            raise DisconnectError(DISC_PROTOCOL_ERROR, str(exc))
 
         if not processed:
             self.send_packet(Byte(MSG_UNIMPLEMENTED), UInt32(self._recv_seq))
@@ -958,6 +966,7 @@ class SSHConnection(SSHPacketHandler):
         self._auth_in_progress = False
         self._auth_complete = True
         self._extra.update(username=self._username)
+        self._send_deferred_packets()
 
     def send_channel_open_confirmation(self, send_chan, recv_chan,
                                        recv_window, recv_pktsize,
@@ -1297,6 +1306,9 @@ class SSHConnection(SSHPacketHandler):
                     self.send_userauth_success()
                     return
 
+            if self._auth:
+                self._auth.cancel()
+
             self._auth = lookup_server_auth(self, self._username,
                                             method, packet)
 
@@ -1330,6 +1342,7 @@ class SSHConnection(SSHPacketHandler):
         packet.check_end()
 
         if self.is_client() and self._auth:
+            self._auth.cancel()
             self._auth = None
             self._auth_in_progress = False
             self._auth_complete = True
@@ -1828,6 +1841,10 @@ class SSHClientConnection(SSHConnection):
     def try_next_auth(self):
         """Attempt client authentication using the next compatible method"""
 
+        if self._auth:
+            self._auth.cancel()
+            self._auth = None
+
         while self._auth_methods:
             method = self._auth_methods.pop(0)
 
@@ -1835,30 +1852,36 @@ class SSHClientConnection(SSHConnection):
             if self._auth:
                 return
 
-        raise DisconnectError(DISC_NO_MORE_AUTH_METHODS_AVAILABLE,
-                              'Permission denied')
+        self._force_close(DisconnectError(DISC_NO_MORE_AUTH_METHODS_AVAILABLE,
+                                          'Permission denied'))
 
+    @asyncio.coroutine
     def public_key_auth_requested(self):
         """Return a client key to authenticate with"""
 
         if self._client_keys:
             key, cert = self._client_keys.pop(0)
         else:
-            client_key = self._owner.public_key_auth_requested()
-            key, cert = _load_private_key(client_key)
+            result = self._owner.public_key_auth_requested()
+
+            if asyncio.iscoroutine(result):
+                result = yield from result
+
+            key, cert = _load_private_key(result)
 
         if cert:
             self._client_keys.insert(0, (key, None))
             return cert.algorithm, key, cert.data
         elif key:
-            return key.algorithm, key, key.encode_ssh_public()
+            return key.algorithm, key, key.get_ssh_public_key()
         else:
             return None, None, None
 
+    @asyncio.coroutine
     def password_auth_requested(self):
         """Return a password to authenticate with"""
 
-        # Only allow passwordauth if the connection supports encryption
+        # Only allow password auth if the connection supports encryption
         # and a MAC.
         if (not self._send_cipher or
                 (not self._send_mac and
@@ -1866,17 +1889,26 @@ class SSHClientConnection(SSHConnection):
             return None
 
         if self._password:
-            password = self._password
+            result = self._password
             self._password = None
         else:
-            password = self._owner.password_auth_requested()
+            result = self._owner.password_auth_requested()
+
+            if asyncio.iscoroutine(result):
+                result = yield from result
 
-        return password
+        return result
 
+    @asyncio.coroutine
     def password_change_requested(self):
         """Return a password to authenticate with and what to change it to"""
 
-        return self._owner.password_change_requested()
+        result = self._owner.password_change_requested()
+
+        if asyncio.iscoroutine(result):
+            result = yield from result
+
+        return result
 
     def password_changed(self):
         """Report a successful password change"""
@@ -1888,6 +1920,7 @@ class SSHClientConnection(SSHConnection):
 
         self._owner.password_change_failed()
 
+    @asyncio.coroutine
     def kbdint_auth_requested(self):
         """Return the list of supported keyboard-interactive auth methods
 
@@ -1904,29 +1937,38 @@ class SSHClientConnection(SSHConnection):
                  self._send_mode not in ('chacha', 'gcm'))):
             return None
 
-        submethods = self._owner.kbdint_auth_requested()
-        if submethods is None and self._password is not None:
+        result = self._owner.kbdint_auth_requested()
+
+        if asyncio.iscoroutine(result):
+            result = yield from result
+
+        if result is None and self._password is not None:
             self._kbdint_password_auth = True
-            submethods = ''
+            result = ''
 
-        return submethods
+        return result
 
+    @asyncio.coroutine
     def kbdint_challenge_received(self, name, instructions, lang, prompts):
         """Return responses to a keyboard-interactive auth challenge"""
 
         if self._kbdint_password_auth:
             if len(prompts) == 0:
                 # Silently drop any empty challenges used to print messages
-                return []
-            elif (len(prompts) == 1 and
-                  'password' in prompts[0][0].lower().strip()):
+                result = []
+            elif len(prompts) == 1 and 'password' in prompts[0][0].lower():
                 password = self.password_auth_requested()
-                return [password] if password is not None else None
+                result = [password] if password is not None else None
             else:
-                return None
+                result = None
         else:
-            return self._owner.kbdint_challenge_received(name, instructions,
-                                                         lang, prompts)
+            result = self._owner.kbdint_challenge_received(name, instructions,
+                                                           lang, prompts)
... 7751 lines suppressed ...

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/python-modules/packages/python-asyncssh.git



More information about the Python-modules-commits mailing list