Merge pull request #301 from OHDSI/databricks-bulk-load
Adding bulk load for Spark (DataBricks)
schuemie authored Jan 22, 2025
2 parents ff12112 + 25fd106 commit 79eb5d9
Showing 6 changed files with 137 additions and 1 deletion.
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -48,7 +48,8 @@ Suggests:
odbc,
duckdb,
pool,
ParallelLogger
ParallelLogger,
AzureStor
License: Apache License
VignetteBuilder: knitr
URL: https://ohdsi.github.io/DatabaseConnector/, https://github.com/OHDSI/DatabaseConnector
1 change: 1 addition & 0 deletions DatabaseConnector.Rproj
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: 9d51e576-41a3-432f-b696-8bfdc3eed676

RestoreWorkspace: No
SaveWorkspace: No
81 changes: 81 additions & 0 deletions R/BulkLoad.R
@@ -62,6 +62,25 @@ checkBulkLoadCredentials <- function(connection) {
      return(FALSE)
    }
    return(TRUE)
  } else if (dbms(connection) == "spark") {
    envSet <- FALSE
    container <- FALSE

    if (Sys.getenv("AZR_STORAGE_ACCOUNT") != "" && Sys.getenv("AZR_ACCOUNT_KEY") != "" && Sys.getenv("AZR_CONTAINER_NAME") != "") {
      envSet <- TRUE
    }

    # List storage containers to confirm the container
    # specified in the configuration exists
    ensure_installed("AzureStor")
    azureEndpoint <- getAzureEndpoint()
    containerList <- getAzureContainerNames(azureEndpoint)

    if (Sys.getenv("AZR_CONTAINER_NAME") %in% containerList) {
      container <- TRUE
    }

    return(envSet && container)
  } else {
    return(FALSE)
  }
@@ -72,6 +91,18 @@ getHiveSshUser <- function() {
return(if (sshUser == "") "root" else sshUser)
}

getAzureEndpoint <- function() {
  azureEndpoint <- AzureStor::storage_endpoint(
    paste0("https://", Sys.getenv("AZR_STORAGE_ACCOUNT"), ".dfs.core.windows.net"),
    key = Sys.getenv("AZR_ACCOUNT_KEY")
  )
  return(azureEndpoint)
}

getAzureContainerNames <- function(azureEndpoint) {
  return(names(AzureStor::list_storage_containers(azureEndpoint)))
}

countRows <- function(connection, sqlTableName) {
sql <- "SELECT COUNT(*) FROM @table"
count <- renderTranslateQuerySql(
@@ -354,3 +385,53 @@ bulkLoadPostgres <- function(connection, sqlTableName, sqlFieldNames, sqlDataTyp
delta <- Sys.time() - startTime
inform(paste("Bulk load to PostgreSQL took", signif(delta, 3), attr(delta, "units")))
}

bulkLoadSpark <- function(connection, sqlTableName, data) {
  ensure_installed("AzureStor")
  logTrace(sprintf("Inserting %d rows into table '%s' using DataBricks bulk load", nrow(data), sqlTableName))
  start <- Sys.time()

  csvFileName <- tempfile("spark_insert_", fileext = ".csv")
  write.csv(x = data, na = "", file = csvFileName, row.names = FALSE, quote = TRUE)
  on.exit(unlink(csvFileName))

  azureEndpoint <- getAzureEndpoint()
  containers <- AzureStor::list_storage_containers(azureEndpoint)
  targetContainer <- containers[[Sys.getenv("AZR_CONTAINER_NAME")]]
  # Upload the CSV under its base name so the COPY INTO statement below can find it
  AzureStor::storage_upload(
    targetContainer,
    src = csvFileName,
    dest = basename(csvFileName)
  )

  on.exit(
    AzureStor::delete_storage_file(
      targetContainer,
      file = basename(csvFileName),
      confirm = FALSE
    ),
    add = TRUE
  )

  sql <- SqlRender::loadRenderTranslateSql(
    sqlFilename = "sparkCopy.sql",
    packageName = "DatabaseConnector",
    dbms = "spark",
    sqlTableName = sqlTableName,
    fileName = basename(csvFileName),
    azureAccountKey = Sys.getenv("AZR_ACCOUNT_KEY"),
    azureStorageAccount = Sys.getenv("AZR_STORAGE_ACCOUNT")
  )

  tryCatch(
    {
      DatabaseConnector::executeSql(connection = connection, sql = sql, reportOverallTime = FALSE)
    },
    error = function(e) {
      abort("Error in DataBricks bulk upload. Please check DataBricks/Azure Storage access.")
    }
  )
  delta <- Sys.time() - start
  inform(paste("Bulk load to DataBricks took", signif(delta, 3), attr(delta, "units")))
}
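
The credential check and the two Azure helpers above can be exercised on their own. A minimal sketch, assuming the AzureStor package is installed; the account, key, and container values are placeholders taken from the documentation below, not real credentials:

# Placeholder values; in real use these come from the user's Azure setup.
Sys.setenv(
  "AZR_STORAGE_ACCOUNT" = "some_azure_storage_account",
  "AZR_ACCOUNT_KEY" = "some_secret_account_key",
  "AZR_CONTAINER_NAME" = "some_container_name"
)

# Same calls the new code path makes: build the ADLS Gen2 endpoint and
# confirm the configured container exists.
azureEndpoint <- AzureStor::storage_endpoint(
  paste0("https://", Sys.getenv("AZR_STORAGE_ACCOUNT"), ".dfs.core.windows.net"),
  key = Sys.getenv("AZR_ACCOUNT_KEY")
)
Sys.getenv("AZR_CONTAINER_NAME") %in% names(AzureStor::list_storage_containers(azureEndpoint))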

9 changes: 9 additions & 0 deletions R/InsertTable.R
@@ -121,6 +121,13 @@ validateInt64Insert <- function() {
#' "some_aws_region", "AWS_BUCKET_NAME" = "some_bucket_name", "AWS_OBJECT_KEY" = "some_object_key",
#' "AWS_SSE_TYPE" = "server_side_encryption_type").
#'
#' Spark (DataBricks): The MPP bulk loading relies upon the AzureStor library
#' to test a connection to an Azure ADLS Gen2 storage container using Azure credentials.
#' Credentials are configured directly into the System Environment using the
#' following keys: Sys.setenv("AZR_STORAGE_ACCOUNT" =
#' "some_azure_storage_account", "AZR_ACCOUNT_KEY" = "some_secret_account_key", "AZR_CONTAINER_NAME" =
#' "some_container_name").
#'
#' PDW: The MPP bulk loading relies upon the client
#' having a Windows OS and the DWLoader exe installed, and the following permissions granted: --Grant
#' BULK Load permissions - needed at a server level USE master; GRANT ADMINISTER BULK OPERATIONS TO
@@ -308,6 +315,8 @@ insertTable.default <- function(connection,
bulkLoadHive(connection, sqlTableName, sqlFieldNames, data)
} else if (dbms == "postgresql") {
bulkLoadPostgres(connection, sqlTableName, sqlFieldNames, sqlDataTypes, data)
} else if (dbms == "spark") {
bulkLoadSpark(connection, sqlTableName, data)
}
} else if (useCtasHack) {
# Inserting using CTAS hack ----------------------------------------------------------------
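
With the AZR_* environment variables set as documented above, a Spark (DataBricks) bulk insert only needs bulkLoad = TRUE on insertTable(); the dispatch added above then routes to bulkLoadSpark(). A rough sketch with placeholder connection details, table name, and data frame (the full test script follows below):

connection <- connect(connectionDetails)  # any Spark/DataBricks connection
insertTable(
  connection = connection,
  tableName = "some_scratch_schema.insert_test",  # placeholder table name
  data = someDataFrame,                           # placeholder data frame
  dropTableIfExists = TRUE,
  createTable = TRUE,
  bulkLoad = TRUE
)
disconnect(connection)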
34 changes: 34 additions & 0 deletions extras/TestBulkLoad.R
@@ -114,3 +114,37 @@ all.equal(data, data2)

renderTranslateExecuteSql(connection, "DROP TABLE scratch_mschuemi.insert_test;")
disconnect(connection)


# Spark ------------------------------------------------------------------------------
# Assumes Spark (DataBricks) environment variables have been set
options(sqlRenderTempEmulationSchema = Sys.getenv("DATABRICKS_SCRATCH_SCHEMA"))
databricksConnectionString <- paste0("jdbc:databricks://", Sys.getenv('DATABRICKS_HOST'), "/default;transportMode=http;ssl=1;AuthMech=3;httpPath=", Sys.getenv('DATABRICKS_HTTP_PATH'))
connectionDetails <- createConnectionDetails(dbms = "spark",
connectionString = databricksConnectionString,
user = "token",
password = Sys.getenv("DATABRICKS_TOKEN"))


connection <- connect(connectionDetails)
system.time(
insertTable(connection = connection,
tableName = "scratch.scratch_asena5.insert_test",
data = data,
dropTableIfExists = TRUE,
createTable = TRUE,
tempTable = FALSE,
progressBar = TRUE,
camelCaseToSnakeCase = TRUE,
bulkLoad = TRUE)
)
data2 <- querySql(connection, "SELECT * FROM scratch.scratch_asena5.insert_test;", snakeCaseToCamelCase = TRUE, integer64AsNumeric = FALSE)

data <- data[order(data$id), ]
data2 <- data2[order(data2$id), ]
row.names(data) <- NULL
row.names(data2) <- NULL
all.equal(data, data2)

renderTranslateExecuteSql(connection, "DROP TABLE scratch.scratch_asena5.insert_test;")
disconnect(connection)
10 changes: 10 additions & 0 deletions inst/sql/sql_server/sparkCopy.sql
@@ -0,0 +1,10 @@
COPY INTO @sqlTableName
FROM 'abfss://@azureStorageAccount.dfs.core.windows.net/@fileName'
WITH (
CREDENTIAL (AZURE_SAS_TOKEN = '@azureAccountKey')
)
FILEFORMAT = CSV
FORMAT_OPTIONS (
'header' = 'true',
'inferSchema' = 'true'
);
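
For illustration only: the @-prefixed parameters in this template are filled by SqlRender inside bulkLoadSpark(). A sketch with placeholder values, assuming DatabaseConnector (including this file) is installed; loadRenderTranslateSql() additionally translates the result to the Spark dialect:

# Read the template shipped with the package and substitute placeholder values.
sparkCopyTemplate <- SqlRender::readSql(
  system.file("sql/sql_server/sparkCopy.sql", package = "DatabaseConnector")
)
sql <- SqlRender::render(
  sparkCopyTemplate,
  sqlTableName = "some_scratch_schema.insert_test",
  fileName = "spark_insert_example.csv",
  azureAccountKey = "some_secret_account_key",
  azureStorageAccount = "some_azure_storage_account"
)
cat(sql)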
