diff --git a/R/dplyr.R b/R/dplyr.R index dd1129355..7d4fd7c7e 100644 --- a/R/dplyr.R +++ b/R/dplyr.R @@ -80,9 +80,11 @@ rows_check_y_unmatched <- dplyr$rows_check_y_unmatched rows_df_in_place <- dplyr$rows_df_in_place rowwise_df <- dplyr$rowwise_df slice_rows <- dplyr$slice_rows +stop_join <- dplyr$stop_join summarise_build <- dplyr$summarise_build summarise_cols <- dplyr$summarise_cols summarise_deprecate_variable_size <- dplyr$summarise_deprecate_variable_size the <- dplyr$the tick_if_needed <- dplyr$tick_if_needed +warn_join <- dplyr$warn_join warn_join_cross_by <- dplyr$warn_join_cross_by diff --git a/R/full_join.R b/R/full_join.R index 40578e11a..12d6469f3 100644 --- a/R/full_join.R +++ b/R/full_join.R @@ -17,7 +17,7 @@ full_join.duckplyr_df <- function(x, y, by = NULL, copy = FALSE, suffix = c(".x" "No implicit cross joins for {.code full_join()}" = is_cross_by(by), "{.arg multiple} not supported" = !identical(multiple, "all"), { - out <- rel_join_impl(x, y, by, "full", na_matches, suffix, keep, error_call) + out <- rel_join_impl(x, y, by, "full", na_matches, suffix, keep, relationship, error_call) return(out) } ) diff --git a/R/inner_join.R b/R/inner_join.R index 29bc68699..45e1bd058 100644 --- a/R/inner_join.R +++ b/R/inner_join.R @@ -19,7 +19,7 @@ inner_join.duckplyr_df <- function(x, y, by = NULL, copy = FALSE, suffix = c(".x "{.arg multiple} not supported" = !identical(multiple, "all"), "{.arg unmatched} not supported" = !identical(unmatched, "drop"), { - out <- rel_join_impl(x, y, by, "inner", na_matches, suffix, keep, error_call) + out <- rel_join_impl(x, y, by, "inner", na_matches, suffix, keep, relationship, error_call) return(out) } ) diff --git a/R/join.R b/R/join.R index 42dde6927..ceaf65581 100644 --- a/R/join.R +++ b/R/join.R @@ -6,6 +6,7 @@ rel_join_impl <- function( na_matches, suffix = c(".x", ".y"), keep = NULL, + relationship = NULL, error_call = caller_env() ) { mutating <- !(join %in% c("semi", "anti")) @@ -25,6 +26,10 @@ rel_join_impl <- function( by <- as_join_by(by, error_call = error_call) } + if (mutating) { + check_relationship(relationship, x, y, by, error_call = error_call) + } + x_by <- by$x y_by <- by$y x_rel <- duckdb_rel_from_df(x) @@ -136,3 +141,62 @@ rel_join_impl <- function( return(out) } + +check_relationship <- function(relationship, x, y, by, error_call) { + if (is_null(relationship)) { + # FIXME: Determine behavior based on option + if (!is_key(x, by$x) && !is_key(y, by$y)) { + warn_join( + message = c( + "Detected an unexpected many-to-many relationship between `x` and `y`.", + i = paste0( + "If a many-to-many relationship is expected, ", + "set `relationship = \"many-to-many\"` to silence this warning." + ) + ), + class = "dplyr_warning_join_relationship_many_to_many", + call = error_call + ) + } + return() + } + + if (relationship %in% c("one-to-many", "one-to-one")) { + if (!is_key(x, by$x)) { + stop_join( + message = c( + glue("Each row in `{x_name}` must match at most 1 row in `{y_name}`."), + ), + class = paste0("dplyr_error_join_relationship_", gsub("-", "_", relationship)), + call = error_call + ) + } + } + + if (relationship %in% c("many-to-one", "one-to-one")) { + if (!is_key(y, by$y)) { + stop_join( + message = c( + glue("Each row in `{y_name}` must match at most 1 row in `{x_name}`."), + ), + class = paste0("dplyr_error_join_relationship_", gsub("-", "_", relationship)), + call = error_call + ) + } + } +} + +is_key <- function(x, cols) { + local_options(duckdb.materialize_message = FALSE) + + rows <- + x %>% + # FIXME: Why does this materialize + # as_duckplyr_tibble() %>% + summarize(.by = c(!!!syms(cols)), `___n` = n()) %>% + filter(`___n` > 1L) %>% + head(1L) %>% + nrow() + + rows == 0 +} diff --git a/R/left_join.R b/R/left_join.R index f171a8054..565b3f375 100644 --- a/R/left_join.R +++ b/R/left_join.R @@ -20,7 +20,7 @@ left_join.duckplyr_df <- function(x, y, by = NULL, copy = FALSE, suffix = c(".x" "{.arg multiple} not supported" = !identical(multiple, "all"), "{.arg unmatched} not supported" = !identical(unmatched, "drop"), { - out <- rel_join_impl(x, y, by, "left", na_matches, suffix, keep, error_call) + out <- rel_join_impl(x, y, by, "left", na_matches, suffix, keep, relationship, error_call) return(out) } ) diff --git a/R/relational-duckdb.R b/R/relational-duckdb.R index 95a301cb3..74ee1d14c 100644 --- a/R/relational-duckdb.R +++ b/R/relational-duckdb.R @@ -33,7 +33,7 @@ duckplyr_macros <- c( # "___divide" = "(x, y) AS CASE WHEN y = 0 THEN CASE WHEN x = 0 THEN CAST('NaN' AS double) WHEN x > 0 THEN CAST('+Infinity' AS double) ELSE CAST('-Infinity' AS double) END ELSE CAST(x AS double) / y END", # - "is.na" = "(x) AS (x IS NULL)", + "is.na" = "(x) AS (x IS NULL OR isnan(x))", "n" = "() AS CAST(COUNT(*) AS int32)", # "___log10" = "(x) AS CASE WHEN x < 0 THEN CAST('NaN' AS double) WHEN x = 0 THEN CAST('-Inf' AS double) ELSE log10(x) END", @@ -181,6 +181,14 @@ check_df_for_rel <- function(df, call = caller_env()) { if (!identical(df_attrib, roundtrip_attrib)) { cli::cli_abort("Attributes are lost during conversion. Affected column: {.var {names(df)[[i]]}}.", call = call) } + # Always check roundtrip for timestamp columns + # duckdb uses microsecond precision only, this is in some cases + # less than R does + if (inherits(df[[i]], "POSIXct")) { + if (!identical(df[[i]], roundtrip[[i]])) { + cli::cli_abort("Imperfect roundtrip. Affected column: {.var {names(df)[[i]]}}.", call = call) + } + } } } diff --git a/R/right_join.R b/R/right_join.R index 38ed32b76..5d97298a3 100644 --- a/R/right_join.R +++ b/R/right_join.R @@ -19,7 +19,7 @@ right_join.duckplyr_df <- function(x, y, by = NULL, copy = FALSE, suffix = c(".x "{.arg multiple} not supported" = !identical(multiple, "all"), "{.arg unmatched} not supported" = !identical(unmatched, "drop"), { - out <- rel_join_impl(x, y, by, "right", na_matches, suffix, keep, error_call) + out <- rel_join_impl(x, y, by, "right", na_matches, suffix, keep, relationship, error_call) return(out) } ) diff --git a/tests/testthat/_snaps/dplyr-join-rows.md b/tests/testthat/_snaps/dplyr-join-rows.md index 07d31a943..f478fd51f 100644 --- a/tests/testthat/_snaps/dplyr-join-rows.md +++ b/tests/testthat/_snaps/dplyr-join-rows.md @@ -65,6 +65,39 @@ ! Each row in `x` must match at most 1 row in `y`. i Row 1 of `x` matches multiple rows in `y`. +# join_rows() gives meaningful many-to-many warnings + + Code + join_rows(c(1, 1), c(1, 1)) + Condition + Warning: + Detected an unexpected many-to-many relationship between `x` and `y`. + i Row 1 of `x` matches multiple rows in `y`. + i Row 1 of `y` matches multiple rows in `x`. + i If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this warning. + Output + $x + [1] 1 1 2 2 + + $y + [1] 1 2 1 2 + + +--- + + Code + duckplyr_left_join(df, df, by = join_by(x)) + Condition + Warning in `duckplyr_left_join()`: + Detected an unexpected many-to-many relationship between `x` and `y`. + i If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this warning. + Output + x + 1 1 + 2 1 + 3 1 + 4 1 + # join_rows() gives meaningful error message on unmatched rows Code diff --git a/tests/testthat/_snaps/dplyr-join.md b/tests/testthat/_snaps/dplyr-join.md index 73111ec5a..28ebdbb85 100644 --- a/tests/testthat/_snaps/dplyr-join.md +++ b/tests/testthat/_snaps/dplyr-join.md @@ -50,6 +50,15 @@ Error: ! `na_matches` must be one of "na" or "never", not "foo". +# mutating joins trigger many-to-many warning + + Code + out <- duckplyr_left_join(df, df, join_by(x)) + Condition + Warning in `duckplyr_left_join()`: + Detected an unexpected many-to-many relationship between `x` and `y`. + i If a many-to-many relationship is expected, set `relationship = "many-to-many"` to silence this warning. + # mutating joins compute common columns Code diff --git a/tests/testthat/test-dplyr-join-rows.R b/tests/testthat/test-dplyr-join-rows.R index 3cd8ff956..cd7edf110 100644 --- a/tests/testthat/test-dplyr-join-rows.R +++ b/tests/testthat/test-dplyr-join-rows.R @@ -197,7 +197,6 @@ test_that("join_rows() gives meaningful many-to-one errors", { }) test_that("join_rows() gives meaningful many-to-many warnings", { - skip("TODO duckdb") expect_snapshot({ join_rows(c(1, 1), c(1, 1)) }) diff --git a/tests/testthat/test-dplyr-join.R b/tests/testthat/test-dplyr-join.R index 60b8355ef..f66be8a5b 100644 --- a/tests/testthat/test-dplyr-join.R +++ b/tests/testthat/test-dplyr-join.R @@ -360,7 +360,6 @@ test_that("join_filter() validates arguments", { }) test_that("mutating joins trigger many-to-many warning", { - skip("TODO duckdb") df <- tibble(x = c(1, 1)) expect_snapshot(out <- duckplyr_left_join(df, df, join_by(x))) }) diff --git a/tools/00-funs.R b/tools/00-funs.R index 2ecf6094e..b1b3fe493 100644 --- a/tools/00-funs.R +++ b/tools/00-funs.R @@ -237,14 +237,10 @@ duckplyr_tests <- head(n = -1, list( NULL ), "test-join-rows.R" = c( - "join_rows() gives meaningful many-to-many warnings", NULL ), "test-join.R" = c( - "mutating joins trigger multiple match warning", - "mutating joins don't trigger multiple match warning when called indirectly", - - "mutating joins trigger many-to-many warning", + # FIXME: How to detect an indirect call? "mutating joins don't trigger many-to-many warning when called indirectly", NULL ),