Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva

Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva

Habari Habr!

Msimu wa vuli uliopita, Kaggle aliandaa shindano la kuainisha picha zilizochorwa kwa mkono, Utambuzi wa Doodle wa Chora Haraka, ambapo, miongoni mwa mengine, timu ya wanasayansi wa R ilishiriki: Artem Klevtsova, Meneja Philippa ΠΈ Andrey Ogurtsov. Hatutaelezea mashindano kwa undani; hiyo tayari imefanywa ndani uchapishaji wa hivi karibuni.

Wakati huu haikufanya kazi na kilimo cha medali, lakini uzoefu mwingi wa thamani ulipatikana, kwa hivyo ningependa kuiambia jamii kuhusu mambo kadhaa ya kupendeza na muhimu kwenye Kagle na katika kazi ya kila siku. Miongoni mwa mada zilizojadiliwa: maisha magumu bila OpenCV, uchanganuzi wa JSON (mifano hii inachunguza ujumuishaji wa nambari ya C++ kwenye hati au vifurushi katika R kwa kutumia Rcpp), parameterization ya maandishi na dockerization ya suluhisho la mwisho. Nambari zote kutoka kwa ujumbe katika fomu inayofaa kwa utekelezaji inapatikana hazina.

Yaliyomo:

  1. Pakia data kwa ufanisi kutoka CSV hadi MonetDB
  2. Kuandaa batches
  3. Viigizo vya upakuaji wa bechi kutoka kwa hifadhidata
  4. Kuchagua Usanifu wa Mfano
  5. Uwekaji vigezo vya hati
  6. Uboreshaji wa maandishi
  7. Kutumia GPU nyingi kwenye Wingu la Google
  8. Badala ya hitimisho

1. Pakia data kwa ufanisi kutoka kwa CSV hadi kwenye hifadhidata ya MonetDB

Data katika shindano hili haitolewa kwa namna ya picha zilizotengenezwa tayari, lakini katika mfumo wa faili 340 za CSV (faili moja kwa kila darasa) zilizo na JSON zilizo na viwianishi vya uhakika. Kwa kuunganisha pointi hizi na mistari, tunapata picha ya mwisho ya kupima saizi 256x256. Pia kwa kila rekodi kuna lebo inayoonyesha ikiwa picha hiyo ilitambuliwa kwa usahihi na kiainishaji kilichotumiwa wakati mkusanyiko wa data ulipokusanywa, msimbo wa herufi mbili za nchi anakoishi mwandishi wa picha hiyo, kitambulisho cha kipekee, muhuri wa muda. na jina la darasa linalolingana na jina la faili. Toleo lililorahisishwa la data asili lina uzito wa GB 7.4 kwenye kumbukumbu na takriban GB 20 baada ya kufunguliwa, data kamili baada ya kupekua inachukua hadi GB 240. Waandaaji walihakikisha kwamba matoleo yote mawili yametoa tena michoro sawa, kumaanisha kwamba toleo kamili lilikuwa halihitajiki. Kwa hali yoyote, kuhifadhi picha milioni 50 kwenye faili za picha au kwa namna ya safu mara moja ilionekana kuwa haina faida, na tuliamua kuunganisha faili zote za CSV kutoka kwenye kumbukumbu. treni_iliyorahisishwa.zip kwenye hifadhidata na kizazi kinachofuata cha picha za saizi inayohitajika "kwenye kuruka" kwa kila kundi.

Mfumo uliothibitishwa vizuri ulichaguliwa kama DBMS MonetDB, ambayo ni utekelezaji wa R kama kifurushi MonetDBLite. Kifurushi kinajumuisha toleo lililopachikwa la seva ya hifadhidata na hukuruhusu kuchukua seva moja kwa moja kutoka kwa kikao cha R na kufanya kazi nayo hapo. Kuunda hifadhidata na kuiunganisha hufanywa kwa amri moja:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

Tutahitaji kuunda jedwali mbili: moja kwa data zote, nyingine kwa habari ya huduma kuhusu faili zilizopakuliwa (inafaa ikiwa kitu kitaenda vibaya na mchakato lazima uanzishwe tena baada ya kupakua faili kadhaa):

Kujenga meza

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

Njia ya haraka sana ya kupakia data kwenye hifadhidata ilikuwa kunakili faili za CSV moja kwa moja kwa kutumia SQL - amri COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORTAmbapo tablename - jina la meza na path - njia ya faili. Wakati wa kufanya kazi na kumbukumbu, iligunduliwa kuwa utekelezaji uliojengwa ndani unzip katika R haifanyi kazi kwa usahihi na idadi ya faili kutoka kwenye kumbukumbu, kwa hiyo tulitumia mfumo unzip (kwa kutumia parameta getOption("unzip")).

Kazi ya kuandika kwenye hifadhidata

#' @title Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ ΠΈ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ° Ρ„Π°ΠΉΠ»ΠΎΠ²
#'
#' @description
#' Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ CSV-Ρ„Π°ΠΉΠ»ΠΎΠ² ΠΈΠ· ZIP-Π°Ρ€Ρ…ΠΈΠ²Π° ΠΈ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΈΡ… Π² Π±Π°Π·Ρƒ Π΄Π°Π½Π½Ρ‹Ρ…
#'
#' @param con ΠžΠ±ΡŠΠ΅ΠΊΡ‚ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… (класс `MonetDBEmbeddedConnection`).
#' @param tablename НазваниС Ρ‚Π°Π±Π»ΠΈΡ†Ρ‹ Π² Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ….
#' @oaram zipfile ΠŸΡƒΡ‚ΡŒ ΠΊ ZIP-Π°Ρ€Ρ…ΠΈΠ²Ρƒ.
#' @oaram filename Имя Ρ„Π°ΠΉΠ»Π° Π²Π½ΡƒΡ€ΠΈ ZIP-Π°Ρ€Ρ…ΠΈΠ²Π°.
#' @param preprocess Ѐункция ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ, которая Π±ΡƒΠ΄Π΅Ρ‚ ΠΏΡ€ΠΈΠΌΠ΅Π½Π΅Π½Π° ΠΈΠ·Π²Π»Π΅Ρ‡Ρ‘Π½Π½ΠΎΠΌΡƒ Ρ„Π°ΠΉΠ»Ρƒ.
#'   Π”ΠΎΠ»ΠΆΠ½Π° ΠΏΡ€ΠΈΠ½ΠΈΠΌΠ°Ρ‚ΡŒ ΠΎΠ΄ΠΈΠ½ Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ `data` (ΠΎΠ±ΡŠΠ΅ΠΊΡ‚ `data.table`).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ Ρ„Π°ΠΉΠ»Π°
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # ΠŸΡ€ΠΈΠΌΠ΅Π½ΡΠ΅ΠΌ функция ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Запрос ΠΊ Π‘Π” Π½Π° ΠΈΠΌΠΏΠΎΡ€Ρ‚ CSV
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Π’Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ запроса ΠΊ Π‘Π”
  DBI::dbExecute(con, sql)

  # Π”ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΠ΅ записи ΠΎΠ± ΡƒΡΠΏΠ΅ΡˆΠ½ΠΎΠΉ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ΅ Π² ΡΠ»ΡƒΠΆΠ΅Π±Π½ΡƒΡŽ Ρ‚Π°Π±Π»ΠΈΡ†Ρƒ
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

