from typing import Union
from noise.exceptions import NoiseMaxNonceError
from .constants import Empty, TOKEN_E, TOKEN_S, TOKEN_EE, TOKEN_ES, TOKEN_SE, TOKEN_SS, TOKEN_PSK, MAX_NONCE
[docs]class CipherState(object):
"""
Implemented as per Noise Protocol specification - paragraph 5.1.
The initialize_key() function takes additional required argument - noise_protocol.
This class holds an instance of Cipher wrapper. It manages initialisation of underlying cipher function
with appropriate key in initialize_key() and rekey() methods.
"""
def __init__(self, noise_protocol):
self.k = Empty()
self.n = None
self.cipher = noise_protocol.cipher_class()
[docs] def initialize_key(self, key):
"""
:param key: Key to set within CipherState
"""
self.k = key
self.n = 0
if self.has_key():
self.cipher.initialize(key)
[docs] def has_key(self):
"""
:return: True if self.k is not an instance of Empty
"""
return not isinstance(self.k, Empty)
def set_nonce(self, nonce):
self.n = nonce
[docs] def encrypt_with_ad(self, ad: bytes, plaintext: bytes) -> bytes:
"""
If k is non-empty returns ENCRYPT(k, n++, ad, plaintext). Otherwise returns plaintext.
:param ad: bytes sequence
:param plaintext: bytes sequence
:return: ciphertext bytes sequence
"""
if self.n == MAX_NONCE:
raise NoiseMaxNonceError('Nonce has depleted!')
if not self.has_key():
return plaintext
ciphertext = self.cipher.encrypt(self.k, self.n, ad, plaintext)
self.n = self.n + 1
return ciphertext
[docs] def decrypt_with_ad(self, ad: bytes, ciphertext: bytes) -> bytes:
"""
If k is non-empty returns DECRYPT(k, n++, ad, ciphertext). Otherwise returns ciphertext. If an authentication
failure occurs in DECRYPT() then n is not incremented and an error is signaled to the caller.
:param ad: bytes sequence
:param ciphertext: bytes sequence
:return: plaintext bytes sequence
"""
if self.n == MAX_NONCE:
raise NoiseMaxNonceError('Nonce has depleted!')
if not self.has_key():
return ciphertext
plaintext = self.cipher.decrypt(self.k, self.n, ad, ciphertext)
self.n = self.n + 1
return plaintext
def rekey(self):
self.k = self.cipher.rekey(self.k)
self.cipher.initialize(self.k)
[docs]class SymmetricState(object):
"""
Implemented as per Noise Protocol specification - paragraph 5.2.
The initialize_symmetric function takes different required argument - noise_protocol, which contains protocol_name.
"""
def __init__(self):
self.h = None
self.ck = None
self.noise_protocol = None
self.cipher_state = None
[docs] @classmethod
def initialize_symmetric(cls, noise_protocol: 'NoiseProtocol') -> 'SymmetricState':
"""
Instead of taking protocol_name as an argument, we take full NoiseProtocol object, that way we have access to
protocol name and crypto functions
Comments below are mostly copied from specification.
:param noise_protocol: a valid NoiseProtocol instance
:return: initialised SymmetricState instance
"""
# Create SymmetricState
instance = cls()
instance.noise_protocol = noise_protocol
# If protocol_name is less than or equal to HASHLEN bytes in length, sets h equal to protocol_name with zero
# bytes appended to make HASHLEN bytes. Otherwise sets h = HASH(protocol_name).
if len(noise_protocol.name) <= noise_protocol.hash_fn.hashlen:
instance.h = noise_protocol.name.ljust(noise_protocol.hash_fn.hashlen, b'\0')
else:
instance.h = noise_protocol.hash_fn.hash(noise_protocol.name)
# Sets ck = h.
instance.ck = instance.h
# Calls InitializeKey(empty).
instance.cipher_state = CipherState(noise_protocol)
instance.cipher_state.initialize_key(Empty())
noise_protocol.cipher_state_handshake = instance.cipher_state
return instance
[docs] def mix_key(self, input_key_material: bytes):
"""
:param input_key_material:
:return:
"""
# Sets ck, temp_k = HKDF(ck, input_key_material, 2).
self.ck, temp_k = self.noise_protocol.hkdf(self.ck, input_key_material, 2)
# If HASHLEN is 64, then truncates temp_k to 32 bytes.
if self.noise_protocol.hash_fn.hashlen == 64:
temp_k = temp_k[:32]
# Calls InitializeKey(temp_k).
self.cipher_state.initialize_key(temp_k)
[docs] def mix_hash(self, data: bytes):
"""
Sets h = HASH(h + data).
:param data: bytes sequence
"""
self.h = self.noise_protocol.hash_fn.hash(self.h + data)
def mix_key_and_hash(self, input_key_material: bytes):
# Sets ck, temp_h, temp_k = HKDF(ck, input_key_material, 3).
self.ck, temp_h, temp_k = self.noise_protocol.hkdf(self.ck, input_key_material, 3)
# Calls MixHash(temp_h).
self.mix_hash(temp_h)
# If HASHLEN is 64, then truncates temp_k to 32 bytes.
if self.noise_protocol.hash_fn.hashlen == 64:
temp_k = temp_k[:32]
# Calls InitializeKey(temp_k).
self.cipher_state.initialize_key(temp_k)
def get_handshake_hash(self):
return self.h
[docs] def encrypt_and_hash(self, plaintext: bytes) -> bytes:
"""
Sets ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), and returns ciphertext. Note that if
k is empty, the EncryptWithAd() call will set ciphertext equal to plaintext.
:param plaintext: bytes sequence
:return: ciphertext bytes sequence
"""
ciphertext = self.cipher_state.encrypt_with_ad(self.h, plaintext)
self.mix_hash(ciphertext)
return ciphertext
[docs] def decrypt_and_hash(self, ciphertext: bytes) -> bytes:
"""
Sets plaintext = DecryptWithAd(h, ciphertext), calls MixHash(ciphertext), and returns plaintext. Note that if
k is empty, the DecryptWithAd() call will set plaintext equal to ciphertext.
:param ciphertext: bytes sequence
:return: plaintext bytes sequence
"""
plaintext = self.cipher_state.decrypt_with_ad(self.h, ciphertext)
self.mix_hash(ciphertext)
return plaintext
[docs] def split(self):
"""
Returns a pair of CipherState objects for encrypting/decrypting transport messages.
:return: tuple (CipherState, CipherState)
"""
# Sets temp_k1, temp_k2 = HKDF(ck, b'', 2).
temp_k1, temp_k2 = self.noise_protocol.hkdf(self.ck, b'', 2)
# If HASHLEN is 64, then truncates temp_k1 and temp_k2 to 32 bytes.
if self.noise_protocol.hash_fn.hashlen == 64:
temp_k1 = temp_k1[:32]
temp_k2 = temp_k2[:32]
# Creates two new CipherState objects c1 and c2.
# Calls c1.InitializeKey(temp_k1) and c2.InitializeKey(temp_k2).
c1, c2 = CipherState(self.noise_protocol), CipherState(self.noise_protocol)
c1.initialize_key(temp_k1)
c2.initialize_key(temp_k2)
if self.noise_protocol.handshake_state.initiator:
self.noise_protocol.cipher_state_encrypt = c1
self.noise_protocol.cipher_state_decrypt = c2
else:
self.noise_protocol.cipher_state_encrypt = c2
self.noise_protocol.cipher_state_decrypt = c1
self.noise_protocol.handshake_done()
# Returns the pair (c1, c2).
return c1, c2
[docs]class HandshakeState(object):
"""
Implemented as per Noise Protocol specification - paragraph 5.3.
The initialize() function takes different required argument - noise_protocol, which contains handshake_pattern.
"""
def __init__(self):
self.noise_protocol = None
self.symmetric_state = None
self.initiator = None
self.s = None
self.e = None
self.rs = None
self.re = None
self.message_patterns = None
[docs] @classmethod
def initialize(cls, noise_protocol: 'NoiseProtocol', initiator: bool, prologue: bytes=b'', s: '_KeyPair'=None,
e: '_KeyPair'=None, rs: '_KeyPair'=None, re: '_KeyPair'=None) -> 'HandshakeState':
"""
Constructor method.
Comments below are mostly copied from specification.
Instead of taking handshake_pattern as an argument, we take full NoiseProtocol object, that way we have access
to protocol name and crypto functions
:param noise_protocol: a valid NoiseProtocol instance
:param initiator: boolean indicating the initiator or responder role
:param prologue: byte sequence which may be zero-length, or which may contain context information that both
parties want to confirm is identical
:param s: local static key pair
:param e: local ephemeral key pair
:param rs: remote party’s static public key
:param re: remote party’s ephemeral public key
:return: initialized HandshakeState instance
"""
# Create HandshakeState
instance = cls()
instance.noise_protocol = noise_protocol
# Originally in specification:
# "Derives a protocol_name byte sequence by combining the names for
# the handshake pattern and crypto functions, as specified in Section 8."
# Instead, we supply the NoiseProtocol to the function. The protocol name should already be validated.
# Calls InitializeSymmetric(noise_protocol)
instance.symmetric_state = SymmetricState.initialize_symmetric(noise_protocol)
# Calls MixHash(prologue)
instance.symmetric_state.mix_hash(prologue)
# Sets the initiator, s, e, rs, and re variables to the corresponding arguments
instance.initiator = initiator
instance.s = s if s is not None else Empty()
instance.e = e if e is not None else Empty()
instance.rs = rs if rs is not None else Empty()
instance.re = re if re is not None else Empty()
# Calls MixHash() once for each public key listed in the pre-messages from handshake_pattern, with the specified
# public key as input (...). If both initiator and responder have pre-messages, the initiator’s public keys are
# hashed first
initiator_keypair_getter = instance._get_local_keypair if initiator else instance._get_remote_keypair
responder_keypair_getter = instance._get_remote_keypair if initiator else instance._get_local_keypair
for keypair in map(initiator_keypair_getter, noise_protocol.pattern.get_initiator_pre_messages()):
instance.symmetric_state.mix_hash(keypair.public_bytes)
for keypair in map(responder_keypair_getter, noise_protocol.pattern.get_responder_pre_messages()):
instance.symmetric_state.mix_hash(keypair.public_bytes)
# Sets message_patterns to the message patterns from handshake_pattern
instance.message_patterns = noise_protocol.pattern.tokens.copy()
return instance
[docs] def write_message(self, payload: Union[bytes, bytearray], message_buffer: bytearray):
"""
Comments below are mostly copied from specification.
:param payload: byte sequence which may be zero-length
:param message_buffer: buffer-like object
:return: None or result of SymmetricState.split() - tuple (CipherState, CipherState)
"""
# Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token
# from the message pattern
message_pattern = self.message_patterns.pop(0)
for token in message_pattern:
if token == TOKEN_E:
# Sets e = GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key)
self.e = self.noise_protocol.dh_fn.generate_keypair() if isinstance(self.e, Empty) else self.e
message_buffer += self.e.public_bytes
self.symmetric_state.mix_hash(self.e.public_bytes)
if self.noise_protocol.is_psk_handshake:
self.symmetric_state.mix_key(self.e.public_bytes)
elif token == TOKEN_S:
# Appends EncryptAndHash(s.public_key) to the buffer
message_buffer += self.symmetric_state.encrypt_and_hash(self.s.public_bytes)
elif token == TOKEN_EE:
# Calls MixKey(DH(e, re))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.re.public))
elif token == TOKEN_ES:
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
elif token == TOKEN_SE:
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
elif token == TOKEN_SS:
# Calls MixKey(DH(s, rs))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
elif token == TOKEN_PSK:
self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0))
else:
raise NotImplementedError('Pattern token: {}'.format(token))
# Appends EncryptAndHash(payload) to the buffer
message_buffer += self.symmetric_state.encrypt_and_hash(payload)
# If there are no more message patterns returns two new CipherState objects by calling Split()
if len(self.message_patterns) == 0:
return self.symmetric_state.split()
[docs] def read_message(self, message: Union[bytes, bytearray], payload_buffer: bytearray):
"""
Comments below are mostly copied from specification.
:param message: byte sequence containing a Noise handshake message
:param payload_buffer: buffer-like object
:return: None or result of SymmetricState.split() - tuple (CipherState, CipherState)
"""
# Fetches and deletes the next message pattern from message_patterns, then sequentially processes each token
# from the message pattern
dhlen = self.noise_protocol.dh_fn.dhlen
message_pattern = self.message_patterns.pop(0)
for token in message_pattern:
if token == TOKEN_E:
# Sets re to the next DHLEN bytes from the message. Calls MixHash(re.public_key).
self.re = self.noise_protocol.keypair_class.from_public_bytes(bytes(message[:dhlen]))
message = message[dhlen:]
self.symmetric_state.mix_hash(self.re.public_bytes)
if self.noise_protocol.is_psk_handshake:
self.symmetric_state.mix_key(self.re.public_bytes)
elif token == TOKEN_S:
# Sets temp to the next DHLEN + 16 bytes of the message if HasKey() == True, or to the next DHLEN bytes
# otherwise. Sets rs to DecryptAndHash(temp).
if self.noise_protocol.cipher_state_handshake.has_key():
temp = bytes(message[:dhlen + 16])
message = message[dhlen + 16:]
else:
temp = bytes(message[:dhlen])
message = message[dhlen:]
self.rs = self.noise_protocol.keypair_class.from_public_bytes(
self.symmetric_state.decrypt_and_hash(temp)
)
elif token == TOKEN_EE:
# Calls MixKey(DH(e, re)).
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.re.public))
elif token == TOKEN_ES:
# Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
elif token == TOKEN_SE:
# Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) if responder
if self.initiator:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.re.public))
else:
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.e.private, self.rs.public))
elif token == TOKEN_SS:
# Calls MixKey(DH(s, rs))
self.symmetric_state.mix_key(self.noise_protocol.dh_fn.dh(self.s.private, self.rs.public))
elif token == TOKEN_PSK:
self.symmetric_state.mix_key_and_hash(self.noise_protocol.psks.pop(0))
else:
raise NotImplementedError('Pattern token: {}'.format(token))
# Calls DecryptAndHash() on the remaining bytes of the message and stores the output into payload_buffer.
payload_buffer += self.symmetric_state.decrypt_and_hash(bytes(message))
# If there are no more message patterns returns two new CipherState objects by calling Split()
if len(self.message_patterns) == 0:
return self.symmetric_state.split()
def _get_local_keypair(self, token: str) -> 'KeyPair':
keypair = getattr(self, token) # Maybe explicitly handle exception when getting improper keypair
if isinstance(keypair, Empty):
raise Exception('Required keypair {} is empty!'.format(token)) # Maybe subclassed exception
return keypair
def _get_remote_keypair(self, token: str) -> 'KeyPair':
keypair = getattr(self, 'r' + token) # Maybe explicitly handle exception when getting improper keypair
if isinstance(keypair, Empty):
raise Exception('Required keypair {} is empty!'.format('r' + token)) # Maybe subclassed exception
return keypair