Quick Draw Doodle Recognition: how to make friends with R, C++ and neural networks


Hi Habr!

Last fall, Kaggle hosted a competition for classifying hand-drawn pictures, Quick Draw Doodle Recognition, in which, among others, a team of R data scientists took part: Artem Klevtsov, Philipp Upravitelev and Andrey Ogurtsov. We won't describe the competition in detail; that has already been done in a recent publication.

This time we didn't manage to farm any medals, but we gained a lot of valuable experience, so I'd like to tell the community about a number of the most interesting and useful things, both on Kaggle and in everyday work. Among the topics discussed: the hard life without OpenCV, JSON parsing (these examples explore integrating C++ code into R scripts or packages using Rcpp), parameterization of scripts and dockerization of the final solution. All the code from the post, in a form suitable for running, is available in the repository.

Contents:

  1. Efficiently loading data from CSV into MonetDB
  2. Preparing batches
  3. Iterators for unloading batches from the database
  4. Choosing a model architecture
  5. Script parameterization
  6. Dockerizing the scripts
  7. Using multiple GPUs on Google Cloud
  8. In lieu of a conclusion

1. Efficiently loading data from CSV into the MonetDB database

The competition data were provided not as ready-made images but as 340 CSV files (one file per class) containing JSONs with point coordinates. Connecting these points with lines yields the final 256x256-pixel image. Each record also contains a label indicating whether the drawing was recognized by the classifier in use when the dataset was collected, the two-letter code of the country where the author of the drawing lives, a unique identifier, a timestamp, and a class name that matches the file name. The simplified version of the original data weighs 7.4 GB in the archive and about 20 GB unpacked; the full data take 240 GB unpacked. The organizers guaranteed that both versions reproduce the same drawings, so the full version is redundant. In any case, storing 50 million images as graphics files or as arrays was immediately deemed impractical, and we decided to merge all the CSV files from the train_simplified.zip archive into a database and then generate images of the required size "on the fly" for each batch.

As the DBMS we chose the well-proven MonetDB, namely its R implementation, the MonetDBLite package. The package includes an embedded version of the database server and lets you start the server directly from an R session and work with it there. Creating a database and connecting to it takes a single command:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

We need to create two tables: one for all the data and another for service information about the uploaded files (useful if something goes wrong and the process has to be resumed after some of the files have been uploaded):

Creating the tables

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

The fastest way to load data into the database turned out to be copying the CSV files directly with the SQL command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT, where tablename is the table name and path is the path to the file. While working with the archive, we found that R's built-in unzip implementation does not work correctly with a number of files from the archive, so we used the system unzip (via the getOption("unzip") parameter).
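For illustration, here is a minimal C++ sketch of how such a statement can be assembled. It is not part of the original solution: sql_quote and build_copy_sql are hypothetical helpers, and we assume that doubling single quotes (standard SQL string-literal escaping) is enough to embed an arbitrary file path. Our reading of MonetDB's OFFSET 2 is that loading starts at the second line, which skips the CSV header.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: wraps a value in single quotes, doubling any embedded
// quote ('' is the standard SQL escape inside a string literal).
std::string sql_quote(const std::string& s) {
  std::string out = "'";
  for (char c : s) {
    if (c == '\'') {
      out += "''";
    } else {
      out += c;
    }
  }
  out += "'";
  return out;
}

// Hypothetical helper: the MonetDB bulk-load statement described above.
// Delimiters: field ',', record '\n' (sent as backslash + n), quote '"'.
std::string build_copy_sql(const std::string& table, const std::string& path) {
  assert(!table.empty() && !path.empty());
  return "COPY OFFSET 2 INTO " + table + " FROM " + sql_quote(path) +
         " USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT";
}
```

Quoting the path is the one thing the R version (which pastes the path with sprintf) leaves to luck; with temporary-directory paths this is usually harmless.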

Function for writing to the database

#' @title Extract and load files
#'
#' @description
#' Extracts CSV files from a ZIP archive and loads them into the database
#'
#' @param con Database connection object (class `MonetDBEmbeddedConnection`).
#' @param tablename Name of the table in the database.
#' @param zipfile Path to the ZIP archive.
#' @param filename Name of the file inside the ZIP archive.
#' @param preprocess Preprocessing function applied to the extracted file.
#'   Must take a single argument `data` (a `data.table` object).
#'
#' @return `TRUE` (invisibly).
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Argument checks
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Extract the file; with junkpaths = TRUE it lands directly in tempdir()
  path <- file.path(tempdir(), basename(filename))
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(path))

  # Apply the preprocessing function, if any
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Bulk-load query; OFFSET 2 starts at the second line (skips the header)
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Run the query
  DBI::dbExecute(con, sql)

  # Record the successful upload in the service table
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

If the table needs to be transformed before it is written to the database, it is enough to pass a function that transforms the data in the preprocess argument.

Code for loading the data into the database file by file:

Writing data to the database

# List of files to load
files <- unzip(zipfile, list = TRUE)$Name

# Exclusions, in case some of the files have already been loaded
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Start the timer
  tictoc::tic()
  # Progress bar
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Stop the timer
  tictoc::toc()
}

# 526.141 sec elapsed - copying SSD->SSD
# 558.879 sec elapsed - copying USB->SSD
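The restart logic in the loop above reduces to an order-preserving set difference between the archive listing and upload_log. A small C++ sketch of the same semantics as R's setdiff() (archive order kept, duplicates dropped); pending_files is a hypothetical name:

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical helper: archive entries not yet listed in upload_log, in archive
// order and without duplicates (the semantics of R's setdiff(files, to_skip)).
std::vector<std::string> pending_files(const std::vector<std::string>& files,
                                       const std::vector<std::string>& uploaded) {
  // Names that must not be emitted: already uploaded ones plus, as we go,
  // the ones already returned (so duplicates in `files` appear only once).
  std::unordered_set<std::string> seen(uploaded.begin(), uploaded.end());
  std::vector<std::string> out;
  for (const auto& f : files) {
    if (seen.insert(f).second) {  // true only if f was not in the set yet
      out.push_back(f);
    }
  }
  return out;
}
```

Because a file is logged only after its COPY succeeded, rerunning the loop after a crash retries exactly the files that were not finished.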

Load time may vary depending on the speed of the drive used. In our case, reading and writing within a single SSD, or from a flash drive (source files) to an SSD (database), takes less than 10 minutes.

A few more seconds are needed to create a column with an integer class label and an index column (ORDERED INDEX) with row numbers, which are used to sample observations when forming batches:

Creating additional columns and an index

message("Generate labels")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))
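The label column assigns each class name a 0-based integer in sorted order. A C++ sketch of what dense_rank() OVER (ORDER BY word) - 1 computes, assuming plain byte-wise string ordering (the database collation could order some names differently); dense_rank0 is a hypothetical name:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// 0-based dense rank of each word in sorted order: equal words share a label
// and labels have no gaps, so 340 classes map onto 0..339.
std::vector<int> dense_rank0(const std::vector<std::string>& words) {
  // Sorted list of distinct class names ("levels")
  std::vector<std::string> levels(words);
  std::sort(levels.begin(), levels.end());
  levels.erase(std::unique(levels.begin(), levels.end()), levels.end());
  std::vector<int> labels;
  labels.reserve(words.size());
  for (const auto& w : words) {
    // Position of the word among the levels is its label
    auto it = std::lower_bound(levels.begin(), levels.end(), w);
    labels.push_back(static_cast<int>(it - levels.begin()));
  }
  return labels;
}
```

The "- 1" matters later: 0-based labels are what the one-hot encoding in the iterator expects.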

To generate batches on the fly, we needed to extract random rows from the doodles table as fast as possible. We used three tricks for this. The first was to reduce the width of the type that stores the observation ID. In the original dataset the ID requires bigint, but the number of observations allows us to replace the IDs with their ordinal numbers, which fit into int. Lookups are much faster in this case. The second trick was ORDERED INDEX; we arrived at this choice empirically, after trying all the available options. The third was parameterized queries: the PREPARE command is executed once, and the prepared statement is then reused to issue many queries of the same kind. In practice, however, the gain over a plain SELECT turned out to be within the statistical error.
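Two of these tricks are easy to sanity-check in isolation: the placeholder list of the prepared statement, and whether row numbers fit into a 32-bit INT. A C++ sketch (prepare_batch_sql and fits_int32 are hypothetical helpers; the 50 million figure is the dataset size mentioned above):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>

// Builds "PREPARE SELECT ... WHERE id IN (?,?,...)" with one placeholder per
// row of the batch, like the R helper prep_sql() in the benchmark.
std::string prepare_batch_sql(int batch_size) {
  assert(batch_size > 0);
  std::string sql = "PREPARE SELECT id FROM doodles WHERE id IN (";
  for (int i = 0; i < batch_size; ++i) {
    sql += (i == 0) ? "?" : ",?";
  }
  return sql + ")";
}

// A serial row number fits a 32-bit INT as long as the row count does.
constexpr bool fits_int32(std::int64_t n_rows) {
  return n_rows <= std::numeric_limits<std::int32_t>::max();
}

// About 50 million drawings are far below 2^31 - 1 = 2147483647.
static_assert(fits_int32(50000000), "row ids must fit INT");
```

Note that the placeholder count is fixed at prepare time, which is why a separate prepared statement is needed for every batch size.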

The loading process consumes no more than 450 MB of RAM. In other words, the described approach lets you load datasets weighing tens of gigabytes on almost any budget hardware, including some single-board computers, which is pretty neat.

All that remains is to measure the speed of retrieving (random) data and to evaluate how it scales when sampling batches of different sizes:

Benchmarking the database

library(ggplot2)

set.seed(0)
# Connect to the database
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
# Number of rows: id is serial, so ids run from 1 to n
n <- DBI::dbGetQuery(con, "SELECT max(id) FROM doodles")[[1L]]

# Prepares a server-side query with batch_size placeholders
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Fetches a random batch of rows through the prepared query
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Benchmark columns to show
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)
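The benchmark draws batch_size distinct ids per request with sample(). If the sampling were ever done outside R, Robert Floyd's algorithm is one way to get k distinct ids from 1..n without materializing the whole range. A C++ sketch, not what the original code uses (sample_ids is a hypothetical name); the returned set is uniformly random, but its order is not a uniform permutation, which does not matter for an IN (...) query:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <unordered_set>
#include <vector>

// Robert Floyd's algorithm: k distinct ids from 1..n in O(k) memory.
std::vector<std::int64_t> sample_ids(std::int64_t n, std::int64_t k,
                                     std::mt19937_64& rng) {
  assert(k >= 0 && k <= n);
  std::unordered_set<std::int64_t> chosen;
  std::vector<std::int64_t> out;
  out.reserve(static_cast<std::size_t>(k));
  for (std::int64_t j = n - k + 1; j <= n; ++j) {
    std::uniform_int_distribution<std::int64_t> dist(1, j);
    std::int64_t t = dist(rng);
    // Everything chosen so far is < j, so if t is taken, j is still free.
    std::int64_t pick = chosen.count(t) ? j : t;
    chosen.insert(pick);
    out.push_back(pick);
  }
  return out;
}
```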


2. Preparing batches

The whole batch preparation process consists of the following steps:

  1. Parsing several JSON strings containing vectors of strokes with point coordinates.
  2. Drawing colored lines through the point coordinates on an image of the required size (for example, 256x256 or 128x128).
  3. Converting the resulting images into a tensor.
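To show what step 2 involves without any graphics library, here is a C++ sketch that joins consecutive points of already parsed strokes with 1-pixel Bresenham lines on a grayscale canvas. It is an illustration only: the Stroke and Canvas types are hypothetical, and the output differs from OpenCV's cv::line with LINE_4 connectivity and 3-pixel width used later in the post.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <vector>

// One stroke: parallel x and y coordinates, as in [[x0, x1, ...], [y0, y1, ...]].
struct Stroke { std::vector<int> x, y; };

// Grayscale canvas, row-major, 1 = background (white), 0 = ink.
struct Canvas {
  int size;
  std::vector<float> px;
  explicit Canvas(int s) : size(s), px(static_cast<std::size_t>(s) * s, 1.0f) {}
  void set(int x, int y) {
    if (x >= 0 && y >= 0 && x < size && y < size)
      px[static_cast<std::size_t>(y) * size + x] = 0.0f;
  }
};

// Integer Bresenham line: 1 px wide, 8-connected.
void draw_line(Canvas& c, int x0, int y0, int x1, int y1) {
  int dx = std::abs(x1 - x0), sx = x0 < x1 ? 1 : -1;
  int dy = -std::abs(y1 - y0), sy = y0 < y1 ? 1 : -1;
  int err = dx + dy;
  while (true) {
    c.set(x0, y0);
    if (x0 == x1 && y0 == y1) break;
    int e2 = 2 * err;
    if (e2 >= dy) { err += dy; x0 += sx; }
    if (e2 <= dx) { err += dx; y0 += sy; }
  }
}

// Joins consecutive points of every stroke; single-point strokes draw nothing.
Canvas render(const std::vector<Stroke>& strokes, int size = 256) {
  Canvas c(size);
  for (const auto& s : strokes) {
    assert(s.x.size() == s.y.size());
    for (std::size_t i = 0; i + 1 < s.x.size(); ++i)
      draw_line(c, s.x[i], s.y[i], s.x[i + 1], s.y[i + 1]);
  }
  return c;
}
```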

In the competition's Python kernels, this problem was solved mainly with OpenCV. One of the simplest and most obvious analogues in R looks like this:

Implementing JSON-to-tensor conversion in R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Parse the JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Remove the temporary file when the function exits
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Empty plot
  plot.new()
  # Plot window: x grows to the right, y grows downwards (image coordinates)
  plot.window(xlim = c(0, 256 * scale), ylim = c(256 * scale, 0))
  # Line colors: one per stroke
  cols <- if (color) rainbow(length(coords)) else rep("#000000", length(coords))
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Read the image back as a 3-dimensional array
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Stack the 3-dimensional image arrays into a 4-dimensional tensor
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

The drawing is made with standard R tools and saved to a temporary PNG kept in RAM (R's temporary directories on Linux live under /tmp, which in our setup is mounted in RAM). This file is then read as a three-dimensional array with values from 0 to 1. This matters because a more conventional BMP would be read into a raster matrix of hex color codes.
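The difference between the two representations is a per-channel rescaling: each hex byte pair becomes a value in [0, 1]. A C++ sketch (hex_to_unit_rgb is a hypothetical helper and assumes well-formed #RRGGBB input):

```cpp
#include <array>
#include <cassert>
#include <string>

// Converts a "#RRGGBB" color (the form a raster matrix holds) into three
// channel intensities in [0, 1] (the form png::readPNG returns).
std::array<double, 3> hex_to_unit_rgb(const std::string& hex) {
  assert(hex.size() == 7 && hex[0] == '#');
  std::array<double, 3> rgb{};
  for (int ch = 0; ch < 3; ++ch) {
    // Two hex digits per channel, base 16
    int value = std::stoi(hex.substr(1 + 2 * ch, 2), nullptr, 16);
    rgb[ch] = value / 255.0;
  }
  return rgb;
}
```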

Let's look at the result:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))


The batch itself is formed as follows:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

This implementation seemed suboptimal to us, since forming large batches takes a long time, so we decided to draw on our colleagues' experience and use the powerful OpenCV library. At that time there was no ready-made R package for it (there still isn't), so a minimal implementation of the required functionality was written in C++ and integrated into R code via Rcpp.

The following packages and libraries were used to solve the problem:

  1. OpenCV for working with images and drawing lines. We used preinstalled system libraries and header files, with dynamic linking.

  2. xtensor for working with multidimensional arrays and tensors. We used the header files included in the R package of the same name. The library lets you work with multidimensional arrays in both row-major and column-major order.

  3. nlohmann/json for parsing JSON. xtensor uses this library automatically if it is present in the project.

  4. RcppThread for multithreaded processing of a vector of JSON strings. We used the header files provided by this package. Unlike the more popular RcppParallel, this package, among other things, has a built-in loop interruption mechanism.

It is worth noting that xtensor turned out to be a godsend: besides its extensive functionality and high performance, its developers proved very responsive and answered questions promptly and in detail. With their help we implemented the conversion of OpenCV matrices into xtensor tensors, as well as a way to combine 3-dimensional image tensors into a 4-dimensional tensor of the correct shape (the batch itself).
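Behind this conversion is a layout change: a continuous cv::Mat stores pixels row by row with interleaved channels, while R arrays are column-major (the first index varies fastest). The sketch below, with hypothetical names and float input standing in for the CV_32F matrices, copies N row-major H x W x C images into one column-major N x H x W x C buffer, roughly what to_xt(), the xt::view assignment and the export to R achieve together in the C++ code of this section:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Offset of (i, j, k) in a row-major H x W x C buffer (continuous cv::Mat).
inline std::size_t rm3(std::size_t i, std::size_t j, std::size_t k,
                       std::size_t W, std::size_t C) {
  return (i * W + j) * C + k;
}

// Offset of (n, i, j, k) in a column-major N x H x W x C array,
// i.e. R's array(dim = c(N, H, W, C)) with the first index fastest.
inline std::size_t cm4(std::size_t n, std::size_t i, std::size_t j, std::size_t k,
                       std::size_t N, std::size_t H, std::size_t W) {
  return n + N * (i + H * (j + W * k));
}

// Stacks N row-major images into one column-major batch tensor.
std::vector<double> stack_batch(const std::vector<std::vector<float>>& imgs,
                                std::size_t H, std::size_t W, std::size_t C) {
  const std::size_t N = imgs.size();
  std::vector<double> out(N * H * W * C);
  for (std::size_t n = 0; n < N; ++n) {
    assert(imgs[n].size() == H * W * C);
    for (std::size_t i = 0; i < H; ++i)
      for (std::size_t j = 0; j < W; ++j)
        for (std::size_t k = 0; k < C; ++k)
          out[cm4(n, i, j, k, N, H, W)] = imgs[n][rm3(i, j, k, W, C)];
  }
  return out;
}
```

Logical indices are preserved; only the memory order changes, which is why a plain memcpy between the two layouts would scramble the images.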

Materials for learning Rcpp, xtensor and RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

To compile code that uses system header files and links dynamically against libraries installed on the system, we used the plugin mechanism implemented in the Rcpp package. To find the paths and flags automatically, we used the popular Linux utility pkg-config.

Implementing an Rcpp plugin for the OpenCV library

Rcpp::registerPlugin("opencv", function() {
  # Possible package names
  pkg_config_name <- c("opencv", "opencv4")
  # Path to the pkg-config binary
  pkg_config_bin <- Sys.which("pkg-config")
  # Check that the utility is available on the system
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Check which OpenCV configuration files pkg-config can find
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, "--exists", pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  # Use the first name that was found
  pkg_config_name <- pkg_config_name[check == 0][1L]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

As a result of running the plugin, the following values are substituted during compilation:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

The code for parsing JSON and forming a batch to pass to the model is given below. First, add the local project directory to the header search path (needed for nlohmann/json):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Implementing JSON-to-tensor conversion in C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Type aliases
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Point coordinates of one stroke (2 x n), parsed from JSON
using strokes = std::vector<points>;      // All strokes of one drawing
using xtensor3d = xt::xtensor<double, 3>; // Tensor holding one image
using xtensor4d = xt::xtensor<double, 4>; // Tensor holding a batch of images
using rtensor3d = xt::rtensor<double, 3>; // Wrapper for export to R
using rtensor4d = xt::rtensor<double, 4>; // Wrapper for export to R

// Static constants
// Image size in pixels
const static int SIZE = 256;
// Line type
// See https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Line width in pixels
const static int LINE_WIDTH = 3;
// Resize algorithm
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Template for converting an OpenCV matrix into a tensor
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Shape of the target tensor
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Total number of elements in the array
  size_t size = src.total() * NCH;
  // Convert cv::Mat to xt::xtensor (the adaptor is copied into res)
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Convert JSON into a list of stroke point coordinates
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // The parsed result must be an array
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Each element must be an array of two arrays: x and y coordinates
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Extract the point coordinates
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Line drawing
// Colors are set in HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Source matrix type
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Resulting matrix type
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of strokes
  size_t n = x.size();
  for (const auto& s: x) {
    // Number of points in the stroke
    size_t n_points = s.shape()[1];
    // i + 1 < n_points avoids size_t underflow for empty strokes
    for (size_t i = 0; i + 1 < n_points; ++i) {
      // Start point of the segment
      cv::Point from(s(0, i), s(1, i));
      // End point of the segment
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Draw the segment
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Shift the hue for the next stroke
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Convert the color representation to RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Convert to float32 with values in [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Process JSON and return a tensor with the image data
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

This code should be placed in the file src/cv_xt.cpp and compiled with the command Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); it also needs the file nlohmann/json.hpp from the repository. The code is split into several functions:

  • to_xt - a template function for converting an image matrix (cv::Mat) into an xt::xtensor tensor;

  • parse_json - parses a JSON string, extracts the point coordinates and packs them into a vector;

  • ocv_draw_lines - draws multicolored lines from the resulting vector of points;

  • process - combines the functions above and adds the ability to scale the resulting image;

  • cpp_process_json_str - a wrapper over process that exports the result as an R object (a multidimensional array);

  • cpp_process_json_vector - a wrapper over process that processes a vector of strings in multiple threads.

To draw multicolored lines, we used the HSV color model followed by conversion to RGB. Let's look at the result:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))
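The color scheme in ocv_draw_lines() can be reproduced by hand: the hue starts at 0 and grows by 180 / n per stroke (integer division; 8-bit OpenCV hue is stored in half-degrees, 0..180), with saturation 255 and value 220, while the background (0, 0, 255) is white. A C++ sketch using the textbook HSV to RGB formula; OpenCV's 8-bit conversion may differ by rounding, and the function names are hypothetical. A side effect worth noting: for drawings with more than 180 strokes the step becomes 0 and all strokes share one color.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hue of each stroke as in ocv_draw_lines(): start at 0, add 180 / n per
// stroke, where the division is integer division on size_t.
std::vector<int> stroke_hues(std::size_t n) {
  assert(n > 0);
  std::vector<int> hues;
  int h = 0;
  for (std::size_t s = 0; s < n; ++s) {
    hues.push_back(h);
    h += static_cast<int>(180 / n);
  }
  return hues;
}

// Textbook HSV -> RGB; h in OpenCV 8-bit units (half-degrees), s, v in [0, 1].
std::array<double, 3> hsv_to_rgb(int h_cv, double s, double v) {
  double hp = (h_cv * 2.0) / 60.0;  // hue sector, 0..6
  double c = v * s;                 // chroma
  double x = c * (1.0 - std::fabs(std::fmod(hp, 2.0) - 1.0));
  double r = 0, g = 0, b = 0;
  if (hp < 1)      { r = c; g = x; }
  else if (hp < 2) { r = x; g = c; }
  else if (hp < 3) { g = c; b = x; }
  else if (hp < 4) { g = x; b = c; }
  else if (hp < 5) { r = x; b = c; }
  else             { r = c; b = x; }
  double m = v - c;  // lifts all channels to the requested value
  return std::array<double, 3>{{r + m, g + m, b + m}};
}
```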

Comparison of R and C++ implementation speed

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Benchmark columns to show
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, c("expression", "batch_size", cols[-1])]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 


As you can see, the speedup turned out to be very significant, and catching up with the C++ code by parallelizing the R code is not feasible.

3. Iterators for unloading batches from the database

R has a well-deserved reputation as a language for processing data that fits in RAM, whereas Python is more associated with iterative data processing, which makes out-of-core computation (computation using external memory) easy and natural. A classic example relevant to the problem at hand is deep neural networks trained by gradient descent, where at each step the gradient is approximated on a small portion of the observations, a mini-batch.
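For readers less familiar with the idea, a toy C++ sketch of mini-batch gradient descent: a linear model y = w*x + b with squared loss, where every update uses the gradient estimated on one shuffled mini-batch. It only illustrates the training scheme the iterator has to serve; it is unrelated to the networks used in the competition, and fit_minibatch is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Toy mini-batch gradient descent for y = w * x + b with squared loss.
std::pair<double, double> fit_minibatch(const std::vector<double>& x,
                                        const std::vector<double>& y,
                                        std::size_t batch_size, double lr,
                                        int epochs, unsigned seed) {
  assert(x.size() == y.size() && batch_size > 0);
  std::vector<std::size_t> idx(x.size());
  std::iota(idx.begin(), idx.end(), std::size_t{0});
  std::mt19937 rng(seed);
  double w = 0.0, b = 0.0;
  for (int e = 0; e < epochs; ++e) {
    std::shuffle(idx.begin(), idx.end(), rng);  // new visiting order each epoch
    // Only full mini-batches are used; a partial tail is skipped.
    for (std::size_t start = 0; start + batch_size <= idx.size(); start += batch_size) {
      double gw = 0.0, gb = 0.0;
      for (std::size_t t = start; t < start + batch_size; ++t) {
        const double err = w * x[idx[t]] + b - y[idx[t]];
        gw += 2.0 * err * x[idx[t]];  // d(err^2)/dw
        gb += 2.0 * err;              // d(err^2)/db
      }
      // Step against the mini-batch estimate of the gradient
      w -= lr * gw / static_cast<double>(batch_size);
      b -= lr * gb / static_cast<double>(batch_size);
    }
  }
  return {w, b};
}
```

The point for this section: the training loop only ever asks for "the next batch", so the data source can live anywhere, including a database.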

Deep learning frameworks written in Python have special classes that implement iterators over data: tables, images in folders, binary formats and so on. In R we can use all the features of the Python library keras with its various backends via the R package of the same name, which in turn works on top of the reticulate package. The latter deserves a separate long article: it not only lets you run Python code from R but also lets you pass objects between R and Python sessions, automatically performing all the necessary type conversions.

We got rid of the need to keep all the data in RAM by using MonetDBLite, and all the "neural network" work is done by the original Python code, so we only had to write an iterator over the data, since nothing ready-made exists for this situation in either R or Python. There are essentially only two requirements: it must return batches in an endless loop and preserve its state between iterations (in R, the latter is most simply implemented with closures). Previously, R arrays had to be explicitly converted into numpy arrays inside the iterator, but the current version of the keras package does this itself.

The iterator for training and validation data turned out as follows:

Iterator for training and validation data

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Argument checks
  checkmate::assert_class(db_connection, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Shuffle, so that batch indices can then be taken in order
  dt <- data.table::data.table(id = sample(samples_index))
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Keep only full batches and key the table by batch
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Initialize the counter
  i <- 1
  # Number of batches
  max_i <- dt[, max(batch)]

  # Prepare the query for unloading data
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(db_connection, sql)

  # Analogue of keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Closure
  function() {
    # Start a new epoch
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Reset the counter
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # IDs to unload
    batch_ind <- dt[batch == i, id]
    # Unload the data
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Increment the counter
    i <<- i + 1

    # Parse JSON and prepare the array
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from [0, 1] to [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

The function takes the following inputs:

- a variable holding the database connection;
- the numbers of the rows to use;
- the number of classes;
- the batch size;
- the scale (scale = 1 corresponds to rendering images of 256x256 pixels, scale = 0.5 to 128x128 pixels);
- a color flag (color = FALSE renders in grayscale, while with color = TRUE each stroke is drawn in a new color);
- a preprocessing flag for networks pretrained on imagenet.

The last flag rescales pixel values from the interval [0, 1] to the interval [-1, 1], the range that was used when training the keras models that ship with the package.
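The two numeric conventions from this paragraph are easy to check in isolation. Below is a minimal C++ sketch. It assumes the 256x256 base canvas described above and plain truncation when converting the scaled size to an integer. The actual rounding inside cpp_process_json_vector() is not shown in this section, and the function names are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Side length of the rendered image, assuming a 256x256 base canvas and
// truncation to an integer (scale = 1 -> 256, scale = 0.5 -> 128).
int image_side(double scale) {
  assert(scale >= 0.001 && scale <= 5);  // same bounds as in the R checks
  return static_cast<int>(256 * scale);
}

// In-place equivalent of (batch_x - 0.5) * 2: maps [0, 1] onto [-1, 1].
void imagenet_preprocess(std::vector<float>& pixels) {
  for (float& p : pixels) p = (p - 0.5f) * 2.0f;
}
```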

The outer function contains:

- argument type checks;
- a data.table with randomly shuffled row numbers from samples_index and the corresponding batch numbers;
- a counter and the maximum number of batches;
- an SQL statement for fetching data from the database.

We also defined a fast analogue of keras::to_categorical() inside it. We used almost all of the data for training and left half a percent for validation. For that reason, the epoch size was limited by the steps_per_epoch parameter when calling keras::fit_generator(), and the condition if (i > max_i) only ever fired for the validation iterator.
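The batch bookkeeping of the outer function can also be expressed without data.table. The following C++ sketch is a hypothetical helper and not part of the project. It reproduces the integer arithmetic (.I - 1L) %/% batch_size + 1L and then drops the trailing incomplete batch, which is what the .N == batch_size filter does.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors dt[, batch := (.I - 1L) %/% batch_size + 1L] followed by
// dt[, if (.N == batch_size) .SD, keyby = batch]: the k-th shuffled id
// (0-based) goes to batch k / batch_size, and an incomplete last batch is
// dropped. Batches are 0-based here and 1-based in the R code.
std::vector<std::vector<int>> split_full_batches(const std::vector<int>& shuffled_ids,
                                                 std::size_t batch_size) {
  assert(batch_size > 0);
  const std::size_t n_full = shuffled_ids.size() / batch_size;  // max_i in R
  std::vector<std::vector<int>> batches(n_full);
  for (std::size_t k = 0; k < n_full * batch_size; ++k) {
    batches[k / batch_size].push_back(shuffled_ids[k]);
  }
  return batches;
}
```

With 10 shuffled ids and batch_size = 4, this yields two batches. In the R iterator the filter is applied only once, when the iterator is created. The few dropped ids (always fewer than batch_size) are therefore excluded for the iterator's whole lifetime, and later epochs only reshuffle the remaining ids.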

The inner function does the following:

1. It selects the row indices for the next batch.
2. It fetches the records from the database and increments the batch counter.
3. It parses the JSON (with the function cpp_process_json_vector(), written in C++) and creates the arrays corresponding to the images.
4. It creates one-hot vectors for the class labels.
5. It combines the pixel arrays and the labels into a list, which is the return value.

To speed things up, we built indexes on the data.table tables and modified them by reference. Without these data.table tricks, it is hard to imagine working efficiently with any significant amount of data in R.
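The one-hot encoding relies on R's column-major storage. For n labels and num classes, element (row r, column c) of the matrix is stored at position c * n + r (0-based), so m[x * n + seq_len(n)] <- 1 sets exactly one element per row. Below is the same indexing in C++, as a sketch. It assumes 0-based labels, which is consistent with num_classes being computed as max(label_int) + 1 in the benchmark below.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One-hot matrix in column-major order, like the R helper to_categorical():
// with n labels, element (row r, column c) is stored at c * n + r.
// Labels are assumed to be 0-based integers in [0, num_classes).
std::vector<double> to_categorical(const std::vector<int>& labels, int num_classes) {
  assert(num_classes > 0);
  const std::size_t n = labels.size();
  std::vector<double> m(n * static_cast<std::size_t>(num_classes), 0.0);
  for (std::size_t r = 0; r < n; ++r) {
    assert(labels[r] >= 0 && labels[r] < num_classes);
    m[static_cast<std::size_t>(labels[r]) * n + r] = 1.0;
  }
  return m;
}
```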

The benchmark results on a Core i5 machine are as follows:

Iterator benchmark

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Indices for the training set
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Indices for the validation set
val_ind <- ind[-train_ind]
rm(ind)
# Scale factor
scale <- 0.5

# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Benchmark columns to display
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)
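The median times are easier to compare once they are converted to throughput. Below is a small C++ sketch of that arithmetic using the median column of the table above (the BenchRow struct is hypothetical). Dividing batch_size by the median gives about 249 images/s for batches of 16, 543 for 64, 532 for 128 and 267 for 1024. On this machine, the per-image generation cost was therefore lowest at around 64 to 128 images per batch and grew for larger batches.

```cpp
#include <cassert>
#include <vector>

struct BenchRow {
  int batch_size;
  double median_s;  // median time of one iterator call, in seconds
};

// Images per second implied by one row of the benchmark table.
double images_per_second(const BenchRow& row) {
  assert(row.median_s > 0);
  return row.batch_size / row.median_s;
}

// Batch size with the highest throughput among the measured rows.
int best_batch_size(const std::vector<BenchRow>& rows) {
  assert(!rows.empty());
  const BenchRow* best = &rows[0];
  for (const auto& r : rows)
    if (images_per_second(r) > images_per_second(*best)) best = &r;
  return best->batch_size;
}
```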


If you have enough RAM, you can speed up the database considerably by moving it into that same RAM (32 GB is enough for our task). On Linux, the /dev/shm partition is mounted by default and takes up to half of the RAM. You can allocate more by editing /etc/fstab so that it contains an entry like tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Be sure to reboot and then check the result with the df -h command.

The iterator for the test data is much simpler, since the test dataset fits entirely in RAM:

Iterator for test data

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Argument checks
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Assign batch numbers (note: this modifies dt by reference)
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Closure
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing],
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to the interval [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}
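Unlike the training iterator, the test iterator keeps the last, possibly incomplete batch and has no epoch wrap-around. The caller therefore has to request exactly ceiling(n / batch_size) batches, for instance through the steps argument of keras::predict_generator(). Below is a sketch of that arithmetic in C++; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Number of iterator calls needed to cover n_rows when the last, possibly
// incomplete batch is kept (unlike the training iterator).
std::size_t n_test_steps(std::size_t n_rows, std::size_t batch_size) {
  assert(batch_size > 0);
  return (n_rows + batch_size - 1) / batch_size;  // ceiling division
}

// Size of the i-th batch (1-based, like the counter in the R closure).
std::size_t test_batch_size(std::size_t i, std::size_t n_rows, std::size_t batch_size) {
  assert(i >= 1 && i <= n_test_steps(n_rows, batch_size));
  const std::size_t start = (i - 1) * batch_size;
  const std::size_t remaining = n_rows - start;
  return remaining < batch_size ? remaining : batch_size;
}
```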

4. Choosing a model architecture

The first architecture we used was mobilenet v1, whose features are discussed in this post. It ships with keras as standard and is therefore available in the R package of the same name. When we tried to use it with single-channel images, however, something strange came up: the input tensor must always have the shape (batch, height, width, 3), so the number of channels cannot be changed. Python has no such restriction. We hurried and wrote our own implementation of this architecture, following the original paper (without the dropout present in the keras version):

Mobilenet v1 architecture

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}
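Two properties of this network can be checked with a little arithmetic.

First, with "same" padding each layer outputs ceil(in / stride). The five stride-2 layers in get_mobilenet_v1() therefore divide the input side by 32 in total, so a 128x128 input (scale = 0.5) ends in a 4x4 feature map before global average pooling.

Second, a depthwise separable block costs a fraction 1/N + 1/k^2 of the multiply-adds of a standard k x k convolution with N output filters, as derived in the MobileNet v1 paper.

Below is a C++ sketch of both calculations; the function names are ours.

```cpp
#include <cassert>
#include <vector>

// Spatial size after a "same"-padded convolution: ceil(in / stride).
int same_out(int in, int stride) { return (in + stride - 1) / stride; }

// Trace the feature-map side through the stem convolution and the 13
// separable blocks of get_mobilenet_v1() (strides copied from the R code).
int mobilenet_v1_final_side(int input_side) {
  const std::vector<int> strides = {2,  // stem convolution
                                    1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1};
  int side = input_side;
  for (int s : strides) side = same_out(side, s);
  return side;
}

// Multiply-adds of a depthwise separable block relative to a standard k x k
// convolution on an f x f map with m input and n output channels.
double separable_cost_ratio(int k, int m, int n, int f) {
  const double standard = double(k) * k * m * n * f * f;
  const double separable = double(k) * k * m * f * f + double(m) * n * f * f;
  return separable / standard;  // algebraically equal to 1/n + 1/k^2
}
```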

The drawbacks of writing the architecture by hand are obvious. We wanted to try many models, but we did not want to rewrite each architecture manually. We also lost the ability to use the weights of models pretrained on imagenet. As usual, studying the documentation helped. The function get_config() returns a description of the model in a form suitable for editing (base_model_conf$layers is an ordinary R list), and the function from_config() converts it back into a model object:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

With this, it is easy to write a universal function that returns any of the keras models that ship with the package, with or without weights pretrained on imagenet:

Function for loading ready-made architectures

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Argument checks
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Look up the application_* function in the keras package
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Check that the function exists in the package
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # If the images are not in color, change the number of input channels
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  # With pooling = "avg" the base model already ends in global average pooling;
  # using its output avoids depending on an auto-generated layer name
  predictions <- base_model$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

With single-channel images, no pretrained weights are used. This could be fixed in three steps:

1. Get the model weights as a list of R arrays with get_weights().
2. Change the dimensions of the first element of that list, either by taking one color channel or by averaging all three.
3. Load the weights back into the model with set_weights().

We never added this functionality, because by that stage it was already clear that working with color images was more productive.
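The weight adaptation described above was not implemented in the project, so the following is only a sketch of the idea in C++, under two assumptions. The first convolution kernel is stored in row-major order with the Keras channels_last shape (kh, kw, c_in, c_out), and the input channels are averaged into one. The function name is hypothetical. In R itself, the same reduction is apply(w, c(1, 2, 4), mean), followed by restoring a singleton third dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Collapse a convolution kernel stored row-major with shape
// (kh, kw, c_in, c_out) into shape (kh, kw, 1, c_out) by averaging over the
// input channels, i.e. turning an RGB first layer into a grayscale one.
std::vector<float> average_input_channels(const std::vector<float>& kernel,
                                          std::size_t kh, std::size_t kw,
                                          std::size_t c_in, std::size_t c_out) {
  assert(c_in > 0 && kernel.size() == kh * kw * c_in * c_out);
  std::vector<float> out(kh * kw * c_out, 0.0f);
  for (std::size_t y = 0; y < kh; ++y)
    for (std::size_t x = 0; x < kw; ++x)
      for (std::size_t c = 0; c < c_in; ++c)
        for (std::size_t o = 0; o < c_out; ++o)
          out[(y * kw + x) * c_out + o] +=
              kernel[((y * kw + x) * c_in + c) * c_out + o] / c_in;
  return out;
}
```

A design note: if a grayscale image were fed to the original network as three identical channels, the first layer would respond with the sum of the three channel weights, not their mean. Summing reproduces those activations exactly. Averaging scales the linear part of the first layer's output by 1/3, which fine-tuning would then have to absorb.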

We ran most of our experiments with mobilenet versions 1 and 2, as well as resnet34. More modern architectures such as SE-ResNeXt performed well in this competition. Unfortunately, we had no ready-made implementations at hand, and we did not write our own (but we certainly will).

5. Script parameterization

For convenience, all the code for launching training was organized as a single script, parameterized with docopt as follows:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

The docopt package is the R implementation of http://docopt.org/. With it, scripts are launched with simple commands such as Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, or ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db if train_nn.R is executable. This command starts training the resnet50 model on three-color images of 128x128 pixels, with the database located in /home/andrey/doodle_db. You can add the learning rate, optimizer type and any other customizable parameters to the list. While preparing this post, we found that the mobilenet_v2 architecture from the current version of keras cannot be used from R, because of changes not yet accounted for in the R package. We are waiting for them to fix it.

This approach let us experiment with different models considerably faster than the more traditional way of running scripts in RStudio (the tfruns package is a possible alternative). The main advantage, however, is that script launches are easy to manage in Docker or simply on a server, without installing RStudio there.
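docopt builds its options from the usage text itself. To illustrate the resulting behavior (defaults from the usage text, overridden by whatever is passed on the command line), below is a deliberately minimal C++ sketch. It handles only long options of the form --name=value and bare flags such as --color. It is not docopt, the function name is hypothetical, short options like -m are not supported, and the --db-dir default is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical minimal parser: start from the defaults listed in the usage
// text of train_nn.R, then apply overrides from the command line.
std::map<std::string, std::string> parse_options(const std::vector<std::string>& argv) {
  std::map<std::string, std::string> opts = {
      {"model", "mobilenet_v2"}, {"batch-size", "32"},      {"scale-factor", "0.5"},
      {"color", "false"},        {"validate-ratio", "0.995"}, {"n-gpu", "1"}};
  for (const std::string& arg : argv) {
    assert(arg.rfind("--", 0) == 0);  // only long options are handled here
    const std::string body = arg.substr(2);
    const std::size_t eq = body.find('=');
    if (eq == std::string::npos) {
      opts[body] = "true";  // bare flag such as --color
    } else {
      opts[body.substr(0, eq)] = body.substr(eq + 1);
    }
  }
  return opts;
}
```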

6. Dockerization of scripts

We used Docker to make the model-training environment portable between team members and to deploy it quickly in the cloud. If this tool, rather unusual for an R programmer, is new to you, you can start with this series of publications or a video course.

Docker lets you both build your own images from scratch and use other images as a base for your own. After reviewing the available options, we concluded that the NVIDIA drivers, CUDA + cuDNN and the Python libraries make up a fairly large part of the image. We therefore took the official image tensorflow/tensorflow:1.12.0-gpu as the base and added the required R packages to it.

The final Dockerfile looked like this:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && \
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && \
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && \
    add-apt-repository -y ppa:marutter/c2d4u3.5 && \
    add-apt-repository -y ppa:timsc/opencv-3.4 && \
    apt-get update && \
    apt-get install -y locales && \
    locale-gen ${LOCALE} && \
    apt-get install -y --no-install-recommends ${APT_PKG} && \
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && \
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && \
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && \
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && \
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && \
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && \
    install.r ${R_SRC_PKG} && \
    pip install ${PY_PIP_PKG} && \
    mkdir -p ${DIRS} && \
    chmod 777 ${DIRS} && \
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && \
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

For convenience, the packages used are listed in variables, and most of the scripts are copied into the container at build time. We also changed the command shell to /bin/bash so that the contents of /etc/os-release could be used easily. This avoided hard-coding the OS version in the code.

We also wrote a small bash script that launches the container with different commands. These can be, for example, the neural network training scripts previously placed inside the container, or a command shell for debugging and monitoring the container:

Script for launching the container

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

The script behaves as follows:

- Without parameters, it calls train_nn.R inside the container with the default values.
- If the first positional argument is "bash", the container starts interactively with a command shell.
- In all other cases, the positional arguments are passed on to the training script: CMD="Rscript /app/train_nn.R $@".

Note that the directories with the source data and the database, as well as the directory for saving trained models, are mounted into the container from the host system. The results of the scripts are therefore accessible without any extra steps.

7. Using multiple GPUs in Google Cloud

One of the features of the competition was very noisy data (see the title image, borrowed from @Leigh.plt on the ODS Slack). Large batches help to counter this. After experiments on a PC with 1 GPU, we decided to try training models on several GPUs in the cloud. We chose Google Cloud (a good guide to the basics) for its wide choice of configurations, reasonable prices and the $300 bonus.

Out of greed, I ordered a 4xV100 instance with an SSD and a ton of RAM, and that was a big mistake. Such a machine eats money quickly, and you can go broke experimenting without a proven pipeline. For learning purposes, a K80 is a better choice. The large amount of RAM did come in handy, though: the cloud SSD's performance was unimpressive, so the database was moved to /dev/shm.

The most interesting part is the code responsible for using several GPUs. First, the model is created on the CPU using a context manager, just as in Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics = c(top_3_categorical_accuracy),
    compile = FALSE
  )
})

The uncompiled model (this is important) is then copied to the given number of available GPUs, and only after that is it compiled:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

The classic technique could not be implemented for multiple GPUs. That technique freezes all layers except the last one, trains the last layer, then unfreezes everything and retrains the whole model.
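As far as we understand the Keras implementation of multi_gpu_model, it is plain data parallelism. Each incoming batch is cut into n_gpu slices of batch_size / n_gpu samples, with the last slice also taking the remainder. Every GPU runs its replica of the model on its slice, and the results are concatenated on the CPU. Below is a C++ sketch of that slicing rule; the function name is hypothetical.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// (start, size) of each GPU's slice of a batch: every replica gets
// batch_size / n_gpu samples, and the last one also takes the remainder.
std::vector<std::pair<int, int>> gpu_slices(int batch_size, int n_gpu) {
  assert(n_gpu >= 1 && batch_size >= n_gpu);
  const int step = batch_size / n_gpu;
  std::vector<std::pair<int, int>> slices;
  for (int i = 0; i < n_gpu; ++i) {
    const int size = (i == n_gpu - 1) ? batch_size - step * i : step;
    slices.emplace_back(step * i, size);
  }
  return slices;
}
```

One practical consequence is that the batch_size given to the iterator is the global batch, so on the 4xV100 machine each GPU sees a quarter of it. To keep the per-GPU batch unchanged when moving from 1 to 4 GPUs, the iterator's batch_size has to be multiplied by 4.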

Training was monitored without tensorboard. We limited ourselves to writing logs and saving models under informative names after each epoch:

Callbacks

# Log file name template
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Model file name template
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # halve the learning rate
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)
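The interaction between the two plateau-related callbacks is easier to see on a concrete sequence of validation losses. The following C++ sketch is a simplified simulation, not the Keras code. It mimics callback_reduce_lr_on_plateau(factor = 0.5, patience = 4) and callback_early_stopping(patience = 8) with min_delta = 1e-4 in "min" mode, and it ignores cooldown, min_lr and the other options of the real callbacks.

```cpp
#include <cassert>
#include <limits>
#include <vector>

struct Schedule {
  std::vector<double> lr_per_epoch;  // learning rate used during each epoch
  int epochs_run = 0;                // epochs completed before stopping
};

// Simplified simulation of reduce-LR-on-plateau and early stopping, both
// monitoring val_loss in "min" mode.
Schedule simulate_callbacks(const std::vector<double>& val_loss, double lr,
                            double factor = 0.5, int lr_patience = 4,
                            int stop_patience = 8, double min_delta = 1e-4) {
  Schedule s;
  double best = std::numeric_limits<double>::infinity();
  int wait_stop = 0, wait_lr = 0;
  for (double loss : val_loss) {
    s.lr_per_epoch.push_back(lr);  // this epoch ran with the current lr
    ++s.epochs_run;
    if (loss < best - min_delta) {  // improvement resets both counters
      best = loss;
      wait_stop = 0;
      wait_lr = 0;
      continue;
    }
    if (++wait_stop >= stop_patience) break;  // early stopping
    if (++wait_lr >= lr_patience) {           // plateau: shrink the lr
      lr *= factor;
      wait_lr = 0;
    }
  }
  return s;
}
```

Take a loss that improves for two epochs and then stays flat, starting from lr = 0.0004. The learning rate is halved after the sixth epoch, and training stops after the tenth, so no epoch ever runs with a second halving.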

8. Instead of a conclusion

A number of the problems we ran into remain unsolved:

  • keras has no ready-made function for automatically finding the optimal learning rate (an analogue of lr_finder in the fast.ai library); with some effort, third-party implementations can be ported to R, for example this one;
  • as a consequence of the previous point, we could not choose the right learning rate when training on several GPUs;
  • there is a shortage of modern neural network architectures, especially ones pretrained on imagenet;
  • there is no one cycle policy and no discriminative learning rates (cosine annealing was implemented at our request, thanks to skeydan).
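For reference, the two learning-rate schedules mentioned in this list are short formulas. The C++ sketch below shows them; the function names and parameters are our own.

- Cosine annealing without restarts: lr_min + (lr_max - lr_min) * (1 + cos(pi * t / T)) / 2.
- The exponentially growing learning rate used by an LR range test, in the spirit of fast.ai's lr_finder.

In R, a schedule like this could be plugged in through keras::callback_learning_rate_scheduler(), but that wiring is not shown here.

```cpp
#include <cassert>
#include <cmath>

// Cosine annealing without restarts: lr_max at step 0, lr_min at step total.
double cosine_annealing(int step, int total, double lr_max, double lr_min) {
  assert(total > 0 && step >= 0 && step <= total);
  const double pi = std::acos(-1.0);
  return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + std::cos(pi * step / total));
}

// LR range test: grow the rate exponentially from lr_start to lr_end, so
// that the loss can be recorded for each value and a good rate picked.
double lr_range_test(int step, int total, double lr_start, double lr_end) {
  assert(total > 0 && step >= 0 && step <= total && lr_start > 0);
  return lr_start * std::pow(lr_end / lr_start, static_cast<double>(step) / total);
}
```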

What we took away from this competition:

  • Even on low-powered hardware, you can work painlessly with decent volumes of data (many times the size of RAM). The data.table package saves memory by modifying tables in place, which avoids copying them. Used correctly, it almost always shows the highest speed of all the tools we know of for scripting languages. Storing the data in a database means that, in many cases, you do not have to think about squeezing the entire dataset into RAM at all.
  • Slow functions in R can be replaced with fast C++ ones using the Rcpp package. If you also use RcppThread or RcppParallel, you get cross-platform multithreaded implementations, so there is no need to parallelize the code at the R level.
  • The Rcpp package can be used without deep knowledge of C++; the required minimum is outlined here. Header files for a number of great C++ libraries such as xtensor are available on CRAN, which means an infrastructure is forming for projects that integrate ready-made high-performance C++ code into R. RStudio adds further convenience with syntax highlighting and a static analyzer for C++ code.
  • docopt lets you run self-contained scripts with parameters. This is convenient on a remote server, including inside docker. Running many hours of neural network training experiments in RStudio is inconvenient, and installing the IDE on the server itself is not always justified.
  • Docker makes the code portable and the results reproducible across developers with different OS and library versions, and it makes running on servers easy. You can launch the entire training pipeline with a single command.
  • Google Cloud is a budget-friendly way to experiment on expensive hardware, but you need to choose configurations carefully.
  • Measuring the speed of individual code fragments is very useful, especially when combining R and C++, and with the bench package it is also very easy.

Overall, this experience was very rewarding, and we are continuing to work on some of the issues it raised.

Source: www.habr.com
