diff --git a/packages/datacommons-api/datacommons_api/api_cli.py b/packages/datacommons-api/datacommons_api/api_cli.py index 3b50566..7e2a7a5 100644 --- a/packages/datacommons-api/datacommons_api/api_cli.py +++ b/packages/datacommons-api/datacommons_api/api_cli.py @@ -16,9 +16,10 @@ import uvicorn from datacommons_api.app import app -from datacommons_api.core.config import initialize_config +from datacommons_api.core.config import get_config, initialize_config from datacommons_api.core.logging import get_logger, setup_logging -from datacommons_db.session import initialize_db +from datacommons_db.session import get_session, initialize_db +from datacommons_api.services.graph_service import GraphService setup_logging() logger = get_logger(__name__) @@ -72,3 +73,42 @@ def start( port=port, reload=reload, ) + + +@api.command() +@click.option("--gcp-project-id", help="GCP project id.", required=True) +@click.option( + "--gcp-spanner-instance-id", help="GCP Spanner instance id.", required=True +) +@click.option( + "--gcp-spanner-database-name", help="GCP Spanner database name.", required=True +) +@click.option("--yes", is_flag=True, help="Skip confirmation prompt.") +def drop_tables( + gcp_project_id: str, + gcp_spanner_instance_id: str, + gcp_spanner_database_name: str, + yes: bool, +): + """Drop Node and Edge tables from the graph database.""" + # TODO: Refactor this method to only drop the data from the tables, not the tables themselves. 
+ if not yes: + click.confirm( + "Are you sure you want to drop the Node and Edge tables?", abort=True + ) + + logger.info("Dropping Node and Edge tables from the graph database") + initialize_config( + gcp_project_id=gcp_project_id, + gcp_spanner_instance_id=gcp_spanner_instance_id, + gcp_spanner_database_name=gcp_spanner_database_name, + ) + config = get_config() + db = get_session( + config.GCP_PROJECT_ID, + config.GCP_SPANNER_INSTANCE_ID, + config.GCP_SPANNER_DATABASE_NAME, + ) + graph_service = GraphService(db) + graph_service.drop_tables() + logger.info("Successfully dropped Node and Edge tables") diff --git a/packages/datacommons-api/datacommons_api/core/config.py b/packages/datacommons-api/datacommons_api/core/config.py index a73810b..3fbba81 100644 --- a/packages/datacommons-api/datacommons_api/core/config.py +++ b/packages/datacommons-api/datacommons_api/core/config.py @@ -65,7 +65,7 @@ def validate_config_or_exit(config: Config) -> None: # Ensure GCP Spanner is configured for var in REQUIRED_ENV_VARS: if not getattr(config, var): - logger.error("Environment variable %s must be set", var) + logger.error("Config variable %s must be set", var) sys.exit(1) diff --git a/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py b/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py index 41f036c..02d12f8 100644 --- a/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py +++ b/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py @@ -29,7 +29,13 @@ # JSON-LD endpoint -@router.get("/nodes/", response_model=JSONLDDocument, response_model_exclude_none=True) +@router.get("/nodes", response_model=JSONLDDocument, response_model_exclude_none=True) +@router.get( + "/nodes/", + response_model=JSONLDDocument, + response_model_exclude_none=True, + include_in_schema=False, +) def get_nodes( limit: int = DEFAULT_NODE_FETCH_LIMIT, type_filter: Annotated[ @@ -44,7 +50,13 @@ def get_nodes( return 
graph_service.get_graph_nodes(limit=limit, type_filter=type_filter) -@router.post("/nodes/", response_model=UpdateResponse, response_model_exclude_none=True) +@router.post("/nodes", response_model=UpdateResponse, response_model_exclude_none=True) +@router.post( + "/nodes/", + response_model=UpdateResponse, + response_model_exclude_none=True, + include_in_schema=False, +) def insert_nodes( jsonld: JSONLDDocument, graph_service: Annotated[GraphService, Depends(with_graph_service)] = None, diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 941f2e6..0b0bf88 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -13,17 +13,19 @@ # limitations under the License. # Standard library imports +import base64 import logging - +import traceback +from google.cloud import spanner +from google.cloud.spanner_v1 import database from sqlalchemy import text from sqlalchemy.orm import Session, joinedload -# Third-party imports +from datacommons_api.core.config import get_config from datacommons_api.core.constants import DEFAULT_NODE_FETCH_LIMIT -from datacommons_db.models.edge import EdgeModel - -# Local application imports -from datacommons_db.models.node import NodeModel +from datacommons_db.models.edge import EdgeModel, EDGE_TABLE_NAME +from datacommons_db.models.node import NodeModel, NODE_TABLE_NAME +from datacommons_db.models.edge import OBJECT_VALUE_MAX_LENGTH from datacommons_schema.models.jsonld import ( GraphNode, GraphNodePropertyValue, @@ -33,6 +35,12 @@ # Configure logging logger = logging.getLogger(__name__) +# Silence OpenTelemetry warnings/errors (Spanner client integration triggers these) +logging.getLogger("opentelemetry.metrics._internal").setLevel(logging.ERROR) +logging.getLogger("opentelemetry.sdk.metrics._internal.export").setLevel( + logging.CRITICAL +) + class 
GraphServiceError(Exception): """ @@ -71,12 +79,22 @@ def create_node_model(graph_node: GraphNode) -> NodeModel: types = [types] types = [t for t in types if t is not None] + # Remove all CURIE namespaces before storing the node id + subject_id = strip_namespace(graph_node.id) + types = [strip_namespace(t) for t in types] return NodeModel( - subject_id=graph_node.id, + subject_id=subject_id, types=types, ) +def strip_namespace(id: str) -> str: + """ + Strip all CURIE namespaces from an id. + """ + return id.split(":")[-1] + + def create_edge_model( subject_id: str, predicate: str, @@ -90,7 +108,8 @@ def create_edge_model( Args: subject_id: The ID of the source node predicate: The edge predicate - value_data: The edge value + object_id: The ID of the target node + object_value: The edge value - A string literal - A GraphNode provenance: The ID of a node that is the provenance of the edge @@ -101,17 +120,17 @@ def create_edge_model( """ # Handle lists of values by creating multiple edges edge = EdgeModel( - object_id=object_id, - predicate=predicate, - subject_id=subject_id, + object_id=strip_namespace(object_id), + predicate=strip_namespace(predicate), + subject_id=strip_namespace(subject_id), ) if provenance: - edge.provenance = provenance + edge.provenance = strip_namespace(provenance) if object_value: - edge.object_value = object_value + edge.object_value = strip_namespace(object_value) if object_id else object_value if object_value and not object_id: # If the edge value is a string, use the subject id as the object id - edge.object_id = subject_id + edge.object_id = strip_namespace(subject_id) if not object_id and not object_value: message = f"Missing object_id or object_value for edge {subject_id} {predicate}" raise GraphServiceError(message) @@ -182,11 +201,17 @@ def node_model_to_graph_node(node: NodeModel) -> GraphNode: edge_groups[edge.predicate] = [] property_value = {} - # If the edge has a literal value, add it to the property value - if 
edge.object_value: + + if edge.object_bytes: + # If the edge has bytes, decode them and add them to the property value + property_value["@value"] = base64.b64decode(edge.object_bytes).decode( + "utf-8" + ) + elif edge.object_value: + # If the edge has a literal value, add it to the property value property_value["@value"] = edge.object_value - # If the edge has an object id, add it to the property value else: + # If the edge has an object id, add it to the property value property_value["@id"] = edge.object_id # If the edge has provenance, add it to the property value @@ -203,6 +228,152 @@ def node_model_to_graph_node(node: NodeModel) -> GraphNode: return GraphNode(**graph_node_properties) +def coerce_edge_val_for_db_write(e: EdgeModel, col: str) -> str | None: + """ + Coerces and truncates edge values to comply with Spanner index limits. + Args: + e: The EdgeModel instance containing raw data. + col: The target database column name. + Returns: + - For 'object_value': A UTF-8 string truncated to 4096 bytes (safe-decoded). + - For 'object_bytes': A Base64-encoded representation of the model's 'object_value'. + - For other columns: The raw attribute value from the model. + """ + if col not in ("object_value", "object_bytes"): + return getattr(e, col) + + val = getattr(e, "object_value") + if not val: + return None + val_bytes = str(val).encode("utf-8") + + # A Spanner index key incorporates both the indexed columns AND the Primary Key. + # Max index key length is 8192 bytes total. The Primary Keys can swallow up to 4096 bytes easily. + # So we must restrict object_value to 4096 bytes to guarantee the total key size is < 8192 bytes. + if col == "object_value": + if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: + # Slice to exactly OBJECT_VALUE_MAX_LENGTH bytes, dropping fragmented chars gracefully + # TODO: To avoid hash index collisions, we should use a deterministic hash of the object_value + # and store that along with the truncated value. 
+ val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode( + "utf-8", errors="ignore" + ) + return val_truncated + return val + elif col == "object_bytes": + if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: + return base64.b64encode(val_bytes).decode("utf-8") + return None + + +def get_node_models(jsonld: JSONLDDocument) -> list[NodeModel]: + """ + Converts a JSON-LD document into a list of NodeModel instances with their outgoing edges loaded. + """ + node_models = [] + for graph_node in jsonld.graph: + node_model = create_node_model(graph_node) + node_model.outgoing_edges = extract_edges_from_node(graph_node) + node_models.append(node_model) + return node_models + + +def get_node_model_batches( + node_models: list[NodeModel], batch_size: int = 1000 +) -> list[list[NodeModel]]: + """ + Splits a list of NodeModel instances into batches of nodes and edges. + + Args: + node_models: List of NodeModel instances + batch_size: Maximum number of nodes and edges per batch + + Returns: + List of batches of nodes and edges + """ + node_batches: list[list[NodeModel]] = [] + current_batch: list[NodeModel] = [] + current_batch_len = 0 + for node_model in node_models: + node_len = len(node_model.outgoing_edges) + 1 + + # If the node itself is larger than the batch_size, add it as its own batch + if node_len >= batch_size: + if current_batch: + node_batches.append(current_batch) + current_batch = [] + current_batch_len = 0 + node_batches.append([node_model]) + continue + + # Add node and its edges to the current batch + if current_batch_len + node_len <= batch_size: + current_batch.append(node_model) + current_batch_len += node_len + else: + # If the current batch is full, add it to the list of batches + node_batches.append(current_batch) + current_batch = [node_model] + current_batch_len = node_len + + # Add the last batch if it's not empty + if current_batch: + node_batches.append(current_batch) + return node_batches + + +def insert_node_models_batch( + node_models: 
list[NodeModel], spanner_batch: database.BatchCheckout +): + """ + Inserts a batch of NodeModel instances into the database using Spanner API. + + Args: + node_models: List of NodeModel instances + spanner_batch: Spanner batch to insert into + + Returns: + None + """ + # Get the column names from the NodeModel and EdgeModel + node_columns = tuple(c.name for c in NodeModel.__table__.columns) + edge_columns = tuple( + c.name + for c in EdgeModel.__table__.columns + if c.name != "object_value_tokenlist" + ) + + # Insert nodes into the database + spanner_batch.insert_or_update( + table=NODE_TABLE_NAME, + columns=node_columns, + values=[tuple(getattr(n, col) for col in node_columns) for n in node_models], + ) + + # Delete existing edges for these nodes using a KeyRange prefix + keyset = spanner.KeySet( + ranges=[ + spanner.KeyRange(start_closed=[n.subject_id], end_closed=[n.subject_id]) + for n in node_models + ] + ) + spanner_batch.delete(table=EDGE_TABLE_NAME, keyset=keyset) + + # Insert the new edges + for node_model in node_models: + # Skip if there are no edges to avoid empty insert errors + if not node_model.outgoing_edges: + continue + spanner_batch.insert_or_update( + table=EDGE_TABLE_NAME, + columns=edge_columns, + values=[ + tuple(coerce_edge_val_for_db_write(e, col) for col in edge_columns) + for e in node_model.outgoing_edges + ], + ) + + class GraphService: """ Service for managing graph database operations. 
@@ -219,7 +390,14 @@ def __init__(self, session: Session): session: SQLAlchemy session for database operations """ self.session = session - logger.info("Initialized GraphService with new session") + + config = get_config() + spanner_client = spanner.Client(project=config.GCP_PROJECT_ID) + instance = spanner_client.instance(config.GCP_SPANNER_INSTANCE_ID) + self.spanner_database = instance.database(config.GCP_SPANNER_DATABASE_NAME) + + # Silence Spanner client INFO logs + self.spanner_database.logger.setLevel(logging.WARNING) def get_graph_nodes( self, @@ -290,40 +468,63 @@ def _get_nodes_with_outgoing_edges( logger.debug("Retrieved %d nodes with outgoing edges", len(nodes)) return nodes - def insert_graph_nodes(self, jsonld: JSONLDDocument) -> None: + def insert_graph_nodes( + self, jsonld: JSONLDDocument, batch_size: int = 1000 + ) -> None: """ - Insert nodes and edges from a JSON-LD document into the database. - - Raises an exception if the node already exists. + Inserts nodes and edges from a JSON-LD document into the database using Spanner API. - This method processes the JSON-LD document, creating NodeModel and EdgeModel - instances for each node and its edges. It handles both literal values and - references to other nodes, preserving provenance information. + Updates the nodes and edges if they already exist. 
Args: jsonld: The JSON-LD document containing nodes and edges to insert """ - nodes: list[NodeModel] = [] - edges: list[EdgeModel] = [] - logger.info("Inserting %d nodes from JSON-LD document", len(jsonld.graph)) + # Convert JSON-LD to NodeModels + node_models = get_node_models(jsonld) + node_model_batches = get_node_model_batches(node_models, batch_size) + total_edges = sum(len(node_model.outgoing_edges) for node_model in node_models) - # Process each node in the graph - for graph_node in jsonld.graph: - # Create node model - node_model = create_node_model(graph_node) - nodes.append(node_model) - - # Extract and create edge models - node_edges = extract_edges_from_node(graph_node) - edges.extend(node_edges) + logger.info( + "Inserting %d nodes and %d edges in %d batch(es) to Spanner", + len(node_models), + total_edges, + len(node_model_batches), + ) - logger.info("Inserting %d nodes and %d edges", len(nodes), len(edges)) + # Insert nodes and edges in batches + # TODO(dwnoble): this insert may fail if a node in an earlier batch references a node in a later batch. + # Also may fail if a node references a node that is in a remote knowledge graph + # Possible solution: Insert all nodes first, then insert all edges in a second pass. 
+ success_count = 0 + try: + for node_model_batch in node_model_batches: + with self.spanner_database.batch() as spanner_batch: + insert_node_models_batch(node_model_batch, spanner_batch) + success_count += len(node_model_batch) + except Exception as e: + error_message = f"Failed to insert nodes and edges to Spanner after {success_count}/{len(node_models)} nodes inserted" + logger.exception(error_message) + raise GraphServiceError(error_message) from e - # Add all nodes and edges to the session - self.session.add_all(nodes) - self.session.add_all(edges) + logger.info( + "Successfully committed %d nodes and %d edges to Spanner", + success_count, + total_edges, + ) - # Commit the transaction + def drop_tables(self) -> None: + """ + Delete Node and Edge tables from the graph database. + """ + logger.info("Dropping index EdgeByObjectValue") + query = "DROP INDEX EdgeByObjectValue" + self.session.execute(text(query)) + logger.info("Dropping table %s", EDGE_TABLE_NAME) + query = f"DROP TABLE {EDGE_TABLE_NAME}" + self.session.execute(text(query)) + logger.info("Dropping table %s", NODE_TABLE_NAME) + query = f"DROP TABLE {NODE_TABLE_NAME}" + self.session.execute(text(query)) self.session.commit() - logger.info("Successfully committed all nodes and edges to database") + logger.info("Successfully dropped Node and Edge tables") diff --git a/packages/datacommons-api/datacommons_api/services/graph_service_test.py b/packages/datacommons-api/datacommons_api/services/graph_service_test.py new file mode 100644 index 0000000..fa9ed8f --- /dev/null +++ b/packages/datacommons-api/datacommons_api/services/graph_service_test.py @@ -0,0 +1,176 @@ +import pytest +from unittest.mock import MagicMock, patch, call + +from sqlalchemy.orm import Session +from google.cloud import spanner +from datacommons_api.core.config import Config +from datacommons_api.services.graph_service import ( + GraphService, + GraphServiceError, + get_node_model_batches, +) +from datacommons_db.models.node import 
NodeModel +from datacommons_db.models.edge import EdgeModel +from datacommons_schema.models.jsonld import JSONLDDocument, GraphNode + + +def test_get_node_model_batches(): + node1 = NodeModel(subject_id="n1", types=["T1"]) + node1.outgoing_edges = [ + EdgeModel(subject_id="n1", predicate="p", object_id=f"o{i}") for i in range(5) + ] + + node2 = NodeModel(subject_id="n2", types=["T1"]) + node2.outgoing_edges = [ + EdgeModel(subject_id="n2", predicate="p", object_id=f"o{i}") for i in range(5) + ] + + node3 = NodeModel(subject_id="n3", types=["T1"]) + node3.outgoing_edges = [ + EdgeModel(subject_id="n3", predicate="p", object_id=f"o{i}") for i in range(5) + ] + + # 6 items per node + # batch size 10 means 10 items max. n1 = 6 items -> batch 0. n2 = 6 items -> batch 1. n3 = 6 items -> batch 2. + batches = get_node_model_batches([node1, node2, node3], batch_size=10) + assert len(batches) == 3 + assert batches[0] == [node1] + assert batches[1] == [node2] + assert batches[2] == [node3] + + # test a node larger than the batch size (6 items > batch size 5) + batches = get_node_model_batches([node1, node2], batch_size=5) + assert len(batches) == 2 + assert batches[0] == [node1] + assert batches[1] == [node2] + + # Test batch size 12. n1 + n2 = 12 items -> batch 0. n3 = 6 items -> batch 1. 
+ batches = get_node_model_batches([node1, node2, node3], batch_size=12) + assert len(batches) == 2 + assert batches[0] == [node1, node2] + assert batches[1] == [node3] + + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + + +@pytest.fixture +def mock_config(): + with patch("datacommons_api.services.graph_service.get_config") as mock: + mock_config_instance = MagicMock(spec=Config) + mock_config_instance.GCP_PROJECT_ID = "test-project" + mock_config_instance.GCP_SPANNER_INSTANCE_ID = "test-instance" + mock_config_instance.GCP_SPANNER_DATABASE_NAME = "test-db" + mock.return_value = mock_config_instance + yield mock + + +@pytest.fixture +def mock_spanner_client(): + with patch("datacommons_api.services.graph_service.spanner.Client") as mock: + mock_client_instance = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + + mock_client_instance.instance.return_value = mock_instance + mock_instance.database.return_value = mock_database + + mock.return_value = mock_client_instance + yield mock_client_instance + + +@pytest.fixture +def graph_service(mock_session, mock_config, mock_spanner_client): + return GraphService(session=mock_session) + + +def test_init(mock_session, mock_config, mock_spanner_client): + service = GraphService(session=mock_session) + assert service.session == mock_session + mock_spanner_client.instance.assert_called_once_with("test-instance") + mock_spanner_client.instance.return_value.database.assert_called_once_with( + "test-db" + ) + + +def test_get_graph_nodes(graph_service, mock_session): + # Setup mock data + mock_node = NodeModel(subject_id="test_node", types=["TestType"]) + mock_edge = EdgeModel( + subject_id="test_node", predicate="test_predicate", object_id="test_target" + ) + mock_node.outgoing_edges = [mock_edge] + + # Mock the query chain + mock_query = MagicMock() + mock_query.options.return_value.limit.return_value.all.return_value = [mock_node] + # Handle type filter + 
mock_query.filter.return_value.params.return_value.options.return_value.limit.return_value.all.return_value = [ + mock_node + ] + mock_session.query.return_value = mock_query + + # Test without filter + result = graph_service.get_graph_nodes(limit=10) + + # Verify + assert isinstance(result, JSONLDDocument) + assert len(result.graph) == 1 + assert result.graph[0].id == "test_node" + assert result.graph[0].type == ["TestType"] + assert result.graph[0].model_dump(by_alias=True, exclude_none=True)[ + "test_predicate" + ] == {"@id": "test_target"} + + # Test with filter + result = graph_service.get_graph_nodes(limit=10, type_filter=["TestType"]) + assert isinstance(result, JSONLDDocument) + assert len(result.graph) == 1 + + +def test_insert_graph_nodes(graph_service, mock_session, mock_spanner_client): + # Setup mock data for JSONLD + graph_node = GraphNode( + **{ + "@id": "test_node", + "@type": ["TestType"], + "test_predicate": {"@id": "test_target"}, + } + ) + mock_jsonld = JSONLDDocument( + context={"test": "http://test.com/"}, graph=[graph_node] + ) + + mock_batch = MagicMock() + mock_database = mock_spanner_client.instance.return_value.database.return_value + mock_database.batch.return_value.__enter__.return_value = mock_batch + + # Test + graph_service.insert_graph_nodes(mock_jsonld) + + # Verify + assert mock_batch.insert_or_update.call_count == 2 + mock_batch.delete.assert_called_once() + + +def test_insert_graph_nodes_error(graph_service, mock_spanner_client): + # Setup mock data that triggers an error + mock_jsonld = JSONLDDocument( + context={}, graph=[GraphNode(**{"@id": "n1", "@type": "t1"})] + ) + + mock_database = mock_spanner_client.instance.return_value.database.return_value + mock_database.batch.side_effect = Exception("Spanner Error") + + with pytest.raises(GraphServiceError) as exc_info: + graph_service.insert_graph_nodes(mock_jsonld) + + assert "Failed to insert nodes and edges to Spanner" in str(exc_info.value) + + +def 
test_drop_tables(graph_service, mock_session): + graph_service.drop_tables() + assert mock_session.execute.call_count == 3 + mock_session.commit.assert_called_once() diff --git a/packages/datacommons-api/test_node_batches.py b/packages/datacommons-api/test_node_batches.py new file mode 100644 index 0000000..240b1e8 --- /dev/null +++ b/packages/datacommons-api/test_node_batches.py @@ -0,0 +1,43 @@ +from datacommons_db.models.node import NodeModel +from datacommons_db.models.edge import EdgeModel +from datacommons_api.services.graph_service import get_node_model_batches + + +def test_get_node_model_batches_bug(): + node1 = NodeModel(subject_id="node1", types=["TypeA"]) + # 5 edges + 1 node = 6 items + node1.outgoing_edges = [ + EdgeModel(subject_id="node1", predicate="p", object_id=f"obj{i}") + for i in range(5) + ] + + node2 = NodeModel(subject_id="node2", types=["TypeA"]) + # 5 edges + 1 node = 6 items + node2.outgoing_edges = [ + EdgeModel(subject_id="node2", predicate="p", object_id=f"obj{i}") + for i in range(5) + ] + + node3 = NodeModel(subject_id="node3", types=["TypeA"]) + # 5 edges + 1 node = 6 items + node3.outgoing_edges = [ + EdgeModel(subject_id="node3", predicate="p", object_id=f"obj{i}") + for i in range(5) + ] + + # Total items = 18. Let's set batch size to 10. + # Node 1 (6 items) -> Batch 1 + # Node 2 (6 items) -> 6 + 6 = 12 > 10, so the else branch flushes Batch 1 and Node 2 starts a new batch (it is not skipped). 
+ batches = get_node_model_batches([node1, node2, node3], batch_size=10) + + print(f"Number of batches: {len(batches)}") + for i, batch in enumerate(batches): + print(f"Batch {i}: {[n.subject_id for n in batch]}") + + all_nodes_in_batches = [n for batch in batches for n in batch] + print(f"Total nodes returned: {len(all_nodes_in_batches)}") + print(f"Expected: 3, Actual: {len(all_nodes_in_batches)}") + + +if __name__ == "__main__": + test_get_node_model_batches_bug() diff --git a/packages/datacommons-db/datacommons_db/models/edge.py b/packages/datacommons-db/datacommons_db/models/edge.py index 25e85eb..9aeedb5 100644 --- a/packages/datacommons-db/datacommons_db/models/edge.py +++ b/packages/datacommons-db/datacommons_db/models/edge.py @@ -20,18 +20,23 @@ from datacommons_db.models.base import Base +EDGE_TABLE_NAME = "Edge" +OBJECT_VALUE_MAX_LENGTH = 4096 + + class EdgeModel(Base): """ Represents an edge in the graph. """ - __tablename__ = "Edge" + __tablename__ = EDGE_TABLE_NAME subject_id = sa.Column( String(1024), sa.ForeignKey("Node.subject_id"), primary_key=True ) predicate = sa.Column(String(1024), primary_key=True) object_id = sa.Column(String(1024), primary_key=True) - object_value = sa.Column(Text(), nullable=True) + object_value = sa.Column(String(OBJECT_VALUE_MAX_LENGTH), nullable=True) + object_bytes = sa.Column(sa.LargeBinary(), nullable=True) object_hash = sa.Column(String(64), primary_key=True, nullable=True) provenance = sa.Column(String(1024), primary_key=True, nullable=True) # Use deferred to avoid loading the node data into memory diff --git a/packages/datacommons-db/datacommons_db/models/node.py b/packages/datacommons-db/datacommons_db/models/node.py index 053367d..ece650e 100644 --- a/packages/datacommons-db/datacommons_db/models/node.py +++ b/packages/datacommons-db/datacommons_db/models/node.py @@ -19,13 +19,15 @@ from datacommons_db.models.base import Base +NODE_TABLE_NAME = "Node" + class NodeModel(Base): """ Represents a node in the graph. 
""" - __tablename__ = "Node" + __tablename__ = NODE_TABLE_NAME subject_id = sa.Column(String(1024), primary_key=True, autoincrement=False) name = sa.Column(Text(), nullable=True) types = sa.Column(ARRAY(String(1024)), nullable=True) diff --git a/packages/datacommons-db/datacommons_db/models/observation.py b/packages/datacommons-db/datacommons_db/models/observation.py index 9bfec2f..3a6981a 100644 --- a/packages/datacommons-db/datacommons_db/models/observation.py +++ b/packages/datacommons-db/datacommons_db/models/observation.py @@ -18,13 +18,15 @@ from datacommons_db.models.base import Base +OBSERVATION_TABLE_NAME = "Observation" + class ObservationModel(Base): """ Represents a statistical observation of a variable. """ - __tablename__ = "Observation" + __tablename__ = OBSERVATION_TABLE_NAME variable_measured = sa.Column(String(1024), nullable=False, primary_key=True) observation_about = sa.Column(String(1024), nullable=False, primary_key=True) diff --git a/packages/datacommons-db/pyproject.toml b/packages/datacommons-db/pyproject.toml index b81dd58..9a9aa07 100644 --- a/packages/datacommons-db/pyproject.toml +++ b/packages/datacommons-db/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "sqlalchemy", "sqlalchemy-spanner", + "google-cloud-spanner", "setuptools<=80.0.0", # Pin version to <=80 to avoid https://stackoverflow.com/questions/76043689/pkg-resources-is-deprecated-as-an-api ] diff --git a/uv.lock b/uv.lock index 2dca153..4a33497 100644 --- a/uv.lock +++ b/uv.lock @@ -453,6 +453,7 @@ requires-dist = [ name = "datacommons-db" source = { editable = "packages/datacommons-db" } dependencies = [ + { name = "google-cloud-spanner" }, { name = "setuptools" }, { name = "sqlalchemy" }, { name = "sqlalchemy-spanner" }, @@ -460,6 +461,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "google-cloud-spanner" }, { name = "setuptools", specifier = "<=80.0.0" }, { name = "sqlalchemy" }, { name = "sqlalchemy-spanner" },