Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
------
"""

from typing import Optional, List
from typing import Optional, List, Union
from pgpy import PGPKey, PGPMessage
from pgpy.constants import SymmetricKeyAlgorithm
from pgpy.errors import PGPError
Expand Down Expand Up @@ -74,7 +74,9 @@ def __init__(self, private_key_armored: str, passphrase: Optional[str] = None):
else:
raise ValueError("Private key locked. Passphrase needed")

def sign_and_encrypt(self, message: str, public_keys: List[str]) -> str:
def sign_and_encrypt(
self, message: Union[str, bytes], public_keys: List[str]
) -> str:
"""
Signs and encrypts a message using the private key and recipient's public keys.

Expand Down Expand Up @@ -139,7 +141,9 @@ def sign_and_encrypt(self, message: str, public_keys: List[str]) -> str:
"your message", [public_key2, public_key3]
)
"""

pgp_message = PGPMessage.new(message)

if not self.private_key.is_unlocked:
try:
with self.private_key.unlock(self.passphrase):
Expand All @@ -159,7 +163,7 @@ def sign_and_encrypt(self, message: str, public_keys: List[str]) -> str:
del sessionkey
return pgp_message.__str__()

def decrypt(self, message: str, public_key: Optional[str] = None) -> str:
def decrypt(self, message: str, public_key: Optional[str] = None) -> bytes:
"""
Decrypts a message using the private key.

Expand Down Expand Up @@ -209,7 +213,10 @@ def decrypt(self, message: str, public_key: Optional[str] = None) -> str:
public_key, _ = PGPKey.from_blob(public_key)
public_key.verify(decrypted_message)

return decrypted_message.message.__str__()
if isinstance(decrypted_message.message, str):
return bytes(decrypted_message.message, encoding="utf-8")
else:
return bytes(decrypted_message.message)
except PGPError as e:
if (
decrypted_message
Expand All @@ -221,7 +228,7 @@ def decrypt(self, message: str, public_key: Optional[str] = None) -> str:
)
raise ValueError("Failed to decrypt message: {}".format(str(e)))

def sign(self, message: str) -> str:
def sign(self, message: Union[str, bytes]) -> str:
"""
Signs a message using the private key.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,20 @@ def test_encrypt_with_locked_private_key(self):
def test_decrypt(self):
encryption = Encryption(private_key2)
decrypted_message = encryption.decrypt(encrypted_message)
self.assertIsInstance(decrypted_message, str)
self.assertEqual(decrypted_message, message)
self.assertIsInstance(decrypted_message, bytes)
self.assertEqual(decrypted_message.decode("utf-8"), message)

def test_decrypt_checking_signature(self):
encryption = Encryption(private_key2)
decrypted_message = encryption.decrypt(encrypted_message, public_key)
self.assertIsInstance(decrypted_message, str)
self.assertEqual(decrypted_message, message)
self.assertIsInstance(decrypted_message, bytes)
self.assertEqual(decrypted_message.decode("utf-8"), message)

def test_decrypt_with_locked_private_key(self):
encryption = Encryption(private_key3, passphrase)
decrypted_message = encryption.decrypt(encrypted_message)
self.assertIsInstance(decrypted_message, str)
self.assertEqual(decrypted_message, message)
self.assertIsInstance(decrypted_message, bytes)
self.assertEqual(decrypted_message.decode("utf-8"), message)

def test_decrypt_wrong_public_key(self):
encryption = Encryption(private_key2)
Expand All @@ -85,8 +85,8 @@ def test_decrypt_wrong_public_key(self):
def test_decrypt_unsigned_message(self):
encryption = Encryption(private_key3, passphrase)
decrypted_message = encryption.decrypt(encrypted_unsigned_message)
self.assertIsInstance(decrypted_message, str)
self.assertEqual(decrypted_message, message)
self.assertIsInstance(decrypted_message, bytes)
self.assertEqual(decrypted_message.decode("utf-8"), message)

def test_sign(self):
encryption = Encryption(private_key)
Expand Down