feat: allow loose matching on primary values after checking for unique #108

Open
wants to merge 4 commits into base: develop
110 changes: 98 additions & 12 deletions netbox_diode_plugin/api/matcher.py
@@ -17,8 +17,8 @@
from django.db.models.query_utils import Q
from extras.models.customfields import CustomField

from .common import UnresolvedReference
from .plugin_utils import content_type_id, get_object_type, get_object_type_model
from .common import NON_FIELD_ERRORS, ChangeSetException, UnresolvedReference
from .plugin_utils import content_type_id, get_object_type, get_object_type_model, get_primary_value_field

logger = logging.getLogger(__name__)

@@ -58,32 +58,53 @@
),
],
"ipam.ipaddress": lambda: [
GlobalIPNetworkIPMatcher(
# matches global ip addresses when no vrf is specified
IPNetworkIPMatcher(
ip_fields=("address",),
vrf_field="vrf",
model_class=get_object_type_model("ipam.ipaddress"),
name="logical_ip_address_global_no_vrf",
global_only=True,
),
# matches ip addresses within a vrf when a vrf is specified
VRFIPNetworkIPMatcher(
ip_fields=("address",),
vrf_field="vrf",
model_class=get_object_type_model("ipam.ipaddress"),
name="logical_ip_address_within_vrf",
),
# loosely matches anything with the same IP address, ignoring the mask;
# needed because the default loose primary-value match would not ignore
# the network mask.
IPNetworkIPMatcher(
ip_fields=("address",),
vrf_field="vrf",
model_class=get_object_type_model("ipam.ipaddress"),
name="loose_ip_address_match",
global_only=False,
),
],
"ipam.iprange": lambda: [
GlobalIPNetworkIPMatcher(
IPNetworkIPMatcher(
ip_fields=("start_address", "end_address"),
vrf_field="vrf",
model_class=get_object_type_model("ipam.iprange"),
name="logical_ip_range_start_end_global_no_vrf",
global_only=True,
),
VRFIPNetworkIPMatcher(
ip_fields=("start_address", "end_address"),
vrf_field="vrf",
model_class=get_object_type_model("ipam.iprange"),
name="logical_ip_range_start_end_within_vrf",
),
IPNetworkIPMatcher(
ip_fields=("start_address", "end_address"),
vrf_field="vrf",
model_class=get_object_type_model("ipam.iprange"),
name="loose_ip_range_match",
global_only=False,
),
],
"ipam.prefix": lambda: [
ObjectMatchCriteria(
@@ -260,6 +281,10 @@ def _get_insensitive_refs(self) -> set[str]:
refs.add(source_expr.name)
return refs

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
return ", ".join(f"{k}='{v}'" for k, v in data.items() if k in self._get_refs())

def fingerprint(self, data: dict) -> str|None:
"""
Returns a fingerprint of the data based on these criteria.
@@ -415,6 +440,10 @@ class CustomFieldMatcher:
custom_field: str
model_class: Type[models.Model]

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
return f"{self.custom_field}={data.get(self.custom_field)}"

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
if not self.has_required_fields(data):
@@ -443,17 +472,29 @@ def has_required_fields(self, data: dict) -> bool:


@dataclass
class GlobalIPNetworkIPMatcher:
class IPNetworkIPMatcher:
"""A matcher that ignores the mask."""

ip_fields: tuple[str]
vrf_field: str
model_class: Type[models.Model]
name: str
global_only: bool = False

def _check_condition(self, data: dict) -> bool:
"""Check the condition for the custom field."""
return data.get(self.vrf_field, None) is None
if self.global_only:
return data.get(self.vrf_field, None) is None
return True

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
desc = f"{', '.join(f'{k}={v}' for k, v in data.items() if k in self.ip_fields)}"
if self.global_only:
desc = f"{desc} (global only)"
else:
desc = f"{desc} (loose match)"
return desc

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
@@ -491,9 +532,10 @@ def build_queryset(self, data: dict) -> models.QuerySet:
if not self._check_condition(data):
return None

filter = {
f'{self.vrf_field}__isnull': True,
}
filter = {}
if self.global_only:
filter[f'{self.vrf_field}__isnull'] = True

for field in self.ip_fields:
value = self.ip_value(data, field)
if value is None:
@@ -515,6 +557,12 @@ def _check_condition(self, data: dict) -> bool:
"""Check the condition for the custom field."""
return data.get(self.vrf_field, None) is not None

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
desc = f"{', '.join(f'{k}={v}' for k, v in data.items() if k in self.ip_fields)}"
desc = f"{desc}, {self.vrf_field}={data.get(self.vrf_field)})"
return desc

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
if not self.has_required_fields(data):
@@ -585,6 +633,10 @@ class AutoSlugMatcher:
slug_field: str
model_class: Type[models.Model]

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
return f"{self.slug_field}={data.get('_auto_slug', None)}"

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
if not self.has_required_fields(data):
@@ -629,6 +681,7 @@ def get_model_matchers(model_class) -> list:
name=f"unique_custom_field_{cf.name}",
)
)
matchers += _get_model_loose_matchers(model_class)
matchers += _get_autoslug_matchers(model_class)
return matchers

@@ -698,6 +751,21 @@ def _get_model_matchers(model_class) -> list[ObjectMatchCriteria]:

return matchers

@lru_cache(maxsize=256)
def _get_model_loose_matchers(model_class) -> list[ObjectMatchCriteria]:
object_type = get_object_type(model_class)
# search loosely by primary value if one is defined ...
primary_value_field = get_primary_value_field(object_type)
if primary_value_field is None:
return []

return [
ObjectMatchCriteria(
model_class=model_class,
fields=(primary_value_field,),
name=f"loose_match_by_{primary_value_field}",
)
]

def _is_supported_constraint(constraint, model_class) -> bool:
if not isinstance(constraint, models.UniqueConstraint):
@@ -790,13 +858,31 @@ def find_existing_object(data: dict, object_type: str): # noqa: C901
Returns the object if found, otherwise None.
"""
model_class = get_object_type_model(object_type)
ambiguous_match_errors = []
for matcher in get_model_matchers(model_class):
if not matcher.has_required_fields(data):
continue
q = matcher.build_queryset(data)
if q is None:
continue
existing = q.order_by('pk').first()
if existing is not None:
return existing
try:
return q.get()
except model_class.MultipleObjectsReturned:
# at least one loose (non-unique) matcher found multiple candidate
# objects, so the incoming data is not treated as new. If no other
# matcher can identify the object unambiguously, an ambiguity error
# is raised below.
ambiguous_match_errors.append(f"More than one object found loosely matching {matcher.describe(data)}")
continue
except model_class.DoesNotExist:
continue
if len(ambiguous_match_errors) > 0:
ambiguous_match_errors.insert(0, "No strictly unique matches found.")
raise ChangeSetException(
f"Ambiguous match for {object_type}: could not uniquely identify object from given data",
errors={
object_type: {
NON_FIELD_ERRORS: ambiguous_match_errors
}
})
return None
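
For reviewers, a minimal usage sketch (not part of this change) of how the new ambiguity handling surfaces to callers. It assumes the ChangeSetException raised above exposes the errors dict it is constructed with, and uses the module paths of the files touched in this PR.

import logging

from netbox_diode_plugin.api.common import ChangeSetException
from netbox_diode_plugin.api.matcher import find_existing_object

logger = logging.getLogger(__name__)

def resolve_existing(data: dict, object_type: str):
    """Resolve an existing object, surfacing loose-match ambiguity to the caller."""
    try:
        # Unique matchers run first; the loose primary-value matcher only raises
        # when several rows share the primary value and nothing stricter matched.
        return find_existing_object(data, object_type)
    except ChangeSetException as exc:
        # Constructed above with errors={object_type: {NON_FIELD_ERRORS: [...]}};
        # reading it back via an "errors" attribute is an assumption.
        logger.warning("Ambiguous match for %s: %s", object_type, getattr(exc, "errors", exc))
        raise
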
91 changes: 80 additions & 11 deletions netbox_diode_plugin/api/plugin_utils.py
@@ -1,19 +1,19 @@
"""Diode plugin helpers."""

# Generated code. DO NOT EDIT.
# Timestamp: 2025-04-13 16:50:25Z
# Timestamp: 2025-05-20 14:28:26Z

from dataclasses import dataclass
import datetime
import decimal
import logging
from dataclasses import dataclass
from functools import lru_cache
import logging
from typing import Type

import netaddr
from core.models import ObjectType as NetBoxType
from django.contrib.contenttypes.models import ContentType
from django.db import models
import netaddr
from rest_framework.exceptions import ValidationError

logger = logging.getLogger(__name__)
@@ -987,20 +987,89 @@ def legal_fields(object_type: str|Type[models.Model]) -> frozenset[str]:
return _LEGAL_FIELDS.get(object_type, frozenset())

_OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP = {
'ipam.asn': 'asn',
'dcim.devicetype': 'model',
'circuits.circuit': 'cid',
'ipam.ipaddress': 'address',
'circuits.circuitgroup': 'name',
'circuits.circuittype': 'name',
'circuits.provider': 'name',
'circuits.provideraccount': 'name',
'circuits.providernetwork': 'name',
'circuits.virtualcircuit': 'cid',
'circuits.virtualcircuittype': 'name',
'dcim.consoleport': 'name',
'dcim.consoleserverport': 'name',
'dcim.device': 'name',
'dcim.devicebay': 'name',
'dcim.devicerole': 'name',
'dcim.devicetype': 'model',
'dcim.frontport': 'name',
'dcim.interface': 'name',
'dcim.inventoryitem': 'name',
'dcim.inventoryitemrole': 'name',
'dcim.location': 'name',
'dcim.macaddress': 'mac_address',
'dcim.manufacturer': 'name',
'dcim.modulebay': 'name',
'dcim.moduletype': 'model',
'ipam.prefix': 'prefix',
'dcim.platform': 'name',
'dcim.powerfeed': 'name',
'dcim.poweroutlet': 'name',
'dcim.powerpanel': 'name',
'dcim.powerport': 'name',
'dcim.rack': 'name',
'dcim.rackrole': 'name',
'dcim.racktype': 'model',
'circuits.virtualcircuit': 'cid',
'dcim.rearport': 'name',
'dcim.region': 'name',
'dcim.site': 'name',
'dcim.sitegroup': 'name',
'dcim.virtualchassis': 'name',
'dcim.virtualdevicecontext': 'name',
'extras.tag': 'name',
'ipam.aggregate': 'prefix',
'ipam.asn': 'asn',
'ipam.asnrange': 'name',
'ipam.fhrpgroup': 'name',
'ipam.ipaddress': 'address',
'ipam.prefix': 'prefix',
'ipam.rir': 'name',
'ipam.role': 'name',
'ipam.routetarget': 'name',
'ipam.service': 'name',
'ipam.vlan': 'name',
'ipam.vlangroup': 'name',
'ipam.vlantranslationpolicy': 'name',
'ipam.vrf': 'name',
'tenancy.contact': 'name',
'tenancy.contactgroup': 'name',
'tenancy.contactrole': 'name',
'tenancy.tenant': 'name',
'tenancy.tenantgroup': 'name',
'virtualization.cluster': 'name',
'virtualization.clustergroup': 'name',
'virtualization.clustertype': 'name',
'virtualization.virtualdisk': 'name',
'virtualization.virtualmachine': 'name',
'virtualization.vminterface': 'name',
'vpn.ikepolicy': 'name',
'vpn.ikeproposal': 'name',
'vpn.ipsecpolicy': 'name',
'vpn.ipsecprofile': 'name',
'vpn.ipsecproposal': 'name',
'vpn.l2vpn': 'name',
'vpn.tunnel': 'name',
'vpn.tunnelgroup': 'name',
'wireless.wirelesslan': 'ssid',
'wireless.wirelesslangroup': 'name',
'wireless.wirelesslink': 'ssid',
}

def get_primary_value_field(object_type: str) -> str|None:
return _OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP.get(object_type)

def get_primary_value(data: dict, object_type: str) -> str|None:
field = _OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP.get(object_type, 'name')
field = get_primary_value_field(object_type)
if field is None:
return None
return data.get(field)


Expand Down Expand Up @@ -1214,4 +1283,4 @@ def apply_format_transformations(data: dict, object_type: str):
logger.error(f"Error processing field {key} in {sanitized_object_type} with value {sanitized_val}: {e}")
raise ValidationError(f"Invalid value for field {key} in {sanitized_object_type}.")
except Exception as e:
raise ValidationError(f'Invalid value for field {key} in {object_type}')
raise ValidationError(f'Invalid value {val} for field {key} in {object_type}')
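
For clarity, a quick illustration (not part of the PR) of the regenerated primary-value helpers; the expected values follow the map above, and dcim.cable stands in for any object type without a mapped primary field.

from netbox_diode_plugin.api.plugin_utils import get_primary_value, get_primary_value_field

# Mapped types resolve to their primary field; unmapped types now return None
# instead of falling back to "name" as before.
assert get_primary_value_field("dcim.devicetype") == "model"
assert get_primary_value_field("ipam.ipaddress") == "address"

assert get_primary_value({"model": "DT-1"}, "dcim.devicetype") == "DT-1"
assert get_primary_value({"name": "x"}, "dcim.cable") is None  # not in the map above
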
12 changes: 5 additions & 7 deletions netbox_diode_plugin/tests/test_api_diff_and_apply.py
@@ -918,9 +918,8 @@ def test_generate_diff_update_ip_address(self):
self.assertEqual(ip.vrf.name, f"VRF {vrf_uuid}")
self.assertEqual(ip.status, "active")

ip2 = IPAddress.objects.get(address="254.198.174.116/24", vrf__isnull=True)
self.assertEqual(ip2.vrf, None)
self.assertEqual(ip2.status, "deprecated")
# the loose match updated the existing ip and associated the vrf, so there is no separate vrf-less address
self.assertRaises(IPAddress.DoesNotExist, IPAddress.objects.get, address="254.198.174.116/24", vrf__isnull=True)

payload = {
"timestamp": 1,
@@ -939,9 +938,8 @@
ip = IPAddress.objects.get(address="254.198.174.116", vrf__name=f"VRF {vrf_uuid}")
self.assertEqual(ip.status, "dhcp")

ip2 = IPAddress.objects.get(address="254.198.174.116/24", vrf__isnull=True)
self.assertEqual(ip2.vrf, None)
self.assertEqual(ip2.status, "deprecated")
# the loose match updated the existing ip and associated the vrf, so there is no separate vrf-less address
self.assertRaises(IPAddress.DoesNotExist, IPAddress.objects.get, address="254.198.174.116/24", vrf__isnull=True)

def test_generate_diff_and_apply_complex_vminterface(self):
"""Test generate diff and apply and update a complex vm interface."""
@@ -1047,7 +1045,7 @@ def test_generate_diff_and_apply_dedupe_devicetype(self):
"role": {"name": "Device Role 1"},
"site": {"name": "Site 1"}
},
"name": "Radio0/1",
"name": "Radio0/2",
"type": "ieee802.11ac",
"enabled": True
},