Skip to content

[PTDT-4605] Add ability to specify relationship constraints #1969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions libs/labelbox/src/labelbox/schema/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Tuple

from lbox.exceptions import InconsistentOntologyException

Expand Down Expand Up @@ -155,7 +155,6 @@ def add_classification(self, classification: Classification) -> None:
)
self.classifications.append(classification)


"""
The following 2 functions help to bridge the gap between the step reasoning all other tool ontologies.
"""
Expand All @@ -165,6 +164,8 @@ def tool_cls_from_type(tool_type: str):
tool_cls = map_tool_type_to_tool_cls(tool_type)
if tool_cls is not None:
return tool_cls
if tool_type == Tool.Type.RELATIONSHIP:
return RelationshipTool
return Tool


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# type: ignore

import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

from labelbox.schema.ontology import Tool

@dataclass
class RelationshipTool(Tool):
"""
A relationship tool to be added to a Project's ontology.

The "tool" parameter is automatically set to Tool.Type.RELATIONSHIP
and doesn't need to be passed during instantiation.

The "classifications" parameter holds a list of Classification objects.
This can be used to add nested classifications to a tool.

Example(s):
tool = RelationshipTool(
name = "Relationship Tool example",
constraints = [
("source_tool_feature_schema_id_1", "target_tool_feature_schema_id_1"),
("source_tool_feature_schema_id_2", "target_tool_feature_schema_id_2")
]
)
classification = Classification(
class_type = Classification.Type.TEXT,
instructions = "Classification Example")
tool.add_classification(classification)

Attributes:
tool: Tool.Type.RELATIONSHIP (automatically set)
name: (str)
required: (bool)
color: (str)
classifications: (list)
schema_id: (str)
feature_schema_id: (str)
attributes: (list)
constraints: (list of [str, str])
"""

constraints: Optional[List[Tuple[str, str]]] = None

def __init__(self, name: str, constraints: Optional[List[Tuple[str, str]]] = None, **kwargs):
super().__init__(Tool.Type.RELATIONSHIP, name, **kwargs)
if constraints is not None:
self.constraints = constraints

def __post_init__(self):
# Ensure tool type is set to RELATIONSHIP
self.tool = Tool.Type.RELATIONSHIP
super().__post_init__()

def asdict(self) -> Dict[str, Any]:
result = super().asdict()
if self.constraints is not None:
result["definition"] = { "constraints": self.constraints }
return result

def add_constraint(self, start: Tool, end: Tool) -> None:
if self.constraints is None:
self.constraints = []

# Ensure feature schema ids are set for the tools,
# the newly set ids will be changed during ontology creation
# but we need to refer to the same ids in the constraints array
# to ensure that the valid constraints are created.
if start.feature_schema_id is None:
start.feature_schema_id = str(uuid.uuid4())
if start.schema_id is None:
start.schema_id = str(uuid.uuid4())
if end.feature_schema_id is None:
end.feature_schema_id = str(uuid.uuid4())
if end.schema_id is None:
end.schema_id = str(uuid.uuid4())

self.constraints.append((start.feature_schema_id, end.feature_schema_id))

def set_constraints(self, constraints: List[Tuple[Tool, Tool]]) -> None:
self.constraints = []
for constraint in constraints:
self.add_constraint(constraint[0], constraint[1])
195 changes: 195 additions & 0 deletions libs/labelbox/tests/unit/test_unit_relationship_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import pytest
import uuid
from unittest.mock import patch

from labelbox.schema.ontology import Tool
from labelbox.schema.tool_building.relationship_tool import RelationshipTool
from labelbox.schema.tool_building.classification import Classification


def test_basic_instantiation():
tool = RelationshipTool(name="Test Relationship Tool")

assert tool.name == "Test Relationship Tool"
assert tool.tool == Tool.Type.RELATIONSHIP
assert tool.constraints is None
assert tool.required is False
assert tool.color is None
assert tool.schema_id is None
assert tool.feature_schema_id is None


def test_instantiation_with_constraints():
constraints = [
("source_id_1", "target_id_1"),
("source_id_2", "target_id_2")
]
tool = RelationshipTool(name="Test Tool", constraints=constraints)

assert tool.name == "Test Tool"
assert tool.constraints == constraints
assert len(tool.constraints) == 2

def test_post_init_sets_tool_type():
tool = RelationshipTool(name="Test Tool")
assert tool.tool == Tool.Type.RELATIONSHIP


def test_asdict_without_constraints():
tool = RelationshipTool(
name="Test Tool",
required=True,
color="#FF0000"
)

result = tool.asdict()
expected = {
"tool": "edge",
"name": "Test Tool",
"required": True,
"color": "#FF0000",
"classifications": [],
"schemaNodeId": None,
"featureSchemaId": None,
"attributes": None
}

assert result == expected

def test_asdict_with_constraints():
constraints = [("source_id", "target_id")]
tool = RelationshipTool(name="Test Tool", constraints=constraints)

result = tool.asdict()

assert "definition" in result
assert result["definition"] == {"constraints": constraints}
assert result["tool"] == "edge"
assert result["name"] == "Test Tool"


