Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network

Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network

Hoy Habr!

Sa miaging tingdagdag, si Kaggle nag-host sa usa ka kompetisyon sa pagklasipikar sa mga hulagway nga gidrowing sa kamot, Quick Draw Doodle Recognition, diin, ug uban pa, usa ka grupo sa mga R-scientist ang miapil: Artem Klevtsova, Philippa Manager и Andrey Ogurtsov. Dili namon ihulagway ang kompetisyon sa detalye; nahimo na kana sa bag-o nga publikasyon.

Niining higayona wala kini molampos sa pagpanguma sa medalya, apan daghang bililhong kasinatian ang naangkon, mao nga gusto nakong isulti sa komunidad ang pipila sa labing makaiikag ug mapuslanong mga butang sa Kagle ug sa adlaw-adlaw nga trabaho. Lakip sa mga hilisgutan nga gihisgutan: lisud nga kinabuhi nga wala OpenCV, JSON parsing (kini nga mga pananglitan nagsusi sa paghiusa sa C++ code ngadto sa mga script o mga pakete sa R ​​gamit ang Rcpp), parameterization sa mga script ug dockerization sa katapusang solusyon. Ang tanan nga code gikan sa mensahe sa usa ka porma nga angay alang sa pagpatuman anaa sa mga tipiganan.

Mga Kaundan:

  1. Epektibo nga pagkarga sa datos gikan sa CSV ngadto sa MonetDB
  2. Pag-andam sa mga batch
  3. Mga iterator alang sa pagdiskarga sa mga batch gikan sa database
  4. Pagpili sa usa ka Modelong Arkitektura
  5. Parameterization sa script
  6. Dockerization sa mga script
  7. Paggamit sa daghang mga GPU sa Google Cloud
  8. Kay sa usa ka konklusyon

1. Epektibo nga pagkarga sa datos gikan sa CSV ngadto sa MonetDB database

Ang datos sa kini nga kompetisyon gihatag dili sa porma sa andam nga mga imahe, apan sa porma sa 340 nga mga file sa CSV (usa ka file alang sa matag klase) nga adunay mga JSON nga adunay mga koordinasyon sa punto. Pinaagi sa pagkonektar niini nga mga punto sa mga linya, makakuha kami usa ka katapusang imahe nga adunay sukod nga 256x256 pixels. Usab alang sa matag rekord adunay usa ka label nga nagpakita kung ang hulagway husto nga giila sa classifier nga gigamit sa panahon nga ang dataset nakolekta, usa ka duha ka letra nga code sa nasud nga pinuy-anan sa tagsulat sa hulagway, usa ka talagsaon nga identifier, usa ka timestamp ug usa ka ngalan sa klase nga mohaum sa ngalan sa file. Ang usa ka gipasimple nga bersyon sa orihinal nga datos adunay gibug-aton nga 7.4 GB sa archive ug gibana-bana nga 20 GB pagkahuman sa pag-unpack, ang tibuuk nga datos pagkahuman sa pag-unpack mokabat sa 240 GB. Gipaneguro sa mga tig-organisar nga ang duha ka bersyon nag-reproduce sa parehas nga mga drowing, nagpasabut nga ang tibuuk nga bersyon sobra. Sa bisan unsang kaso, ang pagtipig sa 50 milyon nga mga imahe sa mga graphic file o sa porma sa mga arrays giisip dayon nga dili mapuslanon, ug nakahukom kami nga i-merge ang tanan nga mga file sa CSV gikan sa archive. train_simplified.zip ngadto sa database nga adunay sunod nga henerasyon sa mga hulagway sa gikinahanglan nga gidak-on "on the fly" alang sa matag batch.

Usa ka maayo nga napamatud-an nga sistema ang gipili isip DBMS MonetDB, nga mao ang pagpatuman alang sa R ​​isip usa ka pakete MonetDBLite. Ang package naglakip sa usa ka embedded nga bersyon sa database server ug nagtugot kanimo sa pagkuha sa server direkta gikan sa usa ka R session ug pagtrabaho uban niini didto. Ang paghimo og database ug pagkonektar niini gihimo gamit ang usa ka sugo:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

Kinahanglan namon nga maghimo duha ka mga lamesa: usa alang sa tanan nga datos, ang lain alang sa kasayuran sa serbisyo bahin sa na-download nga mga file (mapuslanon kung adunay sayup ug ang proseso kinahanglan ipadayon pagkahuman sa pag-download sa daghang mga file):

Paghimo og mga lamesa

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

Ang pinakapaspas nga paagi sa pagkarga sa datos ngadto sa database mao ang direktang pagkopya sa mga file sa CSV gamit ang SQL - command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORTdiin tablename - ngalan sa lamesa ug path - ang dalan sa file. Samtang nagtrabaho uban sa archive, nadiskobrehan nga ang built-in nga pagpatuman unzip sa R dili molihok sa husto sa daghang mga file gikan sa archive, mao nga gigamit namon ang sistema unzip (gamit ang parameter getOption("unzip")).

Function alang sa pagsulat sa database

