Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf

Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf

Hai Habr!

Musim luruh yang lalu, Kaggle menganjurkan pertandingan untuk mengklasifikasikan gambar lukisan tangan, Pengecaman Doodle Draw Pantas, di mana, antara lain, sepasukan saintis R mengambil bahagian: Artem Klevtsova, Pengurus Philippa ΠΈ Andrey Ogurtsov. Kami tidak akan menerangkan pertandingan secara terperinci; itu telah dilakukan dalam penerbitan terkini.

Kali ini ia tidak berjaya dengan ladang pingat, tetapi banyak pengalaman berharga diperoleh, jadi saya ingin memberitahu komuniti tentang beberapa perkara yang paling menarik dan berguna tentang Kagle dan dalam kerja seharian. Antara topik yang dibincangkan: hidup susah tanpa OpenCV, penghuraian JSON (contoh ini mengkaji integrasi kod C++ ke dalam skrip atau pakej dalam R menggunakan Rcpp), parameterisasi skrip dan dockerisasi penyelesaian akhir. Semua kod daripada mesej dalam bentuk yang sesuai untuk pelaksanaan tersedia dalam repositori.

Kandungan:

  1. Muatkan data daripada CSV ke MonetDB dengan cekap
  2. Menyediakan kumpulan
  3. Iterator untuk memunggah kelompok daripada pangkalan data
  4. Memilih Seni Bina Model
  5. Parameterisasi skrip
  6. Dokerisasi skrip
  7. Menggunakan berbilang GPU pada Google Cloud
  8. Daripada kesimpulan

1. Muatkan data daripada CSV ke dalam pangkalan data MonetDB dengan cekap

Data dalam pertandingan ini disediakan bukan dalam bentuk imej siap, tetapi dalam bentuk 340 fail CSV (satu fail untuk setiap kelas) yang mengandungi JSON dengan koordinat titik. Dengan menyambungkan titik ini dengan garisan, kami mendapat imej akhir berukuran 256x256 piksel. Juga untuk setiap rekod terdapat label yang menunjukkan sama ada gambar itu dikenali dengan betul oleh pengelas yang digunakan pada masa set data dikumpulkan, kod dua huruf negara tempat tinggal pengarang gambar, pengecam unik, cap waktu dan nama kelas yang sepadan dengan nama fail. Versi ringkas data asal mempunyai berat 7.4 GB dalam arkib dan kira-kira 20 GB selepas membongkar, data penuh selepas membongkar mengambil 240 GB. Penganjur memastikan bahawa kedua-dua versi mengeluarkan semula lukisan yang sama, bermakna versi penuh adalah berlebihan. Walau apa pun, menyimpan 50 juta imej dalam fail grafik atau dalam bentuk tatasusunan serta-merta dianggap tidak menguntungkan dan kami memutuskan untuk menggabungkan semua fail CSV daripada arkib train_simplified.zip ke dalam pangkalan data dengan generasi imej berikutnya dengan saiz yang diperlukan "dengan cepat" untuk setiap kumpulan.

Sistem yang terbukti telah dipilih sebagai DBMS MonetDB, iaitu pelaksanaan untuk R sebagai pakej MonetDBLite. Pakej ini termasuk versi terbenam pelayan pangkalan data dan membolehkan anda mengambil pelayan terus dari sesi R dan bekerja dengannya di sana. Mencipta pangkalan data dan menyambungkannya dilakukan dengan satu arahan:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

Kami perlu membuat dua jadual: satu untuk semua data, satu lagi untuk maklumat perkhidmatan tentang fail yang dimuat turun (berguna jika berlaku masalah dan proses perlu disambung semula selepas memuat turun beberapa fail):

Mencipta jadual

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

Cara terpantas untuk memuatkan data ke dalam pangkalan data adalah dengan menyalin terus fail CSV menggunakan perintah SQL COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORTJika tablename - nama jadual dan path - laluan ke fail. Semasa bekerja dengan arkib, didapati bahawa pelaksanaan terbina dalam unzip dalam R tidak berfungsi dengan betul dengan beberapa fail daripada arkib, jadi kami menggunakan sistem tersebut unzip (menggunakan parameter getOption("unzip")).

Fungsi untuk menulis ke pangkalan data

#' @title Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ ΠΈ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ° Ρ„Π°ΠΉΠ»ΠΎΠ²
#'
#' @description
#' Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ CSV-Ρ„Π°ΠΉΠ»ΠΎΠ² ΠΈΠ· ZIP-Π°Ρ€Ρ…ΠΈΠ²Π° ΠΈ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ° ΠΈΡ… Π² Π±Π°Π·Ρƒ Π΄Π°Π½Π½Ρ‹Ρ…
#'
#' @param con ΠžΠ±ΡŠΠ΅ΠΊΡ‚ ΠΏΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΡ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ… (класс `MonetDBEmbeddedConnection`).
#' @param tablename НазваниС Ρ‚Π°Π±Π»ΠΈΡ†Ρ‹ Π² Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ….
#' @oaram zipfile ΠŸΡƒΡ‚ΡŒ ΠΊ ZIP-Π°Ρ€Ρ…ΠΈΠ²Ρƒ.
#' @oaram filename Имя Ρ„Π°ΠΉΠ»Π° Π²Π½ΡƒΡ€ΠΈ ZIP-Π°Ρ€Ρ…ΠΈΠ²Π°.
#' @param preprocess Ѐункция ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ, которая Π±ΡƒΠ΄Π΅Ρ‚ ΠΏΡ€ΠΈΠΌΠ΅Π½Π΅Π½Π° ΠΈΠ·Π²Π»Π΅Ρ‡Ρ‘Π½Π½ΠΎΠΌΡƒ Ρ„Π°ΠΉΠ»Ρƒ.
#'   Π”ΠΎΠ»ΠΆΠ½Π° ΠΏΡ€ΠΈΠ½ΠΈΠΌΠ°Ρ‚ΡŒ ΠΎΠ΄ΠΈΠ½ Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ `data` (ΠΎΠ±ΡŠΠ΅ΠΊΡ‚ `data.table`).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ Ρ„Π°ΠΉΠ»Π°
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # ΠŸΡ€ΠΈΠΌΠ΅Π½ΡΠ΅ΠΌ функция ΠΏΡ€Π΅Π΄ΠΎΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠΈ
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Запрос ΠΊ Π‘Π” Π½Π° ΠΈΠΌΠΏΠΎΡ€Ρ‚ CSV
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Π’Ρ‹ΠΏΠΎΠ»Π½Π΅Π½ΠΈΠ΅ запроса ΠΊ Π‘Π”
  DBI::dbExecute(con, sql)

  # Π”ΠΎΠ±Π°Π²Π»Π΅Π½ΠΈΠ΅ записи ΠΎΠ± ΡƒΡΠΏΠ΅ΡˆΠ½ΠΎΠΉ Π·Π°Π³Ρ€ΡƒΠ·ΠΊΠ΅ Π² ΡΠ»ΡƒΠΆΠ΅Π±Π½ΡƒΡŽ Ρ‚Π°Π±Π»ΠΈΡ†Ρƒ
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

