Quick Draw Doodle Recognition: how to make friends with R, C++ and neural networks

Hello, Habr!

Last autumn, Kaggle hosted a competition to classify hand-drawn pictures, Quick Draw Doodle Recognition, in which, among others, a team of R scientists took part: Artem Klevtsov, Philipp Upravitelev and Andrey Ogurtsov. We will not describe the competition in detail; that has already been done in a recent publication.

This time we did not manage to farm any medals, but we gained a lot of valuable experience, so I would like to tell the community about a number of the most interesting and useful things, both for Kaggle and for everyday work. Topics covered: the hard life without OpenCV, JSON parsing (these examples look at integrating C++ code into R scripts or packages using Rcpp), parameterization of scripts and dockerization of the final solution. All the code from the post, in a form suitable for running, is available in the repository.

Contents:

  1. Efficiently load data from CSV into MonetDB
  2. Preparing batches
  3. Iterators for extracting batches from the database
  4. Choosing a model architecture
  5. Script parameterization
  6. Dockerization of scripts
  7. Using multiple GPUs on Google Cloud
  8. Instead of a conclusion

1. Efficiently load data from CSV into the MonetDB database

The data in this competition were provided not as ready-made images but as 340 CSV files (one file per class) containing JSON with point coordinates. By connecting these points with lines, we get the final image, 256x256 pixels in size. Each record also has a label indicating whether the picture was recognized correctly by the classifier in use when the dataset was collected, a two-letter code for the country of residence of the picture's author, a unique identifier, a timestamp and the class name, which matches the file name. The simplified version of the original data weighs 7.4 GB in the archive and about 20 GB after unpacking; the full data take up 240 GB after unpacking. The organizers made sure that both versions reproduce the same drawings, which means the full version was redundant. In any case, storing 50 million images as graphic files or as arrays was immediately deemed unprofitable, and we decided to merge all the CSV files from the train_simplified.zip archive into a database, from which images of the required size would be generated "on the fly" for each batch.
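Connecting the points with lines is ordinary rasterization. The post does it with R graphics and later with cv::line from OpenCV; purely as a library-free illustration, here is a minimal sketch that draws the strokes of one drawing into a 256x256 single-channel buffer with Bresenham's line algorithm. It assumes integer coordinates in [0, 255] and draws one-pixel lines (the post's C++ code uses 3-pixel lines); the Stroke type and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

// One stroke as stored in the "drawing" JSON: parallel lists of x and y.
struct Stroke {
  std::vector<int> x;
  std::vector<int> y;
};

constexpr int kSize = 256;  // side of the target image in pixels

// Set the pixels on the segment (x0, y0)-(x1, y1) using Bresenham's algorithm.
void draw_segment(std::vector<std::uint8_t>& img, int x0, int y0, int x1, int y1) {
  int dx = std::abs(x1 - x0), sx = x0 < x1 ? 1 : -1;
  int dy = -std::abs(y1 - y0), sy = y0 < y1 ? 1 : -1;
  int err = dx + dy;
  while (true) {
    if (x0 >= 0 && x0 < kSize && y0 >= 0 && y0 < kSize) {
      img[static_cast<std::size_t>(y0) * kSize + x0] = 255;  // row = y, column = x
    }
    if (x0 == x1 && y0 == y1) break;
    int e2 = 2 * err;
    if (e2 >= dy) { err += dy; x0 += sx; }
    if (e2 <= dx) { err += dx; y0 += sy; }
  }
}

// Rasterize all strokes of a drawing into a kSize x kSize buffer (0 = background).
std::vector<std::uint8_t> rasterize(const std::vector<Stroke>& strokes) {
  std::vector<std::uint8_t> img(static_cast<std::size_t>(kSize) * kSize, 0);
  for (const auto& s : strokes) {
    // A single-point stroke becomes a single pixel.
    if (s.x.size() == 1) draw_segment(img, s.x[0], s.y[0], s.x[0], s.y[0]);
    // Connect consecutive points with straight segments.
    for (std::size_t i = 0; i + 1 < s.x.size(); ++i) {
      draw_segment(img, s.x[i], s.y[i], s.x[i + 1], s.y[i + 1]);
    }
  }
  return img;
}
```

The buffer is stored row by row, which is also how cv::Mat lays out a single-channel image.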

A well-proven system was chosen as the DBMS: MonetDB, specifically its R implementation, the MonetDBLite package. The package includes an embedded version of the database server and lets you start the server directly from an R session and work with it there. Creating the database and connecting to it are done with a single command:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

We need to create two tables: one for all the data, the other for service information about the loaded files (useful if something goes wrong and the process has to be resumed after several files have been loaded):

Creating the tables

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

The fastest way to load data into the database turned out to be copying the CSV files directly with the SQL command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT, where tablename is the table name and path is the path to the file. While working with the archive, we found that the built-in unzip implementation in R does not work correctly with a number of files from the archive, so we used the system unzip (via the getOption("unzip") parameter).
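The R function below pastes the path straight into the COPY statement with sprintf, which would break if a path ever contained a single quote. As an illustration only, here is a minimal C++ sketch that builds the same statement and escapes the path the standard SQL way, by doubling single quotes. sql_quote and build_copy_sql are hypothetical helpers, and the comment on OFFSET reflects how the post uses it (to skip the CSV header).

```cpp
#include <cassert>
#include <string>

// Quote a value as an SQL string literal, doubling any embedded single quotes.
std::string sql_quote(const std::string& s) {
  std::string out = "'";
  for (char c : s) {
    if (c == '\'') out += "''";
    else out += c;
  }
  out += "'";
  return out;
}

// Build the COPY INTO statement used by the post. OFFSET 2 starts reading at
// the second line (skipping the header); the delimiters are the field
// separator, the record separator (written as the escape \n) and the quote.
std::string build_copy_sql(const std::string& table, const std::string& path) {
  return "COPY OFFSET 2 INTO " + table + " FROM " + sql_quote(path) +
         " USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT";
}
```

The table name is not quoted here because it comes from the program itself, not from file names.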

Function for writing to the database

#' @title Extracting and loading files
#'
#' @description
#' Extracts a CSV file from a ZIP archive and loads it into the database
#'
#' @param con Database connection object (class `MonetDBEmbeddedConnection`).
#' @param tablename Name of the table in the database.
#' @param zipfile Path to the ZIP archive.
#' @param filename Name of the file inside the ZIP archive.
#' @param preprocess Preprocessing function to apply to the extracted file.
#'   Must take a single argument `data` (a `data.table` object).
#'
#' @return `TRUE` (invisibly).
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Check the arguments
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Extract the file into the temporary directory
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  # Remove the extracted file when the function exits
  on.exit(unlink(path))

  # Apply the preprocessing function, if one was given
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Query that makes the database import the CSV file (OFFSET 2 skips the header)
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Run the query
  DBI::dbExecute(con, sql)

  # Record the successful upload in the service table
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

If you need to transform the table before writing it to the database, it is enough to pass a function that transforms the data in the preprocess argument.

Code for loading the data into the database file by file:

Writing data to the database

# Path to the archive (assumed; the same path is used later in the post)
zipfile <- file.path("data", "train_simplified.zip")
# List of files to write
files <- unzip(zipfile, list = TRUE)$Name

# Files to skip if some of them have already been loaded
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Start the timer
  tictoc::tic()
  # Progress bar
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Stop the timer
  tictoc::toc()
}

# 526.141 sec elapsed - copying SSD->SSD
# 558.879 sec elapsed - copying USB->SSD
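The resume logic relies on upload_log: whatever is already recorded there is skipped on the next run. For illustration, the same set difference written out in C++ as a hypothetical helper that keeps the archive order and, like R's setdiff(), also drops duplicates:

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Return the archive members that are not yet recorded in upload_log,
// preserving the archive order (the same idea as setdiff(files, to_skip) in R).
std::vector<std::string> files_to_load(const std::vector<std::string>& archive,
                                       const std::vector<std::string>& logged) {
  std::set<std::string> done(logged.begin(), logged.end());
  std::vector<std::string> out;
  for (const auto& f : archive) {
    // insert() fails for files already logged or already seen in the listing.
    if (done.insert(f).second) out.push_back(f);
  }
  return out;
}
```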

Data loading time may vary depending on the speed of the drive used. In our case, reading and writing within a single SSD, or from a flash drive (source file) to an SSD (database), takes less than 10 minutes.

It takes a few more seconds to create a column with an integer class label and an index column (ORDERED INDEX) with row numbers, by which observations will be sampled when building batches:

Creating additional columns and an index

message("Generate labels")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))
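The UPDATE above numbers the classes with dense_rank() over the alphabetically sorted class names and subtracts 1, so labels run from 0 to 339 with no gaps. The same encoding in plain C++, as an illustrative sketch (dense_rank_labels is a hypothetical name; it sorts byte-wise, which may differ from the database collation for non-ASCII names):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Map each class name to a 0-based integer in sorted order, mirroring
// dense_rank() OVER (ORDER BY word) - 1: equal words share a label and
// labels have no gaps.
std::vector<int> dense_rank_labels(const std::vector<std::string>& words) {
  std::map<std::string, int> rank;  // std::map keeps its keys sorted
  for (const auto& w : words) rank.emplace(w, 0);
  int next = 0;
  for (auto& kv : rank) kv.second = next++;
  std::vector<int> labels;
  labels.reserve(words.size());
  for (const auto& w : words) labels.push_back(rank.at(w));
  return labels;
}
```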

To solve the problem of building batches on the fly, we needed to extract random rows from the doodles table as fast as possible. We used three tricks for this. The first was to reduce the size of the type that stores the observation ID. In the original dataset the type needed to store the ID is bigint, but the number of observations makes it possible to fit their identifiers, equal to the row's ordinal number, into the int type. Lookups are much faster in this case. The second trick was to use an ORDERED INDEX; we arrived at this empirically, after trying all the available options. The third was to use parameterized queries. The idea is to execute the PREPARE command once and then reuse the prepared statement when issuing many queries of the same kind; in practice, however, the gain over a plain SELECT turned out to be within the statistical error.
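A prepared statement has a fixed number of parameters, so one statement is needed per batch size, with one ? placeholder per requested id. The R code builds it with paste(rep("?", batch_size), collapse = ","); an equivalent C++ sketch, where prepare_batch_sql is a hypothetical helper:

```cpp
#include <cassert>
#include <string>

// Build "PREPARE SELECT <cols> FROM doodles WHERE id IN (?,?,...)" with one
// placeholder per row of the batch.
std::string prepare_batch_sql(const std::string& cols, int batch_size) {
  std::string sql = "PREPARE SELECT " + cols + " FROM doodles WHERE id IN (";
  for (int i = 0; i < batch_size; ++i) {
    if (i > 0) sql += ",";
    sql += "?";
  }
  sql += ")";
  return sql;
}
```

The ids themselves are then bound to the placeholders for every batch, which is what DBI::dbBind() does in the R code.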

Loading the data consumes no more than 450 MB of RAM. This means the approach described lets you move datasets weighing tens of gigabytes on almost any budget hardware, including single-board devices, which is very nice.

All that remains is to measure the speed of (random) data retrieval and evaluate how it scales when sampling batches of different sizes:

Database benchmark

library(ggplot2)

set.seed(0)
# Connect to the database
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
# Number of rows; the ids of the serial column run from 1 to n
n <- DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]]

# Function that prepares the query on the server side
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Function that fetches a random batch of rows
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Benchmark results
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)
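The benchmark draws ids with sample(seq_len(n), batch_size) on the R side. If the sampler were moved to C++, a partial Fisher-Yates shuffle is one standard way to draw distinct ids; sample_ids is a hypothetical helper, not part of the original code, and it assumes the ids are exactly 1..n.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Draw k distinct ids from 1..n without replacement, like sample(seq_len(n), k)
// in R, using a partial Fisher-Yates shuffle: only the first k positions are
// randomized, so it needs O(n) memory for the pool but just k swaps.
std::vector<std::int32_t> sample_ids(std::int32_t n, std::int32_t k, std::mt19937& rng) {
  assert(k >= 0 && k <= n);
  std::vector<std::int32_t> pool(static_cast<std::size_t>(n));
  std::iota(pool.begin(), pool.end(), 1);  // ids 1..n, as produced by a serial column
  for (std::int32_t i = 0; i < k; ++i) {
    // Pick a random position among the not-yet-chosen ones and move it to slot i.
    std::uniform_int_distribution<std::int32_t> pick(i, n - 1);
    std::swap(pool[i], pool[pick(rng)]);
  }
  pool.resize(static_cast<std::size_t>(k));
  return pool;
}
```

For tens of millions of rows the pool itself costs memory proportional to n; rejection sampling into a hash set is a common alternative when k is much smaller than n.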

2. Preparing batches

The whole batch preparation process consists of the following steps:

  1. Parsing several JSON strings containing vectors of point coordinates.
  2. Drawing colored lines from the point coordinates on an image of the required size (for example, 256×256 or 128×128).
  3. Converting the resulting images into a tensor.
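Step 3 amounts to laying the images out one after another in a single buffer. A minimal sketch of a batch tensor with shape [batch, height, width, channels] in row-major order, assuming float pixels; Batch is a hypothetical type for illustration (the post itself uses R arrays and xtensor):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major (C order) batch tensor of shape [n, h, w, c], stored flat.
struct Batch {
  std::size_t n, h, w, c;
  std::vector<float> data;
  Batch(std::size_t n_, std::size_t h_, std::size_t w_, std::size_t c_)
      : n(n_), h(h_), w(w_), c(c_), data(n_ * h_ * w_ * c_, 0.0f) {}
  // Flat offset of element (i, y, x, ch): the last index varies fastest.
  std::size_t offset(std::size_t i, std::size_t y, std::size_t x, std::size_t ch) const {
    return ((i * h + y) * w + x) * c + ch;
  }
  // Copy one image, itself stored row-major as [h, w, c], into slot i.
  // Because the batch index comes first, each image is one contiguous block.
  void set_image(std::size_t i, const std::vector<float>& img) {
    assert(i < n && img.size() == h * w * c);
    std::copy(img.begin(), img.end(),
              data.begin() + static_cast<std::ptrdiff_t>(i * h * w * c));
  }
};
```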

In the competition's Python kernels the problem was solved mainly with OpenCV. One of the simplest and most obvious analogues in R would look like this:

Implementing JSON-to-tensor conversion in R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Parse JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Delete the temporary file when the function exits
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Empty plot
  plot.new()
  # Plot window; the y axis is reversed because image rows grow downwards
  plot.window(xlim = c(0, 256 * scale), ylim = c(256 * scale, 0))
  # Line colors: one per stroke (black for every stroke if color = FALSE)
  cols <- if (color) rainbow(length(coords)) else rep("#000000", length(coords))
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Convert the image into a 3-dimensional array
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Combine the 3-dimensional image arrays into a 4-dimensional tensor
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

The drawing is done with standard R tools and saved to a temporary PNG stored in RAM (on Linux, R's temporary directories live in /tmp, which is mounted in RAM). This file is then read as a three-dimensional array with values from 0 to 1. This matters because a more conventional BMP would be read into a raw array of hex color codes.
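png::readPNG() returns intensities already scaled to [0, 1], and the C++ version later reaches the same range with img.convertTo(img, dtype, 1 / 255.0). The arithmetic behind both is a division by 255; a minimal sketch with hypothetical helper names, including the optional [-1, 1] rescaling that the batch iterator applies when imagenet_preproc is set:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Convert 8-bit pixel values (0..255) into floats in [0, 1].
std::vector<float> to_unit_range(const std::vector<std::uint8_t>& pixels) {
  std::vector<float> out;
  out.reserve(pixels.size());
  for (std::uint8_t p : pixels) out.push_back(p / 255.0f);
  return out;
}

// Optional rescaling from [0, 1] to [-1, 1]: (v - 0.5) * 2.
float to_symmetric(float v) { return (v - 0.5f) * 2.0f; }
```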

Let's check the result:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

The batch itself is formed as follows:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

This implementation turned out to be suboptimal for us, since forming large batches takes an indecently long time, so we decided to draw on our colleagues' experience and use the powerful OpenCV library. At that time there was no ready-made R package for it (there is none now either), so a minimal implementation of the required functionality was written in C++ and integrated into the R code using Rcpp.

The following packages and libraries were used to solve the problem:

  1. OpenCV for working with images and drawing lines. We used pre-installed system libraries and header files, with dynamic linking.

  2. xtensor for working with multidimensional arrays and tensors. We used the header files shipped with the R package of the same name. The library lets you work with multidimensional arrays in both row-major and column-major order.

  3. ndjson for parsing JSON. xtensor uses this library automatically if it is present in the project.

  4. RcppThread for organizing multi-threaded processing of a vector of JSON strings. We used the header files provided by this package. Unlike the more popular RcppParallel package, it has, among other things, a built-in loop interruption mechanism.

It is worth noting that xtensor turned out to be a godsend: besides offering extensive functionality and high performance, its developers proved quite responsive and answered questions promptly and in detail. With their help, we were able to implement the conversion of OpenCV matrices into xtensor tensors, as well as a way to combine 3-dimensional image tensors into a 4-dimensional tensor of the correct shape (the batch itself).
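Somewhere between the cv::Mat buffer and the R array the pixels have to be reordered: cv::Mat stores an image row by row with the channels interleaved, while R arrays are column-major (the first index varies fastest). xtensor performs this reordering when a tensor is assigned to one with a different layout. Written out by hand, the index arithmetic looks roughly like this; hwc_to_col_major is a hypothetical helper for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reorder an image stored row-major as [rows, cols, channels] (the interleaved
// layout of cv::Mat data) into column-major order, which is how R lays out a
// rows x cols x channels array.
std::vector<double> hwc_to_col_major(const std::vector<double>& src,
                                     std::size_t rows, std::size_t cols, std::size_t ch) {
  assert(src.size() == rows * cols * ch);
  std::vector<double> dst(src.size());
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      for (std::size_t k = 0; k < ch; ++k) {
        std::size_t from = (r * cols + c) * ch + k;  // row-major offset
        std::size_t to = r + rows * (c + cols * k);  // column-major offset
        dst[to] = src[from];
      }
  return dst;
}
```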

Materials for learning Rcpp, xtensor and RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

To compile files that use system header files and dynamic linking against libraries installed on the system, we used the plugin mechanism implemented in the Rcpp package. To find the paths and flags automatically, we used the popular Linux utility pkg-config.

Implementation of an Rcpp plugin for using the OpenCV library

Rcpp::registerPlugin("opencv", function() {
  # Possible names of the package in pkg-config
  pkg_config_name <- c("opencv", "opencv4")
  # Path to the pkg-config binary
  pkg_config_bin <- Sys.which("pkg-config")
  # Check that the utility is available on the system
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Check that pkg-config has an OpenCV configuration file
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, "--exists", pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  # Use the first package name that pkg-config knows about
  pkg_config_name <- pkg_config_name[check == 0][1L]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

As a result of the plugin's work, the following values are substituted during compilation:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

The code implementing JSON parsing and batch generation for passing to the model is given under the spoiler. First, add the local project directory to the header search path (needed for ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Implementation of JSON-to-tensor conversion in C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Type aliases
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Point coordinates of one stroke extracted from JSON
using strokes = std::vector<points>;      // All strokes of one drawing
using xtensor3d = xt::xtensor<double, 3>; // Tensor holding one image matrix
using xtensor4d = xt::xtensor<double, 4>; // Tensor holding a set of images
using rtensor3d = xt::rtensor<double, 3>; // Wrapper for export to R
using rtensor4d = xt::rtensor<double, 4>; // Wrapper for export to R

// Static constants
// Image size in pixels
const static int SIZE = 256;
// Line type
// See https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Line width in pixels
const static int LINE_WIDTH = 3;
// Resize algorithm
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Template for converting an OpenCV matrix into a tensor
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Shape of the target tensor
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Total number of elements in the array
  size_t size = src.total() * NCH;
  // Wrap the cv::Mat buffer without copying, then copy it into an XT tensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Convert JSON into a list of point coordinates
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // The parsed result must be an array
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Each element must be a 2-element array (x and y coordinates)
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Extract the vector of points
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Drawing the lines
// Colors are specified in HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Source matrix type
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Final matrix type
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of lines (strokes)
  size_t n = x.size();
  for (const auto& s: x) {
    // Number of points in the stroke
    size_t n_points = s.shape()[1];
    // i + 1 < n_points also guards against empty strokes (n_points - 1 would wrap around)
    for (size_t i = 0; i + 1 < n_points; ++i) {
      // Start point of the segment
      cv::Point from(s(0, i), s(1, i));
      // End point of the segment
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Draw the line
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Shift the hue for the next stroke
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Convert the color representation to RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Convert to float32 in the range [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Process JSON and get a tensor with the image data
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

This code should be placed in the file src/cv_xt.cpp and compiled with the command Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); it also needs the file nlohmann/json.hpp from the repository. The code is split into several functions:

  • to_xt - a template function that converts an image matrix (cv::Mat) into an xt::xtensor tensor;

  • parse_json - parses a JSON string, extracts the point coordinates and packs them into a vector;

  • ocv_draw_lines - draws multicolored lines from the resulting vector of points;

  • process - combines the functions above and adds the ability to scale the resulting image;

  • cpp_process_json_str - a wrapper over process that exports the result to an R object (a multidimensional array);

  • cpp_process_json_vector - a wrapper over process that processes a vector of strings in multi-threaded mode and stacks the results into a 4-dimensional array.

To draw multicolored lines, the HSV color model is used, followed by conversion to RGB. Let's check the result:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))
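The stroke coloring in ocv_draw_lines can be reproduced without OpenCV to see which colors the strokes get. A minimal sketch, assuming OpenCV's 8-bit HSV convention (hue stored as degrees / 2 in [0, 180), saturation and value in [0, 255]); the helper names are hypothetical, and the textbook hexcone formula used here may differ from cv::cvtColor by a rounding unit.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hue of each stroke as in ocv_draw_lines(): start at 0 and add 180 / n
// (integer division) after every stroke.
std::vector<int> stroke_hues(std::size_t n) {
  std::vector<int> hues;
  int step = n > 0 ? static_cast<int>(180 / n) : 0;
  for (std::size_t k = 0; k < n; ++k) hues.push_back(static_cast<int>(k) * step);
  return hues;
}

// Convert an 8-bit HSV triple (H in [0, 180), S and V in [0, 255]) to RGB.
std::array<int, 3> hsv8_to_rgb(int h, int s, int v) {
  double hd = (h % 180) * 2.0;                // hue in degrees, [0, 360)
  double sv = s / 255.0, vv = v / 255.0;
  double c = vv * sv;                         // chroma
  double hp = hd / 60.0;                      // sector position in [0, 6)
  double x = c * (1.0 - std::fabs(std::fmod(hp, 2.0) - 1.0));
  double r = 0, g = 0, b = 0;
  switch (static_cast<int>(hp)) {
    case 0: r = c; g = x; break;
    case 1: r = x; g = c; break;
    case 2: g = c; b = x; break;
    case 3: g = x; b = c; break;
    case 4: r = x; b = c; break;
    default: r = c; b = x; break;
  }
  double m = vv - c;                          // lift all channels to value v
  return {static_cast<int>(std::lround((r + m) * 255.0)),
          static_cast<int>(std::lround((g + m) * 255.0)),
          static_cast<int>(std::lround((b + m) * 255.0))};
}
```

Spreading the hues over the color wheel makes the stroke order visible to the network as color. Because of the integer division, drawings with more than 180 strokes get a step of 0 and every stroke keeps the starting hue.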

Comparing the speed of the R and C++ implementations

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Benchmark results
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

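As you can see, the speedup is very significant: for batches, the C++ version is roughly an order of magnitude faster (for example, a median of 431 ms versus 4.22 s for a batch of 1024), and catching up with the C++ code by parallelizing the R code is not feasible.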

3. Iterators for extracting batches from the database

R has a well-deserved reputation for processing data that fits in RAM, whereas Python is more associated with iterative data processing, which makes it easy and natural to implement out-of-core computations (computations using external memory). A classic example, relevant to the problem described here, is deep neural networks trained by gradient descent, where the gradient at each step is approximated using a small portion of the observations, a mini-batch.

Deep learning frameworks written in Python have special classes that implement iterators over data: tables, pictures in folders, binary formats and so on. You can use the ready-made options or write your own for specific tasks. In R we can take advantage of all the features of the Python library keras, with its various backends, through the R package of the same name, which in turn works on top of the reticulate package. The latter deserves a separate long article; it not only lets you run Python code from R, but also lets you pass objects between R and Python sessions, automatically performing all the necessary type conversions.

MonetDBLite removed the need to keep all the data in RAM, and all the "neural network" work will be done by native Python code, so we only have to write an iterator over the data, since nothing ready-made exists for this situation in either R or Python. There are essentially only two requirements for it: it must return batches in an endless loop and preserve its state between iterations (in R the latter is most simply implemented with closures). Previously, R arrays had to be converted explicitly into numpy arrays inside the iterator, but the current version of the keras package does this by itself.
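The closure-based iterator can be mirrored in C++ with a lambda that owns its state. A minimal sketch, assuming the ids fit in memory; make_batch_generator is hypothetical and, unlike the R version (which drops the ids of the incomplete batch once, when the generator is created), it reshuffles all ids at the start of each epoch, so a different remainder is skipped each time.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <random>
#include <vector>

// Returns a generator that yields batches of ids forever. The state (shuffled
// ids, batch counter, RNG) lives in the lambda's captures, the C++ counterpart
// of an R closure updating its enclosing environment with <<-.
std::function<std::vector<int>()> make_batch_generator(std::vector<int> ids,
                                                       std::size_t batch_size,
                                                       unsigned seed) {
  assert(batch_size > 0 && ids.size() >= batch_size);
  std::mt19937 rng(seed);
  std::shuffle(ids.begin(), ids.end(), rng);
  // Only full batches are served; leftover ids are skipped in each epoch.
  std::size_t n_batches = ids.size() / batch_size;
  std::size_t i = 0;
  return [ids, batch_size, n_batches, i, rng]() mutable {
    if (i == n_batches) {  // epoch finished: reshuffle and start over
      std::shuffle(ids.begin(), ids.end(), rng);
      i = 0;
    }
    auto first = ids.begin() + static_cast<std::ptrdiff_t>(i * batch_size);
    std::vector<int> batch(first, first + static_cast<std::ptrdiff_t>(batch_size));
    ++i;
    return batch;
  };
}
```

Each call returns the next batch of ids, which would then be bound to the prepared statement exactly as in the R iterator.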

The iterator for training and validation data turned out as follows:

I-Iterator yokuqeqeshwa kanye nedatha yokuqinisekisa

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Argument checks
  checkmate::assert_class(db_connection, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Shuffle so that batch indices can be taken (and discarded once used) in order
  dt <- data.table::data.table(id = sample(samples_index))
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Keep only full batches and set the key
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Initialise the counter
  i <- 1
  # Number of batches
  max_i <- dt[, max(batch)]

  # Prepare the statement used to fetch a batch
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(db_connection, sql)

  # Analogue of keras::to_categorical (labels are 0-based)
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Closure
  function() {
    # Start a new epoch
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Reset the counter
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # IDs of the rows to fetch
    batch_ind <- dt[batch == i, id]
    # Fetch the data
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Increment the counter
    i <<- i + 1

    # Parse the JSON and build the array
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the [0, 1] interval to [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

The function takes the following inputs:

  • a variable holding the database connection;
  • the row numbers to use;
  • the number of classes;
  • the batch size;
  • the scale factor (scale = 1 renders 256x256-pixel images, scale = 0.5 renders 128x128 pixels);
  • a color flag (color = FALSE renders in grayscale, while color = TRUE draws each stroke in a new color);
  • a preprocessing flag for networks pretrained on ImageNet.

The preprocessing flag rescales pixel values from the interval [0, 1] to the interval [-1, 1]. The keras-supplied models were trained on data in that range.

The outer function does several things. It checks the argument types. It creates a data.table that holds the row numbers from samples_index, shuffled at random, together with their batch numbers. It sets a counter and the maximum number of batches. It also prepares an SQL statement for fetching data from the database. In addition, it defines a fast analogue of keras::to_categorical(). We used almost all the data for training and kept half a percent for validation. For that reason, the epoch size was limited by the steps_per_epoch parameter of keras::fit_generator(), and the condition if (i > max_i) only ever triggered for the validation iterator.
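The inner to_categorical() relies on R's column-major storage. For n labels and num classes, the cell in row r and column x + 1 (labels are 0-based) is element x * n + r of the underlying vector. That is why a single vectorised assignment is enough. Below is a quick base-R check against a naive loop. to_categorical_naive() is a hypothetical reference written only for this comparison.

```r
# Vectorised one-hot encoding, same logic as inside train_generator()
to_categorical <- function(x, num) {
  n <- length(x)
  m <- numeric(n * num)
  m[x * n + seq_len(n)] <- 1   # row r, column x[r] + 1 in column-major order
  dim(m) <- c(n, num)
  m
}

# Naive reference: loop over rows and set column (label + 1)
to_categorical_naive <- function(x, num) {
  m <- matrix(0, nrow = length(x), ncol = num)
  for (r in seq_along(x)) m[r, x[r] + 1] <- 1
  m
}

labels <- c(0L, 2L, 1L, 2L)
onehot <- to_categorical(labels, num = 3)
```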

The inner function works through these steps:

  • it takes the row indices for the next batch;
  • it fetches the records from the database;
  • it increments the batch counter;
  • it parses the JSON with cpp_process_json_vector(), a function written in C++, and builds the arrays that hold the images;
  • it builds one-hot vectors from the class labels;
  • it combines the pixel arrays and the labels into a list, which is the return value.

To speed things up, we used keyed indexing of data.table tables and modification by reference. It is hard to imagine working efficiently with any significant amount of data in R without these "tricks" of the data.table package.

Speed measurements on a laptop with a Core i5 gave the following results:

Iterator benchmark

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Indices for the training sample
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Indices for the validation sample
val_ind <- ind[-train_ind]
rm(ind)
# Scale factor
scale <- 0.5

# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Benchmark metrics to display
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)
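Converting the median times from the table above into throughput (images per second) makes it easier to compare batch sizes. The numbers below are copied from that table, and nothing new is measured.

```r
# Median time per batch, copied from the benchmark table above (seconds)
bench_median <- data.frame(
  batch_size = 2^(4:10),
  median_s   = c(0.06436, 0.11813, 0.11793, 0.24074, 0.61352, 1.53, 3.83)
)
# Throughput: images rendered by the iterator per second of wall time
bench_median$images_per_s <- bench_median$batch_size / bench_median$median_s
bench_median
```

On these figures, throughput peaks at batch sizes of 64 to 128 (about 530 to 540 images per second) and drops for larger batches. This is a single laptop run, so treat it as indicative only.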


With enough RAM, you can speed up the database considerably by moving it into RAM as well (32 GB is enough for our task). On Linux, the /dev/shm partition is mounted by default and takes up to half of the RAM. You can allocate more by editing /etc/fstab so that it contains an entry like tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Be sure to reboot, then check the result with the df -h command.
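Here is a minimal sketch of doing this from R. It assumes the following:

  • the database directory is not in use (disconnect with shutdown = TRUE first);
  • /dev/shm is a tmpfs mount with enough free space;
  • the scripts read the database path from the DBDIR environment variable, as the benchmark above does.

move_db_to_ram() is a hypothetical helper. Sys.setenv() only affects the current R session and the processes it starts.

```r
# Hypothetical helper: copy the database directory into a tmpfs mount and
# point DBDIR at the copy for the current R session.
move_db_to_ram <- function(db_dir, shm_dir = "/dev/shm") {
  stopifnot(dir.exists(db_dir), dir.exists(shm_dir))
  # With recursive = TRUE, db_dir is copied into shm_dir/<basename(db_dir)>
  ok <- file.copy(db_dir, shm_dir, recursive = TRUE)
  if (!all(ok)) stop("Copy failed; check free space in ", shm_dir)
  new_dir <- file.path(shm_dir, basename(db_dir))
  Sys.setenv(DBDIR = new_dir)
  new_dir
}
```

tmpfs contents disappear on reboot, so you have to repeat the copy after every restart, including the restart that follows editing /etc/fstab.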

The iterator for the test data is much simpler, since the test dataset fits entirely in RAM:

Test data iterator

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Argument checks
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Assign batch numbers (the last, possibly incomplete batch is kept)
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Closure
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the [0, 1] interval to [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}
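Unlike the training iterator, test_generator() keeps the last, incomplete batch. Covering the whole test set therefore takes ceiling(n / batch_size) calls, which is the value you would pass as steps to a prediction call such as keras::predict_generator(). Below is a base-R check of the same batch numbering. The test size of 70 rows is arbitrary.

```r
# Same numbering as in test_generator(): consecutive rows, batch_size per batch
n_test <- 70L
batch_size <- 32L
batch <- (seq_len(n_test) - 1L) %/% batch_size + 1L
steps <- max(batch)   # equals ceiling(n_test / batch_size)
table(batch)          # sizes of the individual batches
```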

4. Choosing a model architecture

The first architecture we used was MobileNet v1, whose features are discussed in this post. It ships with keras as standard and is therefore available in the R package of the same name. However, when we tried to use it with single-channel images, something odd came up: the input tensor must always have the dimensions (batch, height, width, 3), so the number of channels cannot be changed. Python has no such limitation. So we quickly wrote our own implementation of this architecture, following the original paper (without the dropout present in the keras version):

MobileNet v1 architecture

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}
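layer_sep_conv_bn() shows why MobileNet is cheap. A 3x3 depthwise convolution followed by a 1x1 pointwise convolution replaces a full 3x3 convolution. Below is a rough weight count in base R. It ignores biases and batch-normalization parameters, and the helper names are hypothetical.

```r
# Weights of a plain k x k convolution
conv_params <- function(k, c_in, c_out) k * k * c_in * c_out

# Weights of a depthwise separable block as in layer_sep_conv_bn()
sep_conv_params <- function(k, c_in, c_out, depth_multiplier = 1) {
  depthwise <- k * k * c_in * depth_multiplier   # one k x k filter per input channel
  pointwise <- c_in * depth_multiplier * c_out   # 1 x 1 convolution mixing channels
  depthwise + pointwise
}

conv_params(3, 32, 64)       # 18432
sep_conv_params(3, 32, 64)   # 2336
```

For 32 input and 64 output channels, the separable block needs roughly 8 times fewer weights.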

The drawbacks of this approach are obvious. We wanted to try many models, but we did not want to rewrite every architecture by hand. We also lost the ability to use the weights of models pretrained on ImageNet. As usual, reading the documentation helped. The get_config() function returns a description of the model in a form suitable for editing (base_model_conf$layers is an ordinary R list). The from_config() function converts that description back into a model object:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)
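The configuration is a plain nested list, so the edit is ordinary R list manipulation. The mock below imitates only the relevant corner of what get_config() returns, and its exact structure is an assumption here. It shows what the assignment above changes.

```r
# Mock of the relevant part of get_config() output (structure assumed)
base_model_conf <- list(
  layers = list(
    list(class_name = "InputLayer",
         config = list(name = "input_1",
                       batch_input_shape = list(NULL, 128L, 128L, 3L)))
  )
)

# The same edit as above: switch the input from 3 channels to 1
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
str(base_model_conf$layers[[1]]$config$batch_input_shape)
```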

With that in place, it is easy to write a universal function that returns any keras-supplied model, with or without weights pretrained on ImageNet:

Function for loading ready-made architectures

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Argument checks
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Look up the application_<name>() constructor in the keras package
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Check that such a function exists in the package
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # If the images are not in color, change the number of input channels
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

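  # With pooling = "avg", the base model already ends in global average pooling.
  # Taking its output directly avoids depending on an auto-generated layer name
  # such as "global_average_pooling2d_1", whose suffix depends on how many such
  # layers were created earlier in the session
  predictions <- base_model$output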
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

With single-channel images, no pretrained weights are used. This could be fixed in three steps:

  • get the model weights as a list of R arrays with get_weights();
  • change the dimensions of the first element of this list, either by keeping one color channel or by averaging all three;
  • load the weights back into the model with set_weights().

We never added this functionality, because by that stage it was already clear that working with color images was more productive.
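Below is a base-R sketch of the middle step. It assumes the channels-last kernel layout (height, width, in_channels, filters) used by TensorFlow-backed keras. collapse_rgb_kernel() is a hypothetical helper.

The "sum" option goes beyond the two choices named above. Convolution is linear in its input, so if a grayscale image had been copied into all three channels, summing the kernels reproduces the original activations exactly.

```r
# Hypothetical helper: collapse a 3-channel first-layer kernel to 1 channel.
# Assumed layout: (height, width, in_channels, filters).
collapse_rgb_kernel <- function(kernel, method = c("mean", "sum", "first")) {
  method <- match.arg(method)
  d <- dim(kernel)
  stopifnot(length(d) == 4L, d[3] == 3L)
  collapsed <- switch(method,
    mean  = apply(kernel, c(1, 2, 4), mean),  # average the R, G, B slices
    sum   = apply(kernel, c(1, 2, 4), sum),   # exact for gray input copied to 3 channels
    first = kernel[, , 1, ]                   # keep only the first (R) channel
  )
  # Re-insert the channel axis with size 1; the element order is unchanged
  array(collapsed, dim = c(d[1], d[2], 1L, d[4]))
}

k <- array(rnorm(3 * 3 * 3 * 32), dim = c(3, 3, 3, 32))
k_gray <- collapse_rgb_kernel(k, "mean")
dim(k_gray)
```

With a real model, you would take the list from keras::get_weights() and replace the first convolution kernel with the collapsed array. You would then pass the list to keras::set_weights() on the one-channel model built with from_config(). Which list element holds that kernel depends on the architecture, so check dim() before replacing it.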

We ran most of our experiments with MobileNet versions 1 and 2 and with ResNet34. More modern architectures such as SE-ResNeXt performed well in this competition. Unfortunately, we had no ready-made implementations at hand, and we did not write our own (but we definitely will).

5. Script parameterization

For convenience, all the code for launching training was organized as a single script, parameterized with docopt as follows:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)
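docopt returns option values as character strings, so numeric options have to be converted before use. The input shape follows from the scale factor and the color flag: 256 * 0.5 = 128 pixels, with 3 channels when -c is given and 1 otherwise.

In the sketch below, a plain list stands in for docopt's output. The element names and build_config() are assumptions, not the authors' code.

```r
# Stand-in for docopt's result: values arrive as strings, flags as logicals
args <- list(model = "resnet50", batch_size = "32", scale_factor = "0.5",
             color = TRUE, validate_ratio = "0.995", n_gpu = "1")

# Hypothetical helper: convert types and derive the model input shape
build_config <- function(args) {
  scale <- as.numeric(args$scale_factor)
  dim_size <- as.integer(256 * scale)            # source drawings are 256 px
  channels <- if (isTRUE(args$color)) 3L else 1L
  list(
    model_name     = args$model,
    batch_size     = as.integer(args$batch_size),
    scale          = scale,
    input_shape    = c(dim_size, dim_size, channels),
    validate_ratio = as.numeric(args$validate_ratio),
    n_gpu          = as.integer(args$n_gpu)
  )
}

cfg <- build_config(args)
cfg$input_shape
```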

The docopt package is the R implementation of http://docopt.org/. With it, scripts are launched with simple commands such as Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, or ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db if train_nn.R is executable. This command starts training the resnet50 model on three-channel color images of 128x128 pixels, with the database located in the /home/andrey/doodle_db folder. You can add the learning rate, the optimizer type and any other parameters you want to customize to the list. While we were preparing this post, it turned out that the mobilenet_v2 architecture from the current version of keras cannot be used from R. The R package has not yet caught up with recent changes, and we are waiting for a fix.

This approach let us run experiments with different models much faster than the more traditional way of running scripts in RStudio. (The tfruns package is another possible option.) But the main advantage is that script runs are easy to manage in Docker or on a plain server, without installing RStudio there.

6. Dockerizing the scripts

We used Docker to keep the training environment portable between team members and to deploy quickly to the cloud. If Docker is new to you, as it is to many R programmers, this series of posts or a video course is a good place to start.

Docker lets you build your own images from scratch or base them on existing images. After reviewing the available options, we concluded that the NVIDIA drivers, CUDA+cuDNN and the Python libraries take up a large part of the image. So we took the official tensorflow/tensorflow:1.12.0-gpu image as the base and added the required R packages to it.

The final Dockerfile looked like this:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && \
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && \
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && \
    add-apt-repository -y ppa:marutter/c2d4u3.5 && \
    add-apt-repository -y ppa:timsc/opencv-3.4 && \
    apt-get update && \
    apt-get install -y locales && \
    locale-gen ${LOCALE} && \
    apt-get install -y --no-install-recommends ${APT_PKG} && \
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && \
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && \
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && \
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && \
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && \
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && \
    install.r ${R_SRC_PKG} && \
    pip install ${PY_PIP_PKG} && \
    mkdir -p ${DIRS} && \
    chmod 777 ${DIRS} && \
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && \
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

For convenience, the package lists are kept in build arguments (ARG), and the bulk of the scripts is copied into the container at build time. We also switched the command shell to /bin/bash so that we could easily read /etc/os-release with source. This avoided hard-coding the OS version.

We also wrote a small bash script that launches the container with different commands. For example, it can run the neural network training scripts copied into the container, or open a command shell for debugging and monitoring the container:

Container launch script

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

If this bash script is run without parameters, it calls the train_nn.R script inside the container with default values. If the first positional argument is "bash", the container starts interactively with a command shell. In all other cases, the positional arguments are passed on to the training script: CMD="Rscript /app/train_nn.R $@".

The directories with the source data and the database, as well as the directory for trained models, are mounted into the container from the host system. This gives direct access to the scripts' results without extra steps.

7. Using multiple GPUs on Google Cloud

The data in this competition was very noisy (see the title picture, borrowed from [email protected] on ODS Slack). Large batches help counter this. So, after experiments on a PC with one GPU, we decided to train models on several GPUs in the cloud. We chose Google Cloud (a good guide to the basics) for its wide choice of configurations, reasonable prices and the $300 bonus. Out of greed, I ordered a 4xV100 instance with an SSD and a ton of RAM, and that was a big mistake. Such a machine burns through money quickly, and experimenting on it without a proven pipeline will ruin you. For learning purposes, a K80 is a better choice. The large amount of RAM did come in handy, though. The cloud SSD's performance was disappointing, so we moved the database to /dev/shm.

The most interesting part is the code that uses multiple GPUs. First, the model is created on the CPU inside a context manager, just as in Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    num_classes = num_classes,
    metrics = c(top_3_categorical_accuracy),
    compile = FALSE
  )
})

Next, the uncompiled model (this is important) is replicated across the given number of available GPUs, and only then is it compiled:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)
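keras::multi_gpu_model() implements data parallelism. It splits each incoming batch into sub-batches, one per GPU, and concatenates the results on the CPU. The batch_size passed to the training call is therefore the global batch, and each GPU sees roughly batch_size / n_gpu samples. To keep the per-GPU batch fixed when you add GPUs, you have to scale batch_size by the number of GPUs.

The slicing rule below (equal slices, with the last GPU taking the remainder) is our reading of the Keras implementation. Treat it as an illustration; split_batch() is a hypothetical helper.

```r
# Hypothetical helper: sub-batch sizes when one batch is split across GPUs
split_batch <- function(batch_size, n_gpu) {
  step <- batch_size %/% n_gpu                    # equal share per GPU
  sizes <- rep(step, n_gpu)
  sizes[n_gpu] <- batch_size - step * (n_gpu - 1) # last GPU takes the remainder
  sizes
}

split_batch(1024, 4)   # 256 256 256 256
split_batch(10, 4)     # 2 2 2 4
```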

We could not use the classic technique with several GPUs. That technique freezes all layers except the last, trains the last layer, then unfreezes and retrains the whole model.

We did not use tensorboard to monitor training. We limited ourselves to writing logs and saving models with informative file names after each epoch:

Callbacks

# Log file name template
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Model file name template (epoch and val_loss are filled in by keras)
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # halve the learning rate
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)
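The two plateau-related callbacks interact. The learning rate is halved after 4 epochs without improvement, and training stops after 8. On a flat plateau, the reduced rate therefore gets only about 4 epochs to help.

Below is a simplified base-R re-implementation that checks this on a made-up loss curve. It is not the keras source and ignores options such as cooldown and min_lr. simulate_callbacks() is a hypothetical helper.

```r
# Simplified simulation of callback_reduce_lr_on_plateau + callback_early_stopping
simulate_callbacks <- function(val_loss, lr = 4e-4, min_delta = 1e-4,
                               lr_patience = 4, lr_factor = 0.5,
                               stop_patience = 8) {
  best <- Inf          # best validation loss seen so far
  wait_lr <- 0         # epochs without improvement (learning-rate callback)
  wait_stop <- 0       # epochs without improvement (early-stopping callback)
  lr_after_epoch <- numeric(0)
  for (epoch in seq_along(val_loss)) {
    if (val_loss[epoch] < best - min_delta) {   # improvement in "min" mode
      best <- val_loss[epoch]
      wait_lr <- 0
      wait_stop <- 0
    } else {
      wait_lr <- wait_lr + 1
      wait_stop <- wait_stop + 1
      if (wait_lr >= lr_patience) {             # plateau: halve lr, restart count
        lr <- lr * lr_factor
        wait_lr <- 0
      }
    }
    lr_after_epoch[epoch] <- lr
    if (wait_stop >= stop_patience) break       # early stopping
  }
  list(epochs_run = epoch, lr = lr_after_epoch)
}

# Made-up curve: three improving epochs, then a flat plateau
res <- simulate_callbacks(c(1.0, 0.9, 0.8, rep(0.8, 20)))
res$epochs_run   # 11
res$lr           # 4e-4 after epochs 1-6, 2e-4 after epochs 7-10, 1e-4 after epoch 11
```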

8. Instead of a conclusion

Several problems we ran into are still unsolved:

  • keras has no ready-made function for automatically finding the optimal learning rate (an analogue of lr_finder in the fast.ai library); with some effort, third-party implementations can be ported to R, for example, this one;
  • as a consequence of the previous point, we could not pick the right learning rate when training on several GPUs;
  • there is a shortage of modern neural network architectures, especially ones pretrained on ImageNet;
  • there is no one-cycle policy and no discriminative learning rates (cosine annealing was implemented at our request, thanks to skydan).
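For reference, here is what a cosine annealing schedule looks like as a plain R function. The shape of the returned function (epoch index from 0, current learning rate) is what keras::callback_learning_rate_scheduler() is documented to expect, so it could be passed there. lr has a default in case a keras version supplies only the epoch. This is a sketch with arbitrary parameter values, not the implementation mentioned above.

```r
# Hypothetical schedule factory: cosine annealing from lr_max to lr_min
# over n_epochs; after n_epochs the rate stays at lr_min.
cosine_annealing <- function(lr_max = 1e-3, lr_min = 1e-5, n_epochs = 20) {
  function(epoch, lr = NULL) {
    progress <- min(epoch, n_epochs) / n_epochs   # 0 at the start, 1 at the end
    lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * progress))
  }
}

schedule <- cosine_annealing()
sapply(0:20, schedule)   # learning rate for epochs 0..20
```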

Useful lessons we took away from this competition:

  • Even on relatively low-powered hardware, you can painlessly work with decent data volumes (many times the size of RAM). The data.table package saves memory by modifying tables in place, which avoids copying them. Used properly, it is almost always the fastest of all the scripting-language tools we know. Storing data in a database often means you do not have to squeeze the whole dataset into RAM at all.
  • Slow R functions can be replaced with fast C++ ones using the Rcpp package. Adding RcppThread or RcppParallel gives cross-platform multithreaded implementations, so there is no need to parallelize code at the R level.
  • You can use Rcpp without deep knowledge of C++; the necessary minimum is described here. Header files for a number of excellent C++ libraries such as xtensor are available on CRAN, so there is now infrastructure for R projects that build on ready-made high-performance C++ code. RStudio adds syntax highlighting and a static analyzer for C++ code.
  • docopt lets you run self-contained scripts with parameters. This is convenient on a remote server, including under Docker. RStudio is inconvenient for running many hours of neural network training experiments, and installing the IDE on the server itself is not always justified.
  • Docker makes code portable and results reproducible across developers with different OS and library versions, and it makes running on servers easy. The entire training pipeline can be launched with a single command.
  • Google Cloud is a budget-friendly way to experiment on expensive hardware, but you need to choose the configuration carefully.
  • Measuring the speed of individual code fragments is very useful, especially when combining R and C++, and the bench package makes it very easy.

Overall, this experience was very rewarding, and we are continuing to work on the remaining problems listed above.

Source: www.habr.com