#' @title Извлечение и загрузка файлов
#'
#' @description
#' Извлечение CSV-файлов из ZIP-архива и загрузка их в базу данных
#'
#' @param con Объект подключения к базе данных (класс `MonetDBEmbeddedConnection`).
#' @param tablename Название таблицы в базе данных.
#' @oaram zipfile Путь к ZIP-архиву.
#' @oaram filename Имя файла внури ZIP-архива.
#' @param preprocess Функция предобработки, которая будет применена извлечённому файлу.
#'   Должна принимать один аргумент `data` (объект `data.table`).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Проверка аргументов
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Извлечение файла
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # Применяем функция предобработки
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Запрос к БД на импорт CSV
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Выполнение запроса к БД
  DBI::dbExecute(con, sql)

  # Добавление записи об успешной загрузке в служебную таблицу
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

Kung kinahanglan nimo nga usbon ang lamesa sa dili pa isulat kini sa database, igo na nga ipasa ang argumento preprocess function nga magbag-o sa datos.

Kodigo alang sa sunodsunod nga pagkarga sa datos ngadto sa database:

Pagsulat sa datos sa database

# Список файлов для записи
files <- unzip(zipfile, list = TRUE)$Name

# Список исключений, если часть файлов уже была загружена
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Запускаем таймер
  tictoc::tic()
  # Прогресс бар
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Останавливаем таймер
  tictoc::toc()
}

# 526.141 sec elapsed - копирование SSD->SSD
# 558.879 sec elapsed - копирование USB->SSD

Ang oras sa pagkarga sa datos mahimong magkalahi depende sa mga kinaiya sa gikusgon sa drive nga gigamit. Sa among kaso, ang pagbasa ug pagsulat sulod sa usa ka SSD o gikan sa flash drive (source file) ngadto sa SSD (DB) mokabat ug ubos sa 10 minutos.

Nagkinahanglan kini og pipila ka mga segundo aron makahimo og usa ka kolum nga adunay integer class label ug usa ka index column (ORDERED INDEX) nga adunay mga numero sa linya diin ang mga obserbasyon ma-sample kung maghimo mga batch:

Paghimo og Dugang nga mga Kolum ug Index

message("Generate lables")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))

Aron masulbad ang problema sa paghimo sa usa ka batch sa langaw, kinahanglan namon nga makab-ot ang labing taas nga tulin sa pagkuha sa mga random nga laray gikan sa lamesa doodles. Alang niini gigamit namon ang 3 nga mga limbong. Ang una mao ang pagpakunhod sa dimensyon sa tipo nga nagtipig sa ID sa obserbasyon. Sa orihinal nga set sa datos, ang tipo nga gikinahanglan sa pagtipig sa ID mao bigint, apan ang gidaghanon sa mga obserbasyon nagpaposible sa pagpahiangay sa ilang mga identifier, nga katumbas sa ordinal nga numero, ngadto sa tipo int. Ang pagpangita labi ka paspas sa kini nga kaso. Ang ikaduha nga lansis mao ang paggamit ORDERED INDEX — midangat kami niini nga desisyon sa empirikal nga paagi, nga nakaagi na sa tanang anaa mga kapilian. Ang ikatulo mao ang paggamit sa parameterized nga mga pangutana. Ang diwa sa pamaagi mao ang pagpatuman sa mando sa makausa PREPARE uban ang sunod nga paggamit sa usa ka andam nga ekspresyon sa paghimo sa usa ka hugpong sa mga pangutana sa parehas nga tipo, apan sa tinuud adunay usa ka bentaha kung itandi sa usa ka yano. SELECT nahimo nga naa sa sulud sa sayup sa istatistika.

Ang proseso sa pag-upload sa datos naggamit dili molapas sa 450 MB sa RAM. Kana mao, ang gihulagway nga pamaagi nagtugot kanimo sa paglihok sa mga dataset nga may gibug-aton nga napulo ka gigabytes sa halos bisan unsang hardware sa badyet, lakip ang pipila ka mga single-board device, nga medyo cool.

Ang nahabilin mao ang pagsukod sa katulin sa pagkuha (random) nga datos ug pagtimbang-timbang sa pag-scale kung mag-sample sa mga batch nga lainlain ang gidak-on:

Benchmark sa database

library(ggplot2)

set.seed(0)
# Подключение к базе данных
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

# Функция для подготовки запроса на стороне сервера
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Функция для извлечения данных
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Проведение замера
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Параметры бенчмарка
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network

2. Pag-andam sa mga batch

Ang tibuok proseso sa pag-andam sa batch naglangkob sa mosunod nga mga lakang:

  1. Pag-parse sa daghang mga JSON nga adunay mga vector sa mga kuwerdas nga adunay mga koordinasyon sa mga punto.
  2. Pagdrowing og kolor nga mga linya base sa mga koordinasyon sa mga punto sa usa ka imahe sa gikinahanglan nga gidak-on (pananglitan, 256 × 256 o 128 × 128).
  3. Pag-convert sa resulta nga mga hulagway ngadto sa usa ka tensor.

Isip kabahin sa kompetisyon sa Python kernels, ang problema nasulbad sa panguna gamit OpenCV. Ang usa sa pinakayano ug labing klaro nga analogue sa R ​​mahimong ingon niini:

Pagpatuman sa JSON ngadto sa Tensor Conversion sa R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Парсинг JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Удаляем временный файл по завершению функции
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Пустой график
  plot.new()
  # Размер окна графика
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Цвета линий
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Преобразование изображения в 3-х мерный массив
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Объединение 3-х мерных массивов картинок в 4-х мерный в тензор
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

Ang pagdrowing gihimo gamit ang standard R nga mga himan ug gitipigan sa usa ka temporaryo nga PNG nga gitipigan sa RAM (sa Linux, ang temporaryo nga mga direktoryo sa R ​​nahimutang sa direktoryo /tmp, gitaod sa RAM). Kini nga payl kay basahon isip three-dimensional array nga adunay mga numero gikan sa 0 ngadto sa 1. Importante kini tungod kay ang mas naandan nga BMP basahon ngadto sa hilaw nga array nga adunay hex color codes.

Atong sulayan ang resulta:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network

Ang batch mismo maporma sama sa mosunod:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

Kini nga pagpatuman daw suboptimal alang kanamo, tungod kay ang pagporma sa dagkong mga batch nagkinahanglan og usa ka dili maayo nga panahon, ug kami nakahukom sa pagpahimulos sa kasinatian sa among mga kaubanan pinaagi sa paggamit sa usa ka gamhanan nga librarya OpenCV. Niadtong panahona wala'y andam nga pakete alang sa R ​​(wala na karon), mao nga ang usa ka gamay nga pagpatuman sa gikinahanglan nga pagpaandar gisulat sa C ++ nga adunay integrasyon sa R ​​code gamit ang Rcpp.

Aron masulbad ang problema, ang mosunod nga mga pakete ug mga librarya gigamit:

  1. OpenCV alang sa pagtrabaho sa mga imahe ug mga linya sa pagguhit. Gigamit ang pre-installed system libraries ug header files, ingon man ang dinamikong pagsumpay.

  2. xtensor alang sa pagtrabaho uban sa multidimensional arrays ug tensors. Gigamit namon ang mga file sa header nga gilakip sa R ​​nga pakete nga parehas nga ngalan. Gitugotan ka sa librarya nga magtrabaho uban ang multidimensional arrays, pareho sa row major ug column major order.

  3. ndjson alang sa pag-parse sa JSON. Kini nga librarya gigamit sa xtensor awtomatik kung naa kini sa proyekto.

  4. RcppThread alang sa pag-organisar sa multi-threaded nga pagproseso sa usa ka vector gikan sa JSON. Gigamit ang mga file sa header nga gihatag niini nga package. Gikan sa mas sikat RcppParallel Ang pakete, taliwala sa ubang mga butang, adunay usa ka built-in nga mekanismo sa pag-undang sa loop.

Kini angay nga isulat kana xtensor nahimo nga usa ka diyos: dugang pa sa kamatuoran nga kini adunay daghang pag-andar ug taas nga pasundayag, ang mga nag-develop niini nahimo’g dali nga pagtubag ug gitubag ang mga pangutana dayon ug detalyado. Uban sa ilang tabang, posible nga ipatuman ang mga pagbag-o sa OpenCV matrices ngadto sa xtensor tensors, ingon man usa ka paagi sa paghiusa sa 3-dimensional nga mga tensor sa imahe ngadto sa usa ka 4-dimensional nga tensor sa husto nga dimensyon (ang batch mismo).

Mga materyales para sa pagkat-on sa Rcpp, xtensor ug RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

Aron makolekta ang mga file nga naggamit sa mga file sa system ug dinamikong pag-link sa mga librarya nga na-install sa sistema, gigamit namon ang mekanismo sa plugin nga gipatuman sa package Rcpp. Aron awtomatiko nga makit-an ang mga agianan ug mga bandila, gigamit namon ang usa ka sikat nga utility sa Linux pkg-config.

Pagpatuman sa Rcpp plugin para sa paggamit sa OpenCV library

