Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions tests/portstat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,74 @@ def teardown_class(cls):
os.environ["UTILITIES_UNIT_TESTING"] = "0"
os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] = ""
remove_tmp_cnstat_file()


class TestPortstatGetDbClient(object):
"""Test the get_db_client method for caching DB connections"""

def test_get_db_client_creates_new_connection(self):
"""Test that get_db_client creates a new connection when none exists"""
from unittest import mock
from utilities_common.portstat import Portstat
from utilities_common.constants import DEFAULT_NAMESPACE

portstat = Portstat(namespace=DEFAULT_NAMESPACE, display_option='')

with mock.patch('sonic_py_common.multi_asic.connect_to_all_dbs_for_ns') as mock_connect:
mock_db_client = mock.MagicMock()
mock_connect.return_value = mock_db_client

result = portstat.get_db_client('asic0')

mock_connect.assert_called_once_with('asic0')

assert result == mock_db_client

assert portstat.db_clients['asic0'] == mock_db_client

def test_get_db_client_returns_cached_connection(self):
"""Test that get_db_client returns cached connection on subsequent calls"""
from unittest import mock
from utilities_common.portstat import Portstat
from utilities_common.constants import DEFAULT_NAMESPACE

portstat = Portstat(namespace=DEFAULT_NAMESPACE, display_option='')

with mock.patch('sonic_py_common.multi_asic.connect_to_all_dbs_for_ns') as mock_connect:
mock_db_client = mock.MagicMock()
mock_connect.return_value = mock_db_client
result1 = portstat.get_db_client('asic0')
result2 = portstat.get_db_client('asic0')
mock_connect.assert_called_once_with('asic0')

assert result1 == result2
assert result1 == mock_db_client

def test_get_db_client_multiple_namespaces(self):
"""Test that get_db_client handles multiple namespaces correctly"""
from unittest import mock
from utilities_common.portstat import Portstat
from utilities_common.constants import DEFAULT_NAMESPACE

portstat = Portstat(namespace=DEFAULT_NAMESPACE, display_option='')

with mock.patch('sonic_py_common.multi_asic.connect_to_all_dbs_for_ns') as mock_connect:
mock_db_client_asic0 = mock.MagicMock()
mock_db_client_asic1 = mock.MagicMock()

def side_effect(ns):
if ns == 'asic0':
return mock_db_client_asic0
elif ns == 'asic1':
return mock_db_client_asic1

mock_connect.side_effect = side_effect
result_asic0 = portstat.get_db_client('asic0')
result_asic1 = portstat.get_db_client('asic1')
assert mock_connect.call_count == 2
mock_connect.assert_any_call('asic0')
mock_connect.assert_any_call('asic1')
assert result_asic0 == mock_db_client_asic0
assert result_asic1 == mock_db_client_asic1
assert portstat.db_clients['asic0'] == mock_db_client_asic0
assert portstat.db_clients['asic1'] == mock_db_client_asic1
12 changes: 9 additions & 3 deletions utilities_common/portstat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sonic_py_common import multi_asic
from sonic_py_common import device_info
from swsscommon.swsscommon import SonicV2Connector, CounterTable, PortCounter
from utilities_common import constants

from utilities_common import constants
import utilities_common.multi_asic as multi_asic_util
Expand Down Expand Up @@ -166,7 +167,7 @@ def __init__(self, namespace, display_option):
if device_info.is_supervisor():
self.db = SonicV2Connector(use_unix_socket_path=False)
self.db.connect(self.db.CHASSIS_STATE_DB, False)

self.db_clients = {}
self.sorted = natsorted

def get_cnstat_dict(self):
Expand Down Expand Up @@ -370,7 +371,7 @@ def get_port_speed(self, port_name):
state_db_table_id = PORT_STATE_TABLE_PREFIX + port_name
app_db_table_id = PORT_STATUS_TABLE_PREFIX + port_name
for ns in self.multi_asic.get_ns_list_based_on_options():
self.db = multi_asic.connect_to_all_dbs_for_ns(ns)
self.db = self.get_db_client(ns)
speed = self.db.get(self.db.STATE_DB, state_db_table_id, PORT_SPEED_FIELD)
oper_status = self.db.get(self.db.APPL_DB, app_db_table_id, PORT_OPER_STATUS_FIELD)
if speed is None or speed == STATUS_NA or oper_status != "up":
Expand All @@ -379,6 +380,11 @@ def get_port_speed(self, port_name):
return int(speed)
return STATUS_NA

def get_db_client(self, ns):
if not self.db_clients.get(ns):
self.db_clients[ns] = multi_asic.connect_to_all_dbs_for_ns(ns)
return self.db_clients[ns]

def get_port_state(self, port_name):
"""
Get the port state
Expand All @@ -392,7 +398,7 @@ def get_port_state(self, port_name):

full_table_id = PORT_STATUS_TABLE_PREFIX + port_name
for ns in self.multi_asic.get_ns_list_based_on_options():
self.db = multi_asic.connect_to_all_dbs_for_ns(ns)
self.db = self.get_db_client(ns)
admin_state = self.db.get(self.db.APPL_DB, full_table_id, PORT_ADMIN_STATUS_FIELD)
oper_state = self.db.get(self.db.APPL_DB, full_table_id, PORT_OPER_STATUS_FIELD)

Expand Down
Loading