Jika anda perlu mengubah jadual sebelum menulisnya ke pangkalan data, ia sudah cukup untuk lulus dalam hujah preprocess fungsi yang akan mengubah data.

Kod untuk memuatkan data secara berurutan ke dalam pangkalan data:

Menulis data ke pangkalan data

# Бписок Ρ„Π°ΠΉΠ»ΠΎΠ² для записи
files <- unzip(zipfile, list = TRUE)$Name

# Бписок ΠΈΡΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠΉ, Ссли Ρ‡Π°ΡΡ‚ΡŒ Ρ„Π°ΠΉΠ»ΠΎΠ² ΡƒΠΆΠ΅ Π±Ρ‹Π»Π° Π·Π°Π³Ρ€ΡƒΠΆΠ΅Π½Π°
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # ЗапускаСм Ρ‚Π°ΠΉΠΌΠ΅Ρ€
  tictoc::tic()
  # ΠŸΡ€ΠΎΠ³Ρ€Π΅ΡΡ Π±Π°Ρ€
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # ΠžΡΡ‚Π°Π½Π°Π²Π»ΠΈΠ²Π°Π΅ΠΌ Ρ‚Π°ΠΉΠΌΠ΅Ρ€
  tictoc::toc()
}

# 526.141 sec elapsed - ΠΊΠΎΠΏΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ SSD->SSD
# 558.879 sec elapsed - ΠΊΠΎΠΏΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ USB->SSD

Masa pemuatan data mungkin berbeza bergantung pada ciri kelajuan pemacu yang digunakan. Dalam kes kami, membaca dan menulis dalam satu SSD atau daripada pemacu kilat (fail sumber) kepada SSD (DB) mengambil masa kurang daripada 10 minit.

Ia mengambil masa beberapa saat lagi untuk membuat lajur dengan label kelas integer dan lajur indeks (ORDERED INDEX) dengan nombor baris yang mana pemerhatian akan dijadikan sampel semasa membuat kelompok:

Mencipta Lajur dan Indeks Tambahan

message("Generate lables")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))

Untuk menyelesaikan masalah mencipta kumpulan dengan cepat, kami perlu mencapai kelajuan maksimum mengekstrak baris rawak dari jadual doodles. Untuk ini kami menggunakan 3 helah. Yang pertama adalah untuk mengurangkan dimensi jenis yang menyimpan ID pemerhatian. Dalam set data asal, jenis yang diperlukan untuk menyimpan ID ialah bigint, tetapi bilangan pemerhatian memungkinkan untuk memasukkan pengecam mereka, sama dengan nombor ordinal, ke dalam jenis int. Pencarian adalah lebih pantas dalam kes ini. Helah kedua ialah menggunakan ORDERED INDEX β€” kami membuat keputusan ini secara empirik, setelah melalui semua yang ada pilihan. Yang ketiga ialah menggunakan pertanyaan berparameter. Intipati kaedah adalah untuk melaksanakan arahan sekali PREPARE dengan penggunaan seterusnya ungkapan yang disediakan apabila mencipta sekumpulan pertanyaan daripada jenis yang sama, tetapi sebenarnya terdapat kelebihan berbanding dengan yang mudah SELECT ternyata berada dalam julat ralat statistik.

Proses memuat naik data menggunakan tidak lebih daripada 450 MB RAM. Iaitu, pendekatan yang diterangkan membolehkan anda memindahkan set data seberat berpuluh-puluh gigabait pada hampir mana-mana perkakasan bajet, termasuk beberapa peranti papan tunggal, yang cukup hebat.

Apa yang tinggal adalah untuk mengukur kelajuan mendapatkan semula data (rawak) dan menilai penskalaan apabila mengambil sampel kumpulan saiz yang berbeza:

Tanda aras pangkalan data

library(ggplot2)

set.seed(0)
# ΠŸΠΎΠ΄ΠΊΠ»ΡŽΡ‡Π΅Π½ΠΈΠ΅ ΠΊ Π±Π°Π·Π΅ Π΄Π°Π½Π½Ρ‹Ρ…
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

# Ѐункция для ΠΏΠΎΠ΄Π³ΠΎΡ‚ΠΎΠ²ΠΊΠΈ запроса Π½Π° сторонС сСрвСра
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Ѐункция для извлСчСния Π΄Π°Π½Π½Ρ‹Ρ…
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# ΠŸΡ€ΠΎΠ²Π΅Π΄Π΅Π½ΠΈΠ΅ Π·Π°ΠΌΠ΅Ρ€Π°
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Π±Π΅Π½Ρ‡ΠΌΠ°Ρ€ΠΊΠ°
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf

2. Menyediakan kumpulan

Keseluruhan proses penyediaan kumpulan terdiri daripada langkah-langkah berikut:

  1. Menghuraikan beberapa JSON yang mengandungi vektor rentetan dengan koordinat titik.
  2. Melukis garisan berwarna berdasarkan koordinat titik pada imej saiz yang diperlukan (contohnya, 256Γ—256 atau 128Γ—128).
  3. Menukar imej yang terhasil kepada tensor.

Sebagai sebahagian daripada persaingan antara kernel Python, masalah itu diselesaikan terutamanya menggunakan OpenCV. Salah satu analog yang paling mudah dan paling jelas dalam R akan kelihatan seperti ini:

Melaksanakan Penukaran JSON kepada Tensor dalam R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # ΠŸΠ°Ρ€ΡΠΈΠ½Π³ JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # УдаляСм Π²Ρ€Π΅ΠΌΠ΅Π½Π½Ρ‹ΠΉ Ρ„Π°ΠΉΠ» ΠΏΠΎ Π·Π°Π²Π΅Ρ€ΡˆΠ΅Π½ΠΈΡŽ Ρ„ΡƒΠ½ΠΊΡ†ΠΈΠΈ
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # ΠŸΡƒΡΡ‚ΠΎΠΉ Π³Ρ€Π°Ρ„ΠΈΠΊ
  plot.new()
  # Π Π°Π·ΠΌΠ΅Ρ€ ΠΎΠΊΠ½Π° Π³Ρ€Π°Ρ„ΠΈΠΊΠ°
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Π¦Π²Π΅Ρ‚Π° Π»ΠΈΠ½ΠΈΠΉ
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ изобраТСния Π² 3-Ρ… ΠΌΠ΅Ρ€Π½Ρ‹ΠΉ массив
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # ОбъСдинСниС 3-Ρ… ΠΌΠ΅Ρ€Π½Ρ‹Ρ… массивов ΠΊΠ°Ρ€Ρ‚ΠΈΠ½ΠΎΠΊ Π² 4-Ρ… ΠΌΠ΅Ρ€Π½Ρ‹ΠΉ Π² Ρ‚Π΅Π½Π·ΠΎΡ€
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

Lukisan dilakukan menggunakan alat R standard dan disimpan ke PNG sementara yang disimpan dalam RAM (di Linux, direktori R sementara terletak dalam direktori /tmp, dipasang dalam RAM). Fail ini kemudiannya dibaca sebagai tatasusunan tiga dimensi dengan nombor antara 0 hingga 1. Ini penting kerana BMP yang lebih konvensional akan dibaca ke dalam tatasusunan mentah dengan kod warna heks.

Mari kita uji hasilnya:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf

Kumpulan itu sendiri akan dibentuk seperti berikut:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

Pelaksanaan ini nampaknya tidak optimum kepada kami, memandangkan pembentukan kumpulan besar mengambil masa yang lama, dan kami memutuskan untuk memanfaatkan pengalaman rakan sekerja kami dengan menggunakan perpustakaan yang berkuasa OpenCV. Pada masa itu tidak ada pakej siap sedia untuk R (tiada sekarang), jadi pelaksanaan minimum fungsi yang diperlukan ditulis dalam C++ dengan penyepaduan ke dalam kod R menggunakan Rcpp.

Untuk menyelesaikan masalah, pakej dan perpustakaan berikut digunakan:

  1. OpenCV untuk bekerja dengan imej dan garis lukisan. Menggunakan perpustakaan sistem pra-pasang dan fail pengepala, serta pemautan dinamik.

  2. xtensor untuk bekerja dengan tatasusunan berbilang dimensi dan tensor. Kami menggunakan fail pengepala yang disertakan dalam pakej R dengan nama yang sama. Pustaka membolehkan anda bekerja dengan tatasusunan berbilang dimensi, dalam susunan utama baris dan lajur.

  3. ndjson untuk menghuraikan JSON. Perpustakaan ini digunakan dalam xtensor secara automatik jika ia terdapat dalam projek.

  4. RcppThread untuk mengatur pemprosesan berbilang benang bagi vektor daripada JSON. Menggunakan fail pengepala yang disediakan oleh pakej ini. Daripada lebih popular RcppParallel Pakej itu, antara lain, mempunyai mekanisme gangguan gelung terbina dalam.

Ia harus diperhatikan bahawa xtensor ternyata menjadi anugerah: sebagai tambahan kepada fakta bahawa ia mempunyai fungsi yang luas dan prestasi tinggi, pembangunnya ternyata agak responsif dan menjawab soalan dengan segera dan terperinci. Dengan bantuan mereka, adalah mungkin untuk melaksanakan transformasi matriks OpenCV kepada tensor xtensor, serta cara untuk menggabungkan tensor imej 3 dimensi menjadi tensor 4 dimensi bagi dimensi yang betul (batch itu sendiri).

Bahan untuk pembelajaran Rcpp, xtensor dan RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

Untuk menyusun fail yang menggunakan fail sistem dan pautan dinamik dengan perpustakaan yang dipasang pada sistem, kami menggunakan mekanisme pemalam yang dilaksanakan dalam pakej Rcpp. Untuk mencari laluan dan bendera secara automatik, kami menggunakan utiliti Linux yang popular pkg-konfigurasi.

Pelaksanaan pemalam Rcpp untuk menggunakan perpustakaan OpenCV

