Skip to content

Commit

Permalink
updates for dplyr v0.5.0 compatibility and additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gregrahn committed Jul 14, 2016
1 parent e741edd commit d9897eb
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 34 deletions.
4 changes: 2 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by roxygen2 (4.1.1.9000): do not edit by hand
# Generated by roxygen2: do not edit by hand

export(db_analyze.SnowflakeDBConnection)
export(db_begin.SnowflakeDBConnection)
Expand All @@ -8,9 +8,9 @@ export(db_explain.SnowflakeDBConnection)
export(db_insert_into.SnowflakeDBConnection)
export(db_query_fields.SnowflakeDBConnection)
export(lahman_snowflakedb)
export(sql_translate_env.src_snowflakedb)
export(src_desc.src_snowflakedb)
export(src_snowflakedb)
export(src_translate_env.src_snowflakedb)
export(tbl.src_snowflakedb)
import(RJDBC)
import(assertthat)
Expand Down
26 changes: 18 additions & 8 deletions R/src-snowflakedb.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ src_snowflakedb <- function(user = NULL,
requireNamespace("RJDBC", quietly = TRUE)
requireNamespace("dplyr", quietly = TRUE)

# set client metadata info
snowflakeClientInfo <- paste0('{',
'"APPLICATION": "dplyr.snowflakedb",',
'"dplyr.snowflakedb.version": "', packageVersion("dplyr.snowflakedb"), '",',
'"dplyr.version": "', packageVersion("dplyr"), '",',
'"R.version": "', R.Version()$version.string,'",',
'"R.platform": "', R.Version()$platform,'"',
'}')

# initalize the JVM and set the snowflake properties
.jinit()
.jcall("java/lang/System", "S", "setProperty", "snowflake.client.info", snowflakeClientInfo)

if (length(names(opts)) > 0) {
opts <- paste0("&",
paste(lapply(names(opts),
Expand Down Expand Up @@ -199,7 +212,7 @@ src_desc.src_snowflakedb <- function(x) {
}

#' @export
src_translate_env.src_snowflakedb <- function(x) {
sql_translate_env.src_snowflakedb <- function(x) {
dplyr::sql_variant(
dplyr::base_scalar,
dplyr::sql_translator(.parent = dplyr::base_agg,
Expand All @@ -210,7 +223,6 @@ src_translate_env.src_snowflakedb <- function(x) {
var = dplyr::sql_prefix("VAR_SAMP"),
# all = dplyr::sql_prefix("bool_and"),
# any = dplyr::sql_prefix("bool_or"),
n_distinct = function(x) dplyr::build_sql("COUNT(DISTINCT ", x, ")"),
paste = function(x, collapse) dplyr::build_sql("LISTAGG(", x, collapse, ")")
),
base_win
Expand All @@ -237,12 +249,10 @@ db_begin.SnowflakeDBConnection <- function(con, ...) {
}

#' @export
db_query_fields.SnowflakeDBConnection <- function(con, query, ...) {
# this fails when only a table name is passed in because it is single quoted
# using ident() will add a second double qoting when it is not necessary
s <- dplyr::build_sql("SELECT * FROM ", query, " LIMIT 0", con = con)
if (isTRUE(getOption("dplyr.show_sql"))) message("SQL: ", s)
names(dbGetQuery(con, s))
db_query_fields.SnowflakeDBConnection <- function(con, sql, ...) {
fields <- dplyr::build_sql("SELECT * FROM ", sql_subquery(con, sql), " LIMIT 0", con = con)
if (isTRUE(getOption("dplyr.show_sql"))) message("SQL: ", sql)
names(dbGetQuery(con, fields))
}

#' @export
Expand Down
2 changes: 1 addition & 1 deletion man/dplyr.snowflakedb.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/lahman.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/lahman_snowflakedb.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/src_snowflakedb.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 84 additions & 20 deletions tests/testthat/test-sql-translation.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,11 @@ test <- src_snowflakedb(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
schema = "public",
tracing = "off"))

expect_same_in_sql <- function(expr) {
expr <- substitute(expr)

sql <- translate_sql_q(list(expr))
actual <- dbGetQuery(test$con, paste0("SELECT ", sql))[[1]]

exp <- eval(expr, parent.frame())

expect_equal(actual, exp, label = deparse(substitute(expr)))
}

test_that("Simple maths is correct", {
expect_same_in_sql(1 + 2)
expect_same_in_sql(2 * 4)
expect_same_in_sql(5 / 10)
expect_same_in_sql(1 - 10)
expect_same_in_sql(5 ^ 2)
expect_same_in_sql(5 ^ 1/2)
expect_same_in_sql(100 %% 3)
expect_equal(translate_sql(1 + 2), sql("1.0 + 2.0"))
expect_equal(translate_sql(2 * 4), sql("2.0 * 4.0"))
expect_equal(translate_sql(5 ^ 2), sql("POWER(5.0, 2.0)"))
expect_equal(translate_sql(100L %% 3L), sql("100 % 3"))
})

test_that("dplyr.strict_sql = TRUE prevents auto conversion", {
Expand All @@ -60,11 +46,11 @@ test_that("dplyr.strict_sql = TRUE prevents auto conversion", {
})

test_that("Wrong number of arguments raises error", {
expect_error(translate_sql(mean(1, 2)), "Invalid number of args")
expect_error(translate_sql(mean(1, 2), window = FALSE), "Invalid number of args")
})

test_that("Named arguments generates warning", {
expect_warning(translate_sql(mean(x = 1)), "Named arguments ignored")
expect_warning(translate_sql(mean(x = 1), window = FALSE), "Named arguments ignored")
})

test_that("Subsetting always evaluated locally", {
Expand All @@ -83,3 +69,81 @@ test_that("between translated to special form (#503)", {
out <- translate_sql(between(x, 1, 2))
expect_equal(out, sql('"x" BETWEEN 1.0 AND 2.0'))
})

test_that("is.na and is.null are equivalent",{
expect_equal(translate_sql(!is.na(x)), sql('NOT(("x") IS NULL)'))
expect_equal(translate_sql(!is.null(x)), sql('NOT(("x") IS NULL)'))
})

test_that("if translation adds parens", {
expect_equal(
translate_sql(if (x) y),
sql('CASE WHEN ("x") THEN ("y") END')
)
expect_equal(
translate_sql(if (x) y else z),
sql('CASE WHEN ("x") THEN ("y") ELSE ("z") END')
)

})

test_that("pmin and pmax become min and max", {
expect_equal(translate_sql(pmin(x, y)), sql('MIN("x", "y")'))
expect_equal(translate_sql(pmax(x, y)), sql('MAX("x", "y")'))
})

# Minus -------------------------------------------------------------------

test_that("unary minus flips sign of number", {
expect_equal(translate_sql(-10L), sql("-10"))
expect_equal(translate_sql(x == -10), sql('"x" = -10.0'))
expect_equal(translate_sql(x %in% c(-1L, 0L)), sql('"x" IN (-1, 0)'))
})

test_that("unary minus wraps non-numeric expressions", {
expect_equal(translate_sql(-(1L + 2L)), sql("-(1 + 2)"))
expect_equal(translate_sql(-mean(x), window = FALSE), sql('-AVG("x")'))
})

test_that("binary minus subtracts", {
expect_equal(translate_sql(1L - 10L), sql("1 - 10"))
})

# Window functions --------------------------------------------------------

test_that("window functions without group have empty over", {
expect_equal(translate_sql(n()), sql("COUNT(*) OVER ()"))
expect_equal(translate_sql(sum(x)), sql('sum("x") OVER ()'))
})

test_that("aggregating window functions ignore order_by", {
expect_equal(
translate_sql(n(), vars_order = "x"),
sql("COUNT(*) OVER ()")
)
expect_equal(
translate_sql(sum(x), vars_order = "x"),
sql('sum("x") OVER ()')
)
})

test_that("cumulative windows warn if no order", {
expect_warning(translate_sql(cumsum(x)), "does not have explicit order")
expect_warning(translate_sql(cumsum(x), vars_order = "x"), NA)
})

test_that("ntile always casts to integer", {
expect_equal(
translate_sql(ntile(x, 10.5)),
sql('NTILE(10) OVER (ORDER BY "x")')
)
})

test_that("connection affects quoting character", {
dbiTest <- structure(list(), class = "DBITestConnection")
dbTest <- src_sql("test", con = dbiTest)
testTable <- tbl_sql("test", src = dbTest, from = "table1")

out <- select(testTable, field1)
expect_match(sql_render(out), "^SELECT `field1` AS `field1`\nFROM `table1`$")
})

0 comments on commit d9897eb

Please sign in to comment.