Ikiwa unahitaji kubadilisha meza kabla ya kuiandika kwenye hifadhidata, inatosha kupitisha hoja preprocess kazi ambayo itabadilisha data.

Msimbo wa kupakia data kwa mpangilio kwenye hifadhidata:

Kuandika data kwenye hifadhidata

# Бписок Ρ„Π°ΠΉΠ»ΠΎΠ² для записи
files <- unzip(zipfile, list = TRUE)$Name

# Бписок ΠΈΡΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠΉ, Ссли Ρ‡Π°ΡΡ‚ΡŒ Ρ„Π°ΠΉΠ»ΠΎΠ² ΡƒΠΆΠ΅ Π±Ρ‹Π»Π° Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π°
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # ЗапускаСм Ρ‚Π°ΠΉΠΌΠ΅Ρ€
  tictoc::tic()
  # ΠŸΡ€ΠΎΠ³Ρ€Π΅ΡΡ Π±Π°Ρ€
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # ΠžΡΡ‚Π°Π½Π°Π²Π»ΠΈΠ²Π°Π΅ΠΌ Ρ‚Π°ΠΉΠΌΠ΅Ρ€
  tictoc::toc()
}

# 526.141 sec elapsed - ΠΊΠΎΠΏΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ SSD->SSD
# 558.879 sec elapsed - ΠΊΠΎΠΏΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ USB->SSD

Muda wa kupakia data unaweza kutofautiana kulingana na sifa za kasi ya kiendeshi kilichotumiwa. Kwa upande wetu, kusoma na kuandika ndani ya SSD moja au kutoka kwa gari la flash (faili ya chanzo) hadi SSD (DB) inachukua chini ya dakika 10.

Inachukua sekunde chache zaidi kuunda safu iliyo na lebo ya darasa kamili na safu wima ya faharisi (ORDERED INDEX) na nambari za mstari ambazo uchunguzi utatolewa wakati wa kuunda vikundi:

Kuunda Safu wima na Fahirisi za Ziada

message("Generate lables")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))

Ili kutatua tatizo la kuunda kundi kwenye nzi, tulihitaji kufikia kasi ya juu zaidi ya kutoa safu mlalo bila mpangilio kutoka kwa jedwali. doodles. Kwa hili tulitumia mbinu 3. Ya kwanza ilikuwa kupunguza ukubwa wa aina inayohifadhi kitambulisho cha uchunguzi. Katika seti halisi ya data, aina inayohitajika kuhifadhi kitambulisho ni bigint, lakini idadi ya uchunguzi hufanya iwezekanavyo kutoshea vitambulisho vyao, sawa na nambari ya kawaida, katika aina. int. Utafutaji katika kesi hii ni haraka sana. Ujanja wa pili ulikuwa kutumia ORDERED INDEX - tulifikia uamuzi huu kwa nguvu, baada ya kupitia yote yaliyopatikana chaguo. Ya tatu ilikuwa kutumia maswali ya vigezo. Kiini cha njia ni kutekeleza amri mara moja PREPARE na matumizi ya baadaye ya usemi ulioandaliwa wakati wa kuunda rundo la maswali ya aina moja, lakini kwa kweli kuna faida kwa kulinganisha na rahisi. SELECT iligeuka kuwa ndani ya anuwai ya makosa ya takwimu.

Mchakato wa kupakia data hautumii zaidi ya MB 450 ya RAM. Hiyo ni, mbinu iliyoelezewa hukuruhusu kusonga seti za data zenye uzito wa makumi ya gigabytes kwenye karibu vifaa vyovyote vya bajeti, pamoja na vifaa vya bodi moja, ambayo ni nzuri sana.

Kilichobaki ni kupima kasi ya kurejesha data (nasibu) na kutathmini kiwango wakati wa kuchukua sampuli za saizi tofauti:

Kiwango cha hifadhidata

library(ggplot2)

set.seed(0)
# ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ…
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

# Ѐункция для ΠΏΠΎΠ΄Π³ΠΎΡ‚ΠΎΠ²ΠΊΠΈ запроса Π½Π° сторонС сСрвСра
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Ѐункция для извлСчСния Π΄Π°Π½Π½Ρ‹Ρ…
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# ΠŸΡ€ΠΎΠ²Π΅Π΄Π΅Π½ΠΈΠ΅ Π·Π°ΠΌΠ΅Ρ€Π°
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Π±Π΅Π½Ρ‡ΠΌΠ°Ρ€ΠΊΠ°
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva

2. Kuandaa makundi

Mchakato mzima wa maandalizi ya kundi lina hatua zifuatazo:

  1. Inachanganua JSON kadhaa zilizo na vekta za kamba na kuratibu za vidokezo.
  2. Kuchora mistari ya rangi kulingana na kuratibu za pointi kwenye picha ya ukubwa unaohitajika (kwa mfano, 256 Γ— 256 au 128 Γ— 128).
  3. Kubadilisha picha zinazotokana na tensor.

Kama sehemu ya shindano kati ya kernels za Python, shida ilitatuliwa kimsingi kwa kutumia OpenCV. Mojawapo ya analogi rahisi na dhahiri zaidi katika R ingeonekana kama hii:

Utekelezaji wa JSON hadi Ubadilishaji wa Tensor katika R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # ΠŸΠ°Ρ€ΡΠΈΠ½Π³ JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # УдаляСм Π²Ρ€Π΅ΠΌΠ΅Π½Π½Ρ‹ΠΉ Ρ„Π°ΠΉΠ» ΠΏΠΎ Π·Π°Π²Π΅Ρ€ΡˆΠ΅Π½ΠΈΡŽ Ρ„ΡƒΠ½ΠΊΡ†ΠΈΠΈ
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # ΠŸΡƒΡΡ‚ΠΎΠΉ Π³Ρ€Π°Ρ„ΠΈΠΊ
  plot.new()
  # Π Π°Π·ΠΌΠ΅Ρ€ ΠΎΠΊΠ½Π° Π³Ρ€Π°Ρ„ΠΈΠΊΠ°
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Π¦Π²Π΅Ρ‚Π° Π»ΠΈΠ½ΠΈΠΉ
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ изобраТСния Π² 3-Ρ… ΠΌΠ΅Ρ€Π½Ρ‹ΠΉ массив
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # ОбъСдинСниС 3-Ρ… ΠΌΠ΅Ρ€Π½Ρ‹Ρ… массивов ΠΊΠ°Ρ€Ρ‚ΠΈΠ½ΠΎΠΊ Π² 4-Ρ… ΠΌΠ΅Ρ€Π½Ρ‹ΠΉ Π² Ρ‚Π΅Π½Π·ΠΎΡ€
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

Uchoraji unafanywa kwa kutumia zana za kawaida za R na kuhifadhiwa kwa PNG ya muda iliyohifadhiwa kwenye RAM (kwenye Linux, saraka za muda za R ziko kwenye saraka. /tmp, imewekwa kwenye RAM). Faili hii kisha inasomwa kama safu ya pande tatu yenye nambari kuanzia 0 hadi 1. Hii ni muhimu kwa sababu BMP ya kawaida zaidi inaweza kusomwa katika safu mbichi yenye misimbo ya rangi ya hex.

Wacha tujaribu matokeo:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva

Kundi lenyewe litaundwa kama ifuatavyo:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

Utekelezaji huu ulionekana kuwa mzuri kwetu, kwani uundaji wa vikundi vikubwa huchukua muda mrefu, na tuliamua kuchukua fursa ya uzoefu wa wenzetu kwa kutumia maktaba yenye nguvu. OpenCV. Wakati huo hakukuwa na kifurushi kilichotengenezwa tayari kwa R (hakuna sasa), kwa hivyo utekelezaji mdogo wa utendakazi unaohitajika uliandikwa kwa C ++ na kuunganishwa kwa nambari ya R kwa kutumia. Rcpp.

Ili kutatua shida, vifurushi na maktaba zifuatazo zilitumiwa:

  1. OpenCV kwa kufanya kazi na picha na mistari ya kuchora. Imetumika maktaba za mfumo zilizosakinishwa awali na faili za vichwa, pamoja na kuunganisha kwa nguvu.

  2. xtensor kwa kufanya kazi na safu za multidimensional na tensor. Tulitumia faili za kichwa zilizojumuishwa kwenye kifurushi cha R cha jina moja. Maktaba hukuruhusu kufanya kazi na safu nyingi, katika safu kuu na mpangilio kuu wa safu.

  3. ndjson kwa kuchanganua JSON. Maktaba hii inatumika katika xtensor moja kwa moja ikiwa iko kwenye mradi.

  4. RcppThread kwa kuandaa usindikaji wa nyuzi nyingi za vekta kutoka JSON. Imetumia faili za kichwa zilizotolewa na kifurushi hiki. Kutoka maarufu zaidi RcppSambamba Kifurushi, kati ya mambo mengine, kina utaratibu wa kukatiza kitanzi kilichojengwa.

Ikumbukwe kwamba xtensor iligeuka kuwa godsend: pamoja na ukweli kwamba ina utendaji wa kina na utendaji wa juu, watengenezaji wake waligeuka kuwa msikivu kabisa na kujibu maswali mara moja na kwa undani. Kwa msaada wao, iliwezekana kutekeleza mabadiliko ya matiti ya OpenCV kuwa tensor ya xtensor, na pia njia ya kuchanganya tensor za picha zenye sura 3 kwenye tensor ya 4-dimensional ya mwelekeo sahihi (bechi yenyewe).

Nyenzo za kujifunzia Rcpp, xtensor na RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

Kukusanya faili zinazotumia faili za mfumo na kuunganisha kwa nguvu na maktaba zilizosakinishwa kwenye mfumo, tulitumia utaratibu wa programu-jalizi uliotekelezwa kwenye kifurushi. Rcpp. Ili kupata njia na bendera kiotomatiki, tulitumia matumizi maarufu ya Linux pkg-config.

Utekelezaji wa programu-jalizi ya Rcpp ya kutumia maktaba ya OpenCV

