Skip to content

Commit fc53dc7

Browse files
committed
Updated FreeCallPaymentStrategy
1 parent a921a83 commit fc53dc7

5 files changed

Lines changed: 62 additions & 68 deletions

File tree

snet/sdk/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def create_service_client(self,
103103
org_id: str,
104104
service_id: str,
105105
group_name: str=None,
106+
payment_strategy: PaymentStrategy = None,
106107
payment_strategy_type: PaymentStrategyType=PaymentStrategyType.DEFAULT,
107108
address=None,
108109
options=None,
@@ -135,7 +136,8 @@ def create_service_client(self,
135136
options['concurrency'] = self._sdk_config.get("concurrency", True)
136137
options['concurrent_calls'] = concurrent_calls
137138

138-
139+
if payment_strategy is None:
140+
payment_strategy = payment_strategy_type.value()
139141

140142
service_metadata = self._metadata_provider.enhance_service_metadata(
141143
org_id, service_id
@@ -146,7 +148,8 @@ def create_service_client(self,
146148

147149
pb2_module = self.get_module_by_keyword(keyword="pb2.py")
148150
_service_client = ServiceClient(org_id, service_id, service_metadata,
149-
group, service_stubs, payment_strategy_type.value(),
151+
group, service_stubs,
152+
payment_strategy,
150153
options, self.mpe_contract,
151154
self.account, self.web3, pb2_module,
152155
self.payment_channel_provider,

snet/sdk/payment_strategies/default_payment_strategy.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,9 @@
77

88
class DefaultPaymentStrategy(PaymentStrategy):
99

10-
def __init__(self, concurrent_calls: int = 1):
11-
self.concurrent_calls = concurrent_calls
12-
self.concurrencyManager = ConcurrencyManager(concurrent_calls)
10+
def __init__(self):
1311
self.channel = None
1412

15-
def set_concurrency_token(self, token):
16-
self.concurrencyManager.__token = token
17-
1813
def set_channel(self, channel):
1914
self.channel = channel
2015

@@ -25,7 +20,8 @@ def get_payment_metadata(self, service_client):
2520
metadata = free_call_payment_strategy.get_payment_metadata(service_client)
2621
else:
2722
if service_client.get_concurrency_flag():
28-
payment_strategy = PrePaidPaymentStrategy(self.concurrencyManager)
23+
concurrent_calls = service_client.get_concurrent_calls()
24+
payment_strategy = PrePaidPaymentStrategy(concurrent_calls)
2925
metadata = payment_strategy.get_payment_metadata(service_client, self.channel)
3026
else:
3127
payment_strategy = PaidCallPaymentStrategy()
@@ -37,5 +33,6 @@ def get_price(self, service_client):
3733
pass
3834

3935
def get_concurrency_token_and_channel(self, service_client):
40-
payment_strategy = PrePaidPaymentStrategy(self.concurrencyManager)
36+
concurrent_calls = service_client.get_concurrent_calls()
37+
payment_strategy = PrePaidPaymentStrategy(concurrent_calls)
4138
return payment_strategy.get_concurrency_token_and_channel(service_client)

snet/sdk/payment_strategies/freecall_payment_strategy.py

Lines changed: 45 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,48 @@
11
import importlib
2-
from urllib.parse import urlparse
32

43
import grpc
5-
from grpc import _channel
64
import web3
75

86
from snet.sdk.payment_strategies.payment_strategy import PaymentStrategy
9-
from snet.sdk.resources.root_certificate import certificate
107
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path
118

129
class FreeCallPaymentStrategy(PaymentStrategy):
1310

11+
def __init__(self):
12+
self._user_address = None
13+
self._free_call_token = None
14+
self._token_expiration_block = None
15+
self._free_calls_available = None
16+
1417
def is_free_call_available(self, service_client) -> bool:
15-
try:
18+
if not self._user_address:
1619
self._user_address = service_client.account.signer_address
17-
self._free_call_token, self._token_expiry_date_block = self.get_free_call_token_details(service_client)
1820

19-
if not self._free_call_token:
20-
return False
21+
current_block_number = service_client.get_current_block_number()
2122

22-
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
23-
state_service_pb2 = importlib.import_module("state_service_pb2")
23+
if (not self._free_call_token or
24+
not self._token_expiration_block or
25+
current_block_number > self._token_expiration_block):
26+
self._free_call_token, self._token_expiration_block = self.get_free_call_token_details(service_client)
2427

25-
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
26-
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
28+
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
29+
state_service_pb2 = importlib.import_module("state_service_pb2")
2730

28-
signature, current_block_number = self.generate_signature(service_client)
31+
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
32+
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
2933

30-
request = state_service_pb2.FreeCallStateRequest()
31-
request.user_address = self._user_address
32-
request.token_for_free_call = self._free_call_token
33-
request.token_expiry_date_block = self._token_expiry_date_block
34-
request.signature = signature
35-
request.current_block = current_block_number
34+
signature = self.generate_signature(service_client, current_block_number)
35+
request = state_service_pb2.FreeCallStateRequest(
36+
address=self._user_address,
37+
free_call_token=self._free_call_token,
38+
signature=signature,
39+
current_block=current_block_number
40+
)
3641

37-
channel = self.select_channel(service_client)
42+
channel = service_client._get_grpc_channel()
43+
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
3844

39-
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
45+
try:
4046
response = stub.GetFreeCallsAvailable(request)
4147
if response.free_calls_available > 0:
4248
return True
@@ -45,10 +51,10 @@ def is_free_call_available(self, service_client) -> bool:
4551
if self._user_address:
4652
print(f"Warning: {e.details()}")
4753
return False
48-
except Exception as e:
49-
return False
5054

5155
def get_payment_metadata(self, service_client) -> list:
56+
if self.is_free_call_available(service_client):
57+
raise Exception(f"Free calls limit for address {self._user_address} has expired. Please use another payment strategy")
5258
signature, current_block_number = self.generate_signature(service_client)
5359
metadata = [("snet-free-call-auth-token-bin", self._free_call_token),
5460
("snet-payment-type", "free-call"),
@@ -58,52 +64,39 @@ def get_payment_metadata(self, service_client) -> list:
5864

5965
return metadata
6066

61-
def select_channel(self, service_client) -> _channel.Channel:
62-
_, _, _, daemon_endpoint = service_client.get_service_details()
63-
endpoint_object = urlparse(daemon_endpoint)
64-
if endpoint_object.port is not None:
65-
channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port)
66-
else:
67-
channel_endpoint = endpoint_object.hostname
68-
69-
if endpoint_object.scheme == "http":
70-
channel = grpc.insecure_channel(channel_endpoint)
71-
elif endpoint_object.scheme == "https":
72-
channel = grpc.secure_channel(channel_endpoint, grpc.ssl_channel_credentials(root_certificates=certificate))
73-
else:
74-
raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme))
75-
return channel
76-
77-
def generate_signature(self, service_client) -> tuple[bytes, int]:
67+
def generate_signature(self, service_client, current_block_number=None, with_token=True) -> tuple[bytes, int]:
68+
if not current_block_number:
69+
current_block_number = service_client.get_current_block_number()
7870
org_id, service_id, group_id, _ = service_client.get_service_details()
7971

80-
if self._token_expiry_date_block == 0 or len(self._user_address) == 0 or len(self._free_call_token) == 0:
81-
raise Exception(
82-
"You are using default 'FreeCallPaymentStrategy' to use this strategy you need to pass "
83-
"'free_call_auth_token-bin','user_address','free-call-token-expiry-block' in config")
72+
message_types = ["string", "string", "string", "string", "string", "uint256", "bytes32"]
73+
message_values = ["__prefix_free_trial", self._user_address, org_id, service_id, group_id,
74+
current_block_number, self._free_call_token]
8475

85-
current_block_number = service_client.get_current_block_number()
76+
if not with_token:
77+
message_types = message_types[:-1]
78+
message_values = message_values[:-1]
8679

87-
message = web3.Web3.solidity_keccak(
88-
["string", "string", "string", "string", "string", "uint256", "bytes32"],
89-
["__prefix_free_trial", self._user_address, org_id, service_id, group_id, current_block_number,
90-
self._free_call_token]
91-
)
80+
message = web3.Web3.solidity_keccak(message_types, message_values)
9281
return service_client.generate_signature(message), current_block_number
9382

94-
def get_free_call_token_details(self, service_client) -> tuple[bytes, int]:
83+
def get_free_call_token_details(self, service_client, current_block_number=None) -> tuple[bytes, int]:
84+
85+
signature, current_block_number = self.generate_signature(service_client, current_block_number, with_token=False)
86+
9587
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
9688
state_service_pb2 = importlib.import_module("state_service_pb2")
9789

9890
request = state_service_pb2.GetFreeCallTokenRequest(
9991
address=self._user_address,
100-
92+
signature=signature,
93+
current_block=current_block_number
10194
)
10295

10396
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
10497
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
10598

106-
channel = self.select_channel(service_client)
99+
channel = service_client._get_grpc_channel()
107100
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
108101
response = stub.GetFreeCallToken(request)
109102

snet/sdk/payment_strategies/prepaid_payment_strategy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44

55
class PrePaidPaymentStrategy(PaymentStrategy):
66

7-
def __init__(self, concurrency_manager: ConcurrencyManager,
8-
block_offset: int = 240, call_allowance: int = 1):
9-
self.concurrency_manager = concurrency_manager
7+
def __init__(self, concurrent_calls: int, block_offset: int = 240, call_allowance: int = 1):
8+
self.concurrency_manager = ConcurrencyManager(concurrent_calls)
109
self.block_offset = block_offset
1110
self.call_allowance = call_allowance
1211

1312
def get_price(self, service_client):
1413
return service_client.get_price() * self.concurrency_manager.concurrent_calls
1514

16-
def get_payment_metadata(self, service_client, channel):
17-
if channel is None:
18-
channel = self.select_channel(service_client)
15+
def get_payment_metadata(self, service_client):
16+
channel = self.select_channel(service_client)
1917
token = self.concurrency_manager.get_token(service_client, channel, self.get_price(service_client))
2018
metadata = [
2119
("snet-payment-type", "prepaid-call"),

snet/sdk/service_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,9 @@ def get_concurrency_flag(self) -> bool:
240240
def get_concurrency_token_and_channel(self) -> tuple[str, PaymentChannel]:
241241
return self.payment_strategy.get_concurrency_token_and_channel(self)
242242

243+
def get_concurrent_calls(self):
244+
return self.options.get('concurrent_calls', 1)
245+
243246
def set_concurrency_token_and_channel(self, token: str,
244247
channel: PaymentChannel) -> None:
245248
self.payment_strategy.concurrency_token = token

0 commit comments

Comments
 (0)