def test_add_constraint_to_empty_constraints():
tool = RelationshipTool(name="Test Tool")
start_tool = Tool(Tool.Type.BBOX, "Start Tool")
end_tool = Tool(Tool.Type.POLYGON, "End Tool")

with patch('uuid.uuid4') as mock_uuid:
mock_uuid.return_value.hex = "test-uuid"
tool.add_constraint(start_tool, end_tool)

assert tool.constraints is not None
assert len(tool.constraints) == 1
assert start_tool.feature_schema_id is not None
assert start_tool.schema_id is not None
assert end_tool.feature_schema_id is not None
assert end_tool.schema_id is not None


def test_add_constraint_to_existing_constraints():
existing_constraints = [("existing_source", "existing_target")]
tool = RelationshipTool(name="Test Tool", constraints=existing_constraints)

start_tool = Tool(Tool.Type.BBOX, "Start Tool")
end_tool = Tool(Tool.Type.POLYGON, "End Tool")

tool.add_constraint(start_tool, end_tool)

assert len(tool.constraints) == 2
assert tool.constraints[0] == ("existing_source", "existing_target")
assert tool.constraints[1] == (start_tool.feature_schema_id, end_tool.feature_schema_id)


def test_add_constraint_preserves_existing_ids():
tool = RelationshipTool(name="Test Tool")
start_tool_feature_schema_id = "start_tool_feature_schema_id"
start_tool_schema_id = "start_tool_schema_id"
start_tool = Tool(Tool.Type.BBOX, "Start Tool", feature_schema_id=start_tool_feature_schema_id, schema_id=start_tool_schema_id)
end_tool_feature_schema_id = "end_tool_feature_schema_id"
end_tool_schema_id = "end_tool_schema_id"
end_tool = Tool(Tool.Type.POLYGON, "End Tool", feature_schema_id=end_tool_feature_schema_id, schema_id=end_tool_schema_id)

tool.add_constraint(start_tool, end_tool)

assert start_tool.feature_schema_id == start_tool_feature_schema_id
assert start_tool.schema_id == start_tool_schema_id
assert end_tool.feature_schema_id == end_tool_feature_schema_id
assert end_tool.schema_id == end_tool_schema_id
assert tool.constraints == [(start_tool_feature_schema_id, end_tool_feature_schema_id)]


def test_set_constraints():
tool = RelationshipTool(name="Test Tool")

start_tool1 = Tool(Tool.Type.BBOX, "Start Tool 1")
end_tool1 = Tool(Tool.Type.POLYGON, "End Tool 1")
start_tool2 = Tool(Tool.Type.POINT, "Start Tool 2")
end_tool2 = Tool(Tool.Type.LINE, "End Tool 2")

tool.set_constraints([
(start_tool1, end_tool1),
(start_tool2, end_tool2)
])

assert len(tool.constraints) == 2
assert tool.constraints[0] == (start_tool1.feature_schema_id, end_tool1.feature_schema_id)
assert tool.constraints[1] == (start_tool2.feature_schema_id, end_tool2.feature_schema_id)


def test_set_constraints_replaces_existing():
existing_constraints = [("old_source", "old_target")]
tool = RelationshipTool(name="Test Tool", constraints=existing_constraints)

start_tool = Tool(Tool.Type.BBOX, "Start Tool")
end_tool = Tool(Tool.Type.POLYGON, "End Tool")

tool.set_constraints([(start_tool, end_tool)])

assert len(tool.constraints) == 1
assert tool.constraints[0] != ("old_source", "old_target")
assert tool.constraints[0] == (start_tool.feature_schema_id, end_tool.feature_schema_id)


def test_uuid_generation_in_add_constraint():
tool = RelationshipTool(name="Test Tool")

start_tool = Tool(Tool.Type.BBOX, "Start Tool")
end_tool = Tool(Tool.Type.POLYGON, "End Tool")

# Ensure tools don't have IDs initially
assert start_tool.feature_schema_id is None
assert start_tool.schema_id is None
assert end_tool.feature_schema_id is None
assert end_tool.schema_id is None

tool.add_constraint(start_tool, end_tool)

# Check that UUIDs were generated
assert start_tool.feature_schema_id is not None
assert start_tool.schema_id is not None
assert end_tool.feature_schema_id is not None
assert end_tool.schema_id is not None

# Check that they are valid UUID strings
uuid.UUID(start_tool.feature_schema_id) # Will raise ValueError if invalid
uuid.UUID(start_tool.schema_id)
uuid.UUID(end_tool.feature_schema_id)
uuid.UUID(end_tool.schema_id)


def test_constraints_in_asdict():
tool = RelationshipTool(name="Test Tool")

start_tool = Tool(Tool.Type.BBOX, "Start Tool")
end_tool = Tool(Tool.Type.POLYGON, "End Tool")

tool.add_constraint(start_tool, end_tool)

result = tool.asdict()

assert "definition" in result
assert "constraints" in result["definition"]
assert len(result["definition"]["constraints"]) == 1
assert result["definition"]["constraints"][0] == (
start_tool.feature_schema_id,
end_tool.feature_schema_id
)
Loading