Quick Draw Doodle Recognition: how to make friends with R, C++ and neural networks

Hey Habr!

Last fall, Kaggle hosted a competition on classifying hand-drawn images, Quick Draw Doodle Recognition, in which, among others, a team of R users took part: Artem Klevtsov, Philipp Upravitelev and Andrey Ogurtsov. We will not describe the competition in detail here; that has already been done in a recent publication.

This time medal farming did not work out, but a lot of valuable experience was gained, so I would like to tell the community about some of the most interesting and useful things, both on Kaggle and in everyday work. Among the topics covered: the hard life without OpenCV, JSON parsing (these examples look at integrating C++ code into R scripts or packages using Rcpp), script parameterization and dockerization of the final solution. All the code from this post, in a form suitable for running, is available in the repository.

Contents:

  1. Efficient loading of data from CSV into a MonetDB database
  2. Preparing batches
  3. Iterators for unloading batches from the database
  4. Choosing the model architecture
  5. Script parameterization
  6. Dockerization of scripts
  7. Using multiple GPUs in Google Cloud
  8. Instead of a conclusion

1. Efficient loading of data from CSV into a MonetDB database

The data in this competition is provided not as ready-made images, but as 340 CSV files (one file per class) containing JSONs with point coordinates. Connecting these points with lines produces a final image of 256x256 pixels. Each record also carries a label indicating whether the image was correctly recognized by the classifier in use at the time the dataset was collected, the two-letter country code of the drawing's author, a unique identifier, a timestamp and a class name matching the file name. The simplified version of the source data weighs 7.4 GB in the archive and roughly 20 GB after unpacking; the full data takes 240 GB after unpacking. The organizers guaranteed that both versions reproduce the same drawings, so the full version is redundant. In any case, storing 50 million images as graphics files or as arrays right away was considered unprofitable, and we decided to merge all the CSV files from the archive train_simplified.zip into a database, generating images of the required size "on the fly" for each batch.

The well-proven MonetDB was chosen as the DBMS, namely its R implementation in the form of the MonetDBLite package. The package includes an embedded version of the database server and lets you pick the server up directly from an R session and work with it there. Creating a database and connecting to it takes a single command:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

We will need to create two tables: one for all the data, another for service information about the downloaded files (useful if something goes wrong and the process has to be resumed after a few files have already been loaded):

Creating the tables

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

The fastest way to load the data into the database turned out to be copying the CSV files directly with SQL, using the command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT, where tablename is the table name and path is the path to the file. While working with the archive it turned out that the built-in unzip implementation in R does not handle a number of files from the archive correctly, so we used the system unzip (via the getOption("unzip") parameter).
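
For clarity, this is what the sprintf() template used below actually produces for a single extracted file (the temporary path is hypothetical); OFFSET 2 starts the import from the second line, skipping the CSV header:

cat(sprintf(
  "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT",
  "doodles", "/tmp/Rtmp1234/airplane.csv"
))
# COPY OFFSET 2 INTO doodles FROM '/tmp/Rtmp1234/airplane.csv' USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT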

Function for writing to the database

#' @title Extracting and loading files
#'
#' @description
#' Extracts CSV files from a ZIP archive and loads them into the database
#'
#' @param con Database connection object (class `MonetDBEmbeddedConnection`).
#' @param tablename Name of the table in the database.
#' @param zipfile Path to the ZIP archive.
#' @param filename Name of the file inside the ZIP archive.
#' @param preprocess Preprocessing function that will be applied to the extracted file.
#'   Must accept a single argument `data` (a `data.table` object).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Argument checks
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Extract the file
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # Apply the preprocessing function
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Query the DB to import the CSV
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Execute the query
  DBI::dbExecute(con, sql)

  # Record the successful upload in the service table
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

If the table needs to be transformed before being written to the database, it is enough to pass in the preprocess argument a function that will transform the data.
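
A purely illustrative example of such a function, not part of our pipeline (the file name is hypothetical, and the recognized column may be parsed as logical or as character depending on fread, hence the %in%):

# Drop the rows that the original classifier failed to recognize
drop_unrecognized <- function(data) {
  data[recognized %in% c(TRUE, "True")]
}

upload_file(con = con, tablename = "doodles", zipfile = zipfile,
            filename = "airplane.csv", preprocess = drop_unrecognized)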

Code for sequentially loading the data into the database:

Writing the data to the database

# List of files to load
files <- unzip(zipfile, list = TRUE)$Name

# Files to skip if some have already been uploaded
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Start the timer
  tictoc::tic()
  # Progress bar
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Stop the timer
  tictoc::toc()
}

# 526.141 sec elapsed - copying SSD->SSD
# 558.879 sec elapsed - copying USB->SSD

The data loading time may vary depending on the speed characteristics of the drive used. In our case, reading and writing within a single SSD, or from a flash drive (the source file) to an SSD (the database), takes less than 10 minutes.

It takes a few more seconds to create a column with an integer class label and an index column (ORDERED INDEX) with the row numbers by which observations will be sampled when batches are created:

Creating additional columns and an index

message("Generate lables")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))

To solve the problem of forming batches on the fly, we needed to achieve the maximum speed of extracting random rows from the doodles table. For this we used three tricks. The first was to reduce the dimensionality of the type storing the observation ID. In the original dataset the type needed to store the ID is bigint, but the number of observations makes it possible to fit their identifiers, equal to the ordinal number, into the int type. Lookups are much faster in this case. The second trick was to use ORDERED INDEX: we arrived at this decision empirically, after going through all the available options. The third was to use parameterized queries. The essence of the method is to execute the PREPARE command once and then reuse the prepared expression when creating a bunch of queries of the same type; in practice, though, the gain compared to a simple SELECT turned out to be within the margin of statistical error.

The data loading process consumes no more than 450 MB of RAM. In other words, the described approach lets you move datasets weighing tens of gigabytes on almost any budget hardware, including some single-board devices, which is pretty cool.

All that remains is to measure the speed of fetching (random) data and to evaluate how it scales when sampling batches of different sizes:

Database benchmark

library(ggplot2)

set.seed(0)
# Connect to the database
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))
# Total number of rows in the table (sampled from in fetch_data() below)
n <- DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]]

# Function to prepare the query on the server side
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Function to fetch the data
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Run the measurement
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

(Figure: median random-row retrieval time vs. batch size)

2. Preparing batches

The whole batch preparation process consists of the following steps:

  1. Parsing several JSONs containing vectors of strings with point coordinates.
  2. Drawing colored lines from the point coordinates on an image of the required size (for example, 256×256 or 128×128).
  3. Converting the resulting images into a tensor.

Among the Python kernels in the competition, the problem was mostly solved using OpenCV. One of the simplest and most obvious analogues in R would look like this:

Implementation of the JSON-to-tensor conversion in R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Parse the JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Remove the temporary file when the function exits
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Empty plot
  plot.new()
  # Plot window size
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Line colours
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Convert the image into a 3-dimensional array
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Combine the 3-dimensional image arrays into a 4-dimensional tensor
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

Drawing is performed with standard R tools and saved to a temporary PNG stored in RAM (on Linux, R's temporary directories live in /tmp, which is mounted in RAM). This file is then read as a three-dimensional array with values ranging from 0 to 1. This matters, because a more conventional BMP would be read into a raw array with hex colour codes.

Let's test the result:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

(Figure: the doodle rendered by r_process_json_str())

The batch itself is formed as follows:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

This implementation seemed suboptimal to us, since forming large batches takes indecently long, so we decided to take advantage of our colleagues' experience and use a powerful library, OpenCV. At the time there was no ready-made OpenCV package for R (there is none now either), so a minimal implementation of the required functionality was written in C++ and integrated into the R code using Rcpp.

The following packages and libraries were used to solve the problem:

  1. OpenCV for working with images and drawing lines. We used pre-installed system libraries and header files, together with dynamic linking.

  2. xtensor for working with multidimensional arrays and tensors. We used the header files included in the R package of the same name. The library lets you work with multidimensional arrays in both row-major and column-major order.

  3. ndjson for parsing JSON. This library is used by xtensor automatically if it is present in the project.

  4. RcppThread for multi-threaded processing of a vector of JSON strings. We used the header files provided by this package. Unlike the better-known RcppParallel, the package, among other things, has a built-in loop interruption mechanism.

It is worth noting that xtensor turned out to be a godsend: besides offering extensive functionality and high performance, its developers turned out to be very responsive and answered questions promptly and in detail. With their help we managed to implement the conversion of OpenCV matrices into xtensor tensors, as well as a way to combine 3-dimensional image tensors into a 4-dimensional tensor of the correct dimension (the batch itself).

Materials for learning Rcpp, xtensor and RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

To compile files that use system header files and dynamic linking against libraries installed in the system, we used the plugin mechanism implemented in the Rcpp package. To find the paths and flags automatically, we used the popular Linux utility pkg-config.

Implementation of the Rcpp plugin for using the OpenCV library

Rcpp::registerPlugin("opencv", function() {
  # Possible package names
  pkg_config_name <- c("opencv", "opencv4")
  # pkg-config utility binary
  pkg_config_bin <- Sys.which("pkg-config")
  # Check that the utility is present in the system
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Check that an OpenCV config for pkg-config exists
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  pkg_config_name <- pkg_config_name[check == 0]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

As a result of the plugin's work, the following values will be substituted during compilation:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

The implementation code for parsing JSON and forming a batch to pass to the model is given under the spoiler. First, add a local project directory to the header search path (needed for ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Implementation of the JSON-to-tensor conversion in C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Type aliases
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Point coordinates extracted from JSON
using strokes = std::vector<points>;      // Strokes: point sets extracted from JSON
using xtensor3d = xt::xtensor<double, 3>; // Tensor for storing an image matrix
using xtensor4d = xt::xtensor<double, 4>; // Tensor for storing a set of images
using rtensor3d = xt::rtensor<double, 3>; // Wrapper for exporting to R
using rtensor4d = xt::rtensor<double, 4>; // Wrapper for exporting to R

// Static constants
// Image size in pixels
const static int SIZE = 256;
// Line type
// See https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Line width in pixels
const static int LINE_WIDTH = 3;
// Resize algorithm
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Template for converting an OpenCV matrix into a tensor
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Shape of the target tensor
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Total number of elements in the array
  size_t size = src.total() * NCH;
  // Convert cv::Mat into xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Convert JSON into a list of point coordinates
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // The parsing result must be an array
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Each element of the array must be a 2-dimensional array
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Extract the vector of points
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Drawing the lines
// HSV colours
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Source matrix type
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Target matrix type
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of lines (strokes)
  size_t n = x.size();
  for (const auto& s: x) {
    // Number of points in the stroke
    size_t n_points = s.shape()[1];
    for (size_t i = 0; i < n_points - 1; ++i) {
      // Stroke start point
      cv::Point from(s(0, i), s(1, i));
      // Stroke end point
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Draw the line
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Change the line colour
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Convert the colour representation to RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Convert to float32 with values in the range [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Process JSON and obtain a tensor with the image data
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

This code should be saved to the file src/cv_xt.cpp and compiled with the command Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); nlohmann/json.hpp from the repository is also required for it to work (a sketch of pulling this header in and compiling is given after the list of functions below). The code is split into several functions:

  • to_xt - a templated function for converting an image matrix (cv::Mat) into a tensor xt::xtensor;

  • parse_json - a function that parses a JSON string, extracts the point coordinates and packs them into a vector;

  • ocv_draw_lines - draws multicoloured lines from the resulting vector of points;

  • process - combines the functions above and also adds the ability to scale the resulting image;

  • cpp_process_json_str - a wrapper over the process function that exports the result into an R object (a multidimensional array);

  • cpp_process_json_vector - a wrapper over cpp_process_json_str that lets you process a vector of strings in multi-threaded mode.
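
A possible way to put the pieces together, as a sketch only (the header version, URL and directory layout are assumptions, not necessarily our actual setup):

# Fetch the single-header nlohmann/json so that <nlohmann/json.hpp> resolves
# from the local src/ directory, then compile the translation unit
dir.create(file.path("src", "nlohmann"), recursive = TRUE, showWarnings = FALSE)
download.file(
  "https://raw.githubusercontent.com/nlohmann/json/v3.7.3/single_include/nlohmann/json.hpp",
  destfile = file.path("src", "nlohmann", "json.hpp")
)
Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath("src")))
Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv)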

To draw the multicoloured lines the HSV colour model was used, followed by conversion to RGB. Let's test the result:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

(Figure: the doodle rendered by cpp_process_json_str())
Speed comparison of the R and C++ implementations

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Benchmark parameters
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Run the measurement
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

(Figure: median batch rendering time for the R and C++ implementations vs. batch size)

As you can see, the speedup turned out to be very significant, and catching up with the C++ code by parallelizing the R code is not feasible.

3. Iterators for unloading batches from the database

R has a well-deserved reputation as a language for processing data that fits into RAM, while Python is better known for iterative data processing, which makes it easy and natural to implement out-of-core computations (computations using external memory). A classic example, and a relevant one for the problem described here, is deep neural networks trained by gradient descent, where the gradient is approximated at each step using a small portion of the observations, a mini-batch.

Deep learning frameworks written in Python have special classes that implement iterators over data: tables, images in folders, binary formats, and so on. You can use the ready-made options or write your own for specific tasks. In R we can take advantage of all the features of the Python library keras with its various backends via the package of the same name, which in turn works on top of the reticulate package. The latter deserves a separate long article; it not only lets you run Python code from R, but also lets you pass objects between R and Python sessions, automatically performing all the necessary type conversions.
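
A toy illustration of those automatic conversions (numpy here is used purely for demonstration and has nothing to do with the competition code):

np <- reticulate::import("numpy")
# The R vector c(2L, 3L) is converted to a Python list on the way in,
# and the returned ndarray is converted back to an R matrix on the way out
m <- np$zeros(c(2L, 3L))
dim(m)
# [1] 2 3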

We got rid of the need to keep all the data in RAM by using MonetDBLite, and all the "neural network" work will be done by the original Python code, so we only need to write an iterator over the data, since nothing ready-made exists for such a situation in either R or Python. There are essentially only two requirements for it: it must return batches in an endless loop and it must keep its state between iterations (the latter is implemented in R in the simplest way via closures). Previously, R arrays had to be converted into numpy arrays explicitly inside the iterator, but the current version of the keras package does this itself.
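
Before the real iterator, a toy example of the closure idea (nothing here is specific to keras): the counter i lives in the enclosing environment, so it survives between calls, and the function loops over the data endlessly.

make_generator <- function(x, batch_size = 2L) {
  i <- 1L                                          # state kept inside the closure
  function() {
    if (i + batch_size - 1L > length(x)) i <<- 1L  # wrap around: endless loop
    batch <- x[i:(i + batch_size - 1L)]
    i <<- i + batch_size
    batch
  }
}

gen <- make_generator(1:5)
gen()  # 1 2
gen()  # 3 4
gen()  # 1 2  (started over)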

The iterator for the training and validation data turned out as follows:

Iterator for training and validation data

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Argument checks
  checkmate::assert_class(con, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Shuffle so that used batch indices can be taken and dropped in order
  dt <- data.table::data.table(id = sample(samples_index))
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Keep only complete batches and set the key
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Initialise the counter
  i <- 1
  # Number of batches
  max_i <- dt[, max(batch)]

  # Prepare the statement for fetching data
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(con, sql)

  # Analogue of keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Closure
  function() {
    # Start a new epoch
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Reset the counter
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # IDs of the rows to fetch
    batch_ind <- dt[batch == i, id]
    # Fetch the data
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Increment the counter
    i <<- i + 1

    # Parse the JSON and prepare the array
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

The function takes as input a variable with the database connection, the numbers of the rows to use, the number of classes, the batch size, the scale (scale = 1 corresponds to rendering 256x256-pixel images, scale = 0.5 to 128x128 pixels), a colour flag (with color = FALSE rendering is done in grayscale, with color = TRUE each stroke is drawn in a new colour) and a preprocessing flag for networks pretrained on imagenet. The latter is needed to rescale pixel values from the interval [0, 1] to the interval [-1, 1], which was used when training the keras models that ship with the package.

The outer function contains the argument type checks, a data.table with the randomly shuffled row numbers from samples_index and the batch numbers, the counter and the maximum number of batches, as well as the SQL expression for unloading data from the database. In addition, we defined a fast analogue of keras::to_categorical() inside. We used almost all the data for training, leaving half a percent for validation, so the epoch size was limited by the steps_per_epoch parameter when calling keras::fit_generator(), and the condition if (i > max_i) only matters for the validation iterator.

In the inner function, the row indices for the next batch are retrieved, the records are unloaded from the database with the batch counter incremented, then comes JSON parsing (the cpp_process_json_vector() function written in C++) and the creation of the arrays corresponding to the images. Then one-hot vectors with the class labels are created, the arrays with pixel values and the labels are combined into a list, which is the return value. To speed things up we used index creation in data.table tables and modification by reference: without these data.table "tricks" it is hard to imagine working effectively with any significant amount of data in R.
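
To show where the generator plugs in, here is a rough sketch of the training call; all numbers are illustrative rather than the settings we actually used, model is any compiled keras model (see section 4 below), and train_ind / val_ind are index vectors like the ones defined in the benchmark code below:

train_gen <- train_generator(db_connection = con, samples_index = train_ind,
                             batch_size = 512, scale = 0.5)
val_gen   <- train_generator(db_connection = con, samples_index = val_ind,
                             batch_size = 512, scale = 0.5)

model %>% keras::fit_generator(
  generator        = train_gen,
  steps_per_epoch  = 500,                       # defines the "epoch" length
  epochs           = 100,
  validation_data  = val_gen,
  validation_steps = length(val_ind) %/% 512    # one full pass over the validation split
)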

The results of the speed measurements on a Core i5 laptop are as follows:

Iterator benchmark

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Indices for the training sample
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Indices for the validation sample
val_ind <- ind[-train_ind]
rm(ind)
# Scale factor
scale <- 0.5

# Run the measurement
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

(Figure: median iterator call time vs. batch size)

If you have enough RAM, you can seriously speed up the database by moving it into that same RAM (32 GB is enough for our task). On Linux the partition /dev/shm is mounted by default, occupying up to half of the RAM. You can allocate more by editing /etc/fstab so that it contains a record like tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Be sure to reboot and check the result by running the df -h command.

The iterator for the test data looks much simpler, since the test dataset fits entirely into RAM:

Iterator for test data

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Argument checks
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Closure
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}

4. Choosing the model architecture

The first architecture used was mobilenet v1, whose features are discussed in this post. It ships as standard with keras and, accordingly, is available in the R package of the same name. But when trying to use it with single-channel images, a strange thing turned out: the input tensor must always have the shape (batch, height, width, 3), that is, the number of channels cannot be changed. There is no such limitation in Python, so we rushed ahead and wrote our own implementation of this architecture, following the original article (without the dropout present in the keras version):

Mobilenet v1 architecture

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}

The drawbacks of this approach are obvious. I want to try out a lot of models, but rewriting each architecture by hand is exactly what I do not want to do. We were also deprived of the opportunity to use the weights of models pretrained on imagenet. As usual, studying the documentation helped. The get_config() function lets you get a description of the model in a form suitable for editing (base_model_conf$layers is a regular R list), and the from_config() function performs the reverse conversion into a model object:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Now it is not hard to write a universal function for obtaining any of the keras models that ship with the package, with or without weights trained on imagenet:

Function for loading ready-made architectures

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Argument checks
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Get the object from the keras package
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Check that the object exists in the package
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # If the image is not colour, change the input dimensionality
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

When using single-channel images, the pretrained weights are not used. This could be fixed: with get_weights(), obtain the model weights as a list of R arrays, change the dimension of the first element of this list (by taking a single colour channel or by averaging all three), and then load the weights back into the model with set_weights(). We never added this functionality, because by that stage it was already clear that working with colour images was more productive.
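
As a sketch only, and one we never actually ran, the idea could look like this; base_model_rgb and base_model_gray are hypothetical three-channel and single-channel copies of the same architecture:

# Transfer imagenet weights to a single-channel model by averaging
# the RGB kernels of the very first convolution
w <- keras::get_weights(base_model_rgb)
k <- dim(w[[1]])                                   # (kernel, kernel, 3, filters)
w[[1]] <- array(apply(w[[1]], c(1, 2, 4), mean),   # average over the colour axis
                dim = c(k[1], k[2], 1, k[4]))
keras::set_weights(base_model_gray, w)             # the remaining shapes match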

We carried out most of the experiments with mobilenet versions 1 and 2, as well as resnet34. More modern architectures such as SE-ResNeXt performed well in this competition. Unfortunately, we did not have ready-made implementations at our disposal, and we did not write our own (but we certainly will).

5. Script parameterization

For convenience, all the code for launching training was designed as a single script, parameterized with docopt as follows:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

The docopt package is an implementation of http://docopt.org/ for R. With it, scripts are launched with simple commands like Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db or ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db if the file train_nn.R is executable (this command will start training the resnet50 model on three-channel 128x128-pixel images; the database must be located in the folder /home/andrey/doodle_db). You can add the learning rate, the optimizer type and any other tunable parameters to the list. While preparing this publication, it turned out that the mobilenet_v2 architecture from the current version of keras cannot be used from R because of changes not yet accounted for in the R package; we are waiting for this to be fixed.
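
Inside the script the parsed values arrive as strings, so they still have to be cast to the proper types; a small sketch (the element names assume docopt's default name stripping):

model_name <- args[["model"]]
batch_size <- as.integer(args[["batch-size"]])
scale      <- as.numeric(args[["scale-factor"]])
color      <- isTRUE(args[["color"]])
n_gpu      <- as.integer(args[["n-gpu"]])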

This approach made it possible to significantly speed up experiments with different models compared with the more traditional launching of scripts in RStudio (we note the tfruns package as a possible alternative). But the main advantage is the ability to easily manage the launching of scripts in Docker or simply on a server, without installing RStudio for that purpose.

6. Dockerization of scripts

We used Docker to ensure a reproducible environment for training models between team members and for rapid deployment in the cloud. You can start getting acquainted with this tool, which is relatively unusual for an R programmer, with this series of publications or a video course.

Docker lets you both build your own images from scratch and use other images as a basis for building your own. After analyzing the available options, we concluded that installing the NVIDIA drivers, CUDA+cuDNN and the Python libraries is a fairly voluminous part of the image, so we decided to take the official image tensorflow/tensorflow:1.12.0-gpu as a basis and add the necessary R packages to it.

The final docker file looked like this:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && \
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && \
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && \
    add-apt-repository -y ppa:marutter/c2d4u3.5 && \
    add-apt-repository -y ppa:timsc/opencv-3.4 && \
    apt-get update && \
    apt-get install -y locales && \
    locale-gen ${LOCALE} && \
    apt-get install -y --no-install-recommends ${APT_PKG} && \
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && \
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && \
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && \
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && \
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && \
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && \
    install.r ${R_SRC_PKG} && \
    pip install ${PY_PIP_PKG} && \
    mkdir -p ${DIRS} && \
    chmod 777 ${DIRS} && \
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && \
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

For convenience, the packages used were put into variables; the bulk of the written scripts is copied into the container during the build. We also changed the command shell to /bin/bash for convenient use of the contents of /etc/os-release. This removed the need to specify the OS version in the code.

Additionally, a small bash script was written that lets you launch the container with various commands. For example, these could be the scripts for training neural networks previously placed inside the container, or a command shell for debugging and monitoring how the container operates:

Script to launch the container

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

If this bash script is run without parameters, the script train_nn.R will be called inside the container with the default values; if the first positional argument is "bash", the container will start interactively with a command shell. In all other cases, the values of the positional arguments are substituted: CMD="Rscript /app/train_nn.R $@".

It is worth noting that the directories with the source data and the database, as well as the directory for saving the trained models, are mounted inside the container from the host system, which lets you access the results of the scripts without unnecessary manipulations.

7. Using multiple GPUs in Google Cloud

One of the features of the competition was very noisy data (see the title picture, borrowed from @Leigh.plt in the ODS slack). Large batches help fight this, and after experiments on a PC with 1 GPU we decided to master training models on several GPUs in the cloud. We used GoogleCloud (a good guide to the basics) because of the wide selection of available configurations, reasonable prices and the $300 bonus. Out of greed, I ordered a 4xV100 instance with an SSD and a ton of RAM, and that was a big mistake. Such a machine eats up money quickly; you can go broke experimenting without a proven pipeline. For educational purposes it is better to take a K80. The large amount of RAM did come in handy, though: the cloud SSD did not impress with its performance, so the database was moved to /dev/shm.

The most interesting part is the code fragment responsible for using multiple GPUs. First, the model is created on the CPU using a context manager, just like in Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics = c(top_3_categorical_accuracy),
    compile = FALSE
  )
})

Then the uncompiled (this is important) model is copied to a given number of available GPUs, and only after that is it compiled:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

The classic approach of freezing all layers except the last one, training the last layer, then unfreezing and fine-tuning the whole model could not be made to work with multiple GPUs.
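
For reference, a single-GPU sketch of that classic scheme using keras::freeze_weights()/unfreeze_weights(); the learning rates are illustrative, and it assumes the pretrained backbone (base_model) and the full model with the new head (model) are available as separate objects:

# Stage 1: train only the new classification head
keras::freeze_weights(base_model)
keras::compile(model, optimizer = keras::optimizer_adam(lr = 1e-3),
               loss = "categorical_crossentropy")
# ... a few epochs with fit_generator() ...

# Stage 2: unfreeze everything and fine-tune with a lower learning rate
keras::unfreeze_weights(base_model)
keras::compile(model, optimizer = keras::optimizer_adam(lr = 1e-4),
               loss = "categorical_crossentropy")
# ... continue training ...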

Training was monitored without tensorboard; we limited ourselves to recording logs and saving models with informative names after each epoch:

Callbacks

# Log file name template
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Model file name template
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # halve the lr
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)

8. Instead of a conclusion

A number of problems that we encountered have not yet been overcome:

  • keras has no ready-made function for automatically searching for the optimal learning rate (an analogue of lr_finder in the fast.ai library); with some effort, third-party implementations can be ported to R, for example this one;
  • as a consequence of the previous point, it was not possible to pick the right learning rate when using several GPUs;
  • there is a shortage of modern neural network architectures, especially ones pretrained on imagenet;
  • there is no one-cycle policy and no discriminative learning rates (cosine annealing was implemented at our request, thanks skydan); a rough sketch of one such schedule is shown right after this list.
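One way to hand-roll such a schedule on top of the standard callbacks, as a rough sketch (the period length and learning-rate bounds are made up):

# Cosine annealing via the learning-rate scheduler callback
cosine_anneal <- function(epoch) {
  lr_max <- 2e-3; lr_min <- 1e-5; period <- 10
  lr_min + (lr_max - lr_min) * (1 + cos(pi * (epoch %% period) / period)) / 2
}
callback_lr <- keras::callback_learning_rate_scheduler(schedule = cosine_anneal)
# ... then pass callback_lr in the callbacks list of fit_generator() ...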

And what useful things were learned from this competition:

  • On relatively modest hardware, you can painlessly work with decent volumes of data (several times the size of RAM). The data.table package saves memory thanks to in-place modification of tables, which avoids copying them, and when used correctly its capabilities almost always show the highest speed among all the tools known to us for scripting languages. Storing the data in a database lets you, in many cases, not think at all about squeezing the entire dataset into RAM.
  • Slow functions in R can be replaced with fast ones in C++ using the Rcpp package. If, in addition, RcppThread or RcppParallel is used, we get cross-platform multi-threaded implementations, so there is no need to parallelize the code at the R level.
  • The Rcpp package can be used without serious knowledge of C++; the required minimum is outlined here. Header files for a number of cool C++ libraries like xtensor are available on CRAN, that is, an infrastructure is forming for projects that integrate ready-made high-performance C++ code into R. An additional convenience is syntax highlighting and a static C++ code analyzer in RStudio.
  • docopt lets you run self-contained scripts with parameters. This is convenient for use on a remote server, including under docker. In RStudio it is inconvenient to carry out many-hour experiments training neural networks, and installing the IDE on the server itself is not always justified.
  • Docker ensures code portability and reproducibility of results between developers with different OS versions and libraries, as well as ease of running on servers. The entire training pipeline can be launched with a single command.
  • Google Cloud is a budget-friendly way to experiment on expensive hardware, but you need to choose the configurations carefully.
  • Measuring the speed of individual code fragments is very useful, especially when combining R and C++, and with the bench package it is also very easy.

Overall this experience was very rewarding, and we continue working on resolving some of the issues raised.

Source: www.habr.com
