From b2844f08659b7f3559c375a1e549396aeadd3a1d Mon Sep 17 00:00:00 2001 From: Sergey Dzeranov Date: Thu, 31 Oct 2024 11:40:21 +0300 Subject: [PATCH] fix: encryption module now can work not only with strings --- .../human_protocol_sdk/encryption/encryption.py | 17 ++++++++++++----- .../encryption/test_encryption.py | 16 ++++++++-------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/packages/sdk/python/human-protocol-sdk/human_protocol_sdk/encryption/encryption.py b/packages/sdk/python/human-protocol-sdk/human_protocol_sdk/encryption/encryption.py index 814787b57e..a9d094358b 100644 --- a/packages/sdk/python/human-protocol-sdk/human_protocol_sdk/encryption/encryption.py +++ b/packages/sdk/python/human-protocol-sdk/human_protocol_sdk/encryption/encryption.py @@ -40,7 +40,7 @@ ------ """ -from typing import Optional, List +from typing import Optional, List, Union from pgpy import PGPKey, PGPMessage from pgpy.constants import SymmetricKeyAlgorithm from pgpy.errors import PGPError @@ -74,7 +74,9 @@ def __init__(self, private_key_armored: str, passphrase: Optional[str] = None): else: raise ValueError("Private key locked. Passphrase needed") - def sign_and_encrypt(self, message: str, public_keys: List[str]) -> str: + def sign_and_encrypt( + self, message: Union[str, bytes], public_keys: List[str] + ) -> str: """ Signs and encrypts a message using the private key and recipient's public keys. @@ -139,7 +141,9 @@ def sign_and_encrypt(self, message: str, public_keys: List[str]) -> str: "your message", [public_key2, public_key3] ) """ + pgp_message = PGPMessage.new(message) + if not self.private_key.is_unlocked: try: with self.private_key.unlock(self.passphrase): @@ -159,7 +163,7 @@ def sign_and_encrypt(self, message: str, public_keys: List[str]) -> str: del sessionkey return pgp_message.__str__() - def decrypt(self, message: str, public_key: Optional[str] = None) -> str: + def decrypt(self, message: str, public_key: Optional[str] = None) -> bytes: """ Decrypts a message using the private key. @@ -209,7 +213,10 @@ def decrypt(self, message: str, public_key: Optional[str] = None) -> str: public_key, _ = PGPKey.from_blob(public_key) public_key.verify(decrypted_message) - return decrypted_message.message.__str__() + if isinstance(decrypted_message.message, str): + return bytes(decrypted_message.message, encoding="utf-8") + else: + return bytes(decrypted_message.message) except PGPError as e: if ( decrypted_message @@ -221,7 +228,7 @@ def decrypt(self, message: str, public_key: Optional[str] = None) -> str: ) raise ValueError("Failed to decrypt message: {}".format(str(e))) - def sign(self, message: str) -> str: + def sign(self, message: Union[str, bytes]) -> str: """ Signs a message using the private key. diff --git a/packages/sdk/python/human-protocol-sdk/test/human_protocol_sdk/encryption/test_encryption.py b/packages/sdk/python/human-protocol-sdk/test/human_protocol_sdk/encryption/test_encryption.py index 2a4163eba9..203fde95b9 100644 --- a/packages/sdk/python/human-protocol-sdk/test/human_protocol_sdk/encryption/test_encryption.py +++ b/packages/sdk/python/human-protocol-sdk/test/human_protocol_sdk/encryption/test_encryption.py @@ -58,20 +58,20 @@ def test_encrypt_with_locked_private_key(self): def test_decrypt(self): encryption = Encryption(private_key2) decrypted_message = encryption.decrypt(encrypted_message) - self.assertIsInstance(decrypted_message, str) - self.assertEqual(decrypted_message, message) + self.assertIsInstance(decrypted_message, bytes) + self.assertEqual(decrypted_message.decode("utf-8"), message) def test_decrypt_checking_signature(self): encryption = Encryption(private_key2) decrypted_message = encryption.decrypt(encrypted_message, public_key) - self.assertIsInstance(decrypted_message, str) - self.assertEqual(decrypted_message, message) + self.assertIsInstance(decrypted_message, bytes) + self.assertEqual(decrypted_message.decode("utf-8"), message) def test_decrypt_with_locked_private_key(self): encryption = Encryption(private_key3, passphrase) decrypted_message = encryption.decrypt(encrypted_message) - self.assertIsInstance(decrypted_message, str) - self.assertEqual(decrypted_message, message) + self.assertIsInstance(decrypted_message, bytes) + self.assertEqual(decrypted_message.decode("utf-8"), message) def test_decrypt_wrong_public_key(self): encryption = Encryption(private_key2) @@ -85,8 +85,8 @@ def test_decrypt_wrong_public_key(self): def test_decrypt_unsigned_message(self): encryption = Encryption(private_key3, passphrase) decrypted_message = encryption.decrypt(encrypted_unsigned_message) - self.assertIsInstance(decrypted_message, str) - self.assertEqual(decrypted_message, message) + self.assertIsInstance(decrypted_message, bytes) + self.assertEqual(decrypted_message.decode("utf-8"), message) def test_sign(self): encryption = Encryption(private_key)