Rcpp::registerPlugin("opencv", function() {
  # Π’ΠΎΠ·ΠΌΠΎΠΆΠ½Ρ‹Π΅ названия ΠΏΠ°ΠΊΠ΅Ρ‚Π°
  pkg_config_name <- c("opencv", "opencv4")
  # Π‘ΠΈΠ½Π°Ρ€Π½Ρ‹ΠΉ Ρ„Π°ΠΉΠ» ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Ρ‹ pkg-config
  pkg_config_bin <- Sys.which("pkg-config")
  # ΠŸΡ€ΠΎΠ²Ρ€Π΅ΠΊΠ° наличия ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Ρ‹ Π² систСмС
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° наличия Ρ„Π°ΠΉΠ»Π° настроСк OpenCV для pkg-config
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  pkg_config_name <- pkg_config_name[check == 0]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

Kama matokeo ya utendakazi wa programu-jalizi, maadili yafuatayo yatabadilishwa wakati wa mchakato wa ujumuishaji:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

Nambari ya utekelezaji ya kuchanganua JSON na kutoa kundi la kupitisha kwa modeli imetolewa chini ya kiharibifu. Kwanza, ongeza saraka ya mradi wa ndani kutafuta faili za kichwa (zinazohitajika kwa ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Utekelezaji wa JSON hadi ubadilishaji wa tensor katika C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Π‘ΠΈΠ½ΠΎΠ½ΠΈΠΌΡ‹ для Ρ‚ΠΈΠΏΠΎΠ²
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Π˜Π·Π²Π»Π΅Ρ‡Ρ‘Π½Π½Ρ‹Π΅ ΠΈΠ· JSON ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚Ρ‹ Ρ‚ΠΎΡ‡Π΅ΠΊ
using strokes = std::vector<points>;      // Π˜Π·Π²Π»Π΅Ρ‡Ρ‘Π½Π½Ρ‹Π΅ ΠΈΠ· JSON ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚Ρ‹ Ρ‚ΠΎΡ‡Π΅ΠΊ
using xtensor3d = xt::xtensor<double, 3>; // Π’Π΅Π½Π·ΠΎΡ€ для хранСния ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹ изообраТСния
using xtensor4d = xt::xtensor<double, 4>; // Π’Π΅Π½Π·ΠΎΡ€ для хранСния мноТСства ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ
using rtensor3d = xt::rtensor<double, 3>; // ΠžΠ±Ρ‘Ρ€Ρ‚ΠΊΠ° для экспорта Π² R
using rtensor4d = xt::rtensor<double, 4>; // ΠžΠ±Ρ‘Ρ€Ρ‚ΠΊΠ° для экспорта Π² R

// БтатичСскиС константы
// Π Π°Π·ΠΌΠ΅Ρ€ изобраТСния Π² пиксСлях
const static int SIZE = 256;
// Π’ΠΈΠΏ Π»ΠΈΠ½ΠΈΠΈ
// Π‘ΠΌ. https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Π’ΠΎΠ»Ρ‰ΠΈΠ½Π° Π»ΠΈΠ½ΠΈΠΈ Π² пиксСлях
const static int LINE_WIDTH = 3;
// Алгоритм рСсайза
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Π¨Π°Π±Π»ΠΎΠ½ для конвСртирования OpenCV-ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹ Π² Ρ‚Π΅Π½Π·ΠΎΡ€
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Π Π°Π·ΠΌΠ΅Ρ€Π½ΠΎΡΡ‚ΡŒ Ρ†Π΅Π»Π΅Π²ΠΎΠ³ΠΎ Ρ‚Π΅Π½Π·ΠΎΡ€Π°
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // ΠžΠ±Ρ‰Π΅Π΅ количСство элСмСнтов Π² массивС
  size_t size = src.total() * NCH;
  // ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ cv::Mat Π² xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ JSON Π² список ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚ Ρ‚ΠΎΡ‡Π΅ΠΊ
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ парсинга Π΄ΠΎΠ»ΠΆΠ΅Π½ Π±Ρ‹Ρ‚ΡŒ массивом
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // ΠšΠ°ΠΆΠ΄Ρ‹ΠΉ элСмСнт массива Π΄ΠΎΠ»ΠΆΠ΅Π½ Π±Ρ‹Ρ‚ΡŒ 2-ΠΌΠ΅Ρ€Π½Ρ‹ΠΌ массивом
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ Π²Π΅ΠΊΡ‚ΠΎΡ€Π° Ρ‚ΠΎΡ‡Π΅ΠΊ
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// ΠžΡ‚Ρ€ΠΈΡΠΎΠ²ΠΊΠ° Π»ΠΈΠ½ΠΈΠΉ
// Π¦Π²Π΅Ρ‚Π° HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Π˜ΡΡ…ΠΎΠ΄Π½Ρ‹ΠΉ Ρ‚ΠΈΠΏ ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Π˜Ρ‚ΠΎΠ³ΠΎΠ²Ρ‹ΠΉ Ρ‚ΠΈΠΏ ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // ΠšΠΎΠ»ΠΈΡ‡Π΅ΡΡ‚Π²ΠΎ Π»ΠΈΠ½ΠΈΠΉ
  size_t n = x.size();
  for (const auto& s: x) {
    // ΠšΠΎΠ»ΠΈΡ‡Π΅ΡΡ‚Π²ΠΎ Ρ‚ΠΎΡ‡Π΅ΠΊ Π² Π»ΠΈΠ½ΠΈΠΈ
    size_t n_points = s.shape()[1];
    for (size_t i = 0; i < n_points - 1; ++i) {
      // Π’ΠΎΡ‡ΠΊΠ° Π½Π°Ρ‡Π°Π»Π° ΡˆΡ‚Ρ€ΠΈΡ…Π°
      cv::Point from(s(0, i), s(1, i));
      // Π’ΠΎΡ‡ΠΊΠ° окончания ΡˆΡ‚Ρ€ΠΈΡ…Π°
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // ΠžΡ‚Ρ€ΠΈΡΠΎΠ²ΠΊΠ° Π»ΠΈΠ½ΠΈΠΈ
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // МСняСм Ρ†Π²Π΅Ρ‚ Π»ΠΈΠ½ΠΈΠΈ
      col[0] += 180 / n;
    }
  }
  if (color) {
    // МСняСм Ρ†Π²Π΅Ρ‚ΠΎΠ²ΠΎΠ΅ прСдставлСниС Π½Π° RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // МСняСм Ρ„ΠΎΡ€ΠΌΠ°Ρ‚ прСдставлСния Π½Π° float32 с Π΄ΠΈΠ°ΠΏΠ°Π·ΠΎΠ½ΠΎΠΌ [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// ΠžΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° JSON ΠΈ ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½ΠΈΠ΅ Ρ‚Π΅Π½Π·ΠΎΡ€Π° с Π΄Π°Π½Π½Ρ‹ΠΌΠΈ изобраТСния
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

Nambari hii inapaswa kuwekwa kwenye faili src/cv_xt.cpp na kukusanya na amri Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); pia inahitajika kwa kazi nlohmann/json.hpp ya hazina. Nambari imegawanywa katika kazi kadhaa:

  • to_xt - kazi ya kiolezo cha kubadilisha matrix ya picha (cv::Mat) kwa tensor xt::xtensor;

  • parse_json - kazi huchanganua kamba ya JSON, hutoa kuratibu za pointi, kuzipakia kwenye vector;

  • ocv_draw_lines - kutoka kwa vector inayosababisha ya pointi, huchota mistari ya rangi nyingi;

  • process - inachanganya kazi zilizo hapo juu na pia huongeza uwezo wa kuongeza picha inayosababisha;

  • cpp_process_json_str - wrapper juu ya kazi process, ambayo husafirisha matokeo kwa kitu cha R (safu ya multidimensional);

  • cpp_process_json_vector - wrapper juu ya kazi cpp_process_json_str, ambayo hukuruhusu kusindika vekta ya kamba katika hali ya nyuzi nyingi.

Ili kuchora mistari ya rangi nyingi, mtindo wa rangi ya HSV ulitumiwa, ikifuatiwa na ubadilishaji hadi RGB. Wacha tujaribu matokeo:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva
Ulinganisho wa kasi ya utekelezaji katika R na C++

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Π±Π΅Π½Ρ‡ΠΌΠ°Ρ€ΠΊΠ°
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# ΠŸΡ€ΠΎΠ²Π΅Π΄Π΅Π½ΠΈΠ΅ Π·Π°ΠΌΠ΅Ρ€Π°
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva

Kama unavyoona, ongezeko la kasi liligeuka kuwa muhimu sana, na haiwezekani kupata nambari ya C++ kwa kusawazisha nambari ya R.

3. Viigizo vya upakuaji wa bechi kutoka kwa hifadhidata

R ina sifa inayostahili ya usindikaji wa data ambayo inafaa kwenye RAM, wakati Python ina sifa zaidi ya usindikaji wa data mara kwa mara, hukuruhusu kutekeleza kwa urahisi na kwa asili mahesabu ya nje ya msingi (hesabu kwa kutumia kumbukumbu ya nje). Mfano wa kawaida na unaofaa kwetu katika muktadha wa tatizo lililoelezwa ni mitandao ya kina ya neva iliyofunzwa na mbinu ya mteremko wa kushuka kwa ukadiriaji wa upinde rangi kwa kila hatua kwa kutumia sehemu ndogo ya uchunguzi, au kundi-dogo.

Mifumo ya kina ya kujifunza iliyoandikwa katika Python ina madarasa maalum ambayo hutekeleza virudia kulingana na data: majedwali, picha katika folda, miundo ya binary, nk. Unaweza kutumia chaguo zilizotengenezwa tayari au kuandika yako mwenyewe kwa kazi maalum. Katika R tunaweza kuchukua fursa ya huduma zote za maktaba ya Python keri na viunga vyake mbalimbali kwa kutumia kifurushi cha jina moja, ambacho kwa upande wake hufanya kazi juu ya kifurushi sikia tena. Mwisho unastahili makala tofauti ndefu; hairuhusu tu kuendesha nambari ya Python kutoka R, lakini pia hukuruhusu kuhamisha vitu kati ya vipindi vya R na Python, ukifanya otomatiki ubadilishaji wa aina zote muhimu.

Tuliondoa hitaji la kuhifadhi data zote kwenye RAM kwa kutumia MonetDBLite, kazi yote ya "mtandao wa neva" itafanywa na nambari ya asili kwenye Python, lazima tu tuandike kiboreshaji juu ya data, kwani hakuna kitu tayari. kwa hali kama hiyo katika R au Python. Kwa kweli kuna mahitaji mawili tu kwa hiyo: lazima irudishe batches kwa kitanzi kisicho na mwisho na kuokoa hali yake kati ya marudio (ya mwisho katika R inatekelezwa kwa njia rahisi zaidi kwa kutumia kufungwa). Hapo awali, ilihitajika kubadilisha kwa uwazi safu za R kuwa safu numpy ndani ya kiboreshaji, lakini toleo la sasa la kifurushi. keri anafanya mwenyewe.

Kiboreshaji cha data ya mafunzo na uthibitishaji kiligeuka kuwa kama ifuatavyo:

Iterator kwa mafunzo na data ya uthibitishaji

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_class(con, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # ΠŸΠ΅Ρ€Π΅ΠΌΠ΅ΡˆΠΈΠ²Π°Π΅ΠΌ, Ρ‡Ρ‚ΠΎΠ±Ρ‹ Π±Ρ€Π°Ρ‚ΡŒ ΠΈ ΡƒΠ΄Π°Π»ΡΡ‚ΡŒ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Π½Π½Ρ‹Π΅ индСксы Π±Π°Ρ‚Ρ‡Π΅ΠΉ ΠΏΠΎ порядку
  dt <- data.table::data.table(id = sample(samples_index))
  # ΠŸΡ€ΠΎΡΡ‚Π°Π²Π»ΡΠ΅ΠΌ Π½ΠΎΠΌΠ΅Ρ€Π° Π±Π°Ρ‚Ρ‡Π΅ΠΉ
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # ΠžΡΡ‚Π°Π²Π»ΡΠ΅ΠΌ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ ΠΏΠΎΠ»Π½Ρ‹Π΅ Π±Π°Ρ‚Ρ‡ΠΈ ΠΈ индСксируСм
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # УстанавливаСм счётчик
  i <- 1
  # ΠšΠΎΠ»ΠΈΡ‡Π΅ΡΡ‚Π²ΠΎ Π±Π°Ρ‚Ρ‡Π΅ΠΉ
  max_i <- dt[, max(batch)]

  # ΠŸΠΎΠ΄Π³ΠΎΡ‚ΠΎΠ²ΠΊΠ° выраТСния для Π²Ρ‹Π³Ρ€ΡƒΠ·ΠΊΠΈ
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(con, sql)

  # Аналог keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Π—Π°ΠΌΡ‹ΠΊΠ°Π½ΠΈΠ΅
  function() {
    # НачинаСм Π½ΠΎΠ²ΡƒΡŽ эпоху
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # БбрасываСм счётчик
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # ID для Π²Ρ‹Π³Ρ€ΡƒΠ·ΠΊΠΈ Π΄Π°Π½Π½Ρ‹Ρ…
    batch_ind <- dt[batch == i, id]
    # Π’Ρ‹Π³Ρ€ΡƒΠ·ΠΊΠ° Π΄Π°Π½Π½Ρ‹Ρ…
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Π£Π²Π΅Π»ΠΈΡ‡ΠΈΠ²Π°Π΅ΠΌ счётчик
    i <<- i + 1

    # ΠŸΠ°Ρ€ΡΠΈΠ½Π³ JSON ΠΈ ΠΏΠΎΠ΄Π³ΠΎΡ‚ΠΎΠ²ΠΊΠ° массива
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Π¨ΠΊΠ°Π»ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ c ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π»Π° [0, 1] Π½Π° ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π» [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

Kazi inachukua kama pembejeo la kutofautisha na unganisho kwenye hifadhidata, nambari za mistari iliyotumiwa, idadi ya madarasa, saizi ya kundi, kiwango (scale = 1 inalingana na utoaji wa picha za saizi 256x256, scale = 0.5 - pikseli 128x128), kiashirio cha rangi (color = FALSE hubainisha utoaji kwa rangi ya kijivu inapotumika color = TRUE kila kipigo huchorwa kwa rangi mpya) na kiashirio cha kuchakata awali cha mitandao iliyofunzwa awali kwenye imagenet. Mwisho unahitajika ili kuongeza maadili ya saizi kutoka kwa muda [0, 1] hadi muda [-1, 1], ambao ulitumika wakati wa kufundisha kilichotolewa. keri mifano.

Kitendaji cha nje kina ukaguzi wa aina ya hoja, jedwali data.table na nambari za mstari zilizochanganywa bila mpangilio kutoka samples_index na nambari za kundi, kaunta na idadi ya juu zaidi ya bachi, pamoja na usemi wa SQL wa kupakua data kutoka kwa hifadhidata. Zaidi ya hayo, tulifafanua analog ya haraka ya kazi ndani keras::to_categorical(). Tulitumia karibu data yote kwa mafunzo, tukiacha nusu ya asilimia kwa uthibitisho, kwa hivyo saizi ya epoch ilipunguzwa na parameta. steps_per_epoch alipoitwa keras::fit_generator(), na hali if (i > max_i) ilifanya kazi tu kwa kiboreshaji cha uthibitishaji.

Katika utendakazi wa ndani, faharasa za safu mlalo hutolewa kwa kundi linalofuata, rekodi hupakuliwa kutoka kwa hifadhidata huku kihesabu bechi ikiongezeka, uchanganuzi wa JSON (kazi cpp_process_json_vector(), iliyoandikwa kwa C++) na kuunda safu zinazolingana na picha. Kisha vekta za moto-moja zilizo na lebo za darasa huundwa, safu zilizo na maadili ya pixel na lebo zinajumuishwa kwenye orodha, ambayo ni dhamana ya kurudi. Ili kuharakisha kazi, tulitumia uundaji wa faharisi kwenye meza data.table na marekebisho kupitia kiunga - bila "chips" hizi za kifurushi data.meza Ni ngumu sana kufikiria kufanya kazi kwa ufanisi na idadi yoyote muhimu ya data katika R.

Matokeo ya vipimo vya kasi kwenye kompyuta ndogo ya Core i5 ni kama ifuatavyo.

Benchmark ya iterator

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Π˜Π½Π΄Π΅ΠΊΡΡ‹ для ΠΎΠ±ΡƒΡ‡Π°ΡŽΡ‰Π΅ΠΉ Π²Ρ‹Π±ΠΎΡ€ΠΊΠΈ
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Π˜Π½Π΄Π΅ΠΊΡΡ‹ для ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΎΡ‡Π½ΠΎΠΉ Π²Ρ‹Π±ΠΎΡ€ΠΊΠΈ
val_ind <- ind[-train_ind]
rm(ind)
# ΠšΠΎΡΡ„Ρ„ΠΈΡ†ΠΈΠ΅Π½Ρ‚ ΠΌΠ°ΡΡˆΡ‚Π°Π±Π°
scale <- 0.5

# ΠŸΡ€ΠΎΠ²Π΅Π΄Π΅Π½ΠΈΠ΅ Π·Π°ΠΌΠ΅Ρ€Π°
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Π±Π΅Π½Ρ‡ΠΌΠ°Ρ€ΠΊΠ°
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Utambuzi wa Doodle wa Chora Haraka: jinsi ya kufanya urafiki na R, C++ na mitandao ya neva

Ikiwa una kiasi cha kutosha cha RAM, unaweza kuharakisha uendeshaji wa hifadhidata kwa kuihamisha kwa RAM hii sawa (GB 32 inatosha kwa kazi yetu). Katika Linux, kizigeu kimewekwa na chaguo-msingi /dev/shm, inachukua hadi nusu ya uwezo wa RAM. Unaweza kuangazia zaidi kwa kuhariri /etc/fstabkupata rekodi kama tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Hakikisha kuwasha upya na uangalie matokeo kwa kuendesha amri df -h.

Kirudishi cha data ya jaribio kinaonekana rahisi zaidi, kwani mkusanyiko wa data wa jaribio unatoshea kabisa kwenye RAM:

Iterator kwa data ya majaribio

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # ΠŸΡ€ΠΎΡΡ‚Π°Π²Π»ΡΠ΅ΠΌ Π½ΠΎΠΌΠ΅Ρ€Π° Π±Π°Ρ‚Ρ‡Π΅ΠΉ
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Π—Π°ΠΌΡ‹ΠΊΠ°Π½ΠΈΠ΅
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Π¨ΠΊΠ°Π»ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ c ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π»Π° [0, 1] Π½Π° ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π» [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}

4. Uchaguzi wa usanifu wa mfano

Usanifu wa kwanza uliotumika ulikuwa mtandao wa simu v1, vipengele vyake ambavyo vinajadiliwa katika hii ujumbe. Imejumuishwa kama kawaida keri na, ipasavyo, inapatikana kwenye kifurushi cha jina moja kwa R. Lakini wakati wa kujaribu kuitumia na picha za kituo kimoja, jambo la kushangaza liliibuka: kiboreshaji cha pembejeo lazima kiwe na kipimo kila wakati. (batch, height, width, 3), yaani, idadi ya vituo haiwezi kubadilishwa. Hakuna kizuizi kama hicho katika Python, kwa hivyo tulikimbia na kuandika utekelezaji wetu wenyewe wa usanifu huu, kufuatia nakala asili (bila kuacha ambayo iko kwenye toleo la keras):

Usanifu wa Mobilenet v1

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}

Ubaya wa njia hii ni dhahiri. Ninataka kupima mifano mingi, lakini kinyume chake, sitaki kuandika upya kila usanifu kwa mikono. Pia tulinyimwa fursa ya kutumia uzani wa miundo iliyofunzwa awali kwenye imagenet. Kama kawaida, kusoma nyaraka kulisaidia. Kazi get_config() hukuruhusu kupata maelezo ya modeli katika fomu inayofaa kwa uhariri (base_model_conf$layers - orodha ya kawaida ya R), na kazi from_config() hufanya ubadilishaji wa nyuma kuwa kitu cha mfano:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Sasa si vigumu kuandika kazi ya ulimwengu wote kupata yoyote ya iliyotolewa keri mifano iliyo na au isiyo na uzani iliyofunzwa kwenye imagenet:

Kazi ya kupakia usanifu tayari

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # ΠŸΠΎΠ»ΡƒΡ‡Π°Π΅ΠΌ ΠΎΠ±ΡŠΠ΅ΠΊΡ‚ ΠΈΠ· ΠΏΠ°ΠΊΠ΅Ρ‚Π° keras
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° наличия ΠΎΠ±ΡŠΠ΅ΠΊΡ‚Π° Π² ΠΏΠ°ΠΊΠ΅Ρ‚Π΅
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # Если ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠ΅ Π½Π΅ Ρ†Π²Π΅Ρ‚Π½ΠΎΠ΅, мСняСм Ρ€Π°Π·ΠΌΠ΅Ρ€Π½ΠΎΡΡ‚ΡŒ Π²Ρ…ΠΎΠ΄Π°
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

Unapotumia picha za idhaa moja, hakuna uzani uliofundishwa mapema unaotumika. Hii inaweza kusasishwa: kwa kutumia kitendakazi get_weights() pata uzani wa mfano katika mfumo wa orodha ya safu za R, badilisha kipimo cha kipengee cha kwanza cha orodha hii (kwa kuchukua chaneli moja ya rangi au wastani zote tatu), kisha upakie uzani nyuma kwenye mfano na kazi. set_weights(). Hatukuongeza utendaji huu, kwa sababu katika hatua hii tayari ilikuwa wazi kuwa ilikuwa na tija zaidi kufanya kazi na picha za rangi.

Tulifanya majaribio mengi kwa kutumia matoleo ya mobilenet 1 na 2, pamoja na resnet34. Usanifu wa kisasa zaidi kama vile SE-ResNeXt ulifanya vyema katika shindano hili. Kwa bahati mbaya, hatukuwa na utekelezaji ulioandaliwa tayari, na hatukuandika yetu (lakini hakika tutaandika).

5. Parameterization ya maandiko

Kwa urahisi, msimbo wote wa kuanza mafunzo uliundwa kama hati moja, iliyoainishwa kwa kutumia hati kama ifuatavyo:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

Ufungaji hati inawakilisha utekelezaji http://docopt.org/ kwa R. Kwa msaada wake, hati zinazinduliwa na amri rahisi kama Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db au ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, ikiwa faili train_nn.R inaweza kutekelezwa (amri hii itaanza kufunza mfano resnet50 kwenye picha za rangi tatu zenye ukubwa wa saizi 128x128, hifadhidata lazima iwe kwenye folda. /home/andrey/doodle_db) Unaweza kuongeza kasi ya kujifunza, aina ya kiboreshaji, na vigezo vingine vyovyote vinavyoweza kubinafsishwa kwenye orodha. Katika mchakato wa kuandaa uchapishaji, ikawa kwamba usanifu mobilenet_v2 kutoka kwa toleo la sasa keri katika matumizi ya R hawezi kwa sababu ya mabadiliko ambayo hayajazingatiwa kwenye kifurushi cha R, tunangojea warekebishe.

Njia hii ilifanya iwezekane kuharakisha majaribio na mifano tofauti ikilinganishwa na uzinduzi wa kitamaduni wa hati katika RStudio (tunaona kifurushi kama njia mbadala inayowezekana. tfruns) Lakini faida kuu ni uwezo wa kusimamia kwa urahisi uzinduzi wa hati katika Docker au tu kwenye seva, bila kusakinisha RStudio kwa hili.

6. Dockerization ya scripts

Tulitumia Docker ili kuhakikisha kubebeka kwa mazingira kwa miundo ya mafunzo kati ya washiriki wa timu na kwa usambazaji wa haraka kwenye wingu. Unaweza kuanza kufahamiana na zana hii, ambayo sio ya kawaida kwa programu ya R, na hii mfululizo wa machapisho au kozi ya video.

Docker hukuruhusu kuunda picha zako mwenyewe kutoka mwanzo na kutumia picha zingine kama msingi wa kuunda yako mwenyewe. Wakati wa kuchambua chaguzi zinazopatikana, tulifikia hitimisho kwamba kusanikisha viendeshaji vya NVIDIA, CUDA+cuDNN na maktaba za Python ni sehemu kubwa ya picha, na tuliamua kuchukua picha rasmi kama msingi. tensorflow/tensorflow:1.12.0-gpu, na kuongeza vifurushi muhimu vya R hapo.

Faili ya mwisho ya docker ilionekana kama hii:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && 
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && 
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && 
    add-apt-repository -y ppa:marutter/c2d4u3.5 && 
    add-apt-repository -y ppa:timsc/opencv-3.4 && 
    apt-get update && 
    apt-get install -y locales && 
    locale-gen ${LOCALE} && 
    apt-get install -y --no-install-recommends ${APT_PKG} && 
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && 
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && 
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && 
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && 
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && 
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && 
    install.r ${R_SRC_PKG} && 
    pip install ${PY_PIP_PKG} && 
    mkdir -p ${DIRS} && 
    chmod 777 ${DIRS} && 
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && 
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

Kwa urahisi, vifurushi vilivyotumiwa viliwekwa katika vigezo; wingi wa maandishi yaliyoandikwa yanakiliwa ndani ya vyombo wakati wa mkusanyiko. Pia tulibadilisha ganda la amri kuwa /bin/bash kwa urahisi wa matumizi ya yaliyomo /etc/os-release. Hii iliepusha hitaji la kutaja toleo la OS kwenye msimbo.

Zaidi ya hayo, hati ndogo ya bash iliandikwa ambayo inakuwezesha kuzindua chombo na amri mbalimbali. Kwa mfano, haya yanaweza kuwa hati za mafunzo ya mitandao ya neural ambayo iliwekwa hapo awali ndani ya kontena, au ganda la amri kwa utatuzi na ufuatiliaji wa utendakazi wa kontena:

Hati ya kuzindua kontena

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

Ikiwa hati hii ya bash itaendeshwa bila vigezo, hati itaitwa ndani ya chombo train_nn.R na maadili ya msingi; ikiwa hoja ya kwanza ya msimamo ni "bash", basi chombo kitaanza maingiliano na ganda la amri. Katika visa vingine vyote, maadili ya hoja za msimamo hubadilishwa: CMD="Rscript /app/train_nn.R $@".

Inafaa kumbuka kuwa saraka zilizo na data ya chanzo na hifadhidata, pamoja na saraka ya kuhifadhi mifano iliyofunzwa, imewekwa ndani ya chombo kutoka kwa mfumo wa mwenyeji, ambayo hukuruhusu kupata matokeo ya hati bila ujanja usio wa lazima.

7. Kutumia GPU nyingi kwenye Wingu la Google

Moja ya vipengele vya shindano hilo ilikuwa data yenye kelele sana (angalia picha ya kichwa, iliyokopwa kutoka kwa @Leigh.plt kutoka kwa ODS slack). Kundi kubwa husaidia kukabiliana na hali hii, na baada ya majaribio kwenye Kompyuta yenye 1 GPU, tuliamua kuboresha miundo ya mafunzo kwenye GPU kadhaa kwenye wingu. Umetumia GoogleCloud (mwongozo mzuri wa mambo ya msingi) kutokana na uteuzi mkubwa wa usanidi unaopatikana, bei nzuri na bonasi ya $300. Kwa uchoyo, niliamuru mfano wa 4xV100 na SSD na tani ya RAM, na hiyo ilikuwa kosa kubwa. Mashine kama hiyo hula pesa haraka; unaweza kwenda kuvunja majaribio bila bomba iliyothibitishwa. Kwa madhumuni ya kielimu, ni bora kuchukua K80. Lakini kiasi kikubwa cha RAM kilikuja kwa manufaa - SSD ya wingu haikuvutia na utendaji wake, kwa hivyo hifadhidata ilihamishiwa dev/shm.

La kufurahisha zaidi ni kipande cha msimbo kinachohusika na kutumia GPU nyingi. Kwanza, mfano huundwa kwenye CPU kwa kutumia meneja wa muktadha, kama vile Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics =(top_3_categorical_accuracy,
    compile = FALSE
  )
})

Kisha mfano ambao haujajumuishwa (hii ni muhimu) unakiliwa kwa idadi fulani ya GPU zinazopatikana, na tu baada ya hapo inakusanywa:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

Mbinu ya kitamaduni ya kufungia tabaka zote isipokuwa ile ya mwisho, kufundisha safu ya mwisho, kufungia na kurejesha muundo mzima kwa GPU kadhaa haikuweza kutekelezwa.

Mafunzo yalifuatiliwa bila matumizi. tensorboard, tukijiwekea kikomo kwa kurekodi kumbukumbu na kuhifadhi miundo yenye majina ya taarifa baada ya kila enzi:

Simu za nyuma

# Π¨Π°Π±Π»ΠΎΠ½ ΠΈΠΌΠ΅Π½ΠΈ Ρ„Π°ΠΉΠ»Π° Π»ΠΎΠ³Π°
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Π¨Π°Π±Π»ΠΎΠ½ ΠΈΠΌΠ΅Π½ΠΈ Ρ„Π°ΠΉΠ»Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # ΡƒΠΌΠ΅Π½ΡŒΡˆΠ°Π΅ΠΌ lr Π² 2 Ρ€Π°Π·Π°
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)

8. Badala ya hitimisho

Shida kadhaa ambazo tumekutana nazo bado hazijatatuliwa:

  • Π² keri hakuna kitendakazi kilichotengenezwa tayari cha kutafuta kiotomatiki kiwango bora cha ujifunzaji (analog lr_finder kwenye maktaba haraka.ai); Kwa juhudi fulani, inawezekana kuweka utekelezwaji wa mtu wa tatu kwa R, kwa mfano, hii;
  • kama matokeo ya hatua iliyotangulia, haikuwezekana kuchagua kasi sahihi ya mafunzo wakati wa kutumia GPU kadhaa;
  • kuna ukosefu wa usanifu wa kisasa wa mtandao wa neva, hasa wale waliofunzwa awali kwenye imagenet;
  • hakuna sera ya mzunguko na viwango vya kibaguzi vya kujifunza (cosine annealing ilikuwa ombi letu kutekelezwa, Asante skydan).

Ni mambo gani muhimu ambayo yamejifunza kutoka kwa shindano hili:

  • Kwenye maunzi yenye nguvu kidogo, unaweza kufanya kazi na kiasi cha data kinachofaa (mara nyingi saizi ya RAM) bila maumivu. Mfuko wa plastiki data.meza huhifadhi kumbukumbu kwa sababu ya urekebishaji wa majedwali ya mahali, ambayo huepuka kuziiga, na inapotumiwa kwa usahihi, uwezo wake karibu kila wakati unaonyesha kasi ya juu kati ya zana zote zinazojulikana kwetu kwa lugha za uandishi. Kuhifadhi data katika hifadhidata hukuruhusu, katika hali nyingi, kutofikiria hata kidogo juu ya hitaji la kubana hifadhidata nzima kwenye RAM.
  • Vitendaji vya polepole katika R vinaweza kubadilishwa na vya haraka katika C++ kwa kutumia kifurushi Rcpp. Ikiwa ni pamoja na matumizi RcppThread au RcppSambamba, tunapata utekelezwaji wa nyuzi nyingi kwenye jukwaa, kwa hivyo hakuna haja ya kusawazisha msimbo katika kiwango cha R.
  • Kifurushi Rcpp inaweza kutumika bila ufahamu mkubwa wa C++, kiwango cha chini kinachohitajika kimeainishwa hapa. Faili za kichwa kwa idadi ya maktaba nzuri za C kama vile xtensor inapatikana kwenye CRAN, yaani, miundombinu inaundwa kwa ajili ya utekelezaji wa miradi inayounganisha msimbo wa C++ wa utendaji wa juu ulio tayari kuwa R. Urahisi zaidi ni kuangazia sintaksia na kichanganuzi tuli cha msimbo wa C++ katika RStudio.
  • hati hukuruhusu kuendesha hati zinazojitosheleza zenye vigezo. Hii ni rahisi kutumia kwenye seva ya mbali, incl. chini ya docker. Katika RStudio, ni ngumu kufanya majaribio ya masaa mengi na mafunzo ya mitandao ya neva, na kusakinisha IDE kwenye seva yenyewe sio haki kila wakati.
  • Docker inahakikisha kubebeka kwa msimbo na kunakili matokeo kati ya wasanidi programu wenye matoleo tofauti ya Mfumo wa Uendeshaji na maktaba, pamoja na urahisi wa utekelezaji kwenye seva. Unaweza kuzindua bomba zima la mafunzo kwa amri moja tu.
  • Wingu la Google ni njia rafiki ya bajeti ya kufanya majaribio kwenye maunzi ghali, lakini unahitaji kuchagua usanidi kwa uangalifu.
  • Kupima kasi ya vipande vya nambari ya mtu binafsi ni muhimu sana, haswa wakati wa kuchanganya R na C ++, na kifurushi. benchi - pia rahisi sana.

Kwa ujumla uzoefu huu ulikuwa wa kuridhisha sana na tunaendelea kujitahidi kutatua baadhi ya masuala yaliyoibuliwa.

Chanzo: mapenzi.com

Kuongeza maoni