[Python-modules-commits] [python-asyncssh] 01/09: Import python-asyncssh_1.3.0.orig.tar.gz
Vincent Bernat
bernat at moszumanska.debian.org
Sat Oct 31 22:54:26 UTC 2015
This is an automated email from the git hooks/post-receive script.
bernat pushed a commit to branch master
in repository python-asyncssh.
commit 142bd9768421394f9f6880fb63c8de0932182c5b
Author: Vincent Bernat <bernat at debian.org>
Date: Sat Oct 31 23:13:43 2015 +0100
Import python-asyncssh_1.3.0.orig.tar.gz
---
.coveragerc | 7 +
README.rst | 6 +-
asyncssh/__init__.py | 2 +-
asyncssh/asn1.py | 78 ++-
asyncssh/auth.py | 210 +++---
asyncssh/channel.py | 25 +-
asyncssh/cipher.py | 2 +-
asyncssh/connection.py | 229 +++++--
asyncssh/crypto/__init__.py | 38 +-
asyncssh/crypto/chacha.py | 8 +-
asyncssh/crypto/cipher.py | 4 +-
asyncssh/crypto/curve25519.py | 58 +-
asyncssh/crypto/ec.py | 257 ++++++++
asyncssh/crypto/ecdh.py | 48 ++
asyncssh/crypto/pyca/cipher.py | 3 -
asyncssh/crypto/pyca/dsa.py | 80 ++-
asyncssh/crypto/pyca/ec.py | 118 ++++
asyncssh/crypto/pyca/rsa.py | 99 ++-
asyncssh/crypto/pycrypto/__init__.py | 15 -
asyncssh/crypto/pycrypto/cipher.py | 67 --
asyncssh/crypto/pycrypto/dsa.py | 50 --
asyncssh/crypto/pycrypto/rsa.py | 51 --
asyncssh/dh.py | 21 +-
asyncssh/dsa.py | 77 ++-
asyncssh/ec.py | 667 -------------------
asyncssh/{curve25519.py => ecdh.py} | 47 +-
asyncssh/ecdsa.py | 285 ++++++++
asyncssh/ed25519.py | 19 +-
asyncssh/known_hosts.py | 7 +-
asyncssh/packet.py | 9 +-
asyncssh/pbe.py | 150 ++---
asyncssh/public_key.py | 186 +++---
asyncssh/rsa.py | 86 +--
asyncssh/saslprep.py | 10 +-
asyncssh/sftp.py | 615 ++++++++++++------
asyncssh/version.py | 2 +-
docs/api.rst | 14 +-
docs/changes.rst | 80 +++
pylintrc | 2 +-
setup.py | 7 +-
tests/__init__.py | 3 +-
tests/test_asn1.py | 183 ++++++
tests/test_cipher.py | 105 +++
tests/test_compression.py | 39 ++
tests/test_kex.py | 318 +++++++++
tests/test_keys.py | 361 -----------
tests/test_known_hosts.py | 190 ++++++
tests/test_mac.py | 47 ++
tests/test_native_ec.py | 138 ++++
tests/test_public_key.py | 1183 ++++++++++++++++++++++++++++++++++
tests/test_saslprep.py | 75 +++
tests/util.py | 44 ++
52 files changed, 4369 insertions(+), 2056 deletions(-)
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..f2c3d0b
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,7 @@
+[run]
+branch = True
+
+[report]
+exclude_lines =
+ pragma: no cover
+ raise NotImplementedError
diff --git a/README.rst b/README.rst
index 021872b..25fbf3c 100644
--- a/README.rst
+++ b/README.rst
@@ -78,7 +78,7 @@ Prerequisites
To use ``asyncssh``, you need the following:
* Python 3.4 or later
-* PyCrypto 2.6 or later and/or PyCA 0.6.1 or later
+* cryptography (PyCA) 1.0.0 or later
Installation
------------
@@ -106,8 +106,6 @@ functionality:
AsyncSSH defines the following optional PyPI extra packages to make it
easy to install any or all of these dependencies:
- | pycrypto
- | pyca
| bcrypt
| libnacl
@@ -115,7 +113,7 @@ For example, to install all of these, you can run:
::
- pip install 'asyncssh[pycrypto,pyca,bcrypt,libnacl]'
+ pip install 'asyncssh[bcrypt,libnacl]'
Note that you will still need to manually install the libsodium library
listed above for libnacl to work correctly. Unfortunately, since
diff --git a/asyncssh/__init__.py b/asyncssh/__init__.py
index 86d1991..842ac10 100644
--- a/asyncssh/__init__.py
+++ b/asyncssh/__init__.py
@@ -54,4 +54,4 @@ from .sftp import SEEK_SET, SEEK_CUR, SEEK_END
from .stream import SSHReader, SSHWriter
# Import these explicitly to trigger register calls in them
-from . import curve25519, ed25519, ec, rsa, dsa, dh
+from . import ed25519, ecdsa, rsa, dsa, ecdh, dh
diff --git a/asyncssh/asn1.py b/asyncssh/asn1.py
index 41ac3cc..38aedbf 100644
--- a/asyncssh/asn1.py
+++ b/asyncssh/asn1.py
@@ -55,6 +55,9 @@ _der_class_by_type = {}
def _encode_identifier(asn1_class, constructed, tag):
"""Encode a DER object's identifier"""
+ if asn1_class not in (UNIVERSAL, APPLICATION, CONTEXT_SPECIFIC, PRIVATE):
+ raise ASN1EncodeError('Invalid ASN.1 class')
+
flags = (asn1_class << 6) | (0x20 if constructed else 0x00)
if tag < 0x20:
@@ -128,7 +131,7 @@ class RawDERObject:
"""
- def __init__(self, asn1_class, tag, content):
+ def __init__(self, tag, content, asn1_class):
self.asn1_class = asn1_class
self.tag = tag
self.content = content
@@ -137,6 +140,14 @@ class RawDERObject:
return ('RawDERObject(%s, %s, %r)' %
(_asn1_class[self.asn1_class], self.tag, self.content))
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.asn1_class == other.asn1_class and
+ self.tag == other.tag and self.content == other.content)
+
+ def __hash__(self):
+ return hash((self.asn1_class, self.tag, self.content))
+
def encode_identifier(self):
"""Encode the DER identifier for this object as a byte string"""
@@ -171,6 +182,14 @@ class TaggedDERObject:
return ('TaggedDERObject(%s, %s, %r)' %
(_asn1_class[self.asn1_class], self.tag, self.value))
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.asn1_class == other.asn1_class and
+ self.tag == other.tag and self.value == other.value)
+
+ def __hash__(self):
+ return hash((self.asn1_class, self.tag, self.value))
+
def encode_identifier(self):
"""Encode the DER identifier for this object as a byte string"""
@@ -239,7 +258,8 @@ class _Integer:
l = value.bit_length()
l = l // 8 + 1 if l % 8 == 0 else (l + 7) // 8
- return value.to_bytes(l, 'big', signed=True)
+ result = value.to_bytes(l, 'big', signed=True)
+ return result[1:] if result.startswith(b'\xff\x80') else result
@classmethod
def decode(cls, constructed, content):
@@ -402,7 +422,6 @@ class BitString:
self.value = value
self.unused = unused
- self.named = named
def __str__(self):
result = ''.join(bin(b)[2:].zfill(8) for b in self.value)
@@ -411,7 +430,14 @@ class BitString:
return result
def __repr__(self):
- return 'BitString(%s)' % self
+ return "BitString('%s')" % self
+
+ def __eq__(self, other):
+ return (isinstance(other, self.__class__) and
+ self.value == other.value and self.unused == other.unused)
+
+ def __hash__(self):
+ return hash((self.value, self.unused))
def encode(self):
"""Encode a DER bit string"""
@@ -448,8 +474,11 @@ class ObjectIdentifier:
def __init__(self, value):
self.value = value
+ def __str__(self):
+ return self.value
+
def __repr__(self):
- return 'ObjectIdentifier(%s)' % self.value
+ return "ObjectIdentifier('%s')" % self.value
def __eq__(self, other):
return isinstance(other, self.__class__) and self.value == other.value
@@ -463,6 +492,10 @@ class ObjectIdentifier:
def _bytes(component):
"""Convert a single element of an OID to a DER byte string"""
+ if component < 0:
+ raise ASN1EncodeError('Components of object identifier must '
+ 'be greater than or equal to 0')
+
result = [component & 0x7f]
while component >= 0x80:
component >>= 7
@@ -470,14 +503,18 @@ class ObjectIdentifier:
return bytes(result[::-1])
- components = [int(c) for c in self.value.split('.')]
+ try:
+ components = [int(c) for c in self.value.split('.')]
+ except ValueError:
+ raise ASN1EncodeError('Component values must be integers')
+
if len(components) < 2:
raise ASN1EncodeError('Object identifiers must have at least two '
'components')
elif components[0] < 0 or components[0] > 2:
raise ASN1EncodeError('First component of object identifier must '
'be between 0 and 2')
- elif components[0] < 2 and (components[1] < 0 or components[1] > 40):
+ elif components[0] < 2 and (components[1] < 0 or components[1] > 39):
raise ASN1EncodeError('Second component of object identifier must '
'be between 0 and 39')
@@ -489,26 +526,30 @@ class ObjectIdentifier:
"""Decode a DER object identifier"""
if constructed:
- raise ASN1DecodeError('OBJECT IDENTIFIER should not be '
+ raise ASN1DecodeError('OBJECT IDENTIFIER should not be '
'constructed')
if not content:
raise ASN1DecodeError('Empty object identifier')
- components = [str(component) for component in divmod(content[0], 40)]
+ b = content[0]
+ components = list(divmod(b, 40)) if b < 80 else [2, b-80]
+
component = 0
for b in content[1:]:
- if b < 0x80:
- components.append(str(component | b))
+ if b == 0x80 and component == 0:
+ raise ASN1DecodeError('Invalid component')
+ elif b < 0x80:
+ components.append(component | b)
component = 0
else:
component |= b & 0x7f
component <<= 7
if component:
- raise ASN1DecodeError('Incomplete object identifier')
+ raise ASN1DecodeError('Incomplete component')
- return cls('.'.join(components))
+ return cls('.'.join(str(c) for c in components))
def der_encode(value):
@@ -545,7 +586,7 @@ def der_encode(value):
identifier = cls.identifier
content = cls.encode(value)
else:
- raise TypeError('Cannot DER encode type %s' % t.__name__)
+ raise ASN1EncodeError('Cannot DER encode type %s' % t.__name__)
length = len(content)
if length < 0x80:
@@ -604,7 +645,10 @@ def der_decode(data, partial_ok=False):
tag |= b & 0x7f
tag <<= 7
else:
- raise ASN1DecodeError('Incomplete data')
+ raise ASN1DecodeError('Incomplete tag')
+
+ if offset >= len(data):
+ raise ASN1DecodeError('Incomplete data')
length = data[offset]
offset += 1
@@ -626,9 +670,9 @@ def der_decode(data, partial_ok=False):
value = cls.decode(constructed, data[offset:offset+length])
elif constructed:
value = TaggedDERObject(tag, der_decode(data[offset:offset+length]),
- asn1_class=asn1_class)
+ asn1_class)
else:
- value = RawDERObject(asn1_class, tag, data[offset:offset+length])
+ value = RawDERObject(tag, data[offset:offset+length], asn1_class)
if partial_ok:
return value, offset+length
diff --git a/asyncssh/auth.py b/asyncssh/auth.py
index 4eeb711..1fd0a71 100644
--- a/asyncssh/auth.py
+++ b/asyncssh/auth.py
@@ -12,6 +12,8 @@
"""SSH authentication handlers"""
+import asyncio
+
from .constants import DISC_PROTOCOL_ERROR
from .misc import DisconnectError
from .packet import Boolean, Byte, String, UInt32, SSHPacketHandler
@@ -37,16 +39,35 @@ _client_auth_handlers = {}
_server_auth_handlers = {}
-class _SSHAuthError(Exception):
- """This is raised when we can't proceed with the current form of auth."""
+class _Auth(SSHPacketHandler):
+ """Parent class for authentication"""
+
+ def __init__(self):
+ self._coro = None
+
+ def cancel(self):
+ """Cancel any authentication in progress"""
+ if self._coro:
+ self._coro.cancel()
+ self._coro = None
-class _ClientAuth(SSHPacketHandler):
- """Parent class for client auth"""
+class _ClientAuth(_Auth):
+ """Parent class for client authentication"""
def __init__(self, conn, method):
+ super().__init__()
+
self._conn = conn
self._method = method
+ self._coro = asyncio.async(self._start())
+
+ @asyncio.coroutine
+ def _start(self):
+ """Abstract method for starting client authentication"""
+
+ # Provided by subclass
+ raise NotImplementedError
def auth_succeeded(self):
"""Callback when auth succeeds"""
@@ -54,16 +75,6 @@ class _ClientAuth(SSHPacketHandler):
def auth_failed(self):
"""Callback when auth fails"""
- def process_packet(self, pkttype, packet):
- try:
- processed = super().process_packet(pkttype, packet)
- except _SSHAuthError:
- # We can't complete the current auth - move to the next one
- processed = True
- self._conn.try_next_auth()
-
- return processed
-
def send_request(self, *args, key=None):
"""Send a user authentication request"""
@@ -73,8 +84,9 @@ class _ClientAuth(SSHPacketHandler):
class _ClientNullAuth(_ClientAuth):
"""Client side implementation of null auth"""
- def __init__(self, conn, method):
- super().__init__(conn, method)
+ @asyncio.coroutine
+ def _start(self):
+ """Start client null authentication"""
self.send_request()
@@ -84,12 +96,16 @@ class _ClientNullAuth(_ClientAuth):
class _ClientPublicKeyAuth(_ClientAuth):
"""Client side implementation of public key auth"""
- def __init__(self, conn, method):
- super().__init__(conn, method)
+ @asyncio.coroutine
+ def _start(self):
+ """Start client public key authentication"""
+
+ self._alg, self._key, self._key_data = \
+ yield from self._conn.public_key_auth_requested()
- self._alg, self._key, self._key_data = conn.public_key_auth_requested()
if self._alg is None:
- raise _SSHAuthError()
+ self._conn.try_next_auth()
+ return
self.send_request(Boolean(False), String(self._alg),
String(self._key_data))
@@ -118,15 +134,34 @@ class _ClientPublicKeyAuth(_ClientAuth):
class _ClientKbdIntAuth(_ClientAuth):
"""Client side implementation of keyboard-interactive auth"""
- def __init__(self, conn, method):
- super().__init__(conn, method)
+ @asyncio.coroutine
+ def _start(self):
+ """Start client keyboard interactive authentication"""
+
+ submethods = yield from self._conn.kbdint_auth_requested()
- submethods = conn.kbdint_auth_requested()
if submethods is None:
- raise _SSHAuthError()
+ self._conn.try_next_auth()
+ return
self.send_request(String(''), String(submethods))
+ @asyncio.coroutine
+ def _receive_challenge(self, name, instruction, lang, prompts):
+ """Receive and respond to a keyboard interactive challenge"""
+
+ responses = \
+ yield from self._conn.kbdint_challenge_received(name, instruction,
+ lang, prompts)
+
+ if responses is None:
+ self._conn.try_next_auth()
+ return
+
+ self._conn.send_packet(Byte(MSG_USERAUTH_INFO_RESPONSE),
+ UInt32(len(responses)),
+ b''.join(String(r) for r in responses))
+
def _process_info_request(self, pkttype, packet):
"""Process a keyboard interactive authentication request"""
@@ -158,15 +193,11 @@ class _ClientKbdIntAuth(_ClientAuth):
prompts.append((prompt, echo))
- responses = self._conn.kbdint_challenge_received(name, instruction,
- lang, prompts)
-
- if responses is None:
- raise _SSHAuthError()
+ self.cancel()
+ self._coro = asyncio.async(self._receive_challenge(name, instruction,
+ lang, prompts))
- self._conn.send_packet(Byte(MSG_USERAUTH_INFO_RESPONSE),
- UInt32(len(responses)),
- b''.join(String(r) for r in responses))
+ return True
packet_handlers = {
MSG_USERAUTH_INFO_REQUEST: _process_info_request
@@ -181,12 +212,37 @@ class _ClientPasswordAuth(_ClientAuth):
self._password_change = False
- password = conn.password_auth_requested()
+ @asyncio.coroutine
+ def _start(self):
+ """Start client password authentication"""
+
+ password = yield from self._conn.password_auth_requested()
+
if password is None:
- raise _SSHAuthError()
+ self._conn.try_next_auth()
+ return
self.send_request(Boolean(False), String(password))
+ @asyncio.coroutine
+ def _change_password(self):
+ """Start password change"""
+
+ result = yield from self._conn.password_change_requested()
+
+ if result == NotImplemented:
+ # Password change not supported - move on to the next auth method
+ self._conn.try_next_auth()
+ return
+
+ old_password, new_password = result
+
+ self._password_change = True
+
+ self.send_request(Boolean(True),
+ String(old_password.encode('utf-8')),
+ String(new_password.encode('utf-8')))
+
def auth_succeeded(self):
if self._password_change:
self._password_change = False
@@ -212,18 +268,8 @@ class _ClientPasswordAuth(_ClientAuth):
raise DisconnectError(DISC_PROTOCOL_ERROR,
'Invalid password change request') from None
- result = self._conn.password_change_requested()
- if result == NotImplemented:
- # Password change not supported - move on to the next auth method
- raise _SSHAuthError()
- else:
- old_password, new_password = result
-
- self._password_change = True
-
- self.send_request(Boolean(True),
- String(old_password.encode('utf-8')),
- String(new_password.encode('utf-8')))
+ self.cancel()
+ self._coro = asyncio.async(self._change_password())
return True
@@ -232,12 +278,22 @@ class _ClientPasswordAuth(_ClientAuth):
}
-class _ServerAuth(SSHPacketHandler):
- """Parent class for server side auth"""
+class _ServerAuth(_Auth):
+ """Parent class for server authentication"""
+
+ def __init__(self, conn, username, packet):
+ super().__init__()
- def __init__(self, conn, username):
self._conn = conn
self._username = username
+ self._coro = asyncio.async(self._start(packet))
+
+ @asyncio.coroutine
+ def _start(self, packet):
+ """Abstract method for starting server authentication"""
+
+ # Provided by subclass
+ raise NotImplementedError
def send_failure(self, partial_success=False):
"""Send a user authentication failure response"""
@@ -260,14 +316,13 @@ class _ServerNullAuth(_ServerAuth):
# pylint: disable=unused-argument
return False
- def __init__(self, conn, username, packet):
- super().__init__(conn, username)
+ @asyncio.coroutine
+ def _start(self, packet):
+ """Always fail null server authentication"""
packet.check_end()
-
self.send_failure()
-
class _ServerPublicKeyAuth(_ServerAuth):
"""Server side implementation of public key auth"""
@@ -277,8 +332,9 @@ class _ServerPublicKeyAuth(_ServerAuth):
return conn.public_key_auth_supported()
- def __init__(self, conn, username, packet):
- super().__init__(conn, username)
+ @asyncio.coroutine
+ def _start(self, packet):
+ """Start server public key authentication"""
sig_present = packet.get_boolean()
algorithm = packet.get_string()
@@ -293,8 +349,8 @@ class _ServerPublicKeyAuth(_ServerAuth):
packet.check_end()
- if self._conn.validate_public_key(self._username, key_data,
- msg, signature):
+ if (yield from self._conn.validate_public_key(self._username, key_data,
+ msg, signature)):
if sig_present:
self.send_success()
else:
@@ -313,8 +369,9 @@ class _ServerKbdIntAuth(_ServerAuth):
return conn.kbdint_auth_supported()
- def __init__(self, conn, username, packet):
- super().__init__(conn, username)
+ @asyncio.coroutine
+ def _start(self, packet):
+ """Start server keyboard interactive authentication"""
lang = packet.get_string()
submethods = packet.get_string()
@@ -327,8 +384,9 @@ class _ServerKbdIntAuth(_ServerAuth):
raise DisconnectError(DISC_PROTOCOL_ERROR, 'Invalid keyboard '
'interactive auth request') from None
- challenge = self._conn.get_kbdint_challenge(self._username,
- lang, submethods)
+ challenge = yield from self._conn.get_kbdint_challenge(self._username,
+ lang,
+ submethods)
self._send_challenge(challenge)
def _send_challenge(self, challenge):
@@ -350,6 +408,15 @@ class _ServerKbdIntAuth(_ServerAuth):
else:
self.send_failure()
+ @asyncio.coroutine
+ def _validate_response(self, responses):
+ """Validate a keyboard interactive authentication response"""
+
+ next_challenge = \
+ yield from self._conn.validate_kbdint_response(self._username,
+ responses)
+ self._send_challenge(next_challenge)
+
def _process_info_response(self, pkttype, packet):
"""Process a keyboard interactive authentication response"""
@@ -370,9 +437,8 @@ class _ServerKbdIntAuth(_ServerAuth):
packet.check_end()
- next_challenge = self._conn.validate_kbdint_response(self._username,
- responses)
- self._send_challenge(next_challenge)
+ self.cancel()
+ self._coro = asyncio.async(self._validate_response(responses))
packet_handlers = {
MSG_USERAUTH_INFO_RESPONSE: _process_info_response
@@ -388,8 +454,9 @@ class _ServerPasswordAuth(_ServerAuth):
return conn.password_auth_supported()
- def __init__(self, conn, username, packet):
- super().__init__(conn, username)
+ @asyncio.coroutine
+ def _start(self, packet):
+ """Start server password authentication"""
password_change = packet.get_boolean()
password = packet.get_string()
@@ -405,7 +472,7 @@ class _ServerPasswordAuth(_ServerAuth):
# TODO: Handle password change request
- if self._conn.validate_password(self._username, password):
+ if (yield from self._conn.validate_password(self._username, password)):
self.send_success()
else:
self.send_failure()
@@ -423,12 +490,9 @@ def lookup_client_auth(conn, method):
"""Look up the client authentication method to use"""
if method in _auth_methods:
- try:
- return _client_auth_handlers[method](conn, method)
- except _SSHAuthError:
- pass
-
- return None
+ return _client_auth_handlers[method](conn, method)
+ else:
+ return None
def get_server_auth_methods(conn):
diff --git a/asyncssh/channel.py b/asyncssh/channel.py
index 13db4c0..b4a1923 100644
--- a/asyncssh/channel.py
+++ b/asyncssh/channel.py
@@ -187,8 +187,8 @@ class SSHChannel(SSHPacketHandler):
raise DisconnectError(DISC_PROTOCOL_ERROR,
'Unicode decode error')
- if not self._session.eof_received():
- self.close()
+ if not self._session.eof_received() and self._send_state == 'open':
+ self.write_eof()
else:
self._recv_window -= len(data)
@@ -200,18 +200,20 @@ class SSHChannel(SSHPacketHandler):
if self._encoding:
if datatype in self._recv_partial:
- input = self._recv_partial.pop(datatype) + data
+ encdata = self._recv_partial.pop(datatype) + data
else:
- input = data
+ encdata = data
- while input:
+ while encdata:
try:
- data = input.decode(self._encoding)
- input = b''
+ data = encdata.decode(self._encoding)
+ encdata = b''
except UnicodeDecodeError as exc:
if exc.start > 0:
- data = input[:exc.start].decode()
- input = input[exc.start:]
+ # Avoid pylint false positive
+ # pylint: disable=invalid-slice-index
+ data = encdata[:exc.start].decode()
+ encdata = encdata[exc.start:]
elif exc.reason == 'unexpected end of data':
break
else:
@@ -220,8 +222,8 @@ class SSHChannel(SSHPacketHandler):
self._session.data_received(data, datatype)
- if input:
- self._recv_partial[datatype] = input
+ if encdata:
+ self._recv_partial[datatype] = encdata
else:
self._session.data_received(data, datatype)
@@ -401,6 +403,7 @@ class SSHChannel(SSHPacketHandler):
# If we haven't yet sent a close, send one now
if self._send_state not in {'close_sent', 'closed'}:
self._send_packet(MSG_CHANNEL_CLOSE)
+ self._send_state = 'close_sent'
self._loop.call_soon(self._cleanup)
diff --git a/asyncssh/cipher.py b/asyncssh/cipher.py
index a494d9f..333a156 100644
--- a/asyncssh/cipher.py
+++ b/asyncssh/cipher.py
@@ -25,7 +25,7 @@ def register_encryption_alg(alg, cipher_name, mode_name, key_size,
"""Register an encryption algorithm"""
cipher = lookup_cipher(cipher_name, mode_name)
- if cipher:
+ if cipher: # pragma: no branch
_enc_algs.append(alg)
_enc_params[alg] = (key_size, cipher.iv_size,
cipher.block_size, cipher.mode_name)
diff --git a/asyncssh/connection.py b/asyncssh/connection.py
index e498ccd..55ecacc 100644
--- a/asyncssh/connection.py
+++ b/asyncssh/connection.py
@@ -68,7 +68,7 @@ from .mac import get_mac_algs, get_mac_params, get_mac
from .misc import ChannelOpenError, DisconnectError, ip_address
from .packet import Boolean, Byte, NameList, String, UInt32, UInt64
-from .packet import SSHPacket, SSHPacketHandler
+from .packet import PacketDecodeError, SSHPacket, SSHPacketHandler
from .public_key import CERT_TYPE_HOST, CERT_TYPE_USER
from .public_key import get_public_key_algs, get_certificate_algs
@@ -159,7 +159,7 @@ def _load_private_key(key):
elif isinstance(cert, bytes):
cert = import_certificate(cert)
- if cert and key.encode_ssh_public() != cert.key.encode_ssh_public():
+ if cert and key.get_ssh_public_key() != cert.key.get_ssh_public_key():
raise ValueError('Certificate key mismatch')
return key, cert
@@ -380,6 +380,10 @@ class SSHConnection(SSHPacketHandler):
def _cleanup(self, exc):
"""Clean up this connection"""
+ if self._auth:
+ self._auth.cancel()
+ self._auth = None
+
if self._channels:
for chan in list(self._channels.values()):
chan.process_connection_close(exc)
@@ -679,19 +683,23 @@ class SSHConnection(SSHPacketHandler):
not self._decompress_after_auth):
payload = self._decompressor.decompress(payload)
- packet = SSHPacket(payload)
- pkttype = packet.get_byte()
-
- if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
- if self._ignore_first_kex:
- self._ignore_first_kex = False
- processed = True
+ try:
+ packet = SSHPacket(payload)
+ pkttype = packet.get_byte()
+
+ if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST:
+ if self._ignore_first_kex:
+ self._ignore_first_kex = False
+ processed = True
+ else:
+ processed = self._kex.process_packet(pkttype, packet)
+ elif (self._auth and
+ MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST):
+ processed = self._auth.process_packet(pkttype, packet)
else:
- processed = self._kex.process_packet(pkttype, packet)
- elif self._auth and MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST:
- processed = self._auth.process_packet(pkttype, packet)
- else:
- processed = self.process_packet(pkttype, packet)
+ processed = self.process_packet(pkttype, packet)
+ except PacketDecodeError as exc:
+ raise DisconnectError(DISC_PROTOCOL_ERROR, str(exc))
if not processed:
self.send_packet(Byte(MSG_UNIMPLEMENTED), UInt32(self._recv_seq))
@@ -958,6 +966,7 @@ class SSHConnection(SSHPacketHandler):
self._auth_in_progress = False
self._auth_complete = True
self._extra.update(username=self._username)
+ self._send_deferred_packets()
def send_channel_open_confirmation(self, send_chan, recv_chan,
recv_window, recv_pktsize,
@@ -1297,6 +1306,9 @@ class SSHConnection(SSHPacketHandler):
self.send_userauth_success()
return
+ if self._auth:
+ self._auth.cancel()
+
self._auth = lookup_server_auth(self, self._username,
method, packet)
@@ -1330,6 +1342,7 @@ class SSHConnection(SSHPacketHandler):
packet.check_end()
if self.is_client() and self._auth:
+ self._auth.cancel()
self._auth = None
self._auth_in_progress = False
self._auth_complete = True
@@ -1828,6 +1841,10 @@ class SSHClientConnection(SSHConnection):
def try_next_auth(self):
"""Attempt client authentication using the next compatible method"""
+ if self._auth:
+ self._auth.cancel()
+ self._auth = None
+
while self._auth_methods:
method = self._auth_methods.pop(0)
@@ -1835,30 +1852,36 @@ class SSHClientConnection(SSHConnection):
if self._auth:
return
- raise DisconnectError(DISC_NO_MORE_AUTH_METHODS_AVAILABLE,
- 'Permission denied')
+ self._force_close(DisconnectError(DISC_NO_MORE_AUTH_METHODS_AVAILABLE,
+ 'Permission denied'))
+ @asyncio.coroutine
def public_key_auth_requested(self):
"""Return a client key to authenticate with"""
if self._client_keys:
key, cert = self._client_keys.pop(0)
else:
- client_key = self._owner.public_key_auth_requested()
- key, cert = _load_private_key(client_key)
+ result = self._owner.public_key_auth_requested()
+
+ if asyncio.iscoroutine(result):
+ result = yield from result
+
+ key, cert = _load_private_key(result)
if cert:
self._client_keys.insert(0, (key, None))
return cert.algorithm, key, cert.data
elif key:
- return key.algorithm, key, key.encode_ssh_public()
+ return key.algorithm, key, key.get_ssh_public_key()
else:
return None, None, None
+ @asyncio.coroutine
def password_auth_requested(self):
"""Return a password to authenticate with"""
- # Only allow passwordauth if the connection supports encryption
+ # Only allow password auth if the connection supports encryption
# and a MAC.
if (not self._send_cipher or
(not self._send_mac and
@@ -1866,17 +1889,26 @@ class SSHClientConnection(SSHConnection):
return None
if self._password:
- password = self._password
+ result = self._password
self._password = None
else:
- password = self._owner.password_auth_requested()
+ result = self._owner.password_auth_requested()
+
+ if asyncio.iscoroutine(result):
+ result = yield from result
- return password
+ return result
+ @asyncio.coroutine
def password_change_requested(self):
"""Return a password to authenticate with and what to change it to"""
- return self._owner.password_change_requested()
+ result = self._owner.password_change_requested()
+
+ if asyncio.iscoroutine(result):
+ result = yield from result
+
+ return result
def password_changed(self):
"""Report a successful password change"""
@@ -1888,6 +1920,7 @@ class SSHClientConnection(SSHConnection):
self._owner.password_change_failed()
+ @asyncio.coroutine
def kbdint_auth_requested(self):
"""Return the list of supported keyboard-interactive auth methods
@@ -1904,29 +1937,38 @@ class SSHClientConnection(SSHConnection):
self._send_mode not in ('chacha', 'gcm'))):
return None
- submethods = self._owner.kbdint_auth_requested()
- if submethods is None and self._password is not None:
+ result = self._owner.kbdint_auth_requested()
+
+ if asyncio.iscoroutine(result):
+ result = yield from result
+
+ if result is None and self._password is not None:
self._kbdint_password_auth = True
- submethods = ''
+ result = ''
- return submethods
+ return result
+ @asyncio.coroutine
def kbdint_challenge_received(self, name, instructions, lang, prompts):
"""Return responses to a keyboard-interactive auth challenge"""
if self._kbdint_password_auth:
if len(prompts) == 0:
# Silently drop any empty challenges used to print messages
- return []
- elif (len(prompts) == 1 and
- 'password' in prompts[0][0].lower().strip()):
+ result = []
+ elif len(prompts) == 1 and 'password' in prompts[0][0].lower():
password = self.password_auth_requested()
- return [password] if password is not None else None
+ result = [password] if password is not None else None
else:
- return None
+ result = None
else:
- return self._owner.kbdint_challenge_received(name, instructions,
- lang, prompts)
+ result = self._owner.kbdint_challenge_received(name, instructions,
+ lang, prompts)
... 7751 lines suppressed ...
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/python-modules/packages/python-asyncssh.git
More information about the Python-modules-commits mailing list