Rcpp::registerPlugin("opencv", function() {
  # Возможные названия пакета
  pkg_config_name <- c("opencv", "opencv4")
  # Бинарный файл утилиты pkg-config
  pkg_config_bin <- Sys.which("pkg-config")
  # Проврека наличия утилиты в системе
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Проверка наличия файла настроек OpenCV для pkg-config
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  pkg_config_name <- pkg_config_name[check == 0]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

Ingon usa ka sangputanan sa operasyon sa plugin, ang mga mosunud nga kantidad ipuli sa proseso sa pag-compile:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

Ang code sa pagpatuman alang sa pag-parse sa JSON ug pagmugna og usa ka batch alang sa transmission ngadto sa modelo gihatag ubos sa spoiler. Una, pagdugang usa ka lokal nga direktoryo sa proyekto aron pangitaon ang mga file sa header (gikinahanglan alang sa ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Pagpatuman sa JSON ngadto sa tensor conversion sa C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Синонимы для типов
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Извлечённые из JSON координаты точек
using strokes = std::vector<points>;      // Извлечённые из JSON координаты точек
using xtensor3d = xt::xtensor<double, 3>; // Тензор для хранения матрицы изоображения
using xtensor4d = xt::xtensor<double, 4>; // Тензор для хранения множества изображений
using rtensor3d = xt::rtensor<double, 3>; // Обёртка для экспорта в R
using rtensor4d = xt::rtensor<double, 4>; // Обёртка для экспорта в R

// Статические константы
// Размер изображения в пикселях
const static int SIZE = 256;
// Тип линии
// См. https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Толщина линии в пикселях
const static int LINE_WIDTH = 3;
// Алгоритм ресайза
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Шаблон для конвертирования OpenCV-матрицы в тензор
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Размерность целевого тензора
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Общее количество элементов в массиве
  size_t size = src.total() * NCH;
  // Преобразование cv::Mat в xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Преобразование JSON в список координат точек
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // Результат парсинга должен быть массивом
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Каждый элемент массива должен быть 2-мерным массивом
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Извлечение вектора точек
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Отрисовка линий
// Цвета HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Исходный тип матрицы
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Итоговый тип матрицы
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Количество линий
  size_t n = x.size();
  for (const auto& s: x) {
    // Количество точек в линии
    size_t n_points = s.shape()[1];
    for (size_t i = 0; i < n_points - 1; ++i) {
      // Точка начала штриха
      cv::Point from(s(0, i), s(1, i));
      // Точка окончания штриха
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Отрисовка линии
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Меняем цвет линии
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Меняем цветовое представление на RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Меняем формат представления на float32 с диапазоном [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Обработка JSON и получение тензора с данными изображения
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

Kini nga code kinahanglan ibutang sa file src/cv_xt.cpp ug paghugpong uban sa sugo Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); gikinahanglan usab alang sa trabaho nlohmann/json.hpp gikan sa tipiganan. Ang code gibahin sa daghang mga gimbuhaton:

  • to_xt - usa ka templated function alang sa pagbag-o sa usa ka image matrix (cv::Mat) sa usa ka tensor xt::xtensor;

  • parse_json - ang function nag-parse sa usa ka JSON string, nagkuha sa mga koordinasyon sa mga punto, giputos kini sa usa ka vector;

  • ocv_draw_lines - gikan sa resulta nga vector sa mga punto, nagdrowing og multi-kolor nga mga linya;

  • process - gihiusa ang mga gimbuhaton sa ibabaw ug gidugang usab ang abilidad sa pag-scale sa sangputanan nga imahe;

  • cpp_process_json_str - wrapper sa ibabaw sa function process, nga nag-eksport sa resulta ngadto sa usa ka R-object (multidimensional array);

  • cpp_process_json_vector - wrapper sa ibabaw sa function cpp_process_json_str, nga nagtugot kanimo sa pagproseso sa usa ka string vector sa multi-threaded mode.

Aron magdrowing og daghang kolor nga mga linya, gigamit ang HSV color model, gisundan sa pagkakabig ngadto sa RGB. Atong sulayan ang resulta:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network
Pagtandi sa katulin sa mga pagpatuman sa R ​​ug C ++

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Параметры бенчмарка
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Проведение замера
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network

Sama sa imong nakita, ang pagtaas sa tulin nahimo’g hinungdanon kaayo, ug dili posible nga maabut ang C ++ code pinaagi sa pagparis sa R ​​code.

3. Mga iterator para sa pagdiskarga sa mga batch gikan sa database

Ang R adunay usa ka maayo nga reputasyon alang sa pagproseso sa datos nga mohaum sa RAM, samtang ang Python mas gihulagway pinaagi sa iterative data processing, nga nagtugot kanimo sa dali ug natural nga pagpatuman sa out-of-core nga mga kalkulasyon (mga kalkulasyon gamit ang external memory). Ang usa ka klasiko ug may kalabutan nga panig-ingnan alang kanato sa konteksto sa gihulagway nga problema mao ang lawom nga neural network nga gibansay sa gradient descent method nga adunay gibanabana nga gradient sa matag lakang gamit ang gamay nga bahin sa mga obserbasyon, o mini-batch.

Ang lawom nga mga balangkas sa pagkat-on nga gisulat sa Python adunay mga espesyal nga klase nga nagpatuman sa mga iterator base sa datos: mga lamesa, mga litrato sa mga folder, binary nga mga format, ug uban pa. Sa R mahimo natong pahimuslan ang tanang bahin sa librarya sa Python kusgan uban ang lainlaing mga backend niini gamit ang pakete nga parehas nga ngalan, nga sa baylo molihok sa ibabaw sa pakete isulti usab. Ang naulahi takus sa usa ka lahi nga taas nga artikulo; dili lamang kini nagtugot kanimo sa pagpadagan sa Python code gikan sa R, apan usab nagtugot kanimo sa pagbalhin sa mga butang tali sa R ​​ug Python nga mga sesyon, nga awtomatiko nga nagpahigayon sa tanan nga gikinahanglan nga mga pagkakabig sa tipo.

Gikuha namon ang panginahanglan nga tipigan ang tanan nga datos sa RAM pinaagi sa paggamit sa MonetDBlite, ang tanan nga "neural network" nga trabaho himuon sa orihinal nga code sa Python, kinahanglan ra namon nga magsulat usa ka iterator sa datos, tungod kay wala’y andam. alang sa ingon nga sitwasyon sa R ​​o Python. Adunay duha ra ka kinahanglanon alang niini: kinahanglan nga ibalik ang mga batch sa usa ka wala’y katapusan nga loop ug i-save ang estado niini taliwala sa mga pag-uli (ang ulahi sa R ​​gipatuman sa pinakasimple nga paagi gamit ang mga pagsira). Kaniadto, gikinahanglan nga tin-aw nga i-convert ang R arrays ngadto sa numpy arrays sulod sa iterator, apan ang kasamtangan nga bersyon sa package kusgan nagabuhat niini sa iyang kaugalingon.

Ang iterator alang sa pagbansay ug data sa validation nahimo nga ingon sa mosunod:

Iterator alang sa pagbansay ug data sa validation

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Проверка аргументов
  checkmate::assert_class(con, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Перемешиваем, чтобы брать и удалять использованные индексы батчей по порядку
  dt <- data.table::data.table(id = sample(samples_index))
  # Проставляем номера батчей
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Оставляем только полные батчи и индексируем
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Устанавливаем счётчик
  i <- 1
  # Количество батчей
  max_i <- dt[, max(batch)]

  # Подготовка выражения для выгрузки
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(con, sql)

  # Аналог keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Замыкание
  function() {
    # Начинаем новую эпоху
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Сбрасываем счётчик
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # ID для выгрузки данных
    batch_ind <- dt[batch == i, id]
    # Выгрузка данных
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Увеличиваем счётчик
    i <<- i + 1

    # Парсинг JSON и подготовка массива
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Шкалирование c интервала [0, 1] на интервал [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

Ang function nagkinahanglan isip input sa usa ka variable nga adunay koneksyon sa database, ang gidaghanon sa mga linya nga gigamit, ang gidaghanon sa mga klase, ang gidak-on sa batch, ang sukdanan (scale = 1 katumbas sa paghubad sa mga hulagway sa 256x256 pixels, scale = 0.5 — 128x128 pixels), indikasyon sa kolor (color = FALSE nagtino sa paghubad sa grayscale kung gigamit color = TRUE matag stroke gidrowing sa usa ka bag-ong kolor) ug usa ka preprocessing indicator alang sa mga network nga pre-trained sa imagenet. Ang ulahi gikinahanglan aron masukod ang mga kantidad sa pixel gikan sa agwat [0, 1] hangtod sa agwat [-1, 1], nga gigamit sa pagbansay sa gihatag. kusgan mga modelo.

Ang eksternal nga function adunay sulud nga pagsusi sa tipo sa argumento, usa ka lamesa data.table nga adunay random nga sinagol nga mga numero sa linya gikan sa samples_index ug mga numero sa batch, counter ug maximum nga gidaghanon sa mga batch, ingon man usa ka ekspresyon sa SQL alang sa pagdiskarga sa datos gikan sa database. Dugang pa, gihubit namon ang usa ka paspas nga analogue sa function sa sulod keras::to_categorical(). Gigamit namon ang halos tanan nga datos alang sa pagbansay, nagbilin ug tunga sa porsyento alang sa pag-validate, mao nga ang gidak-on sa panahon limitado sa parameter steps_per_epoch sa dihang gitawag keras::fit_generator(), ug ang kondisyon if (i > max_i) nagtrabaho lamang alang sa validation iterator.

Sa internal nga function, ang row index makuha alang sa sunod nga batch, ang mga rekord gidiskarga gikan sa database uban ang batch counter nga nagdugang, JSON parsing (function cpp_process_json_vector(), gisulat sa C++) ug paghimo og mga array nga katumbas sa mga hulagway. Unya ang usa ka init nga mga vector nga adunay mga label sa klase gihimo, ang mga array nga adunay mga kantidad sa pixel ug mga label gihiusa sa usa ka lista, nga mao ang kantidad sa pagbalik. Aron mapadali ang trabaho, gigamit namon ang paghimo sa mga indeks sa mga lamesa data.table ug pagbag-o pinaagi sa link - kung wala kini nga mga pakete nga "chips" datos.tabla Lisud mahanduraw nga epektibo nga nagtrabaho sa bisan unsang hinungdanon nga kantidad sa datos sa R.

Ang mga resulta sa pagsukod sa tulin sa usa ka Core i5 nga laptop mao ang mosunod:

Ang benchmark sa iterator

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Индексы для обучающей выборки
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Индексы для проверочной выборки
val_ind <- ind[-train_ind]
rm(ind)
# Коэффициент масштаба
scale <- 0.5

# Проведение замера
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Параметры бенчмарка
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Dali nga Draw Doodle Recognition: kung giunsa ang pagpakighigala sa R, C ++ ug neural network

Kung adunay ka igo nga kantidad sa RAM, mahimo nimo nga seryoso nga mapadali ang operasyon sa database pinaagi sa pagbalhin niini sa parehas nga RAM (32 GB igo na alang sa among buluhaton). Sa Linux, ang partisyon gi-mount pinaagi sa default /dev/shm, nga nag-okupar hangtod sa katunga sa kapasidad sa RAM. Mahimo nimong i-highlight ang dugang pinaagi sa pag-edit /etc/fstabpara makakuha ug record like tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Siguruha nga i-reboot ug susihon ang resulta pinaagi sa pagpadagan sa mando df -h.

Ang iterator alang sa data sa pagsulay tan-awon nga mas simple, tungod kay ang pagsulay nga dataset hingpit nga mohaum sa RAM:

Iterator alang sa datos sa pagsulay

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Проверка аргументов
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Проставляем номера батчей
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Замыкание
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Шкалирование c интервала [0, 1] на интервал [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}

4. Pagpili sa modelo nga arkitektura

Ang unang arkitektura nga gigamit mao ang mobilenet v1, ang mga bahin niini gihisgutan sa kini mensahe. Kini gilakip ingon nga sumbanan kusgan ug, sa ingon, anaa sa pakete nga parehas nga ngalan alang sa R. Apan kung gisulayan nga gamiton kini sa mga imahe nga single-channel, usa ka katingad-an nga butang ang nahimo: ang input tensor kinahanglan kanunay adunay dimensyon (batch, height, width, 3), nga mao, ang gidaghanon sa mga channel dili mausab. Walay ingon nga limitasyon sa Python, mao nga kami nagdali ug misulat sa among kaugalingong pagpatuman niini nga arkitektura, nga nagsunod sa orihinal nga artikulo (nga walay dropout nga anaa sa hard version):

Mobilenet v1 nga arkitektura

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}

Ang mga disbentaha niini nga pamaagi klaro. Gusto nakong sulayan ang daghang mga modelo, apan sa kasukwahi, dili ko gusto nga isulat pag-usab ang matag arkitektura nga mano-mano. Gihikawan usab kami sa oportunidad nga magamit ang mga gibug-aton sa mga modelo nga nabansay nang daan sa imagenet. Sama sa naandan, ang pagtuon sa dokumentasyon nakatabang. Kalihokan get_config() nagtugot kanimo nga makakuha usa ka paghulagway sa modelo sa usa ka porma nga angay alang sa pag-edit (base_model_conf$layers - usa ka regular nga lista sa R), ug ang function from_config() naghimo sa reverse conversion sa usa ka modelo nga butang:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Karon dili lisud ang pagsulat sa usa ka unibersal nga gimbuhaton aron makuha ang bisan unsang gihatag kusgan mga modelo nga adunay o walay mga gibug-aton nga gibansay sa imagenet:

Function alang sa pagkarga sa andam nga mga arkitektura

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Проверка аргументов
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Получаем объект из пакета keras
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Проверка наличия объекта в пакете
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # Если изображение не цветное, меняем размерность входа
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

Kung mogamit og single-channel nga mga hulagway, walay gigamit nga pretrained nga mga gibug-aton. Mahimo kini nga ayo: gamit ang function get_weights() kuhaa ang mga gibug-aton sa modelo sa porma sa usa ka lista sa R ​​arrays, usba ang dimensyon sa una nga elemento niini nga lista (pinaagi sa pagkuha sa usa ka kolor nga channel o pag-aberids sa tanan nga tulo), ug dayon i-load ang mga gibug-aton balik sa modelo nga adunay function set_weights(). Wala gyud namo gidugang kini nga pag-andar, tungod kay sa kini nga yugto klaro na nga mas produktibo ang pagtrabaho sa mga kolor nga litrato.

Gihimo namo ang kadaghanan sa mga eksperimento gamit ang mobilenet nga bersyon 1 ug 2, ingon man ang resnet34. Ang mas modernong mga arkitektura sama sa SE-ResNeXt maayo nga nahimo sa kini nga kompetisyon. Ikasubo, wala kami'y andam nga mga implementasyon nga among magamit, ug wala kami nagsulat sa among kaugalingon (apan kami siguradong magsulat).

5. Parameterization sa mga script

Alang sa kasayon, ang tanan nga code alang sa pagsugod sa pagbansay gidisenyo isip usa ka script, gigamit sa parameter docopt ingon sa mosunod:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

Pakete docopt nagrepresentar sa pagpatuman http://docopt.org/ alang sa R. Uban sa tabang niini, ang mga script gilusad sa yano nga mga sugo sama sa Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db o ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, kon file train_nn.R mao ang executable (kini nga sugo magsugod sa pagbansay sa modelo resnet50 sa tulo ka kolor nga mga hulagway nga nagsukod sa 128x128 pixels, ang database kinahanglang mahimutang sa folder /home/andrey/doodle_db). Mahimo nimong idugang ang katulin sa pagkat-on, tipo sa pag-optimize, ug bisan unsang uban pa nga napasibo nga mga parameter sa lista. Sa proseso sa pag-andam sa publikasyon, kini nahimo nga ang arkitektura mobilenet_v2 gikan sa kasamtangan nga bersyon kusgan sa paggamit sa R dili mahimo tungod sa mga pagbag-o nga wala gikonsiderar sa R ​​nga pakete, naghulat kami nga ayohon nila kini.

Kini nga pamaagi nagpaposible sa pagpadali pag-ayo sa mga eksperimento sa lain-laing mga modelo kon itandi sa mas tradisyonal nga paglansad sa mga script sa RStudio (atong namatikdan ang package isip posible nga alternatibo. mga tfrun). Apan ang panguna nga bentaha mao ang abilidad nga dali nga madumala ang paglansad sa mga script sa Docker o yano sa server, nga wala i-install ang RStudio alang niini.

6. Dockerization sa mga script

Gigamit namo ang Docker aron maseguro nga madala ang palibot alang sa mga modelo sa pagbansay tali sa mga miyembro sa team ug alang sa paspas nga pag-deploy sa panganod. Mahimo ka magsugod nga pamilyar sa kini nga himan, nga medyo talagsaon alang sa usa ka R programmer, nga adunay kini serye sa mga publikasyon o video nga kurso.

Gitugotan ka sa Docker nga maghimo sa imong kaugalingon nga mga imahe gikan sa wala ug mogamit sa ubang mga imahe ingon usa ka sukaranan sa paghimo sa imong kaugalingon. Kung gi-analisar ang magamit nga mga kapilian, nakahinapos kami nga ang pag-install sa NVIDIA, CUDA + cuDNN driver ug mga librarya sa Python usa ka labi ka daghan nga bahin sa imahe, ug nakahukom kami nga kuhaon ang opisyal nga imahe ingon usa ka sukaranan. tensorflow/tensorflow:1.12.0-gpu, pagdugang sa gikinahanglan nga R packages didto.

Ang katapusan nga docker file ingon niini:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && 
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && 
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && 
    add-apt-repository -y ppa:marutter/c2d4u3.5 && 
    add-apt-repository -y ppa:timsc/opencv-3.4 && 
    apt-get update && 
    apt-get install -y locales && 
    locale-gen ${LOCALE} && 
    apt-get install -y --no-install-recommends ${APT_PKG} && 
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && 
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && 
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && 
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && 
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && 
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && 
    install.r ${R_SRC_PKG} && 
    pip install ${PY_PIP_PKG} && 
    mkdir -p ${DIRS} && 
    chmod 777 ${DIRS} && 
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && 
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

Alang sa kasayon, ang mga pakete nga gigamit gibutang sa mga variable; ang kinabag-an sa sinulat nga mga script gikopya sulod sa mga sudlanan panahon sa asembliya. Giusab usab namo ang command shell sa /bin/bash alang sa kasayon ​​sa paggamit sa sulod /etc/os-release. Gilikayan niini ang panginahanglan nga ipiho ang bersyon sa OS sa code.

Dugang pa, usa ka gamay nga script sa bash ang gisulat nga nagtugot kanimo sa paglansad sa usa ka sudlanan nga adunay lainlaing mga mando. Pananglitan, kini mahimong mga script alang sa pagbansay sa mga neural network nga kaniadto gibutang sa sulod sa sudlanan, o usa ka command shell alang sa pag-debug ug pagmonitor sa operasyon sa sudlanan:

Script aron ilunsad ang sudlanan

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

Kung kini nga bash script gipadagan nga walay mga parameter, ang script tawgon sa sulod sa sudlanan train_nn.R nga adunay default nga mga kantidad; kung ang una nga positional nga argumento mao ang "bash", nan ang sudlanan magsugod nga interactive sa usa ka command shell. Sa tanan nga uban nga mga kaso, ang mga kantidad sa positional nga mga argumento gipulihan: CMD="Rscript /app/train_nn.R $@".

Angay nga hinumdoman nga ang mga direktoryo nga adunay gigikanan nga datos ug database, ingon man ang direktoryo alang sa pagtipig sa nabansay nga mga modelo, gi-mount sa sulud sa sulud gikan sa host system, nga nagtugot kanimo nga ma-access ang mga sangputanan sa mga script nga wala kinahanglana nga mga manipulasyon.

7. Paggamit sa daghang mga GPU sa Google Cloud

Usa sa mga bahin sa kompetisyon mao ang saba kaayo nga datos (tan-awa ang titulo nga hulagway, hinulaman gikan sa @Leigh.plt gikan sa ODS slack). Ang dagkong mga batch makatabang sa pagpakigbatok niini, ug human sa mga eksperimento sa usa ka PC nga adunay 1 GPU, nakahukom kami sa pag-master sa mga modelo sa pagbansay sa daghang mga GPU sa panganod. Gigamit ang GoogleCloud (maayong giya sa mga sukaranan) tungod sa daghang pagpili sa magamit nga mga pag-configure, makatarunganon nga mga presyo ug $300 nga bonus. Tungod sa kahakog, nag-order ako usa ka 4xV100 nga pananglitan nga adunay SSD ug usa ka tonelada nga RAM, ug kana usa ka dako nga sayup. Ang ingon nga makina mokaon dayon sa salapi; mahimo ka nga mag-eksperimento nga wala’y napamatud-an nga pipeline. Alang sa mga katuyoan sa edukasyon, mas maayo nga kuhaon ang K80. Apan ang daghang kantidad sa RAM magamit - ang cloud SSD wala nakadayeg sa pasundayag niini, mao nga ang database gibalhin sa dev/shm.

Ang labing dako nga interes mao ang tipik sa code nga responsable sa paggamit sa daghang mga GPU. Una, ang modelo gihimo sa CPU gamit ang usa ka tagdumala sa konteksto, sama sa Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics =(top_3_categorical_accuracy,
    compile = FALSE
  )
})

Dayon ang uncompiled (kini importante) nga modelo gikopya ngadto sa usa ka gihatag nga gidaghanon sa anaa nga mga GPU, ug human lamang nga kini gihugpong:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

Ang klasiko nga teknik sa pagyelo sa tanan nga mga layer gawas sa katapusan nga usa, pagbansay sa katapusan nga layer, pag-unfreeze ug pag-retraining sa tibuuk nga modelo alang sa daghang mga GPU dili mapatuman.

Ang pagbansay gimonitor nga walay gamit. tensorboard, gilimitahan ang among kaugalingon sa pagrekord sa mga troso ug pag-save sa mga modelo nga adunay impormasyon nga mga ngalan pagkahuman sa matag panahon:

Mga callback

# Шаблон имени файла лога
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Шаблон имени файла модели
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # уменьшаем lr в 2 раза
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)

8. Imbes nga konklusyon

Ubay-ubay nga mga problema nga among nasugatan wala pa mabuntog:

  • в kusgan wala'y andam nga function alang sa awtomatik nga pagpangita alang sa kamalaumon nga rate sa pagkat-on (analogue lr_finder sa librarya paspas.ai); Uban sa pipila ka paningkamot, posible nga i-port ang mga pagpatuman sa ikatulo nga partido sa R, pananglitan, kini;
  • isip sangputanan sa miaging punto, dili posible nga mapili ang husto nga katulin sa pagbansay kung mogamit daghang mga GPU;
  • adunay kakulang sa modernong mga arkitektura sa neural network, ilabi na kadtong nabansay na sa imagenet;
  • walay usa ka cycle nga palisiya ug diskriminatibo nga pagkat-on rate (cosine annealing kay sa among hangyo gipatuman, salamat skeydan).

Unsa nga mapuslanon nga mga butang ang nakat-unan gikan niini nga kompetisyon:

  • Sa medyo ubos nga gahum nga hardware, mahimo ka nga magtrabaho uban ang disente (daghang beses ang gidak-on sa RAM) nga mga volume sa datos nga walay kasakit. Plastic nga bag datos.tabla makaluwas sa memorya tungod sa in-place modification sa mga lamesa, nga maglikay sa pagkopya niini, ug kung gamiton sa husto, ang mga kapabilidad niini halos kanunay nga nagpakita sa pinakataas nga tulin sa tanan nga mga himan nga nahibal-an namo alang sa scripting nga mga pinulongan. Ang pagtipig sa datos sa usa ka database nagtugot kanimo, sa daghang mga kaso, nga dili maghunahuna sa tanan mahitungod sa panginahanglan sa pagpislit sa tibuok dataset ngadto sa RAM.
  • Ang hinay nga mga function sa R ​​mahimong mapulihan sa mga paspas sa C++ gamit ang package Rcpp. Kon dugang sa paggamit RcppThread o RcppParallel, nakuha namo ang cross-platform nga multi-threaded nga mga implementasyon, mao nga dili na kinahanglan nga iparehas ang code sa R ​​level.
  • Pakete Rcpp mahimong gamiton nga walay seryoso nga kahibalo sa C ++, ang gikinahanglan nga minimum gilatid dinhi. Mga file sa header alang sa daghang mga cool nga C-library sama sa xtensor anaa sa CRAN, sa ato pa, usa ka imprastraktura ang naporma para sa pagpatuman sa mga proyekto nga nag-integrate sa ready-made high-performance C++ code ngadto sa R. Ang dugang nga kasayon ​​mao ang pag-highlight sa syntax ug usa ka static nga C++ code analyzer sa RStudio.
  • docopt nagtugot kanimo sa pagpadagan sa kaugalingon nga mga script nga adunay mga parameter. Kini sayon ​​​​alang sa paggamit sa usa ka hilit nga server, lakip. ubos sa pantalan. Sa RStudio, dili kombenyente ang pagpahigayon og daghang oras sa mga eksperimento sa pagbansay sa mga neural network, ug ang pag-instalar sa IDE sa server mismo dili kanunay nga makatarunganon.
  • Gisiguro sa Docker ang code portability ug reproducibility sa mga resulta tali sa mga developers nga adunay lain-laing mga bersyon sa OS ug mga librarya, ingon man ang kasayon ​​sa pagpatuman sa mga server. Mahimo nimong ilunsad ang tibuok nga pipeline sa pagbansay gamit ang usa lang ka sugo.
  • Ang Google Cloud usa ka budget-friendly nga paagi sa pag-eksperimento sa mahal nga hardware, pero kinahanglan nimo nga pilion og maayo ang mga configuration.
  • Ang pagsukod sa katulin sa indibidwal nga mga tipik sa code mapuslanon kaayo, labi na kung gihiusa ang R ug C ++, ug uban ang pakete bench - sayon ​​usab kaayo.

Sa kinatibuk-an kini nga kasinatian magantihon kaayo ug nagpadayon kami sa pagtrabaho aron masulbad ang pipila nga mga isyu nga gipatungha.

Source: www.habr.com

Idugang sa usa ka comment