Hello, Habr!
Last autumn Kaggle hosted a competition on classifying hand-drawn pictures, Quick Draw Doodle Recognition, in which, among others, a team of R users took part.
Medal farming did not work out this time, but we gained a lot of valuable experience, so we would like to tell the community about a number of the things that proved most interesting and useful on Kaggle and in everyday work. The topics covered include: the hard life without OpenCV, JSON parsing (these examples show how C++ code is integrated into R scripts or packages by means of Rcpp), parameterization of scripts, and dockerization of the final solution. All the code from this post is available in runnable form in the repository.
Contents:
1. Efficient loading of data from CSV into MonetDB
2. Preparing batches
3. Iterators for unloading batches from the database
4. Choosing a model architecture
5. Script parameterization
6. Dockerization of scripts
7. Using multiple GPUs in Google Cloud
8. Instead of a conclusion
1. Efficient loading of data from CSV into MonetDB
The data in this competition is provided not as ready-made images but as 340 CSV files (one file per class) containing JSONs with point coordinates. Connecting these points with lines gives the final image of 256x256 pixels. Each record also has a flag showing whether the picture was correctly recognized by the classifier that was in use when the dataset was collected, a two-letter code of the author's country of residence, a unique identifier, a timestamp, and a class name that matches the file name. The simplified version of the original data weighs 7.4 GB in the archive and about 20 GB after unpacking; the full data takes up 240 GB after unpacking. The organizers guaranteed that both versions reproduce the same drawings, so the full version is redundant. In any case, storing 50 million images as graphic files or as arrays was immediately judged impractical, and we decided to merge all the CSV files from the train_simplified.zip archive into a database and then generate images of the required size "on the fly" for each batch.
The well-proven MonetDB was chosen as the DBMS, namely its implementation for R in the form of the MonetDBLite package:
con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
We need to create two tables: one for all the data, the other for service information about the uploaded files (useful if something goes wrong and the process has to be resumed after several files have been loaded):
Creating the tables
if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}
if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}
The fastest way to load data into the database turned out to be copying the CSV files directly with SQL, using the command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT, where tablename is the table name and path is the path to the file. While working with the archive we found that the built-in unzip implementation in R does not work correctly with a number of files from the archive, so we used the system unzip (via the parameter getOption("unzip")).
Function for writing to the database
#' @title Extract and upload files
#'
#' @description
#' Extracts CSV files from a ZIP archive and loads them into the database
#'
#' @param con Database connection object (class `MonetDBEmbeddedConnection`).
#' @param tablename Name of the table in the database.
#' @param zipfile Path to the ZIP archive.
#' @param filename Name of the file inside the ZIP archive.
#' @param preprocess Preprocessing function applied to the extracted file.
#'   Must take a single argument `data` (a `data.table` object).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Argument checks
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)
  # Extract the file
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(),
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))
  # Apply the preprocessing function
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }
  # SQL query that imports the CSV (backslashes and the double quote are escaped for R)
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Run the query
  DBI::dbExecute(con, sql)
  # Record the successful upload in the service table
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))
  return(invisible(TRUE))
}
If the table needs to be transformed before it is written to the database, it is enough to pass a function that transforms the data in the preprocess argument.
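As an illustration, a preprocessing function of this kind might keep only the drawings that the original classifier recognized. This is a hypothetical sketch: keep_recognized is not part of the original code, and it assumes that the recognized column has been read as a logical vector. A small data.frame stands in for one extracted CSV file.

```r
# Hypothetical preprocessing function: keep only recognized drawings.
# Assumes `data` has a logical column `recognized`.
keep_recognized <- function(data) {
  data[which(data$recognized), ]
}

# Small stand-in for the contents of one extracted CSV file
demo <- data.frame(
  key_id = c(1, 2, 3),
  recognized = c(TRUE, FALSE, TRUE),
  word = "cat"
)
kept <- keep_recognized(demo)
print(kept)
```

Under these assumptions it could be passed as upload_file(..., preprocess = keep_recognized); the single argument is named data because upload_file checks for that name.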
Code for loading the data into the database file by file:
Writing data to the database
# List of files to load
files <- unzip(zipfile, list = TRUE)$Name
# Files to skip if part of the archive has already been loaded
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)
if (length(files) > 0L) {
  # Start the timer
  tictoc::tic()
  # Progress bar
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles",
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Stop the timer
  tictoc::toc()
}
# 526.141 sec elapsed - copying SSD->SSD
# 558.879 sec elapsed - copying USB->SSD
The loading time may vary depending on the speed of the drive used. In our case, reading and writing within a single SSD, or from a flash drive (source file) to an SSD (database), takes less than 10 minutes.
It takes a few more seconds to create a column with an integer class label and an index column (ORDERED INDEX) with row numbers, which is used to sample observations when batches are created:
Creating additional columns and an index
message("Generate labels")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))
message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))
To create batches on the fly, we needed the fastest possible extraction of random rows from the doodles table. We used 3 tricks for this. The first was to reduce the size of the type that stores the observation ID. In the original dataset the ID requires the bigint type, but the number of observations allows their identifiers, equal to the row number, to fit into the int type. Lookups are much faster in this case. The second trick was to use ORDERED INDEX; we arrived at this decision empirically, after going through all the available options. The third was to use parameterized queries. The idea is to run the PREPARE command once and then reuse the prepared statement for a series of queries of the same kind, but in practice the gain over a plain SELECT turned out to be within the statistical error.
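The first and third tricks can be sketched without a database. The helper below is illustrative only: build_prepare_sql is a hypothetical name, and it merely builds the text of a parameterized query with one placeholder per requested row, in the same shape as the queries used in this post. The last lines check that 50 million row numbers fit into a 32-bit integer.

```r
# Build the text of a parameterized query with `batch_size` placeholders.
# Only the SQL string is constructed here; no database connection is needed.
build_prepare_sql <- function(batch_size, table = "doodles", column = "id") {
  stopifnot(length(batch_size) == 1L, batch_size >= 1)
  placeholders <- paste(rep("?", batch_size), collapse = ",")
  sprintf("PREPARE SELECT %s FROM %s WHERE %s IN (%s)",
          column, table, column, placeholders)
}

sql <- build_prepare_sql(4L)
print(sql)

# Row numbers for about 50 million observations fit into a 32-bit int
n_rows <- 5e7
fits_in_int <- n_rows <= .Machine$integer.max
print(fits_in_int)
```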
The data loading process consumes no more than 450 MB of RAM. In other words, the described approach makes it possible to handle datasets weighing tens of gigabytes on almost any budget hardware, including some single-board devices, which is pretty cool.
All that remains is to measure the speed of extracting (random) data and to evaluate how it scales when batches of different sizes are sampled:
Database benchmark
library(ggplot2)
set.seed(0)
# Connect to the database
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
# Number of rows to sample from (not defined in the original snippet;
# assumed here to be the row count of the table)
n <- DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]]
# Function that prepares the query on the server side
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}
# Function that fetches the data
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}
# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]
#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49
ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()
DBI::dbDisconnect(con, shutdown = TRUE)
2. Preparing batches
The whole batch preparation process consists of the following steps:
- Parsing several JSONs containing vectors of strings with point coordinates.
- Drawing colored lines from the point coordinates on an image of the required size (for example, 256×256 or 128×128).
- Converting the resulting images into a tensor.
Among the Python kernels in the competition, the problem was solved mainly with OpenCV. One of the simplest and most obvious analogues in R looks like this:
Implementing the JSON-to-tensor conversion in R
r_process_json_str <- function(json, line.width = 3,
                               color = TRUE, scale = 1) {
  # Parse JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Remove the temporary file when the function exits
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Empty plot
  plot.new()
  # Size of the plot window
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Line colors
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale,
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Convert the image into a 3-dimensional array
  res <- png::readPNG(tmp)
  return(res)
}
r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Combine the 3-dimensional image arrays into a 4-dimensional tensor
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}
Drawing is done with standard R tools and saved to a temporary PNG stored in RAM (on Linux, temporary R directories are located in /tmp, which is mounted in RAM). This file is then read as a three-dimensional array with numbers ranging from 0 to 1. This matters because the more conventional BMP would be read into a raw array with hex color codes.
Let's test the result:
zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(),
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",",
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256 3
plot(magick::image_read(arr))
The batch itself is formed as follows:
res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
# num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
# - attr(*, "dimnames")=List of 4
# ..$ : NULL
# ..$ : NULL
# ..$ : NULL
# ..$ : NULL
This implementation seemed suboptimal to us, since forming large batches takes an indecently long time, so we decided to draw on the experience of our colleagues and use the powerful OpenCV library. At that time there was no ready-made package for R (there is still none), so a minimal implementation of the required functionality was written in C++ and integrated into the R code by means of Rcpp.
To solve the problem, the following packages and libraries were used:
- OpenCV for working with images and drawing lines. We used preinstalled system libraries and header files, together with dynamic linking.
- xtensor for working with multidimensional arrays and tensors. We used the header files included in the R package of the same name. The library allows working with multidimensional arrays in both row-major and column-major order.
- ndjson for parsing JSON. This library is used by xtensor automatically if it is present in the project.
- RcppThread for organizing multi-threaded processing of the vector of JSONs. We used the header files provided by this package. Unlike the more popular RcppParallel, the package has, among other things, a built-in mechanism for interrupting a loop.
It should be noted that xtensor turned out to be a godsend: besides its extensive functionality and high performance, its developers proved quite responsive and answered our questions promptly and in detail. With their help we managed to implement the conversion of OpenCV matrices into xtensor tensors, as well as a way to combine 3-dimensional image tensors into a 4-dimensional tensor of the correct shape (the batch itself).
To compile files that use system headers and link dynamically against libraries installed on the system, we used the plugin mechanism implemented in the Rcpp package. To find paths and flags automatically, we used the popular Linux utility pkg-config.
Implementing an Rcpp plugin for using the OpenCV library
Rcpp::registerPlugin("opencv", function() {
  # Possible package names
  pkg_config_name <- c("opencv", "opencv4")
  # Path to the pkg-config binary
  pkg_config_bin <- Sys.which("pkg-config")
  # Check that the utility is available on the system
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Check that an OpenCV config file for pkg-config exists
  check <- sapply(pkg_config_name,
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }
  # Use the first package name that was found
  pkg_config_name <- pkg_config_name[check == 0][1L]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name),
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name),
                      intern = TRUE)
  ))
})
As a result of the plugin's work, the following values are substituted during compilation:
Rcpp:::.plugins$opencv()$env
# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"
The code that parses JSON and forms a batch to be passed to the model is given under the spoiler. First, add the local project directory to the search path for header files (needed for ndjson):
Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))
Implementing the JSON-to-tensor conversion in C++
// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]
#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Type aliases
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Point coordinates extracted from JSON
using strokes = std::vector<points>;      // Point coordinates extracted from JSON
using xtensor3d = xt::xtensor<double, 3>; // Tensor for storing an image matrix
using xtensor4d = xt::xtensor<double, 4>; // Tensor for storing a set of images
using rtensor3d = xt::rtensor<double, 3>; // Wrapper for export to R
using rtensor4d = xt::rtensor<double, 4>; // Wrapper for export to R

// Static constants
// Image size in pixels
const static int SIZE = 256;
// Line type
// See https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Line width in pixels
const static int LINE_WIDTH = 3;
// Resize algorithm
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Template for converting an OpenCV matrix into a tensor
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Shape of the target tensor
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Total number of elements in the array
  size_t size = src.total() * NCH;
  // Convert cv::Mat to xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Convert JSON into a list of point coordinates
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // The parsing result must be an array
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Each element of the array must be a 2-dimensional array
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Extract the vector of points
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Draw the lines
// Colors are in HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Source matrix type
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Final matrix type
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of lines
  size_t n = x.size();
  for (const auto& s: x) {
    // Number of points in the line
    size_t n_points = s.shape()[1];
    // `i + 1 < n_points` avoids unsigned underflow for an empty stroke
    for (size_t i = 0; i + 1 < n_points; ++i) {
      // Start point of the segment
      cv::Point from(s(0, i), s(1, i));
      // End point of the segment
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Draw the line
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Change the line color
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Change the color representation to RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Change the representation to float32 with the range [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Process JSON and obtain a tensor with the image data
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x,
                               double scale = 1.0,
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x,
                                  double scale = 1.0,
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}
This code should be placed in the file src/cv_xt.cpp and compiled with the command Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); nlohmann/json.hpp is also required. The code is split into several functions:
- to_xt: a template function that converts an image matrix (cv::Mat) into an xt::xtensor tensor;
- parse_json: parses a JSON string, extracts the point coordinates and packs them into a vector;
- ocv_draw_lines: draws multi-colored lines from the resulting vector of points;
- process: combines the functions above and also adds the ability to rescale the resulting image;
- cpp_process_json_str: a wrapper around process that exports the result to an R object (a multidimensional array);
- cpp_process_json_vector: applies the same processing as cpp_process_json_str to a vector of strings in multi-threaded mode.
To draw multi-colored lines, the HSV color model was used, followed by conversion to RGB. Let's test the result:
arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256 3
plot(magick::image_read(arr))
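The HSV coloring scheme can be mimicked in plain R to see which color each stroke receives. This is only a sketch: stroke_colours is a hypothetical helper, and it assumes the constants from ocv_draw_lines (start hue 0, full saturation, value 220 out of 255, and an integer hue step of 180 divided by the number of strokes, on OpenCV's 0..179 hue scale for 8-bit images).

```r
# Colors assigned to n strokes, mirroring the hue step used in ocv_draw_lines().
stroke_colours <- function(n) {
  step <- 180 %/% n                     # integer hue step, as in the C++ code
  hue_cv <- (seq_len(n) - 1L) * step    # OpenCV-style hue, 0..179
  # grDevices::hsv() expects hue in [0, 1]
  grDevices::hsv(h = hue_cv / 180, s = 1, v = 220 / 255)
}

cols <- stroke_colours(4)
print(cols)
```

Each stroke gets its own hue, so the order in which the strokes were drawn is encoded in the color channels of the image.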
Comparing the speed of the R and C++ implementations
res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Benchmark parameters
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]
#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243
library(ggplot2)
# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data, scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)
res_bench[, cols]
#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50
ggplot(res_bench, aes(x = factor(batch_size), y = median,
                      group = expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom")
As you can see, the speed-up turned out to be very significant (roughly tenfold by median time for a batch of 1024), and it is not possible to catch up with the C++ code by parallelizing the R code.
3. Iterators for unloading batches from the database
R has a well-deserved reputation as a language for processing data that fits in RAM, while Python is better known for iterative data processing, which makes it easy and natural to implement out-of-core computations (computations that use external memory). A classic example, and a relevant one for us in the context of the described problem, is deep neural networks trained by gradient descent, where the gradient is approximated at each step from a small portion of the observations, a mini-batch.
Deep learning frameworks written in Python have special classes that implement iterators over data: tables, pictures in folders, binary formats and so on. You can use the ready-made options or write your own for specific tasks. In R we can use all the features of the Python library keras with its various backends through the package of the same name, which in turn works on top of the reticulate package. The latter deserves a separate long article; it not only lets you run Python code from R, but also lets you pass objects between R and Python sessions, performing all the necessary type conversions automatically.
We got rid of the need to keep all the data in RAM by using MonetDBLite, and all the "neural network" work is done by the original Python code, so we only have to write an iterator over the data, since nothing ready-made exists for this situation in either R or Python. There are essentially only two requirements for it: it must return batches in an endless loop, and it must keep its state between iterations (in R the latter is implemented most simply with closures). Previously it was necessary to convert R arrays into numpy arrays explicitly inside the iterator, but the current version of the keras package does this itself.
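The two requirements can be illustrated with a stripped-down closure that works on plain row numbers instead of a database. This is a minimal sketch under that assumption; make_batch_generator is a hypothetical name, and the real iterator additionally fetches and rasterizes the drawings.

```r
# Minimal closure-based batch generator (illustrative only).
# `ids` stands in for row numbers in the database.
make_batch_generator <- function(ids, batch_size) {
  n_batches <- length(ids) %/% batch_size   # only complete batches are used
  stopifnot(n_batches >= 1L)
  shuffled <- sample(ids)                   # shuffled once per epoch
  i <- 1L                                   # state kept between calls
  function() {
    if (i > n_batches) {                    # epoch finished: reshuffle and restart
      shuffled <<- sample(ids)
      i <<- 1L
    }
    pos <- ((i - 1L) * batch_size + 1L):(i * batch_size)
    i <<- i + 1L
    shuffled[pos]
  }
}

set.seed(1)
next_batch <- make_batch_generator(ids = 1:10, batch_size = 4)
b1 <- next_batch()
b2 <- next_batch()
b3 <- next_batch()  # the third call starts a new epoch
```

The counter and the shuffled order live in the enclosing environment of the returned function, which is what lets the generator be called indefinitely without any global state.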
The iterator for the training and validation data turned out as follows:
Iterator for training and validation data
train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Argument checks
  checkmate::assert_class(db_connection, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)
  # Shuffle so that batch indices can be taken and consumed in order
  dt <- data.table::data.table(id = sample(samples_index))
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Keep only complete batches and set the key
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Initialize the counter
  i <- 1
  # Number of batches
  max_i <- dt[, max(batch)]
  # Prepare the statement used to fetch the data
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  # Use the connection passed as an argument rather than the global `con`
  res <- DBI::dbSendQuery(db_connection, sql)
  # Analogue of keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }
  # Closure
  function() {
    # Start a new epoch
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Reset the counter
      i <<- 1
      max_i <<- dt[, max(batch)]
    }
    # IDs of the rows to fetch
    batch_ind <- dt[batch == i, id]
    # Fetch the data
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)
    # Increment the counter
    i <<- i + 1
    # Parse JSON and prepare the array
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to the interval [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}
The function takes as input a variable with a connection to the database, the numbers of the rows to use, the number of classes, the batch size, the scale (scale = 1 corresponds to rendering images of 256x256 pixels, scale = 0.5 to 128x128 pixels), a color flag (color = FALSE renders in grayscale, while with color = TRUE each stroke is drawn in a new color), and a preprocessing flag for networks pre-trained on imagenet. The latter is needed to rescale pixel values from the interval [0, 1] to the interval [-1, 1], which was used when training the models supplied with keras.
The outer function contains the argument type checks, a data.table with randomly shuffled row numbers from samples_index and batch numbers, the counter and the maximum number of batches, and the SQL statement for unloading data from the database. In addition, we defined inside it a fast analogue of keras::to_categorical(). We used almost all the data for training, leaving half a percent for validation, so the epoch size was limited by the steps_per_epoch parameter when calling keras::fit_generator(), and the condition if (i > max_i) only fired for the validation iterator.
The inner function retrieves the row indices for the next batch, unloads the records from the database while incrementing the batch counter, parses the JSON (the cpp_process_json_vector() function, written in C++) and creates the arrays that correspond to the pictures. Then one-hot vectors with class labels are created, and the arrays with pixel values and labels are combined into a list, which is the return value. To speed things up, we used indexes on data.table tables and modification by reference; without these data.table features it is hard to imagine working efficiently with any significant amount of data in R.
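The one-hot step relies on column-major linear indexing, which is easy to miss when reading the generator. The sketch below isolates the to_categorical helper from the iterator and runs it on three made-up labels; the labels are assumed to be zero-based integers, as produced by label_int.

```r
# One-hot encoding via column-major linear indexing.
# `x` holds zero-based integer class labels, `num` is the number of classes.
to_categorical <- function(x, num) {
  n <- length(x)
  m <- numeric(n * num)
  # Element (row r, column k) of an n-by-num matrix sits at position (k - 1) * n + r,
  # so label x[r] (zero-based) maps to position x[r] * n + r
  m[x * n + seq_len(n)] <- 1
  dim(m) <- c(n, num)
  m
}

labels <- c(0L, 2L, 1L)
onehot <- to_categorical(labels, num = 3L)
print(onehot)
```

A single vectorized assignment fills the whole matrix, which avoids a per-row loop when batches contain hundreds of observations.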
The speed measurements on a laptop with a Core i5 are as follows:
Iterator benchmark
library(Rcpp)
library(keras)
library(ggplot2)
source("utils/rcpp.R")
source("utils/keras_iterator.R")
con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]
# Indices for the training set
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Indices for the validation set
val_ind <- ind[-train_ind]
rm(ind)
# Scale factor
scale <- 0.5
# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]
#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44
ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()
DBI::dbDisconnect(con, shutdown = TRUE)
If you have enough RAM, you can seriously speed up the database by moving it into that same RAM (32 GB is enough for our task). On Linux the /dev/shm partition is mounted by default and takes up to half of the RAM capacity. You can allocate more by editing /etc/fstab so that it contains an entry like tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Be sure to reboot and check the result by running the command df -h.
ืืืืืจืืืจ ืื ืชืื ื ืืืืงื ื ืจืื ืืจืื ืืืชืจ ืคืฉืื, ืืืืืื ืฉืืขืจื ืื ืชืื ืื ืฉื ืืืืืงื ืืชืืื ืืืืืืื ื-RAM:
ืืืืจืืืจ ืื ืชืื ื ืืืืงื
test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {
  # Argument checks
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]
  # Closure
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing],
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to the interval [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}
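The batch numbering and the closure are the two ideas that make this iterator work, and they can be tried out in base R without a database or the C++ parser. The sketch below is a hypothetical stand-in, not the article's function: make_batch_iterator() is an invented name, it batches a plain vector, and it adds an exhaustion check that test_generator() itself does not have.

```r
# Hypothetical stand-in for test_generator(): batches a plain vector
make_batch_iterator <- function(x, batch_size = 2L) {
  # Batch number for every element: 1, 1, 2, 2, 3, ...
  batch <- (seq_along(x) - 1L) %/% batch_size + 1L
  i <- 1L
  max_i <- max(batch)
  # The returned closure remembers i between calls
  function() {
    # Extra check (not in the original): signal that the data is exhausted
    if (i > max_i) return(NULL)
    result <- x[batch == i]
    i <<- i + 1L
    result
  }
}

it <- make_batch_iterator(letters[1:5], batch_size = 2L)
first <- it()   # first batch
second <- it()  # second batch
third <- it()   # last, incomplete batch
fourth <- it()  # nothing left
```

Each call advances the counter stored in the enclosing environment, which is why the iterator can be handed to a prediction loop as an ordinary function.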
4. Choosing the model architecture
The first architecture we used was mobilenet v1, but in the keras package for R its input tensor must always have the shape (batch, height, width, 3), that is, the number of channels cannot be changed. There is no such limitation in Python, so we hurried and wrote our own implementation of this architecture, following the original paper (without the dropout that is present in the keras version):
Mobilenet v1 architecture
library(keras)
top_3_categorical_accuracy <- custom_metric(
  name = "top_3_categorical_accuracy",
  metric_fn = function(y_true, y_pred) {
    metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
  }
)
layer_sep_conv_bn <- function(object,
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {
  # NB! depth_multiplier != resolution multiplier
  # https://github.com/keras-team/keras/issues/10349
  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3),
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
    layer_batch_normalization() %>%
    layer_activation_relu() %>%
    layer_conv_2d(
      filters = filters * alpha,
      kernel_size = c(1, 1),
      strides = c(1, 1)
    ) %>%
    layer_batch_normalization() %>%
    layer_activation_relu()
}
get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {
  inputs <- layer_input(shape = input_shape)
  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>%
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()
  model <- keras_model(
    inputs = inputs,
    outputs = outputs
  )
  model %>% compile(
    optimizer = optimizer,
    loss = loss,
    metrics = metrics
  )
  return(model)
}
The drawbacks of this approach are obvious. We want to test many models, and we do not want to rewrite each architecture by hand. We were also unable to use the weights of models pretrained on imagenet. As usual, reading the documentation helped. The get_config() function returns a description of the model in an editable form (base_model_conf$layers is an ordinary R list), and the from_config() function performs the reverse conversion into a model object:
base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)
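Because the configuration is an ordinary nested R list, the edit itself is plain list manipulation and can be tried without keras installed. The sketch below uses a hand-written, heavily simplified imitation of a model configuration; the structure of mock_conf (including the NULL batch dimension) is an assumption made for illustration, not the exact output of get_config().

```r
# Hypothetical, heavily simplified imitation of a model configuration
mock_conf <- list(
  name = "mock_model",
  layers = list(
    list(
      class_name = "InputLayer",
      # Assumed layout: (batch, height, width, channels), batch left unset
      config = list(batch_input_shape = list(NULL, 128L, 128L, 3L))
    )
  )
)

# Number of channels before the edit
old_channels <- mock_conf$layers[[1]]$config$batch_input_shape[[4]]

# Same assignment as in the article, applied to the mock list
mock_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L

new_shape <- mock_conf$layers[[1]]$config$batch_input_shape
str(new_shape)
```

Only the fourth element changes; the rest of the list is left as it was, which is what lets from_config() rebuild the same architecture with a different input depth.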
Now it is easy to write a universal function for obtaining any of the models supplied with keras, with or without weights trained on imagenet:
Function for loading ready-made architectures
get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Argument checks
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)
  # Get the object from the keras package
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Check that the object exists in the package
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }
  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )
  # If the image is not in color, change the input dimension
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }
  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )
  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }
  return(model)
}
When single-channel images are used, the pretrained weights are not used. This could be fixed: use the get_weights() function to get the model weights as a list of R arrays, change the dimension of the first element of this list (by taking one color channel or averaging all three), and then load the weights back into the model with the set_weights() function. We never added this functionality, because by this stage it was already clear that it was more productive to work with color images.
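The array manipulation described above can be sketched in base R. This is an illustration only: it uses a random array in place of real weights, and it assumes the first convolution kernel is laid out as (height, width, input channels, filters); the names w_rgb, w_one and w_mean are invented for the example.

```r
# Random stand-in for a first-layer kernel: 3x3 window, 3 input channels,
# 32 filters (assumed layout: height, width, input channels, filters)
set.seed(1)
w_rgb <- array(rnorm(3 * 3 * 3 * 32), dim = c(3, 3, 3, 32))

# Option 1: keep a single color channel
w_one <- w_rgb[, , 1, , drop = FALSE]

# Option 2: average the three color channels
w_mean <- apply(w_rgb, c(1, 2, 4), mean)
# Restore the singleton channel dimension dropped by apply()
dim(w_mean) <- c(3, 3, 1, 32)

print(dim(w_one))
print(dim(w_mean))
```

Either array has the shape a single-channel first layer would need; in a real model it would replace the first element of the list returned by get_weights() before calling set_weights().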
We carried out most of the experiments using mobilenet versions 1 and 2, as well as resnet34. More modern architectures such as SE-ResNeXt performed well in this competition. Unfortunately, we had no ready-made implementations at our disposal, and we did not write our own (but we definitely will).
5. Parameterizing the scripts
For convenience, all the code for starting training was packaged as a single script, parameterized with docopt as follows:
doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                     Show this message.
  -l --list-models              List available models.
  -m --model=<model>            Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>        Batch size [default: 32].
  -s --scale-factor=<ratio>     Scale factor [default: 0.5].
  -c --color                    Use color lines [default: FALSE].
  -d --db-dir=<path>            Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio>   Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>           Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)
The docopt package provides the implementation for R. With it, scripts are launched with simple commands such as Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db or ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, if the file train_nn.R is executable (this command starts training the resnet50 model on three-color images of 128x128 pixels; the database must be located in the folder /home/andrey/doodle_db). You can add the learning rate, the optimizer type, and any other configurable parameters to the list. While preparing this publication it turned out that the mobilenet_v2 architecture from the current version of keras cannot be used from R.
This approach significantly sped up experiments with different models compared with the more traditional launching of scripts in RStudio (there is also a package that can serve as an alternative).
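Values parsed from the command line generally arrive as character strings, so they need converting before they are passed to the training code. The sketch below does not call docopt: raw_args is a hand-written list standing in for its result, and both the element names and the helper parse_args() are hypothetical. The last line assumes the scale factor is applied to the 256 px side of the source images.

```r
# Hand-written stand-in for the list returned by docopt::docopt(doc);
# the element names are assumptions made for this sketch
raw_args <- list(
  model = "resnet50",
  batch_size = "32",
  scale_factor = "0.5",
  color = TRUE,
  n_gpu = "1"
)

# Hypothetical helper: convert strings to the types the script needs
parse_args <- function(a) {
  list(
    model = a$model,
    batch_size = as.integer(a$batch_size),
    scale = as.numeric(a$scale_factor),
    color = isTRUE(a$color),
    n_gpu = as.integer(a$n_gpu)
  )
}

params <- parse_args(raw_args)

# Image side length, assuming the scale factor applies to 256 px originals
dim_size <- as.integer(256 * params$scale)
print(dim_size)
```

Doing the conversion in one place keeps the rest of the script free of repeated as.integer() and as.numeric() calls.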
6. Dockerizing the scripts
We used Docker to make the model training environment portable between team members and to deploy quickly in the cloud. It is a relatively unusual tool for an R programmer.
Docker lets you both create your own images from scratch and use other images as a base for your own. After analyzing the available options, we concluded that installing the NVIDIA drivers, CUDA+cuDNN, and the Python libraries is a fairly large part of the image, so we decided to take the official tensorflow/tensorflow:1.12.0-gpu image as a base and add the necessary R packages to it.
The final Dockerfile looks like this:
Dockerfile
FROM tensorflow/tensorflow:1.12.0-gpu
MAINTAINER Artem Klevtsov <[email protected]>
SHELL ["/bin/bash", "-c"]
ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"
# Each step of the single RUN instruction is continued with a backslash
RUN source /etc/os-release && \
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && \
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && \
    add-apt-repository -y ppa:marutter/c2d4u3.5 && \
    add-apt-repository -y ppa:timsc/opencv-3.4 && \
    apt-get update && \
    apt-get install -y locales && \
    locale-gen ${LOCALE} && \
    apt-get install -y --no-install-recommends ${APT_PKG} && \
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && \
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && \
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && \
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && \
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && \
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && \
    install.r ${R_SRC_PKG} && \
    pip install ${PY_PIP_PKG} && \
    mkdir -p ${DIRS} && \
    chmod 777 ${DIRS} && \
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && \
    rm -rf /var/lib/apt/lists/*
COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/
ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"
WORKDIR /app
VOLUME /db
VOLUME /app
CMD bash
For convenience, the packages used were moved into variables; most of the scripts we wrote are copied into the container during the build. We also changed the command shell to /bin/bash to make it easier to use the contents of /etc/os-release. This avoided the need to specify the OS version in the code.
In addition, a small bash script was written that lets you run the container with different commands. For example, these can be scripts for training neural networks that were previously placed inside the container, or a command shell for debugging and monitoring the container:
Script for launching the container
#!/bin/sh
DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"
if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi
docker run ${ARGS} doodles-tf ${CMD}
If this bash script is run without parameters, the train_nn.R script is called inside the container with default values; if the first positional argument is "bash", the container starts interactively with a command shell. In all other cases the values of the positional arguments are substituted: CMD="Rscript /app/train_nn.R $@".
It is worth noting that the directories with the source data and the database, as well as the directory for saving trained models, are mounted into the container from the host system, which gives access to the results of the scripts without extra manipulation.
7. Using multiple GPUs on Google Cloud
One of the features of the competition was very noisy data (see the title picture, borrowed from [email protected] in the ODS Slack). Large batches help to combat this, and after experiments on a PC with 1 GPU we decided to master training models on multiple GPUs in the cloud. We used Google Cloud, with the database kept in /dev/shm.
The most interesting part is the code fragment responsible for using multiple GPUs. First, the model is created on the CPU using a context manager, just as in Python:
with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics = c(top_3_categorical_accuracy),
    compile = FALSE
  )
})
Then the uncompiled (this is important) model is copied to the given number of available GPUs, and only after that is it compiled:
model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)
We could not implement, for multiple GPUs, the classic technique of freezing all layers except the last one, training the last layer, then unfreezing and fine-tuning the whole model.
Training was monitored without tensorboard; we limited ourselves to writing logs and saving models with informative names after each epoch:
Callbacks
# Log file name template
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Model file name template
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))
callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # halve the learning rate
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)
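The model file name template mixes two kinds of placeholders, which is easy to miss. In the sketch below, sprintf() in R fills in %s and %d immediately, while {epoch:02d} and {val_loss:.2f} pass through unchanged for Keras to fill in when it saves a checkpoint. The values of model_name, dim_size and channels, as well as the epoch and loss used for the example file name, are invented for illustration.

```r
model_name <- "mobilenet_v2"  # example values, not taken from a real run
dim_size <- 128L
channels <- 3L

# R fills in %s and %d; the {epoch:02d} and {val_loss:.2f} placeholders
# are left untouched by sprintf()
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name, dim_size, channels
))
print(model_file_tmpl)

# Imitate by hand what a file name would look like for epoch 3 with a
# validation loss of 0.8712 (Keras does this substitution itself)
example <- sub("{epoch:02d}", sprintf("%02d", 3L), model_file_tmpl, fixed = TRUE)
example <- sub("{val_loss:.2f}", sprintf("%.2f", 0.8712), example, fixed = TRUE)
print(example)
```

Names built this way sort by epoch and show the validation loss at a glance, which helps when picking a checkpoint without opening the logs.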
8. In lieu of a conclusion
We have not yet managed to overcome a number of the problems we ran into:
- keras has no ready-made function for automatically searching for the optimal learning rate (an analogue of lr_finder in the fast.ai library); with some effort, third-party implementations can be ported to R;
- as a consequence of the previous point, we could not find the right learning rate when using multiple GPUs;
- modern neural network architectures are lacking, especially ones pretrained on imagenet;
- there is no one cycle policy or discriminative learning rates (cosine annealing was implemented at our request, thanks to skeydan).
What useful things we took away from this competition:
- On relatively low-powered hardware you can work painlessly with decent volumes of data (many times larger than RAM). The data.table package saves memory by modifying tables in place, which avoids copying them, and when used correctly it is almost always the fastest of all the tools we know for scripting languages. Storing data in a database means that in many cases you do not have to think about squeezing the whole dataset into RAM at all.
- Slow R functions can be replaced with fast C++ ones using the Rcpp package. If you also use RcppThread or RcppParallel, you get cross-platform multithreaded implementations, so there is no need to parallelize the code at the R level.
- Rcpp can be used without serious knowledge of C++; only a minimum is required. Header files for a number of great C libraries such as xtensor are available on CRAN, so an infrastructure is forming for projects that integrate ready-made high-performance C++ code into R. An additional convenience is syntax highlighting and a static C++ code analyzer in RStudio.
- docopt lets you run self-contained scripts with parameters. This is convenient on a remote server, including under Docker. Running many-hour neural network training experiments in RStudio is inconvenient, and installing the IDE itself on a server is not always justified.
- Docker ensures code portability and reproducibility of results between developers with different OS and library versions, as well as ease of launching on servers. The whole training pipeline can be started with a single command.
- Google Cloud is a budget-friendly way to experiment on expensive hardware, but configurations need to be chosen carefully.
- Measuring the speed of individual code fragments is very useful, especially when combining R and C++, and with the bench package it is also very easy.
Overall this experience was very rewarding, and we continue to work on solving some of the problems raised.
Source: www.habr.com