Skip to content

feat: allow loose matching on primary values after checking for unique #108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 98 additions & 12 deletions netbox_diode_plugin/api/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
# Copyright 2025 NetBox Labs, Inc.
"""Diode NetBox Plugin - API - Object matching utilities."""

import logging
from dataclasses import dataclass
from functools import cache, lru_cache
from typing import Type

import netaddr
from django.contrib.contenttypes.fields import ContentType
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import F, Value
from django.db.models.fields import SlugField
from django.db.models.lookups import Exact
from django.db.models.query_utils import Q
from extras.models.customfields import CustomField

from .common import UnresolvedReference
from .plugin_utils import content_type_id, get_object_type, get_object_type_model
from .common import UnresolvedReference, ChangeSetException, NON_FIELD_ERRORS
from .plugin_utils import content_type_id, get_object_type, get_object_type_model, get_primary_value_field

Check failure on line 21 in netbox_diode_plugin/api/matcher.py

View workflow job for this annotation

GitHub Actions / tests (3.10)

Ruff (I001)

netbox_diode_plugin/api/matcher.py:5:1: I001 Import block is un-sorted or un-formatted

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,32 +58,53 @@
),
],
"ipam.ipaddress": lambda: [
GlobalIPNetworkIPMatcher(
# matches global ip addresses when no vrf is specified
IPNetworkIPMatcher(
ip_fields=("address",),
vrf_field="vrf",
model_class=get_object_type_model("ipam.ipaddress"),
name="logical_ip_address_global_no_vrf",
global_only=True,
),
# matches ip addresses within a vrf when a vrf is specified
VRFIPNetworkIPMatcher(
ip_fields=("address",),
vrf_field="vrf",
model_class=get_object_type_model("ipam.ipaddress"),
name="logical_ip_address_within_vrf",
),
# matches loosely anything with same ip address, ignoring mask
# this is necessary because the default loose primary value match would
# not ignore the network mask.
IPNetworkIPMatcher(
ip_fields=("address",),
vrf_field="vrf",
model_class=get_object_type_model("ipam.ipaddress"),
name="loose_ip_address_match",
global_only=False,
),
],
"ipam.iprange": lambda: [
GlobalIPNetworkIPMatcher(
IPNetworkIPMatcher(
ip_fields=("start_address", "end_address"),
vrf_field="vrf",
model_class=get_object_type_model("ipam.iprange"),
name="logical_ip_range_start_end_global_no_vrf",
global_only=True,
),
VRFIPNetworkIPMatcher(
ip_fields=("start_address", "end_address"),
vrf_field="vrf",
model_class=get_object_type_model("ipam.iprange"),
name="logical_ip_range_start_end_within_vrf",
),
IPNetworkIPMatcher(
ip_fields=("start_address", "end_address"),
vrf_field="vrf",
model_class=get_object_type_model("ipam.iprange"),
name="loose_ip_range_match",
global_only=False,
),
],
"ipam.prefix": lambda: [
ObjectMatchCriteria(
Expand Down Expand Up @@ -260,6 +281,10 @@
refs.add(source_expr.name)
return refs

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
return ", ".join(f"{k}='{v}'" for k, v in data.items() if k in self._get_refs())

def fingerprint(self, data: dict) -> str|None:
"""
Returns a fingerprint of the data based on these criteria.
Expand Down Expand Up @@ -415,6 +440,10 @@
custom_field: str
model_class: Type[models.Model]

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
return f"{self.custom_field}={data.get(self.custom_field)}"

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
if not self.has_required_fields(data):
Expand Down Expand Up @@ -443,17 +472,29 @@


@dataclass
class GlobalIPNetworkIPMatcher:
class IPNetworkIPMatcher:
"""A matcher that ignores the mask."""

ip_fields: tuple[str]
vrf_field: str
model_class: Type[models.Model]
name: str
global_only: bool = False

def _check_condition(self, data: dict) -> bool:
"""Check the condition for the custom field."""
return data.get(self.vrf_field, None) is None
if self.global_only:
return data.get(self.vrf_field, None) is None
return True

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
desc = f"{', '.join(f'{k}={v}' for k, v in data.items() if k in self.ip_fields)}"
if self.global_only:
desc = f"{desc} (global only)"
else:
desc = f"{desc} (loose match)"
return desc

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
Expand Down Expand Up @@ -491,9 +532,10 @@
if not self._check_condition(data):
return None

filter = {
f'{self.vrf_field}__isnull': True,
}
filter = {}
if self.global_only:
filter[f'{self.vrf_field}__isnull'] = True

for field in self.ip_fields:
value = self.ip_value(data, field)
if value is None:
Expand All @@ -515,6 +557,12 @@
"""Check the condition for the custom field."""
return data.get(self.vrf_field, None) is not None

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
desc = f"{', '.join(f'{k}={v}' for k, v in data.items() if k in self.ip_fields)}"
desc = f"{desc}, {self.vrf_field}={data.get(self.vrf_field)})"
return desc

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
if not self.has_required_fields(data):
Expand Down Expand Up @@ -585,6 +633,10 @@
slug_field: str
model_class: Type[models.Model]

def describe(self, data: dict) -> str:
"""Returns a description of the data used to match the object."""
return f"{self.slug_field}={data.get('_auto_slug', None)}"

def fingerprint(self, data: dict) -> str|None:
"""Fingerprint the custom field value."""
if not self.has_required_fields(data):
Expand Down Expand Up @@ -629,6 +681,7 @@
name=f"unique_custom_field_{cf.name}",
)
)
matchers += _get_model_loose_matchers(model_class)
matchers += _get_autoslug_matchers(model_class)
return matchers

Expand Down Expand Up @@ -698,6 +751,21 @@

return matchers

@lru_cache(maxsize=256)
def _get_model_loose_matchers(model_class) -> list[ObjectMatchCriteria]:
object_type = get_object_type(model_class)
# search loosely by primary value if one is defined ...
primary_value_field = get_primary_value_field(object_type)
if primary_value_field is None:
return []

return [
ObjectMatchCriteria(
model_class=model_class,
fields=(primary_value_field,),
name=f"loose_match_by_{primary_value_field}",
)
]

def _is_supported_constraint(constraint, model_class) -> bool:
if not isinstance(constraint, models.UniqueConstraint):
Expand Down Expand Up @@ -790,13 +858,31 @@
Returns the object if found, otherwise None.
"""
model_class = get_object_type_model(object_type)
ambiguous_match_errors = []
for matcher in get_model_matchers(model_class):
if not matcher.has_required_fields(data):
continue
q = matcher.build_queryset(data)
if q is None:
continue
existing = q.order_by('pk').first()
if existing is not None:
return existing
try:
return q.get()
except model_class.MultipleObjectsReturned:
# at least one "loose/non-unique" matcher found multiple
# objects that could be matched... so we don't consider it "new"
# If it cannot be unambiguously matched by some other matcher,
# we will raise an ambiguity error
ambiguous_match_errors.append(f"More than one object found loosely matching {matcher.describe(data)}")
continue
except model_class.DoesNotExist:
continue
if len(ambiguous_match_errors) > 0:
ambiguous_match_errors.insert(0, "No strictly unique matches found.")
raise ChangeSetException(
f"Ambiguous match for {object_type}: could not uniquely identify object from given data",
errors={
object_type: {
NON_FIELD_ERRORS: ambiguous_match_errors
}
})
return None
96 changes: 81 additions & 15 deletions netbox_diode_plugin/api/plugin_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
"""Diode plugin helpers."""

# Generated code. DO NOT EDIT.
# Timestamp: 2025-04-13 16:50:25Z
# Timestamp: 2025-05-20 14:28:26Z

from dataclasses import dataclass
import datetime
import decimal
import logging
from dataclasses import dataclass
from functools import lru_cache
import logging
from typing import Type

import netaddr
from core.models import ObjectType as NetBoxType
from django.contrib.contenttypes.models import ContentType
from django.db import models
import netaddr
from rest_framework.exceptions import ValidationError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -987,20 +987,89 @@
return _LEGAL_FIELDS.get(object_type, frozenset())

_OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP = {
'ipam.asn': 'asn',
'dcim.devicetype': 'model',
'circuits.circuit': 'cid',
'ipam.ipaddress': 'address',
'circuits.circuitgroup': 'name',
'circuits.circuittype': 'name',
'circuits.provider': 'name',
'circuits.provideraccount': 'name',
'circuits.providernetwork': 'name',
'circuits.virtualcircuit': 'cid',
'circuits.virtualcircuittype': 'name',
'dcim.consoleport': 'name',
'dcim.consoleserverport': 'name',
'dcim.device': 'name',
'dcim.devicebay': 'name',
'dcim.devicerole': 'name',
'dcim.devicetype': 'model',
'dcim.frontport': 'name',
'dcim.interface': 'name',
'dcim.inventoryitem': 'name',
'dcim.inventoryitemrole': 'name',
'dcim.location': 'name',
'dcim.macaddress': 'mac_address',
'dcim.manufacturer': 'name',
'dcim.modulebay': 'name',
'dcim.moduletype': 'model',
'ipam.prefix': 'prefix',
'dcim.platform': 'name',
'dcim.powerfeed': 'name',
'dcim.poweroutlet': 'name',
'dcim.powerpanel': 'name',
'dcim.powerport': 'name',
'dcim.rack': 'name',
'dcim.rackrole': 'name',
'dcim.racktype': 'model',
'circuits.virtualcircuit': 'cid',
'dcim.rearport': 'name',
'dcim.region': 'name',
'dcim.site': 'name',
'dcim.sitegroup': 'name',
'dcim.virtualchassis': 'name',
'dcim.virtualdevicecontext': 'name',
'extras.tag': 'name',
'ipam.aggregate': 'prefix',
'ipam.asn': 'asn',
'ipam.asnrange': 'name',
'ipam.fhrpgroup': 'name',
'ipam.ipaddress': 'address',
'ipam.prefix': 'prefix',
'ipam.rir': 'name',
'ipam.role': 'name',
'ipam.routetarget': 'name',
'ipam.service': 'name',
'ipam.vlan': 'name',
'ipam.vlangroup': 'name',
'ipam.vlantranslationpolicy': 'name',
'ipam.vrf': 'name',
'tenancy.contact': 'name',
'tenancy.contactgroup': 'name',
'tenancy.contactrole': 'name',
'tenancy.tenant': 'name',
'tenancy.tenantgroup': 'name',
'virtualization.cluster': 'name',
'virtualization.clustergroup': 'name',
'virtualization.clustertype': 'name',
'virtualization.virtualdisk': 'name',
'virtualization.virtualmachine': 'name',
'virtualization.vminterface': 'name',
'vpn.ikepolicy': 'name',
'vpn.ikeproposal': 'name',
'vpn.ipsecpolicy': 'name',
'vpn.ipsecprofile': 'name',
'vpn.ipsecproposal': 'name',
'vpn.l2vpn': 'name',
'vpn.tunnel': 'name',
'vpn.tunnelgroup': 'name',
'wireless.wirelesslan': 'ssid',
'wireless.wirelesslangroup': 'name',
'wireless.wirelesslink': 'ssid',
}

def get_primary_value_field(object_type: str) -> str|None:
return _OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP.get(object_type)

def get_primary_value(data: dict, object_type: str) -> str|None:
field = _OBJECT_TYPE_PRIMARY_VALUE_FIELD_MAP.get(object_type, 'name')
field = get_primary_value_field(object_type)
if field is None:
return None
return data.get(field)


Expand Down Expand Up @@ -1209,9 +1278,6 @@
except ValidationError:
raise
except ValueError as e:
sanitized_object_type = object_type.replace('\n', '').replace('\r', '')
sanitized_val = str(val).replace('\n', '').replace('\r', '')
logger.error(f"Error processing field {key} in {sanitized_object_type} with value {sanitized_val}: {e}")
raise ValidationError(f"Invalid value for field {key} in {sanitized_object_type}.")
raise ValidationError(f'Invalid value {val} for field {key} in {object_type}: {e}')
except Exception as e:
raise ValidationError(f'Invalid value for field {key} in {object_type}')
raise ValidationError(f'Invalid value {val} for field {key} in {object_type}')
12 changes: 5 additions & 7 deletions netbox_diode_plugin/tests/test_api_diff_and_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,9 +918,8 @@ def test_generate_diff_update_ip_address(self):
self.assertEqual(ip.vrf.name, f"VRF {vrf_uuid}")
self.assertEqual(ip.status, "active")

ip2 = IPAddress.objects.get(address="254.198.174.116/24", vrf__isnull=True)
self.assertEqual(ip2.vrf, None)
self.assertEqual(ip2.status, "deprecated")
# this updated the existing ip due to a loose match and associated the vrf, so there is no "other" address ...
self.assertRaises(IPAddress.DoesNotExist, IPAddress.objects.get, address="254.198.174.116/24", vrf__isnull=True)

payload = {
"timestamp": 1,
Expand All @@ -939,9 +938,8 @@ def test_generate_diff_update_ip_address(self):
ip = IPAddress.objects.get(address="254.198.174.116", vrf__name=f"VRF {vrf_uuid}")
self.assertEqual(ip.status, "dhcp")

ip2 = IPAddress.objects.get(address="254.198.174.116/24", vrf__isnull=True)
self.assertEqual(ip2.vrf, None)
self.assertEqual(ip2.status, "deprecated")
# this updated the existing ip due to a loose match and associated the vrf, so there is no "other" address ...
self.assertRaises(IPAddress.DoesNotExist, IPAddress.objects.get, address="254.198.174.116/24", vrf__isnull=True)

def test_generate_diff_and_apply_complex_vminterface(self):
"""Test generate diff and apply and update a complex vm interface."""
Expand Down Expand Up @@ -1047,7 +1045,7 @@ def test_generate_diff_and_apply_dedupe_devicetype(self):
"role": {"name": "Device Role 1"},
"site": {"name": "Site 1"}
},
"name": "Radio0/1",
"name": "Radio0/2",
"type": "ieee802.11ac",
"enabled": True
},
Expand Down
Loading
Loading