Quick Draw Doodle Recognition: how to make friends with R, C++ and neural networks

Hey Habr!

Last fall Kaggle hosted Quick Draw Doodle Recognition, a competition to classify hand-drawn pictures, and a team of R scientists took part in it: Artem Klevtsov, Philipp Upravitelev and Andrey Ogurtsov. We will not describe the competition in detail; that has already been done in a recent publication.

Medal farming did not work out this time, but we gained a lot of valuable experience, so we would like to tell the community about the most interesting and useful things we learned, both for Kaggle and for everyday work. Among the topics discussed: the hard life without OpenCV, JSON parsing (these examples look at integrating C++ code into R scripts or packages using Rcpp), parameterization of scripts and dockerization of the final solution. All the code from this post, in a form suitable for running, is available in the repository.

Contents:

  1. Efficiently loading data from CSV into MonetDB
  2. Preparing batches
  3. Iterators for unloading batches from the database
  4. Choosing a model architecture
  5. Script parameterization
  6. Dockerizing the scripts
  7. Using multiple GPUs on Google Cloud
  8. In lieu of a conclusion

1. Efficiently loading data from CSV into the MonetDB database

The data in this competition is provided not as ready-made images but as 340 CSV files (one file per class) containing JSON with point coordinates. Connecting these points with lines gives a final image of 256x256 pixels. Each record also has a label indicating whether the picture was correctly recognized by the classifier used when the dataset was collected, a two-letter code for the country of residence of the picture's author, a unique identifier, a timestamp, and a class name that matches the file name. The simplified version of the original data weighs 7.4 GB in the archive and about 20 GB after unpacking; the full data takes up 240 GB after unpacking. The organizers ensured that both versions reproduce the same drawings, which means the full version is redundant. In any case, storing 50 million images as graphic files or as arrays was immediately judged unprofitable, so we decided to merge all the CSV files from the train_simplified.zip archive into a database and then generate images of the required size on the fly for each batch.

The well-proven MonetDB was chosen as the DBMS, namely its implementation for R in the form of the MonetDBLite package. The package includes an embedded version of the database server and lets you start the server directly from an R session and work with it there. Creating the database and connecting to it are done with a single command:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

We need to create two tables: one for all the data, the other for service information about the files already loaded (useful if something goes wrong and the process has to be resumed after several files have been loaded):

Creating tables

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

The fastest way to load data into the database turned out to be copying the CSV files directly with the SQL command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT, where tablename is the table name and path is the path to the file. While working with the archive we found that the unzip implementation built into R does not work correctly with a number of files from the archive, so we used the system unzip (via the getOption("unzip") parameter).

Function for writing to the database

#' @title Extract and load files
#'
#' @description
#' Extract CSV files from a ZIP archive and load them into the database
#'
#' @param con Database connection object (class `MonetDBEmbeddedConnection`).
#' @param tablename Name of the table in the database.
#' @param zipfile Path to the ZIP archive.
#' @param filename Name of the file inside the ZIP archive.
#' @param preprocess Preprocessing function to apply to the extracted file.
#'   Must take a single argument `data` (a `data.table` object).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Argument checks
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Extract the file
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # Apply the preprocessing function
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Database query that imports the CSV
  # (the newline delimiter and the quote character are escaped for R)
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Execute the query
  DBI::dbExecute(con, sql)

  # Record the successful load in the service table
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

If the table needs to be transformed before it is written to the database, it is enough to pass a function that transforms the data in the preprocess argument.

Code for sequentially loading the data into the database:

Writing data to the database

# List of files to write
files <- unzip(zipfile, list = TRUE)$Name

# List of exclusions, in case some files have already been loaded
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Start the timer
  tictoc::tic()
  # Progress bar
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Stop the timer
  tictoc::toc()
}

# 526.141 sec elapsed - copying SSD->SSD
# 558.879 sec elapsed - copying USB->SSD

Data loading time may vary depending on the speed of the drive used. In our case, reading and writing within one SSD, or from a flash drive (source files) to an SSD (database), takes less than 10 minutes.

Next we add a column with an integer class label and an index column (ORDERED INDEX) holding the row numbers by which observations will be sampled when creating batches:

Creating additional columns and an index

message("Generate labels")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))
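
The label_int column gives every class name its zero-based rank in sorted order. To make that mapping concrete, here is a minimal standalone sketch in C++. The helper name dense_rank_labels is hypothetical and not part of the article's code, and it assumes plain byte-wise string ordering, which may differ from the database collation for non-ASCII names.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical helper (not from the article): assign every word its
// zero-based dense rank in sorted order, mirroring the idea behind
// dense_rank() OVER (ORDER BY word) - 1.
std::vector<int> dense_rank_labels(const std::vector<std::string>& words) {
  // std::map keeps its keys sorted, so iteration order is alphabetical.
  std::map<std::string, int> rank;
  for (const auto& w : words) {
    rank.emplace(w, 0);
  }
  // Number the distinct words consecutively, starting from zero.
  int next = 0;
  for (auto& kv : rank) {
    kv.second = next++;
  }
  // Look up the rank of every input word, preserving the input order.
  std::vector<int> labels;
  labels.reserve(words.size());
  for (const auto& w : words) {
    labels.push_back(rank.at(w));
  }
  return labels;
}
```

Equal words always receive the same label and there are no gaps between labels, which is what a classifier with 340 output units needs.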

To create batches on the fly, we needed to extract random rows from the doodles table as fast as possible. For this we used 3 tricks. The first was to reduce the size of the type that stores the observation ID. In the original dataset the ID requires the bigint type, but the number of observations makes it possible to fit identifiers equal to the row's ordinal number into the int type. Lookups are much faster in this case. The second trick was to use ORDERED INDEX; we arrived at this decision empirically, after going through all the available options. The third was to use parameterized queries. The essence of the method is to execute the PREPARE command once and then reuse the prepared expression for a series of queries of the same type, but in practice the gain over a simple SELECT turned out to be within the statistical error.

The data loading process consumes no more than 450 MB of RAM. In other words, the described approach lets you handle datasets weighing tens of gigabytes on almost any budget hardware, including some single-board devices, which is pretty cool.

All that remains is to measure the speed of retrieving (random) data and to evaluate how it scales when sampling batches of different sizes:

Database benchmark

library(ggplot2)

set.seed(0)
# Connect to the database
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
# Number of rows in the table (used by fetch_data below)
n <- DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]]

# Function that prepares the query on the server side
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Function that fetches the data
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Run the measurement
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)


2. Preparing batches

The whole batch preparation process consists of the following steps:

  1. Parsing several JSONs that contain vectors of strings with point coordinates.
  2. Drawing colored lines from the point coordinates on an image of the required size (for example, 256×256 or 128×128).
  3. Converting the resulting images into a tensor.

Among the Python kernels in the competition, the problem was mostly solved with OpenCV. One of the simplest and most obvious analogues in R looks like this:

Implementing JSON to tensor conversion in R
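
To make step 2 concrete: each stroke of a drawing arrives as a pair of coordinate arrays (all x values, then all y values), and drawing it means connecting consecutive points. The following C++ sketch only enumerates the line segments of one stroke and draws nothing. The names Point, Segment and stroke_segments are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical types for illustration only.
struct Point {
  int x;
  int y;
};
using Segment = std::pair<Point, Point>;

// Turn one stroke, given as parallel arrays of x and y coordinates,
// into the list of line segments that connect consecutive points.
std::vector<Segment> stroke_segments(const std::vector<int>& xs,
                                     const std::vector<int>& ys) {
  assert(xs.size() == ys.size());
  std::vector<Segment> out;
  // A stroke with k points yields k - 1 segments (none for k < 2).
  for (std::size_t i = 0; i + 1 < xs.size(); ++i) {
    out.emplace_back(Point{xs[i], ys[i]}, Point{xs[i + 1], ys[i + 1]});
  }
  return out;
}
```

For a stroke with x = {0, 10, 20} and y = {5, 5, 15} this gives two segments, (0,5)-(10,5) and (10,5)-(20,15); a drawing routine would then rasterize each segment onto the image.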

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Parse JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Remove the temporary file when the function exits
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Empty plot
  plot.new()
  # Size of the plot window
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Line colors
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Convert the image into a 3-dimensional array
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Combine the 3-dimensional image arrays into a 4-dimensional tensor
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

Drawing is done with standard R tools and the result is saved to a temporary PNG stored in RAM (on Linux, temporary R directories are located in /tmp, which is mounted in RAM). This file is then read as a three-dimensional array with values ranging from 0 to 1. This matters because the more conventional BMP would be read into a raw array with hexadecimal color codes.

Let's test the result:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

The batch itself is formed as follows:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

This implementation seemed suboptimal to us, since forming large batches takes an indecently long time, so we decided to take advantage of our colleagues' experience and use the powerful OpenCV library. At that time there was no ready-made package for R (and there still is none), so a minimal implementation of the required functionality was written in C++ and integrated into the R code using Rcpp.

To solve the problem, the following packages and libraries were used:

  1. OpenCV for working with images and drawing lines. We used the pre-installed system libraries and header files, as well as dynamic linking.

  2. xtensor for working with multidimensional arrays and tensors. We used the header files included in the R package of the same name. The library lets you work with multidimensional arrays in both row-major and column-major order.

  3. ndjson for parsing JSON. This library is used by xtensor automatically if it is present in the project.

  4. RcppThread for organizing multithreaded processing of the vector of JSON strings. We used the header files provided by this package. It differs from the more popular RcppParallel in, among other things, its built-in loop interruption mechanism.
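
The difference between row-major and column-major order matters here because OpenCV stores image data row by row with interleaved channels, while R arrays are column-major. A small sketch of the two offset formulas for a 3-dimensional array follows. The names Shape3, offset_row_major and offset_col_major are hypothetical, and this is plain index arithmetic, not xtensor's API.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical alias: shape or index of a 3-dimensional array,
// for example {rows, cols, channels}.
using Shape3 = std::array<std::size_t, 3>;

// Row-major (C and OpenCV style): the last index varies fastest.
std::size_t offset_row_major(const Shape3& shape, const Shape3& idx) {
  assert(idx[0] < shape[0] && idx[1] < shape[1] && idx[2] < shape[2]);
  return (idx[0] * shape[1] + idx[1]) * shape[2] + idx[2];
}

// Column-major (R and Fortran style): the first index varies fastest.
std::size_t offset_col_major(const Shape3& shape, const Shape3& idx) {
  assert(idx[0] < shape[0] && idx[1] < shape[1] && idx[2] < shape[2]);
  return idx[0] + shape[0] * (idx[1] + shape[1] * idx[2]);
}
```

For a 256×256×3 image, the element at row 1, column 2, channel 0 sits at flat position 774 in row-major order but at position 513 in column-major order, so a buffer cannot simply be handed from one convention to the other without reordering.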

It should be noted that xtensor turned out to be a godsend: besides its extensive functionality and high performance, its developers proved to be quite responsive and answered our questions promptly and in detail. With their help we managed to implement the conversion of OpenCV matrices into xtensor tensors, as well as a way to combine 3-dimensional image tensors into a 4-dimensional tensor of the correct dimension (the batch itself).

Materials for learning Rcpp, xtensor and RcppThread
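
What "combining 3-dimensional tensors into a 4-dimensional one" means for the memory layout can be sketched without any library. The function below is a hypothetical illustration (the name stack_images is ours); it assumes every image is already a flat column-major buffer and that the batch has shape (n, height, width, channels) in R's column-major order.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: stack n flat image buffers of equal length into
// one flat batch buffer laid out as a column-major (n, h, w, c) array.
// In column-major order the batch index varies fastest, so the element
// at flat position p of image b lands at position b + n * p.
std::vector<double> stack_images(const std::vector<std::vector<double>>& images,
                                 std::size_t image_size) {
  const std::size_t n = images.size();
  std::vector<double> batch(n * image_size);
  for (std::size_t b = 0; b < n; ++b) {
    assert(images[b].size() == image_size);
    for (std::size_t p = 0; p < image_size; ++p) {
      batch[b + n * p] = images[b][p];
    }
  }
  return batch;
}
```

Note that the images end up interleaved rather than concatenated: stacking {1, 2, 3} and {10, 20, 30} gives {1, 10, 2, 20, 3, 30}. A tensor library hides this bookkeeping behind views and layout-aware assignment.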

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

To compile files that use system header files and dynamic linking with libraries installed on the system, we used the plugin mechanism implemented in the Rcpp package. To find the paths and flags automatically, we used the popular Linux utility pkg-config.

Implementing an Rcpp plugin for using the OpenCV library

Rcpp::registerPlugin("opencv", function() {
  # Possible package names
  pkg_config_name <- c("opencv", "opencv4")
  # Path to the pkg-config binary
  pkg_config_bin <- Sys.which("pkg-config")
  # Check that the utility is present on the system
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Check that an OpenCV config file for pkg-config exists
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  # Use the first package name that was found
  pkg_config_name <- pkg_config_name[check == 0][1L]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

As a result of the plugin's work, the following values are substituted during compilation:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

The implementation code for parsing JSON and generating a batch to pass to the model is given below under the spoiler. First, add the local project directory to the search path for header files (needed for ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Implementing JSON to tensor conversion in C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Type aliases
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Point coordinates of one stroke extracted from JSON
using strokes = std::vector<points>;      // All strokes of one drawing extracted from JSON
using xtensor3d = xt::xtensor<double, 3>; // Tensor for storing an image matrix
using xtensor4d = xt::xtensor<double, 4>; // Tensor for storing a set of images
using rtensor3d = xt::rtensor<double, 3>; // Wrapper for export to R
using rtensor4d = xt::rtensor<double, 4>; // Wrapper for export to R

// Static constants
// Image size in pixels
const static int SIZE = 256;
// Line type
// See https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Line width in pixels
const static int LINE_WIDTH = 3;
// Resize algorithm
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Template for converting an OpenCV matrix into a tensor
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Shape of the target tensor
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Total number of elements in the array
  size_t size = src.total() * NCH;
  // Convert cv::Mat to xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Convert JSON into a list of point coordinates
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // The parse result must be an array
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Each array element must be a 2-dimensional array
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Extract the vector of points
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Draw the lines
// Colors are in HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Source matrix type
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Final matrix type
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of lines
  size_t n = x.size();
  for (const auto& s: x) {
    // Number of points in the line
    size_t n_points = s.shape()[1];
    // "i + 1 < n_points" avoids unsigned underflow for an empty stroke
    for (size_t i = 0; i + 1 < n_points; ++i) {
      // Start point of the segment
      cv::Point from(s(0, i), s(1, i));
      // End point of the segment
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Draw the segment
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Change the line color
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Change the color representation to RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Change the representation format to float32 with range [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Process JSON and obtain a tensor with the image data
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

This code should be placed in the file src/cv_xt.cpp and compiled with the command Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); it also requires nlohmann/json.hpp from the repository. The code is divided into several functions:

  • to_xt: a template function for converting an image matrix (cv::Mat) into an xt::xtensor tensor;

  • parse_json: parses a JSON string, extracts the point coordinates and packs them into a vector;

  • ocv_draw_lines: draws multicolored lines from the resulting vector of points;

  • process: combines the functions above and adds the ability to scale the resulting image;

  • cpp_process_json_str: a wrapper around process that exports the result to an R object (a multidimensional array);

  • cpp_process_json_vector: a vectorized counterpart of cpp_process_json_str that calls process for each element of a string vector in multithreaded mode.

To draw multicolored lines, the HSV color model was used, followed by conversion to RGB. Let's test the result:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))
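
The color assignment inside ocv_draw_lines can be looked at in isolation. The sketch below reproduces only the hue arithmetic: the first stroke gets hue 0 and every following stroke is shifted by 180 / n using integer division, where n is the number of strokes. OpenCV's 8-bit HSV representation uses hues in the range 0 to 179. The function name stroke_hues is hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: hue assigned to each of n_strokes strokes when
// the hue starts at 0 and grows by 180 / n_strokes (integer division)
// after every stroke, as in ocv_draw_lines.
std::vector<int> stroke_hues(std::size_t n_strokes) {
  std::vector<int> hues;
  if (n_strokes == 0) {
    return hues;
  }
  const int step = static_cast<int>(180 / n_strokes);
  hues.reserve(n_strokes);
  for (std::size_t i = 0; i < n_strokes; ++i) {
    hues.push_back(static_cast<int>(i) * step);
  }
  return hues;
}
```

Four strokes receive hues 0, 45, 90 and 135, so the stroke order is encoded in the color. One consequence of the integer division is that a drawing with more than 180 strokes gets a step of 0, and all its strokes share the same hue.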

Comparing the speed of the R and C++ implementations

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Benchmark parameters
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Run the measurement
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

Quick Draw Doodle Recognition: R, C++ ๋ฐ ์‹ ๊ฒฝ๋ง๊ณผ ์นœ๊ตฌ๊ฐ€ ๋˜๋Š” ๋ฐฉ๋ฒ•

๋ณด์‹œ๋‹ค์‹œํ”ผ ์†๋„ ํ–ฅ์ƒ์ด ๋งค์šฐ ์ปธ์œผ๋ฉฐ R ์ฝ”๋“œ๋ฅผ ๋ณ‘๋ ฌํ™”ํ•˜์—ฌ C++ ์ฝ”๋“œ๋ฅผ ๋”ฐ๋ผ์žก๋Š” ๊ฒƒ์€ ๋ถˆ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค.

3. ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ๋ฐฐ์น˜๋ฅผ ์–ธ๋กœ๋“œํ•˜๊ธฐ ์œ„ํ•œ ๋ฐ˜๋ณต์ž

R์€ RAM์— ๋งž๋Š” ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๊ฒƒ์œผ๋กœ ์ •ํ‰์ด ๋‚˜์žˆ๋Š” ๋ฐ˜๋ฉด, Python์€ ๋ฐ˜๋ณต์ ์ธ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ๊ฐ€ ํŠน์ง•์œผ๋กœ ์•„์›ƒ์˜ค๋ธŒ์ฝ”์–ด ๊ณ„์‚ฐ(์™ธ๋ถ€ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์‚ฌ์šฉํ•œ ๊ณ„์‚ฐ)์„ ์‰ฝ๊ณ  ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์„ค๋ช…๋œ ๋ฌธ์ œ์˜ ๋งฅ๋ฝ์—์„œ ์šฐ๋ฆฌ์—๊ฒŒ ๊ณ ์ „์ ์ด๊ณ  ๊ด€๋ จ ์žˆ๋Š” ์˜ˆ๋Š” ๊ด€์ฐฐ์˜ ์ž‘์€ ๋ถ€๋ถ„ ๋˜๋Š” ๋ฏธ๋‹ˆ ๋ฐฐ์น˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ๋‹จ๊ณ„์—์„œ ๊ฒฝ์‚ฌ๋ฅผ ๊ทผ์‚ฌํ™”ํ•˜๋Š” ๊ฒฝ์‚ฌํ•˜๊ฐ•๋ฒ•์œผ๋กœ ํ›ˆ๋ จ๋œ ์‹ฌ์ธต ์‹ ๊ฒฝ๋ง์ž…๋‹ˆ๋‹ค.

Python์œผ๋กœ ์ž‘์„ฑ๋œ ๋”ฅ ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ์—๋Š” ๋ฐ์ดํ„ฐ(ํ…Œ์ด๋ธ”, ํด๋”์˜ ๊ทธ๋ฆผ, ๋ฐ”์ด๋„ˆ๋ฆฌ ํ˜•์‹ ๋“ฑ)๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ๋ฐ˜๋ณต์ž๋ฅผ ๊ตฌํ˜„ํ•˜๋Š” ํŠน์ˆ˜ ํด๋ž˜์Šค๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ฏธ๋ฆฌ ๋งŒ๋“ค์–ด์ง„ ์˜ต์…˜์„ ์‚ฌ์šฉํ•˜๊ฑฐ๋‚˜ ํŠน์ • ์ž‘์—…์„ ์œ„ํ•ด ์ง์ ‘ ์ž‘์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. R์—์„œ๋Š” Python ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ๋ชจ๋“  ๊ธฐ๋Šฅ์„ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ผ€ ๋ผ์Šค ๋™์ผํ•œ ์ด๋ฆ„์˜ ํŒจํ‚ค์ง€๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋‹ค์–‘ํ•œ ๋ฐฑ์—”๋“œ๊ฐ€ ์žˆ์œผ๋ฉฐ, ์ด๋Š” ํŒจํ‚ค์ง€ ์œ„์—์„œ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค. ๋ง์ƒํ•˜๋‹ค. ํ›„์ž๋Š” ๋ณ„๋„์˜ ๊ธด ๊ธฐ์‚ฌ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด R์—์„œ Python ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์„ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ R๊ณผ Python ์„ธ์…˜ ๊ฐ„์— ๊ฐœ์ฒด๋ฅผ ์ „์†กํ•˜์—ฌ ํ•„์š”ํ•œ ๋ชจ๋“  ์œ ํ˜• ๋ณ€ํ™˜์„ ์ž๋™์œผ๋กœ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

MonetDBLite๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ RAM์— ๋ชจ๋“  ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅํ•  ํ•„์š”์„ฑ์„ ์—†์•ด์Šต๋‹ˆ๋‹ค. ๋ชจ๋“  "์‹ ๊ฒฝ๋ง" ์ž‘์—…์€ Python์˜ ์›๋ž˜ ์ฝ”๋“œ๋กœ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค. ์ค€๋น„๋œ ๊ฒƒ์ด ์—†๊ธฐ ๋•Œ๋ฌธ์— ๋ฐ์ดํ„ฐ ์œ„์— ๋ฐ˜๋ณต์ž๋ฅผ ์ž‘์„ฑํ•˜๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค. R์ด๋‚˜ Python์—์„œ ์ด๋Ÿฌํ•œ ์ƒํ™ฉ์— ๋Œ€๋น„ํ•ฉ๋‹ˆ๋‹ค. ์ด์— ๋Œ€ํ•œ ์š”๊ตฌ ์‚ฌํ•ญ์€ ๋ณธ์งˆ์ ์œผ๋กœ ๋‘ ๊ฐ€์ง€๋ฟ์ž…๋‹ˆ๋‹ค. ๋ฌดํ•œ ๋ฃจํ”„์—์„œ ๋ฐฐ์น˜๋ฅผ ๋ฐ˜ํ™˜ํ•˜๊ณ  ๋ฐ˜๋ณต ์‚ฌ์ด์— ์ƒํƒœ๋ฅผ ์ €์žฅํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค(R์˜ ํ›„์ž๋Š” ํด๋กœ์ €๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ ๊ตฌํ˜„๋ฉ๋‹ˆ๋‹ค). ์ด์ „์—๋Š” ๋ฐ˜๋ณต์ž ๋‚ด์—์„œ R ๋ฐฐ์—ด์„ numpy ๋ฐฐ์—ด๋กœ ๋ช…์‹œ์ ์œผ๋กœ ๋ณ€ํ™˜ํ•ด์•ผ ํ–ˆ์ง€๋งŒ ํ˜„์žฌ ๋ฒ„์ „์˜ ํŒจํ‚ค์ง€๋Š” ์ผ€ ๋ผ์Šค ์Šค์Šค๋กœ ๊ทธ๋ ‡๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค.

The iterator for the training and validation data turned out as follows:

Iterator for training and validation data

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Check arguments
  checkmate::assert_class(db_connection, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Shuffle so that used batch indices can be taken and removed in order
  dt <- data.table::data.table(id = sample(samples_index))
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Keep only complete batches and index by batch
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Initialize the counter
  i <- 1
  # Number of batches
  max_i <- dt[, max(batch)]

  # Prepare the statement for unloading data
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(db_connection, sql)

  # Analogue of keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Closure
  function() {
    # Start a new epoch
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Reset the counter
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # IDs of the rows to unload
    batch_ind <- dt[batch == i, id]
    # Unload the data
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Increment the counter
    i <<- i + 1

    # Parse JSON and prepare the array
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to the interval [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

The function takes as input a variable with the database connection, the numbers of the rows to use, the number of classes, the batch size, the scale (scale = 1 corresponds to rendering 256x256 pixel images, scale = 0.5 to 128x128 pixels), a color flag (color = FALSE renders in grayscale; with color = TRUE each stroke is drawn in a new color), and a preprocessing flag for networks pretrained on imagenet. The latter is needed to rescale pixel values from the interval [0, 1] to the interval [-1, 1], which was used when training the models supplied with keras.
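As a small illustration of the scale and imagenet_preproc arguments, the sketch below reproduces the arithmetic in base R. The helper names image_side() and rescale_pixels() are hypothetical, and the rounding rule is an assumption; the actual rendering is done by the C++ code.

```r
# Side length of the rendered square image for a given scale factor.
# Assumption: the base canvas is 256 px, as stated in the text.
image_side <- function(scale, base = 256L) as.integer(round(base * scale))

# Map pixel intensities from [0, 1] to [-1, 1],
# the same expression the iterator applies when imagenet_preproc = TRUE.
rescale_pixels <- function(x) (x - 0.5) * 2

sides <- sapply(c(1, 0.5, 0.25), image_side)
px <- rescale_pixels(c(0, 0.5, 1))
```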

The outer function contains argument type checks, a data.table with the randomly shuffled row numbers from samples_index and the batch numbers, a counter and the maximum number of batches, as well as the SQL expression for unloading data from the database. In addition, we defined inside it a fast analogue of the keras::to_categorical() function. We used almost all the data for training and left half a percent for validation, so the epoch size was limited by the steps_per_epoch parameter when calling keras::fit_generator(), and the condition if (i > max_i) only ever fired for the validation iterator.
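The one-hot helper is compact enough to be puzzling on first reading, so here it is in isolation with the indexing spelled out. The function body is the one from the iterator; the sample labels are made up for illustration, and labels are assumed to be zero-based as in label_int.

```r
to_categorical <- function(x, num) {
  n <- length(x)
  m <- numeric(n * num)
  # In column-major order, element (row r, column c) of an n x num matrix
  # sits at position (c - 1) * n + r. For a zero-based label x[r] the
  # column is x[r] + 1, so the position is x[r] * n + r.
  m[x * n + seq_len(n)] <- 1
  dim(m) <- c(n, num)
  m
}

labels <- c(0L, 2L, 1L)
one_hot <- to_categorical(labels, num = 3)
# Recover the zero-based labels from the one-hot rows
recovered <- apply(one_hot, 1, which.max) - 1L
```

Filling a flat vector and setting dim afterwards avoids building the matrix row by row, which is presumably where the speed-up over a generic implementation comes from.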

The inner function retrieves the row indices for the next batch, unloads the records from the database while incrementing the batch counter, parses the JSON (the cpp_process_json_vector() function, written in C++), and creates the arrays corresponding to the pictures. One-hot vectors with the class labels are then created, and the arrays with pixel values and labels are combined into a list, which is the return value. To speed things up we relied on indexes in data.table tables and on modification by reference; without these data.table features it is hard to imagine working effectively with any significant amount of data in R.

Speed measurements on a Core i5 laptop look like this:

Iterator benchmark

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Indices for the training set
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Indices for the validation set
val_ind <- ind[-train_ind]
rm(ind)
# Scale factor
scale <- 0.5

# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Benchmark columns to show
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)
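The table reports iterations per second, which is hard to compare across batch sizes. A rough way to read it is to convert to images per second; the sketch below only redoes the arithmetic on the numbers printed above and makes no claim beyond this single, noisy laptop run.

```r
# Throughput implied by the benchmark table (numbers copied from it).
batch_size  <- c(16, 32, 64, 128, 256, 512, 1024)
itr_per_sec <- c(15.9, 8.17, 8.57, 3.85, 1.54, 0.674, 0.261)

images_per_sec <- batch_size * itr_per_sec
best <- batch_size[which.max(images_per_sec)]
rounded <- round(images_per_sec)
```

On these figures the batch size of 64 gives the highest throughput, roughly 550 images per second, and larger batches do not improve on it.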


If you have enough RAM, you can seriously speed up the database by moving it into that same RAM (32 GB is enough for our task). On Linux the /dev/shm partition is mounted by default and takes up to half of the RAM capacity. You can allocate more by editing /etc/fstab so that it contains an entry such as tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Be sure to reboot and check the result by running df -h.

The iterator for the test data looks much simpler, because the test dataset fits entirely in RAM:

Iterator for test data

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Check arguments
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Closure
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing],
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to the interval [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}
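Unlike the training iterator, the test iterator keeps the last, possibly incomplete, batch and never wraps around, so the caller has to request exactly the right number of steps (for example via the steps argument of keras::predict_generator()). A minimal sketch of that count, with the hypothetical helper n_steps() and made-up sizes:

```r
# Number of generator calls needed to see every test row exactly once.
n_steps <- function(n_rows, batch_size) {
  as.integer(ceiling(n_rows / batch_size))
}

# Batch numbers assigned the same way as in test_generator():
# 10 rows with batch_size = 4 give batches of 4, 4 and 2 rows.
batch_id <- (seq_len(10) - 1L) %/% 4L + 1L
steps <- n_steps(10, 4)
```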

4. Choosing a model architecture

The first architecture we used was mobilenet v1, whose features are discussed in this post. It is included in keras as standard and is therefore also available in the R package of the same name. But when we tried to use it with single-channel images, something odd turned up: the input tensor must always have the shape (batch, height, width, 3), that is, the number of channels cannot be changed. Python has no such restriction, so we hurriedly wrote our own implementation of this architecture, following the original paper (without the dropout that is present in the keras version):

Mobilenet v1 architecture

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}
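The layer_sep_conv_bn() block above is a depthwise separable convolution, and the reason it keeps the network small is plain arithmetic on weight counts. The sketch below is illustrative only: it ignores biases and batch normalization parameters, assumes depth_multiplier = 1, and the helper names are hypothetical.

```r
# Weights of a standard k x k convolution
conv_params <- function(c_in, c_out, k = 3) k * k * c_in * c_out

# Depthwise k x k convolution followed by a 1 x 1 pointwise convolution
sep_conv_params <- function(c_in, c_out, k = 3) k * k * c_in + c_in * c_out

std <- conv_params(512, 512)      # 2359296 weights
sep <- sep_conv_params(512, 512)  # 266752 weights
ratio <- std / sep                # a bit under 9 times fewer weights
```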

The drawbacks of this approach are obvious. We wanted to test many models, yet had no wish to rewrite each architecture by hand. We were also deprived of the opportunity to use the weights of models pretrained on imagenet. As usual, studying the documentation helped. The get_config() function returns a description of the model in a form suitable for editing (base_model_conf$layers is an ordinary R list), and the from_config() function performs the reverse conversion into a model object:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Now it is not hard to write a universal function that returns any of the models supplied with keras, with or without weights trained on imagenet:

Function for loading ready-made architectures

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Check arguments
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Get the constructor from the keras package
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Check that the object exists in the package
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # If the image is not in color, change the input dimensions
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

With single-channel images, pretrained weights are not used. This could be fixed: use get_weights() to obtain the model weights as a list of R arrays, change the dimensions of the first element of that list (by taking a single color channel or averaging all three), and then load the weights back into the model with set_weights(). We never added this functionality, because by that stage it was already clear that working with color pictures was more productive.
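The averaging step described above can be sketched in base R on a plain array. This is an illustration, not code from the project: collapse_rgb_kernel() is a hypothetical helper, and it assumes the first weight array is a convolution kernel laid out as (height, width, input channels, filters) with three input channels.

```r
# Average a (h, w, 3, filters) kernel over its input-channel axis,
# returning a (h, w, 1, filters) kernel for single-channel input.
collapse_rgb_kernel <- function(kernel) {
  d <- dim(kernel)
  stopifnot(length(d) == 4L, d[3] == 3L)
  avg <- apply(kernel, c(1, 2, 4), mean)  # drops the channel axis
  array(avg, dim = c(d[1], d[2], 1L, d[4]))
}

kernel <- array(as.numeric(seq_len(3 * 3 * 3 * 2)), dim = c(3, 3, 3, 2))
gray_kernel <- collapse_rgb_kernel(kernel)

# Hypothetical usage with keras (not run here; model names are placeholders):
# w <- keras::get_weights(pretrained_model)
# w[[1]] <- collapse_rgb_kernel(w[[1]])
# keras::set_weights(single_channel_model, w)
```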

We ran most of the experiments with mobilenet versions 1 and 2 and with resnet34. More modern architectures such as SE-ResNeXt performed well in this competition. Unfortunately, we had no ready-made implementations at our disposal and did not write our own (but we certainly will).

5. Script parameterization

For convenience, all the code for launching training was put into a single script, parameterized with docopt as follows:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

The docopt package is an implementation of http://docopt.org/ for R. With it, scripts are launched with simple commands such as Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db or ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, if the file train_nn.R is executable (this command starts training the resnet50 model on three-color images of 128x128 pixels; the database must be located in the folder /home/andrey/doodle_db). Learning rate, optimizer type and any other configurable parameters can be added to the list. While preparing this publication, it turned out that the mobilenet_v2 architecture from the current version of keras cannot be used from R because of changes that the R package does not take into account; we are waiting for a fix.

This approach made it possible to speed up experiments with different models considerably compared with the more traditional launching of scripts in RStudio (we note the tfruns package as a possible alternative). But the main advantage is the ability to manage script launches easily in Docker or simply on a server, without installing RStudio for that.

6. Dockerizing the scripts

We used Docker to ensure a portable environment for training models across team members and for rapid deployment in the cloud. You can start getting acquainted with this tool, which is relatively unusual for an R programmer, with this series of publications or the video course.

Docker lets you both build your own images from scratch and use other images as a basis for your own. Analyzing the available options, we concluded that installing the NVIDIA drivers, CUDA+cuDNN and the Python libraries is a fairly bulky part of the image, so we decided to take the official image tensorflow/tensorflow:1.12.0-gpu as a basis and add the necessary R packages to it.

The final Dockerfile looked like this:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && \
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && \
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && \
    add-apt-repository -y ppa:marutter/c2d4u3.5 && \
    add-apt-repository -y ppa:timsc/opencv-3.4 && \
    apt-get update && \
    apt-get install -y locales && \
    locale-gen ${LOCALE} && \
    apt-get install -y --no-install-recommends ${APT_PKG} && \
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && \
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && \
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && \
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && \
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && \
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && \
    install.r ${R_SRC_PKG} && \
    pip install ${PY_PIP_PKG} && \
    mkdir -p ${DIRS} && \
    chmod 777 ${DIRS} && \
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && \
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

For convenience, the packages used were put into variables; the bulk of the scripts we wrote is copied into the container during the build. We also changed the command shell to /bin/bash to make it easier to use the contents of /etc/os-release. This avoided having to specify the OS version in the code.

In addition, we wrote a small bash script that launches a container with various commands. For example, these could be the scripts for training neural networks that were previously placed inside the container, or a command shell for debugging and monitoring the container:

Script for launching the container

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

If this bash script is run without parameters, the script train_nn.R is called inside the container with default values; if the first positional argument is "bash", the container starts interactively with a command shell. In all other cases the values of the positional arguments are substituted: CMD="Rscript /app/train_nn.R $@".

It is worth noting that the directories with the source data and the database, as well as the directory for saving trained models, are mounted into the container from the host system, which makes the results of the scripts accessible without extra manipulation.

7. Using multiple GPUs on Google Cloud

One feature of the competition was very noisy data (see the title picture, borrowed from @Leigh.plt in the ODS Slack). Large batches help to combat this, and after experiments on a PC with one GPU we decided to master training models on several GPUs in the cloud. We used Google Cloud (a good guide to the basics) because of its wide choice of available configurations, reasonable prices and $300 bonus. Out of greed we ordered a 4xV100 instance with an SSD and a ton of RAM, and that was a big mistake. Such a machine eats money quickly, and you can go broke experimenting without a proven pipeline. For learning purposes it is better to take a K80. The large amount of RAM did come in handy, though: the cloud SSD did not impress with its performance, so the database was moved to /dev/shm.

Of greatest interest is the code fragment responsible for using multiple GPUs. First, the model is created on the CPU using a context manager, just as in Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics = c(top_3_categorical_accuracy),
    compile = FALSE
  )
})

Then the uncompiled (this is important) model is copied to the given number of available GPUs, and only after that is it compiled:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

The classic technique of freezing all layers except the last, training the last layer, then unfreezing and retraining the whole model could not be implemented for multiple GPUs.

Training was monitored without using tensorboard; we limited ourselves to writing logs and saving models with informative names after each epoch:

Callbacks

# Log file name template
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Model file name template
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # halve the learning rate
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)
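To see what the "informative names" look like, the model file template can be expanded by hand in base R. The values of model_name, dim_size and channels below are made up for illustration, and the two sub() calls only imitate the substitution that the checkpoint callback performs for the {epoch} and {val_loss} placeholders.

```r
model_name <- "resnet50"  # hypothetical values for illustration
dim_size <- 128L
channels <- 3L

model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name, dim_size, channels
))

# Imitate the placeholders being filled for epoch 7 with val_loss 0.8312
example_file <- sub("{epoch:02d}", sprintf("%02d", 7L),
                    model_file_tmpl, fixed = TRUE)
example_file <- sub("{val_loss:.2f}", sprintf("%.2f", 0.8312),
                    example_file, fixed = TRUE)
```

With save_best_only = FALSE, one such file is written per epoch, so the validation loss in the name makes it easy to pick a checkpoint afterwards.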

8. Instead of a conclusion

A number of problems we ran into have not yet been overcome:

  • keras has no ready-made function for automatically finding the optimal learning rate (an analogue of lr_finder in the fast.ai library); with some effort, third-party implementations can be ported to R, for example this one;
  • as a consequence of the previous point, we could not choose the right learning rate when using several GPUs;
  • there is a shortage of modern neural network architectures, especially ones pretrained on imagenet;
  • there is no one-cycle policy and no discriminative learning rates (cosine annealing was implemented at our request, thanks to skeydan).

What useful things did we take away from this competition?

  • On relatively low-powered hardware you can work with decent volumes of data (many times the size of RAM) without pain. The data.table package saves memory through in-place modification of tables, which avoids copying them, and when used correctly its capabilities almost always show the highest speed among all the tools we know of for scripting languages. Storing data in a database lets you, in many cases, not think at all about having to squeeze the whole dataset into RAM.
  • Slow functions in R can be replaced with fast ones in C++ using the Rcpp package. If you additionally use RcppThread or RcppParallel, you get cross-platform multithreaded implementations, so there is no need to parallelize code at the R level.
  • The Rcpp package can be used without serious knowledge of C++; the required minimum is described here. Header files are available for a number of cool C++ libraries such as xtensor, that is, an infrastructure is taking shape for projects that integrate ready-made high-performance C++ code into R. An additional convenience is the syntax highlighting and static C++ code analyzer in RStudio.
  • docopt lets you run self-contained scripts with parameters. This is convenient for use on a remote server, including under Docker. In RStudio it is inconvenient to run many hours of experiments with neural network training, and installing the IDE on the server itself is not always justified.
  • Docker ensures code portability and reproducibility of results between developers with different versions of the OS and libraries, as well as ease of launching on servers. The entire training pipeline can be launched with a single command.
  • Google Cloud is a budget-friendly way to experiment on expensive hardware, but you need to choose configurations carefully.
  • Measuring the speed of individual code fragments is very useful, especially when combining R and C++, and with the bench package it is also very easy.

Overall, this experience was very rewarding, and we continue to work on resolving some of the issues raised.

Source: habr.com