Rcpp::registerPlugin("opencv", function() {
  # Π’ΠΎΠ·ΠΌΠΎΠΆΠ½Ρ‹Π΅ названия ΠΏΠ°ΠΊΠ΅Ρ‚Π°
  pkg_config_name <- c("opencv", "opencv4")
  # Π‘ΠΈΠ½Π°Ρ€Π½Ρ‹ΠΉ Ρ„Π°ΠΉΠ» ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Ρ‹ pkg-config
  pkg_config_bin <- Sys.which("pkg-config")
  # ΠŸΡ€ΠΎΠ²Ρ€Π΅ΠΊΠ° наличия ΡƒΡ‚ΠΈΠ»ΠΈΡ‚Ρ‹ Π² систСмС
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° наличия Ρ„Π°ΠΉΠ»Π° настроСк OpenCV для pkg-config
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  pkg_config_name <- pkg_config_name[check == 0]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

Hasil daripada operasi pemalam, nilai berikut akan digantikan semasa proses penyusunan:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

Kod pelaksanaan untuk menghuraikan JSON dan menjana kumpulan untuk penghantaran kepada model diberikan di bawah spoiler. Pertama, tambahkan direktori projek tempatan untuk mencari fail pengepala (diperlukan untuk ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Pelaksanaan JSON kepada penukaran tensor dalam C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Π‘ΠΈΠ½ΠΎΠ½ΠΈΠΌΡ‹ для Ρ‚ΠΈΠΏΠΎΠ²
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Π˜Π·Π²Π»Π΅Ρ‡Ρ‘Π½Π½Ρ‹Π΅ ΠΈΠ· JSON ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚Ρ‹ Ρ‚ΠΎΡ‡Π΅ΠΊ
using strokes = std::vector<points>;      // Π˜Π·Π²Π»Π΅Ρ‡Ρ‘Π½Π½Ρ‹Π΅ ΠΈΠ· JSON ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚Ρ‹ Ρ‚ΠΎΡ‡Π΅ΠΊ
using xtensor3d = xt::xtensor<double, 3>; // Π’Π΅Π½Π·ΠΎΡ€ для хранСния ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹ изообраТСния
using xtensor4d = xt::xtensor<double, 4>; // Π’Π΅Π½Π·ΠΎΡ€ для хранСния мноТСства ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠΉ
using rtensor3d = xt::rtensor<double, 3>; // ΠžΠ±Ρ‘Ρ€Ρ‚ΠΊΠ° для экспорта Π² R
using rtensor4d = xt::rtensor<double, 4>; // ΠžΠ±Ρ‘Ρ€Ρ‚ΠΊΠ° для экспорта Π² R

// БтатичСскиС константы
// Π Π°Π·ΠΌΠ΅Ρ€ изобраТСния Π² пиксСлях
const static int SIZE = 256;
// Π’ΠΈΠΏ Π»ΠΈΠ½ΠΈΠΈ
// Π‘ΠΌ. https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Π’ΠΎΠ»Ρ‰ΠΈΠ½Π° Π»ΠΈΠ½ΠΈΠΈ Π² пиксСлях
const static int LINE_WIDTH = 3;
// Алгоритм рСсайза
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Π¨Π°Π±Π»ΠΎΠ½ для конвСртирования OpenCV-ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹ Π² Ρ‚Π΅Π½Π·ΠΎΡ€
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Π Π°Π·ΠΌΠ΅Ρ€Π½ΠΎΡΡ‚ΡŒ Ρ†Π΅Π»Π΅Π²ΠΎΠ³ΠΎ Ρ‚Π΅Π½Π·ΠΎΡ€Π°
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // ΠžΠ±Ρ‰Π΅Π΅ количСство элСмСнтов Π² массивС
  size_t size = src.total() * NCH;
  // ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ cv::Mat Π² xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ JSON Π² список ΠΊΠΎΠΎΡ€Π΄ΠΈΠ½Π°Ρ‚ Ρ‚ΠΎΡ‡Π΅ΠΊ
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // Π Π΅Π·ΡƒΠ»ΡŒΡ‚Π°Ρ‚ парсинга Π΄ΠΎΠ»ΠΆΠ΅Π½ Π±Ρ‹Ρ‚ΡŒ массивом
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // ΠšΠ°ΠΆΠ΄Ρ‹ΠΉ элСмСнт массива Π΄ΠΎΠ»ΠΆΠ΅Π½ Π±Ρ‹Ρ‚ΡŒ 2-ΠΌΠ΅Ρ€Π½Ρ‹ΠΌ массивом
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Π˜Π·Π²Π»Π΅Ρ‡Π΅Π½ΠΈΠ΅ Π²Π΅ΠΊΡ‚ΠΎΡ€Π° Ρ‚ΠΎΡ‡Π΅ΠΊ
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// ΠžΡ‚Ρ€ΠΈΡΠΎΠ²ΠΊΠ° Π»ΠΈΠ½ΠΈΠΉ
// Π¦Π²Π΅Ρ‚Π° HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Π˜ΡΡ…ΠΎΠ΄Π½Ρ‹ΠΉ Ρ‚ΠΈΠΏ ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Π˜Ρ‚ΠΎΠ³ΠΎΠ²Ρ‹ΠΉ Ρ‚ΠΈΠΏ ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // ΠšΠΎΠ»ΠΈΡ‡Π΅ΡΡ‚Π²ΠΎ Π»ΠΈΠ½ΠΈΠΉ
  size_t n = x.size();
  for (const auto& s: x) {
    // ΠšΠΎΠ»ΠΈΡ‡Π΅ΡΡ‚Π²ΠΎ Ρ‚ΠΎΡ‡Π΅ΠΊ Π² Π»ΠΈΠ½ΠΈΠΈ
    size_t n_points = s.shape()[1];
    for (size_t i = 0; i < n_points - 1; ++i) {
      // Π’ΠΎΡ‡ΠΊΠ° Π½Π°Ρ‡Π°Π»Π° ΡˆΡ‚Ρ€ΠΈΡ…Π°
      cv::Point from(s(0, i), s(1, i));
      // Π’ΠΎΡ‡ΠΊΠ° окончания ΡˆΡ‚Ρ€ΠΈΡ…Π°
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // ΠžΡ‚Ρ€ΠΈΡΠΎΠ²ΠΊΠ° Π»ΠΈΠ½ΠΈΠΈ
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // МСняСм Ρ†Π²Π΅Ρ‚ Π»ΠΈΠ½ΠΈΠΈ
      col[0] += 180 / n;
    }
  }
  if (color) {
    // МСняСм Ρ†Π²Π΅Ρ‚ΠΎΠ²ΠΎΠ΅ прСдставлСниС Π½Π° RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // МСняСм Ρ„ΠΎΡ€ΠΌΠ°Ρ‚ прСдставлСния Π½Π° float32 с Π΄ΠΈΠ°ΠΏΠ°Π·ΠΎΠ½ΠΎΠΌ [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// ΠžΠ±Ρ€Π°Π±ΠΎΡ‚ΠΊΠ° JSON ΠΈ ΠΏΠΎΠ»ΡƒΡ‡Π΅Π½ΠΈΠ΅ Ρ‚Π΅Π½Π·ΠΎΡ€Π° с Π΄Π°Π½Π½Ρ‹ΠΌΠΈ изобраТСния
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

Kod ini hendaklah diletakkan dalam fail src/cv_xt.cpp dan menyusun dengan arahan Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); juga diperlukan untuk bekerja nlohmann/json.hpp daripada repositori. Kod ini dibahagikan kepada beberapa fungsi:

  • to_xt β€” fungsi templat untuk mengubah matriks imej (cv::Mat) kepada tensor xt::xtensor;

  • parse_json β€” fungsi menghuraikan rentetan JSON, mengekstrak koordinat titik, membungkusnya ke dalam vektor;

  • ocv_draw_lines β€” daripada vektor mata yang terhasil, lukis garisan pelbagai warna;

  • process β€” menggabungkan fungsi di atas dan juga menambah keupayaan untuk menskala imej yang terhasil;

  • cpp_process_json_str - pembalut atas fungsi process, yang mengeksport hasil ke objek R (tatasusunan pelbagai dimensi);

  • cpp_process_json_vector - pembalut atas fungsi cpp_process_json_str, yang membolehkan anda memproses vektor rentetan dalam mod berbilang benang.

Untuk melukis garisan berbilang warna, model warna HSV telah digunakan, diikuti dengan penukaran kepada RGB. Mari kita uji hasilnya:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf
Perbandingan kelajuan pelaksanaan dalam R dan C++

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Π±Π΅Π½Ρ‡ΠΌΠ°Ρ€ΠΊΠ°
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# ΠŸΡ€ΠΎΠ²Π΅Π΄Π΅Π½ΠΈΠ΅ Π·Π°ΠΌΠ΅Ρ€Π°
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf

Seperti yang anda lihat, peningkatan kelajuan ternyata sangat ketara, dan tidak mungkin untuk mengejar kod C++ dengan menyelaraskan kod R.

3. Iterator untuk memunggah kelompok daripada pangkalan data

R mempunyai reputasi yang baik untuk memproses data yang sesuai dengan RAM, manakala Python lebih dicirikan oleh pemprosesan data berulang, membolehkan anda dengan mudah dan semulajadi melaksanakan pengiraan luar teras (pengiraan menggunakan memori luaran). Contoh klasik dan relevan untuk kita dalam konteks masalah yang diterangkan ialah rangkaian saraf dalam yang dilatih oleh kaedah penurunan kecerunan dengan anggaran kecerunan pada setiap langkah menggunakan sebahagian kecil pemerhatian, atau kelompok mini.

Rangka kerja pembelajaran mendalam yang ditulis dalam Python mempunyai kelas khas yang melaksanakan iterator berdasarkan data: jadual, gambar dalam folder, format binari, dsb. Anda boleh menggunakan pilihan sedia atau menulis sendiri untuk tugasan tertentu. Dalam R kita boleh memanfaatkan semua ciri perpustakaan Python keras dengan pelbagai bahagian belakang menggunakan pakej dengan nama yang sama, yang seterusnya berfungsi di atas pakej berunding. Yang terakhir ini layak mendapat artikel panjang yang berasingan; ia bukan sahaja membenarkan anda menjalankan kod Python daripada R, tetapi juga membolehkan anda memindahkan objek antara sesi R dan Python, secara automatik melaksanakan semua penukaran jenis yang diperlukan.

Kami menyingkirkan keperluan untuk menyimpan semua data dalam RAM dengan menggunakan MonetDBite, semua kerja "rangkaian saraf" akan dilakukan oleh kod asal dalam Python, kami hanya perlu menulis iterator ke atas data, kerana tiada apa yang sedia untuk situasi sedemikian sama ada dalam R atau Python. Pada asasnya hanya terdapat dua keperluan untuknya: ia mesti mengembalikan kelompok dalam gelung yang tidak berkesudahan dan menyimpan keadaannya antara lelaran (yang terakhir dalam R dilaksanakan dengan cara paling mudah menggunakan penutupan). Sebelum ini, ia dikehendaki menukar tatasusunan R secara eksplisit kepada tatasusunan numpy di dalam iterator, tetapi versi semasa pakej keras melakukannya sendiri.

Iterator untuk data latihan dan pengesahan ternyata seperti berikut:

Iterator untuk data latihan dan pengesahan

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_class(con, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # ΠŸΠ΅Ρ€Π΅ΠΌΠ΅ΡˆΠΈΠ²Π°Π΅ΠΌ, Ρ‡Ρ‚ΠΎΠ±Ρ‹ Π±Ρ€Π°Ρ‚ΡŒ ΠΈ ΡƒΠ΄Π°Π»ΡΡ‚ΡŒ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Π½Π½Ρ‹Π΅ индСксы Π±Π°Ρ‚Ρ‡Π΅ΠΉ ΠΏΠΎ порядку
  dt <- data.table::data.table(id = sample(samples_index))
  # ΠŸΡ€ΠΎΡΡ‚Π°Π²Π»ΡΠ΅ΠΌ Π½ΠΎΠΌΠ΅Ρ€Π° Π±Π°Ρ‚Ρ‡Π΅ΠΉ
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # ΠžΡΡ‚Π°Π²Π»ΡΠ΅ΠΌ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ ΠΏΠΎΠ»Π½Ρ‹Π΅ Π±Π°Ρ‚Ρ‡ΠΈ ΠΈ индСксируСм
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # УстанавливаСм счётчик
  i <- 1
  # ΠšΠΎΠ»ΠΈΡ‡Π΅ΡΡ‚Π²ΠΎ Π±Π°Ρ‚Ρ‡Π΅ΠΉ
  max_i <- dt[, max(batch)]

  # ΠŸΠΎΠ΄Π³ΠΎΡ‚ΠΎΠ²ΠΊΠ° выраТСния для Π²Ρ‹Π³Ρ€ΡƒΠ·ΠΊΠΈ
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(con, sql)

  # Аналог keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Π—Π°ΠΌΡ‹ΠΊΠ°Π½ΠΈΠ΅
  function() {
    # НачинаСм Π½ΠΎΠ²ΡƒΡŽ эпоху
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # БбрасываСм счётчик
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # ID для Π²Ρ‹Π³Ρ€ΡƒΠ·ΠΊΠΈ Π΄Π°Π½Π½Ρ‹Ρ…
    batch_ind <- dt[batch == i, id]
    # Π’Ρ‹Π³Ρ€ΡƒΠ·ΠΊΠ° Π΄Π°Π½Π½Ρ‹Ρ…
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Π£Π²Π΅Π»ΠΈΡ‡ΠΈΠ²Π°Π΅ΠΌ счётчик
    i <<- i + 1

    # ΠŸΠ°Ρ€ΡΠΈΠ½Π³ JSON ΠΈ ΠΏΠΎΠ΄Π³ΠΎΡ‚ΠΎΠ²ΠΊΠ° массива
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Π¨ΠΊΠ°Π»ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ c ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π»Π° [0, 1] Π½Π° ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π» [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

Fungsi ini mengambil sebagai input pembolehubah dengan sambungan ke pangkalan data, bilangan baris yang digunakan, bilangan kelas, saiz kelompok, skala (scale = 1 sepadan dengan memaparkan imej 256x256 piksel, scale = 0.5 β€” 128x128 piksel), penunjuk warna (color = FALSE menentukan pemaparan dalam skala kelabu apabila digunakan color = TRUE setiap lejang dilukis dalam warna baharu) dan penunjuk prapemprosesan untuk rangkaian pra-latihan pada imagenet. Yang terakhir ini diperlukan untuk menskalakan nilai piksel dari selang [0, 1] kepada selang [-1, 1], yang digunakan semasa melatih yang dibekalkan. keras model.

Fungsi luaran mengandungi pemeriksaan jenis argumen, jadual data.table dengan nombor baris bercampur rawak daripada samples_index dan nombor kelompok, bilangan pembilang dan maksimum kelompok, serta ungkapan SQL untuk memunggah data daripada pangkalan data. Selain itu, kami menentukan analog pantas fungsi di dalamnya keras::to_categorical(). Kami menggunakan hampir semua data untuk latihan, meninggalkan setengah peratus untuk pengesahan, jadi saiz zaman dihadkan oleh parameter steps_per_epoch apabila dipanggil keras::fit_generator(), dan syaratnya if (i > max_i) hanya berfungsi untuk lelaran pengesahan.

Dalam fungsi dalaman, indeks baris diambil untuk kumpulan seterusnya, rekod dipunggah daripada pangkalan data dengan pembilang kelompok meningkat, penghuraian JSON (fungsi cpp_process_json_vector(), ditulis dalam C++) dan mencipta tatasusunan yang sepadan dengan gambar. Kemudian vektor satu panas dengan label kelas dicipta, tatasusunan dengan nilai piksel dan label digabungkan ke dalam senarai, yang merupakan nilai pulangan. Untuk mempercepatkan kerja, kami menggunakan penciptaan indeks dalam jadual data.table dan pengubahsuaian melalui pautan - tanpa pakej "cip" ini data.tabel Agak sukar untuk membayangkan bekerja dengan berkesan dengan sebarang jumlah data yang ketara dalam R.

Keputusan pengukuran kelajuan pada komputer riba Core i5 adalah seperti berikut:

Penanda aras iterator

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Π˜Π½Π΄Π΅ΠΊΡΡ‹ для ΠΎΠ±ΡƒΡ‡Π°ΡŽΡ‰Π΅ΠΉ Π²Ρ‹Π±ΠΎΡ€ΠΊΠΈ
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Π˜Π½Π΄Π΅ΠΊΡΡ‹ для ΠΏΡ€ΠΎΠ²Π΅Ρ€ΠΎΡ‡Π½ΠΎΠΉ Π²Ρ‹Π±ΠΎΡ€ΠΊΠΈ
val_ind <- ind[-train_ind]
rm(ind)
# ΠšΠΎΡΡ„Ρ„ΠΈΡ†ΠΈΠ΅Π½Ρ‚ ΠΌΠ°ΡΡˆΡ‚Π°Π±Π°
scale <- 0.5

# ΠŸΡ€ΠΎΠ²Π΅Π΄Π΅Π½ΠΈΠ΅ Π·Π°ΠΌΠ΅Ρ€Π°
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# ΠŸΠ°Ρ€Π°ΠΌΠ΅Ρ‚Ρ€Ρ‹ Π±Π΅Π½Ρ‡ΠΌΠ°Ρ€ΠΊΠ°
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Pengecaman Doodle Draw Pantas: cara berkawan dengan R, C++ dan rangkaian saraf

Jika anda mempunyai jumlah RAM yang mencukupi, anda boleh mempercepatkan operasi pangkalan data dengan serius dengan memindahkannya ke RAM yang sama ini (32 GB sudah cukup untuk tugas kami). Di Linux, partition dipasang secara lalai /dev/shm, menduduki sehingga separuh kapasiti RAM. Anda boleh menyerlahkan lagi dengan mengedit /etc/fstabuntuk mendapatkan rekod seperti tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Pastikan anda but semula dan semak hasilnya dengan menjalankan arahan df -h.

Peulang untuk data ujian kelihatan lebih mudah, kerana set data ujian sesuai sepenuhnya dengan RAM:

Iterator untuk data ujian

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # ΠŸΡ€ΠΎΡΡ‚Π°Π²Π»ΡΠ΅ΠΌ Π½ΠΎΠΌΠ΅Ρ€Π° Π±Π°Ρ‚Ρ‡Π΅ΠΉ
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Π—Π°ΠΌΡ‹ΠΊΠ°Π½ΠΈΠ΅
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Π¨ΠΊΠ°Π»ΠΈΡ€ΠΎΠ²Π°Π½ΠΈΠ΅ c ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π»Π° [0, 1] Π½Π° ΠΈΠ½Ρ‚Π΅Ρ€Π²Π°Π» [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}

4. Pemilihan seni bina model

Seni bina pertama yang digunakan ialah mobilenet v1, ciri yang dibincangkan dalam ini mesej. Ia disertakan sebagai standard keras dan, oleh itu, tersedia dalam pakej dengan nama yang sama untuk R. Tetapi apabila cuba menggunakannya dengan imej saluran tunggal, satu perkara yang pelik ternyata: tensor input mesti sentiasa mempunyai dimensi (batch, height, width, 3), iaitu bilangan saluran tidak boleh diubah. Tiada batasan sedemikian dalam Python, jadi kami tergesa-gesa dan menulis pelaksanaan seni bina ini sendiri, mengikut artikel asal (tanpa keciciran dalam versi keras):

Seni bina Mobilenet v1

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}

Kelemahan pendekatan ini adalah jelas. Saya ingin menguji banyak model, tetapi sebaliknya, saya tidak mahu menulis semula setiap seni bina secara manual. Kami juga telah kehilangan peluang untuk menggunakan berat model yang telah dilatih di imagenet. Seperti biasa, mengkaji dokumentasi membantu. Fungsi get_config() membolehkan anda mendapatkan penerangan tentang model dalam bentuk yang sesuai untuk diedit (base_model_conf$layers - senarai R biasa), dan fungsi from_config() melakukan penukaran terbalik kepada objek model:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Kini tidak sukar untuk menulis fungsi universal untuk mendapatkan mana-mana yang dibekalkan keras model dengan atau tanpa pemberat yang dilatih pada imagenet:

Fungsi untuk memuatkan seni bina siap sedia

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ²
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # ΠŸΠΎΠ»ΡƒΡ‡Π°Π΅ΠΌ ΠΎΠ±ΡŠΠ΅ΠΊΡ‚ ΠΈΠ· ΠΏΠ°ΠΊΠ΅Ρ‚Π° keras
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # ΠŸΡ€ΠΎΠ²Π΅Ρ€ΠΊΠ° наличия ΠΎΠ±ΡŠΠ΅ΠΊΡ‚Π° Π² ΠΏΠ°ΠΊΠ΅Ρ‚Π΅
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # Если ΠΈΠ·ΠΎΠ±Ρ€Π°ΠΆΠ΅Π½ΠΈΠ΅ Π½Π΅ Ρ†Π²Π΅Ρ‚Π½ΠΎΠ΅, мСняСм Ρ€Π°Π·ΠΌΠ΅Ρ€Π½ΠΎΡΡ‚ΡŒ Π²Ρ…ΠΎΠ΄Π°
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

Apabila menggunakan imej saluran tunggal, tiada pemberat terlatih digunakan. Ini boleh diperbaiki: menggunakan fungsi get_weights() dapatkan pemberat model dalam bentuk senarai tatasusunan R, tukar dimensi elemen pertama senarai ini (dengan mengambil satu saluran warna atau purata ketiga-tiganya), dan kemudian muatkan pemberat semula ke dalam model dengan fungsi set_weights(). Kami tidak pernah menambah fungsi ini, kerana pada peringkat ini sudah jelas bahawa ia adalah lebih produktif untuk bekerja dengan gambar berwarna.

Kami menjalankan kebanyakan percubaan menggunakan mobilenet versi 1 dan 2, serta resnet34. Seni bina yang lebih moden seperti SE-ResNeXt menunjukkan prestasi yang baik dalam pertandingan ini. Malangnya, kami tidak mempunyai pelaksanaan siap sedia untuk kami gunakan, dan kami tidak menulis sendiri (tetapi kami pasti akan menulis).

5. Parameterisasi skrip

Untuk kemudahan, semua kod untuk memulakan latihan direka bentuk sebagai skrip tunggal, menggunakan parameter docopt seperti berikut:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

Pakej docopt mewakili pelaksanaan http://docopt.org/ untuk R. Dengan bantuannya, skrip dilancarkan dengan arahan mudah seperti Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db atau ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, jika fail train_nn.R boleh laku (arahan ini akan mula melatih model resnet50 pada imej tiga warna berukuran 128x128 piksel, pangkalan data mesti terletak dalam folder /home/andrey/doodle_db). Anda boleh menambah kelajuan pembelajaran, jenis pengoptimum dan sebarang parameter lain yang boleh disesuaikan pada senarai. Dalam proses penyediaan penerbitan, ternyata seni bina mobilenet_v2 daripada versi semasa keras dalam penggunaan R tidak boleh disebabkan perubahan tidak diambil kira dalam pakej R, kami sedang menunggu mereka untuk membetulkannya.

Pendekatan ini memungkinkan untuk mempercepatkan eksperimen dengan model yang berbeza dengan ketara berbanding dengan pelancaran skrip yang lebih tradisional dalam RStudio (kami perhatikan pakej sebagai alternatif yang mungkin tfruns). Tetapi kelebihan utama ialah keupayaan untuk menguruskan pelancaran skrip dengan mudah di Docker atau hanya pada pelayan, tanpa memasang RStudio untuk ini.

6. Dockerization skrip

Kami menggunakan Docker untuk memastikan kemudahalihan persekitaran untuk model latihan antara ahli pasukan dan untuk penggunaan pantas dalam awan. Anda boleh mula berkenalan dengan alat ini, yang agak luar biasa untuk pengaturcara R, dengan ini siri penerbitan atau kursus video.

Docker membolehkan anda membuat imej anda sendiri dari awal dan menggunakan imej lain sebagai asas untuk mencipta imej anda sendiri. Apabila menganalisis pilihan yang tersedia, kami membuat kesimpulan bahawa memasang NVIDIA, pemacu CUDA+cuDNN dan perpustakaan Python adalah bahagian imej yang agak besar, dan kami memutuskan untuk mengambil imej rasmi sebagai asas. tensorflow/tensorflow:1.12.0-gpu, menambah pakej R yang diperlukan di sana.

Fail docker terakhir kelihatan seperti ini:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && 
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && 
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && 
    add-apt-repository -y ppa:marutter/c2d4u3.5 && 
    add-apt-repository -y ppa:timsc/opencv-3.4 && 
    apt-get update && 
    apt-get install -y locales && 
    locale-gen ${LOCALE} && 
    apt-get install -y --no-install-recommends ${APT_PKG} && 
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && 
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && 
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && 
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && 
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && 
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && 
    install.r ${R_SRC_PKG} && 
    pip install ${PY_PIP_PKG} && 
    mkdir -p ${DIRS} && 
    chmod 777 ${DIRS} && 
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && 
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

Untuk kemudahan, pakej yang digunakan telah dimasukkan ke dalam pembolehubah; sebahagian besar skrip bertulis disalin di dalam bekas semasa pemasangan. Kami juga menukar shell arahan kepada /bin/bash untuk kemudahan penggunaan kandungan /etc/os-release. Ini mengelakkan keperluan untuk menentukan versi OS dalam kod.

Selain itu, skrip bash kecil telah ditulis yang membolehkan anda melancarkan bekas dengan pelbagai arahan. Sebagai contoh, ini boleh menjadi skrip untuk melatih rangkaian saraf yang sebelum ini diletakkan di dalam bekas, atau shell arahan untuk nyahpepijat dan memantau operasi bekas:

Skrip untuk melancarkan bekas

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

Jika skrip bash ini dijalankan tanpa parameter, skrip akan dipanggil di dalam bekas train_nn.R dengan nilai lalai; jika hujah kedudukan pertama ialah "bash", maka bekas akan bermula secara interaktif dengan shell arahan. Dalam semua kes lain, nilai hujah kedudukan digantikan: CMD="Rscript /app/train_nn.R $@".

Perlu diingat bahawa direktori dengan data sumber dan pangkalan data, serta direktori untuk menyimpan model terlatih, dipasang di dalam bekas dari sistem hos, yang membolehkan anda mengakses hasil skrip tanpa manipulasi yang tidak perlu.

7. Menggunakan berbilang GPU pada Google Cloud

Salah satu ciri pertandingan adalah data yang sangat bising (lihat gambar tajuk, dipinjam daripada @Leigh.plt dari ODS slack). Kumpulan besar membantu memerangi ini, dan selepas percubaan pada PC dengan 1 GPU, kami memutuskan untuk menguasai model latihan pada beberapa GPU dalam awan. GoogleCloud yang digunakan (panduan yang baik kepada asas) disebabkan oleh banyak pilihan konfigurasi yang tersedia, harga yang berpatutan dan bonus $300. Kerana tamak, saya memesan contoh 4xV100 dengan SSD dan satu tan RAM, dan itu adalah kesilapan besar. Mesin sedemikian memakan wang dengan cepat; anda boleh gagal bereksperimen tanpa saluran paip yang terbukti. Untuk tujuan pendidikan, lebih baik mengambil K80. Tetapi jumlah RAM yang besar berguna - SSD awan tidak menarik perhatian dengan prestasinya, jadi pangkalan data dipindahkan ke dev/shm.

Yang paling menarik ialah serpihan kod yang bertanggungjawab untuk menggunakan berbilang GPU. Pertama, model dibuat pada CPU menggunakan pengurus konteks, sama seperti dalam Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics =(top_3_categorical_accuracy,
    compile = FALSE
  )
})

Kemudian model yang tidak tersusun (ini penting) disalin ke beberapa GPU yang tersedia, dan hanya selepas itu ia disusun:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

Teknik klasik membekukan semua lapisan kecuali yang terakhir, melatih lapisan terakhir, menyahbeku dan melatih semula keseluruhan model untuk beberapa GPU tidak dapat dilaksanakan.

Latihan dipantau tanpa digunakan. papan tensor, mengehadkan diri kami untuk merekodkan log dan menyimpan model dengan nama bermaklumat selepas setiap zaman:

Panggilan balik

# Π¨Π°Π±Π»ΠΎΠ½ ΠΈΠΌΠ΅Π½ΠΈ Ρ„Π°ΠΉΠ»Π° Π»ΠΎΠ³Π°
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Π¨Π°Π±Π»ΠΎΠ½ ΠΈΠΌΠ΅Π½ΠΈ Ρ„Π°ΠΉΠ»Π° ΠΌΠΎΠ΄Π΅Π»ΠΈ
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # ΡƒΠΌΠ΅Π½ΡŒΡˆΠ°Π΅ΠΌ lr Π² 2 Ρ€Π°Π·Π°
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)

8. Daripada kesimpulan

Beberapa masalah yang kami hadapi masih belum dapat diatasi:

  • Π² keras tiada fungsi sedia untuk mencari kadar pembelajaran optimum secara automatik (analog lr_finder di perpustakaan pantas.ai); Dengan sedikit usaha, adalah mungkin untuk memindahkan pelaksanaan pihak ketiga ke R, sebagai contoh, ini;
  • akibat daripada perkara sebelumnya, adalah tidak mungkin untuk memilih kelajuan latihan yang betul apabila menggunakan beberapa GPU;
  • terdapat kekurangan seni bina rangkaian saraf moden, terutamanya yang telah dilatih di imagenet;
  • tiada satu kitaran dasar dan kadar pembelajaran diskriminatif (penyepuhlindapan kosinus adalah atas permintaan kami dilaksanakan, terima kasih skeydan).

Apakah perkara berguna yang dipelajari daripada pertandingan ini:

  • Pada perkakasan kuasa yang agak rendah, anda boleh bekerja dengan volum data yang baik (berkali-kali ganda saiz RAM) tanpa rasa sakit. Beg plastik data.tabel menjimatkan memori kerana pengubahsuaian jadual di tempat, yang mengelak daripada menyalinnya, dan apabila digunakan dengan betul, keupayaannya hampir selalu menunjukkan kelajuan tertinggi antara semua alat yang kami ketahui untuk bahasa skrip. Menyimpan data dalam pangkalan data membolehkan anda, dalam banyak kes, tidak memikirkan sama sekali tentang keperluan untuk memerah keseluruhan dataset ke dalam RAM.
  • Fungsi perlahan dalam R boleh digantikan dengan fungsi pantas dalam C++ menggunakan pakej Rcpp. Jika di samping menggunakan RcppThread atau RcppParallel, kami mendapat pelaksanaan berbilang benang merentas platform, jadi tidak perlu menyamakan kod pada tahap R.
  • Pakej Rcpp boleh digunakan tanpa pengetahuan serius tentang C++, minimum yang diperlukan digariskan di sini. Fail pengepala untuk beberapa perpustakaan C yang hebat seperti xtensor tersedia pada CRAN, iaitu, infrastruktur sedang dibentuk untuk pelaksanaan projek yang mengintegrasikan kod C++ berprestasi tinggi siap sedia ke dalam R. Kemudahan tambahan ialah penonjolan sintaks dan penganalisis kod C++ statik dalam RStudio.
  • docopt membolehkan anda menjalankan skrip serba lengkap dengan parameter. Ini mudah digunakan pada pelayan jauh, termasuk. di bawah buruh pelabuhan. Dalam RStudio, adalah menyusahkan untuk menjalankan banyak jam eksperimen dengan melatih rangkaian saraf, dan memasang IDE pada pelayan itu sendiri tidak selalu wajar.
  • Docker memastikan kemudahalihan kod dan kebolehulangan hasil antara pembangun dengan versi OS dan perpustakaan yang berbeza, serta kemudahan pelaksanaan pada pelayan. Anda boleh melancarkan keseluruhan saluran latihan dengan hanya satu arahan.
  • Google Cloud ialah cara mesra bajet untuk mencuba perkakasan yang mahal, tetapi anda perlu memilih konfigurasi dengan berhati-hati.
  • Mengukur kelajuan serpihan kod individu sangat berguna, terutamanya apabila menggabungkan R dan C++, dan dengan pakej bangku - juga sangat mudah.

Secara keseluruhannya pengalaman ini sangat bermanfaat dan kami terus berusaha untuk menyelesaikan beberapa isu yang dibangkitkan.

Sumber: www.habr.com

Tambah komen