Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike

Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike

Hey Habr!

Payîza çûyî, Kaggle pêşbaziyek ji bo dabeşkirina wêneyên bi destan hatine xêzkirin, Quick Draw Doodle Recognition, ku tê de, di nav yên din de, tîmek zanyarên R-ê beşdar bûn, kir: Artem Klevtsova, Rêveberê Philippa и Andrey Ogurtsov. Em ê pêşbaziyê bi hûrgulî rave nekin; ku ji berê ve hatî kirin weşana dawî.

Vê carê bi cotkariya madalyayê re bi ser neket, lê gelek ezmûnên hêja hatin bidestxistin, ji ber vê yekê ez dixwazim ji civakê re li ser çend tiştên herî balkêş û kêrhatî li ser Kagle û di xebata rojane de vebêjim. Di nav mijarên ku hatine nîqaş kirin de: jiyanek dijwar bêyî OpenCV, Parskirina JSON (van mînakan entegrasyona koda C++ di nav nivîsar an pakêtên di R de bi karanîna Rcpp), Parametrekirina nivîsan û dokerkirina çareseriya dawîn. Hemî kodên ji peyamê di formek ku ji bo darvekirinê guncan e tê de heye depoyên.

Contains:

  1. Daneyên ji CSV-ê bi bandor li MonetDB barkirin
  2. Amadekirina koman
  3. Iteratorên ji bo daxistina koman ji databasê
  4. Hilbijartina Mîmarek Model
  5. Parametrekirina skrîptê
  6. Dokerkirina senaryoyan
  7. Li ser Google Cloud-ê gelek GPU bikar tînin
  8. Şûna encamê

1. Daneyên ji CSV-ê bi bandor li databasa MonetDB barkirin

Daneyên di vê pêşbaziyê de ne di forma wêneyên amadekirî de, lê di forma 340 pelên CSV de (ji bo her polê yek pel) ku JSON-yên bi koordînatên xalê hene, têne peyda kirin. Bi girêdana van xalan bi xêzan re, em wêneyek paşîn a bi pîvana 256x256 pixel distînin. Di heman demê de ji bo her tomarek etîketek heye ku destnîşan dike ka wêne ji hêla dabeşkera ku di dema berhevkirina databasê de hatî bikar anîn rast rast hatîye nas kirin, kodek du tîpî ya welatê niştecîhiya nivîskarê wêneyê, nasnameyek yekta, mohra demjimêrek heye. û navek pola ku bi navê pelê re têkildar e. Guhertoyek hêsan a daneya orîjînal di arşîvê de 7.4 GB giran e û piştî vekêşanê bi qasî 20 GB giran e, daneyên tevahî piştî vekêşanê 240 GB digire. Organîzator piştrast kirin ku her du guhertoyan heman nexşeyan dubare dikin, tê vê wateyê ku guhertoya tevahî zêde bû. Di her rewşê de, hilanîna 50 mîlyon wêneyan di pelên grafîkî de an di forma rêzan de tavilê bêkêr hate hesibandin, û me biryar da ku em hemî pelên CSV ji arşîvê bikin yek. train_simplified.zip di nav databasê de bi nifşa paşîn a wêneyên bi mezinahiya pêwîst "li ser firînê" ji bo her komê.

Pergalek baş-îsbatkirî wekî DBMS hate hilbijartin MonetDB, ango pêkanînek ji bo R wekî pakêtek MonetDBLite. Di pakêtê de guhertoyek pêvekirî ya servera databasê vedihewîne û dihêle hûn serverê rasterast ji danişînek R hildin û li wir pê re bixebitin. Afirandina databasek û girêdana wê bi yek fermanê têne kirin:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

Pêdivî ye ku em du tabloyan biafirînin: yek ji bo hemî daneyan, ya din ji bo agahdariya karûbarê li ser pelên dakêşandî (kêr e heke tiştek xelet derkeve û pêdivî ye ku pêvajo piştî dakêşana çend pelan ji nû ve were domandin):

Çêkirina tabloyan

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

Awayê bilez a barkirina daneyan li databasê ew bû ku rasterast pelên CSV bi karanîna fermana SQL-ê kopî bike COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORTko tablename - navê sifrê û path - riya pelê. Dema ku bi arşîvê re xebitîn, hat dîtin ku pêkanîna çêkirî ye unzip di R de bi hejmarek pelên ji arşîvê re rast naxebite, ji ber vê yekê me pergalê bikar anî unzip (bikaranîna parametreyê getOption("unzip")).

Fonksiyon ji bo nivîsandina databasê

#' @title Извлечение и загрузка файлов
#'
#' @description
#' Извлечение CSV-файлов из ZIP-архива и загрузка их в базу данных
#'
#' @param con Объект подключения к базе данных (класс `MonetDBEmbeddedConnection`).
#' @param tablename Название таблицы в базе данных.
#' @oaram zipfile Путь к ZIP-архиву.
#' @oaram filename Имя файла внури ZIP-архива.
#' @param preprocess Функция предобработки, которая будет применена извлечённому файлу.
#'   Должна принимать один аргумент `data` (объект `data.table`).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Проверка аргументов
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Извлечение файла
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # Применяем функция предобработки
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # Запрос к БД на импорт CSV
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','n','"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Выполнение запроса к БД
  DBI::dbExecute(con, sql)

  # Добавление записи об успешной загрузке в служебную таблицу
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

Heke hûn hewce ne ku tabloyê berî ku wê li databasê binivîsin veguherînin, bes e ku hûn di argumanê de derbas bibin preprocess fonksiyona ku dê daneyê veguherîne.

Koda ji bo barkirina daneyan bi rêzê li databasê:

Nivîsandina daneyan li ser databasê

# Список файлов для записи
files <- unzip(zipfile, list = TRUE)$Name

# Список исключений, если часть файлов уже была загружена
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Запускаем таймер
  tictoc::tic()
  # Прогресс бар
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Останавливаем таймер
  tictoc::toc()
}

# 526.141 sec elapsed - копирование SSD->SSD
# 558.879 sec elapsed - копирование USB->SSD

Dibe ku dema barkirina daneyê li gorî taybetmendiyên leza ajokera ku tê bikar anîn ve girêdayî be. Di rewşa me de, xwendin û nivîsandin di nav yek SSD-ê de an ji ajokerek flash (pelê çavkanî) heya SSD (DB) kêmtirî 10 hûrdem digire.

Çend saniyeyên din hewce dike ku stûnek bi etîketa pola yekjimar û stûnek nîşanek were afirandin (ORDERED INDEX) bi jimareyên rêzê yên ku dema çêkirina koman dê çavdêrî werin nimûne:

Çêkirina Stûn û Indeksa Zêdeyî

message("Generate lables")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))

Ji bo çareserkirina pirsgirêka afirandina komikek li ser masê, me hewce bû ku em bileziya herî zêde ya derxistina rêzên bêserûber ji tabloyê bi dest bixin. doodles. Ji bo vê me 3 hîle bikar anîn. Ya yekem ew bû ku mezinahiya celebê ku nasnameya çavdêriyê hilîne kêm bike. Di berhevoka daneya orîjînal de, celebê ku ji bo hilanîna nasnameyê hewce ye ev e bigint, lê hejmara çavdêriyan dihêle ku nasnameyên wan, yên ku bi hejmara rêzî re wekhev in, di celebê de bi cih bikin. int. Lêgerîn di vê rewşê de pir zûtir e. Tîpa duyemîn bi kar anîn bû ORDERED INDEX - Em bi awayekî ampîrîk gihîştin vê biryarê, ji ber ku hemî berdest derbas bûn vebijarkî. Ya sêyem jî bikaranîna pirsên parameterkirî bû. Esasê rêbazê ev e ku emrê carekê were bicîh kirin PREPARE bi karanîna dûv re bêjeyek amadekirî dema ku komek pirs ji heman celebê diafirîne, lê di rastiyê de li gorî yekek hêsan feydeyek heye. SELECT derket holê ku di nav rêza xeletiya îstatîstîkî de ye.

Pêvajoya barkirina daneyan ji 450 MB RAM bêtir naxwe. Ango, nêzîkatiya diyarkirî dihêle hûn li ser hema hema her amûrek budceyê, tevî hin cîhazên yek-board, ku pir xweş e, databasên ku bi dehan gigabaytan giran in biguhezînin.

Tiştê ku dimîne ev e ku meriv leza wergirtina daneyan (random) bipîve û pîvandinê binirxîne dema ku beşên bi mezinahiyên cûda têne nimûne kirin:

Pîvana databasê

library(ggplot2)

set.seed(0)
# Подключение к базе данных
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

# Функция для подготовки запроса на стороне сервера
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Функция для извлечения данных
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Проведение замера
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Параметры бенчмарка
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike

2. Amadekirina koman

Tevahiya pêvajoya amadekirina bacê ji gavên jêrîn pêk tê:

  1. Parvekirina çend JSON-yên ku vektorên rêzan bi koordînatên xalan vedihewîne.
  2. Xêzkirina xêzên rengîn li ser bingeha hevrêzên xalan li ser wêneyek bi mezinahiya pêwîst (mînak, 256×256 an 128×128).
  3. Veguherandina wêneyên ku di encamê de di nav tensorekê de ye.

Wekî beşek pêşbaziya di nav kernelên Python de, pirsgirêk di serî de bi karanîna hate çareser kirin OpenCV. Yek ji sadetirîn û eşkere analogên di R de dê bi vî rengî xuya bike:

Bicîhkirina JSON bo Veguheztina Tensor li R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Парсинг JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Удаляем временный файл по завершению функции
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Пустой график
  plot.new()
  # Размер окна графика
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Цвета линий
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Преобразование изображения в 3-х мерный массив
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Объединение 3-х мерных массивов картинок в 4-х мерный в тензор
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

Xêzkirin bi karanîna amûrên R-ya standard tête çêkirin û li PNG-ya demkî ya ku di RAM-ê de hatî hilanîn tê hilanîn (li Linux-ê, pelrêça R-ya demkî di pelrêçê de cih digire. /tmp, di RAM-ê de hatî çêkirin). Dûv re ev pel wekî rêzek sê-alî ya bi hejmarên ji 0 heya 1-ê tê xwendin. Ev girîng e ji ber ku BMP-ya kevneşoptir dê di nav rêzek xav bi kodên rengê hex-ê de were xwendin.

Ka em encamê biceribînin:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike

Kom bi xwe dê wiha pêk were:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

Ev pêkanîn ji me re nebaş xuya bû, ji ber ku damezrandina komên mezin demek dirêj dirêj digire, û me biryar da ku bi karanîna pirtûkxaneyek hêzdar ji ezmûna hevkarên xwe sûd werbigirin. OpenCV. Wê demê ji bo R pakêtek amade tune bû (niha tune ye), ji ber vê yekê pêkanînek hindiktirîn a fonksiyona pêwîst di C++ de bi entegrasyona koda R-yê ve hatî nivîsandin. Rcpp.

Ji bo çareserkirina pirsgirêkê, pakêt û pirtûkxaneyên jêrîn hatin bikaranîn:

  1. OpenCV ji bo xebata bi wêneyan û xêzkirina xetên. Pirtûkxaneyên pergalê yên pêş-sazkirî û pelên sernavê, û her weha girêdana dînamîkî bikar anîn.

  2. xtensor ji bo xebitandina bi rêz û tensorên piralî. Me pelên sernavê yên ku di pakêta R ya bi heman navî de cih digirin bikar anîn. Pirtûkxane dihêle hûn bi rêzikên piralî, hem di rêza sereke û hem jî di rêza sereke de bixebitin.

  3. ndjson ji bo parkirina JSON. Ev pirtûkxane tê bikaranîn xtensor bixweber heke ew di projeyê de hebe.

  4. RcppThread ji bo organîzekirina pêvajoyek pir-mijalek vektorek ji JSON. Pelên sernavê yên ku ji hêla vê pakêtê ve hatî peyda kirin bikar anîn. Ji bêtir populer RcppParallel Di pakêtê de, di nav tiştên din de, mekanîzmayek qutkirina lûkê ya çêkirî heye.

Divê were zanîn ku xtensor derket holê ku xwedêgiravî ye: ji bilî vê yekê ku ew xwedan fonksiyonek berfireh û performansa bilind e, pêşdebirên wê pir bersivdar derketin û pirsan zû û bi hûrgulî bersivandin. Bi alîkariya wan, gengaz bû ku veguheztinên matrices OpenCV li tensorên xtensor, û her weha rêyek ku tensorên wêneya 3-dimensî li tensorek 4-alî ya pîvana rast (hevok bixwe) bi hev re bikin.

Materyalên ji bo fêrbûna Rcpp, xtensor û RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

Ji bo berhevkirina pelên ku pelên pergalê û girêdana dînamîkî bi pirtûkxaneyên ku li ser pergalê hatine saz kirin re bikar tînin, me mekanîzmaya pêvekê ya ku di pakêtê de hatî bicîh kirin bikar anî. Rcpp. Ji bo ku bixweber rê û alayan bibînin, me amûrek populer a Linux bikar anî pkg-mîheng.

Pêkanîna pêveka Rcpp ji bo karanîna pirtûkxaneya OpenCV

Rcpp::registerPlugin("opencv", function() {
  # Возможные названия пакета
  pkg_config_name <- c("opencv", "opencv4")
  # Бинарный файл утилиты pkg-config
  pkg_config_bin <- Sys.which("pkg-config")
  # Проврека наличия утилиты в системе
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Проверка наличия файла настроек OpenCV для pkg-config
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  pkg_config_name <- pkg_config_name[check == 0]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

Di encama xebata pêvekê de, dê di pêvajoya berhevkirinê de nirxên jêrîn werin şûna:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

Koda pêkanînê ji bo parskirina JSON û çêkirina komek ji bo veguheztina modelê di binê spoilerê de tê dayîn. Pêşîn, pelrêçek projeya herêmî lê zêde bike ku li pelên sernavê bigerin (ji bo ndjson hewce ye):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

Pêkanîna JSON ji bo veguheztina tensor di C ++ de

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Синонимы для типов
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Извлечённые из JSON координаты точек
using strokes = std::vector<points>;      // Извлечённые из JSON координаты точек
using xtensor3d = xt::xtensor<double, 3>; // Тензор для хранения матрицы изоображения
using xtensor4d = xt::xtensor<double, 4>; // Тензор для хранения множества изображений
using rtensor3d = xt::rtensor<double, 3>; // Обёртка для экспорта в R
using rtensor4d = xt::rtensor<double, 4>; // Обёртка для экспорта в R

// Статические константы
// Размер изображения в пикселях
const static int SIZE = 256;
// Тип линии
// См. https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Толщина линии в пикселях
const static int LINE_WIDTH = 3;
// Алгоритм ресайза
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Шаблон для конвертирования OpenCV-матрицы в тензор
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Размерность целевого тензора
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Общее количество элементов в массиве
  size_t size = src.total() * NCH;
  // Преобразование cv::Mat в xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Преобразование JSON в список координат точек
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // Результат парсинга должен быть массивом
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Каждый элемент массива должен быть 2-мерным массивом
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Извлечение вектора точек
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Отрисовка линий
// Цвета HSV
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Исходный тип матрицы
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Итоговый тип матрицы
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Количество линий
  size_t n = x.size();
  for (const auto& s: x) {
    // Количество точек в линии
    size_t n_points = s.shape()[1];
    for (size_t i = 0; i < n_points - 1; ++i) {
      // Точка начала штриха
      cv::Point from(s(0, i), s(1, i));
      // Точка окончания штриха
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Отрисовка линии
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Меняем цвет линии
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Меняем цветовое представление на RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Меняем формат представления на float32 с диапазоном [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Обработка JSON и получение тензора с данными изображения
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

Divê ev kod di pelê de were danîn src/cv_xt.cpp û bi fermanê berhev bikin Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); ji bo xebatê jî pêwîst e nlohmann/json.hpp ji depo. Kod di çend fonksiyonan de dabeş dibe:

  • to_xt - fonksiyonek şablonek ji bo veguheztina matrixek wêneyê (cv::Mat) ji tensorekê re xt::xtensor;

  • parse_json - Fonksiyon rêzek JSON par dike, koordînatên xalan derdixe, wan di vektorekê de pak dike;

  • ocv_draw_lines - ji vektora encam a xalan, xêzên pir-reng xêz dike;

  • process - fonksiyonên jorîn tevlihev dike û di heman demê de şiyana pîvandina wêneya encam jî zêde dike;

  • cpp_process_json_str - pêça li ser fonksiyonê process, ku encam ji bo R-objekt (array piralî);

  • cpp_process_json_vector - pêça li ser fonksiyonê cpp_process_json_str, ku destûrê dide te ku hûn vektorek rêzikan di moda pir-mijarî de pêvajo bikin.

Ji bo xêzkirina xêzên pirreng, modela rengê HSV hate bikar anîn, li dûv veguhertina RGB. Ka em encamê biceribînin:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike
Berawirdkirina leza pêkanînan di R û C++ de

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Параметры бенчмарка
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Проведение замера
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 

Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike

Wekî ku hûn dikarin bibînin, zêdebûna lezê pir girîng derket holê, û ne gengaz e ku meriv bi koda C ++ re bi paralelkirina koda R-yê bigire.

3. Iteratorên ji bo daxistina heviyên ji databasê

R ji bo hilanîna daneyên ku di RAM-ê de cih digire navûdengek hêja ye, di heman demê de Python ji hêla hilberandina daneya dubare ve tête taybetmend kirin, ku dihêle hûn bi hêsanî û xwezayî hesabên derveyî-bingehîn pêk bînin (hesabên ku bîranîna derveyî bikar tînin). Di çarçoweya pirsgirêka diyarkirî de ji bo me mînakek klasîk û têkildar toreyên neuralî yên kûr e ku bi rêbaza daketina gradientê bi nêzikbûna gradientê di her gavê de bi karanîna beşek piçûk a çavdêriyan, an mini-hevokê ve hatî perwerde kirin.

Çarçoveyên fêrbûna kûr ên ku di Python de hatine nivîsandin xwedan dersên taybetî ne ku îteratoran li ser bingeha daneyan pêk tînin: tablo, wêneyên di peldankan de, formatên binary hwd. Hûn dikarin vebijarkên amade bikar bînin an jî ji bo karên taybetî yên xwe binivîsin. Di R de em dikarin ji hemî taybetmendiyên pirtûkxaneya Python sûd werbigirin kera bi paşnavên xwe yên cihêreng pakêta bi heman navî bikar tîne, ku di encamê de li ser pakêtê dixebite reticulate. Ya dawî gotareke dirêj a cuda heq dike; ew ne tenê dihêle hûn koda Python-ê ji R-yê bimeşînin, lê di heman demê de dihêle hûn tiştan di navbera danişînên R û Python de veguhezînin, bixweber hemî veguheztinên cûrbecûr yên pêwîst pêk bînin.

Me ji hewcedariya hilanîna hemî daneyan di RAM-ê de bi karanîna MonetDBLite xilas kir, hemî xebata "tora neuralî" dê ji hêla koda orîjînal a li Python ve were kirin, divê em tenê li ser daneyan îteratorek binivîsin, ji ber ku tiştek amade tune. ji bo rewşek weha di R an Python de. Di eslê xwe de tenê du hewcedarî ji bo wê hene: pêdivî ye ku ew koman di çerxek bêdawî de vegerîne û rewşa xwe di navbera dubareyan de xilas bike (ya paşîn di R de bi awayê herî hêsan bi karanîna girtinan ve tête bicîh kirin). Berê, pêdivî bû ku bi eşkereyî rêzikên R-yê di hundurê iteratorê de veguhezînin rêzikên numpy, lê guhertoya heyî ya pakêtê kera xwe dike.

Vegere ji bo daneyên perwerde û pejirandinê wiha derket holê:

Iterator ji bo daneyên perwerde û pejirandinê

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Проверка аргументов
  checkmate::assert_class(con, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Перемешиваем, чтобы брать и удалять использованные индексы батчей по порядку
  dt <- data.table::data.table(id = sample(samples_index))
  # Проставляем номера батчей
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Оставляем только полные батчи и индексируем
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Устанавливаем счётчик
  i <- 1
  # Количество батчей
  max_i <- dt[, max(batch)]

  # Подготовка выражения для выгрузки
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(con, sql)

  # Аналог keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Замыкание
  function() {
    # Начинаем новую эпоху
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Сбрасываем счётчик
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # ID для выгрузки данных
    batch_ind <- dt[batch == i, id]
    # Выгрузка данных
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Увеличиваем счётчик
    i <<- i + 1

    # Парсинг JSON и подготовка массива
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Шкалирование c интервала [0, 1] на интервал [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

Fonksiyon wekî têketinê guhêrbarek bi girêdana databasê re, hejmarên rêzikên hatine bikar anîn, hejmara çînan, mezinahiya hevîrê, pîvan (scale = 1 bi vegotina wêneyên 256x256 pixel re têkildar e, scale = 0.5 - 128x128 pixel), nîşana rengîn (color = FALSE dema ku tê bikar anîn renderkirina bi rengê gewr diyar dike color = TRUE her lêdan bi rengek nû tê kişandin) û nîşanek pêş-processing ji bo torên ku li ser imagenet-ê pêş-perwerde kirine. Ya paşîn ji bo pîvandina nirxên pixelê ji navbera [0, 1] heya navbera [-1, 1], ya ku dema perwerdehiya peydakirî hatî bikar anîn hewce ye. kera modelên.

Fonksiyona derveyî kontrolkirina celebê arguman, tabloyek heye data.table bi bi korfelaqî hejmarên xeta tevlihev ji samples_index û hejmarên hevîrê, hejmar û hejmareke herî zêde ya koman, û her weha vegotinek SQL ji bo rakirina daneyan ji databasê. Wekî din, me analogek bilez a fonksiyonê li hundur diyar kir keras::to_categorical(). Me hema hema hemî daneyên ji bo perwerdehiyê bikar anîn, ji sedî nîv ji bo pejirandinê hiştin, ji ber vê yekê mezinahiya serdemê ji hêla pîvanê ve sînorkirî bû steps_per_epoch dema gazî kirin keras::fit_generator(), û şert if (i > max_i) tenê ji bo îteratorê pejirandinê xebitî.

Di fonksiyona hundurîn de, îndeksên rêzan ji bo koma paşîn têne hilanîn, tomar ji databasê têne barkirin digel ku jimareya hevîrê zêde dibe, parskirina JSON (fonksiyonê cpp_process_json_vector(), bi C++ hatiye nivîsandin û rêzikên li gorî wêneyan diafirîne. Dûv re vektorên yek-germ ên bi etîketên polê têne afirandin, rêzikên bi nirxên pixel û etîketan di navnîşek de têne hev kirin, ku nirxa vegerê ye. Ji bo lezkirina xebatê, me di tabloyan de çêkirina îndeksan bikar anî data.table û guheztin bi rêya girêdanê - bêyî van pakêtê "çîp" data.sifre Zehmet e ku meriv bi rengek girîng bi daneyên R re bi bandor bixebite.

Encamên pîvandinên bilez ên li ser laptopek Core i5 wiha ne:

Pîvana Iterator

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Индексы для обучающей выборки
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Индексы для проверочной выборки
val_ind <- ind[-train_ind]
rm(ind)
# Коэффициент масштаба
scale <- 0.5

# Проведение замера
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Параметры бенчмарка
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

Naskirina Quick Draw Doodle: Meriv çawa bi R, C++ û torên neuralî re hevaltiyê dike

Ger têra we RAM hebe, hûn dikarin bi ciddî xebata databasê bi veguheztina wê li heman RAM-ê bileztir bikin (32 GB ji bo karê me bes e). Di Linuxê de, dabeşkirin ji hêla xwerû ve hatî çêkirin /dev/shm, heta nîvê kapasîteya RAM-ê digire. Hûn dikarin bi guherandinê bêtir ronî bikin /etc/fstabji bo ku qeydek mîna tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Bawer bikin ku ji nû ve dest pê bikin û encamê bi xebitandina fermanê kontrol bikin df -h.

Iterator ji bo daneyên ceribandinê pir hêsan xuya dike, ji ber ku daneyên testê bi tevahî di RAM-ê de cih digire:

Iterator ji bo daneyên testê

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Проверка аргументов
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Проставляем номера батчей
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Замыкание
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Шкалирование c интервала [0, 1] на интервал [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}

4. Hilbijartina mîmariya model

Mîmariya yekem a ku hatî bikar anîn bû mobilenet v1, taybetmendiyên ku di nav de têne nîqaş kirin ev agah. Ew wekî standard tê de ye kera û, li gorî vê yekê, di pakêta bi heman navî de ji bo R heye. Lê gava ku hûn hewl didin ku wê bi wêneyên yek-kanal re bikar bînin, tiştek ecêb derket holê: divê tensora têketinê her dem xwedî pîvan be. (batch, height, width, 3), ango hejmara kanalan nayê guhertin. Di Python-ê de tixûbek wusa tune, ji ber vê yekê me lezand û me li pey gotara orîjînal (bêyî dakêşana ku di guhertoya kerasê de ye) pêkanîna xwe ya vê mîmariyê nivîsand:

Mîmariya Mobilenet v1

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}

Dezawantajên vê nêzîkbûnê diyar in. Ez dixwazim gelek modelan biceribînim, lê berevajî vê, ez naxwazim her mîmarî bi destan ji nû ve binivîsim. Em jî ji derfeta bikaranîna giraniya modelên ku li ser imagenet-ê berê hatine perwerdekirin bêpar man. Wekî her car, xwendina belgeyan alîkarî kir. Karkirin get_config() destûrê dide te ku hûn ravekirina modelê bi formek ku ji bo sererastkirinê guncan e (base_model_conf$layers - navnîşek R-ya birêkûpêk), û fonksiyonê from_config() veguhertina berevajî ya li objektek modelê pêk tîne:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Naha ne dijwar e ku meriv fonksiyonek gerdûnî binivîsîne da ku yek ji wan peyda bike kera modelên bi giranî an bê wan ên ku li ser imagenet hatine perwerde kirin:

Fonksiyon ji bo barkirina mîmariyên amade

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Проверка аргументов
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Получаем объект из пакета keras
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Проверка наличия объекта в пакете
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # Если изображение не цветное, меняем размерность входа
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

Dema ku wêneyên yek-kanal bikar tînin, tu giraniyên pêşdibistanê nayên bikar anîn. Ev dikare were rast kirin: fonksiyonê bikar bînin get_weights() giraniyên modelê di forma navnîşek rêzikên R-yê de bistînin, pîvana hêmana yekem a vê navnîşê biguhezînin (bi girtina kanalek rengîn an navînîkirina her sêyan), û dûv re bi fonksiyonê giraniyan dîsa li modelê bar bikin. set_weights(). Me tu carî vê fonksiyonê lê zêde nekir, ji ber ku di vê qonaxê de jixwe diyar bû ku xebata bi wêneyên rengîn re hilbertir bû.

Me piraniya ceribandinan bi karanîna guhertoyên mobilenet 1 û 2, û hem jî resnet34 pêk anî. Mîmarên nûjen ên wekî SE-ResNeXt di vê pêşbaziyê de baş derketin. Mixabin, pêkanînên hazir di destê me de tunebûn û me ya xwe nenivîsand (lê teqez em ê binivîsin).

5. Parametrekirina nivîsan

Ji bo rehetiyê, hemî kodên ji bo destpêkirina perwerdehiyê wekî skrîptek yekane hate sêwirandin, bi karanîna parameterkirî docopt wiha ye:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

Pakêt docopt pêkanînê temsîl dike http://docopt.org/ ji bo R. Bi alîkariya wê, skrîpt bi fermanên sade yên mîna Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db an ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, heke pel train_nn.R pêkan e (ev ferman dê dest bi perwerdekirina modelê bike resnet50 li ser wêneyên sê-reng ên bi pîvana 128x128 pixel, pêdivî ye ku databas di peldankê de be /home/andrey/doodle_db). Hûn dikarin leza fêrbûnê, celebê xweşbînker, û her pîvanên xwerû yên din li navnîşê zêde bikin. Di pêvajoya amadekirina weşanê de derket holê ku mîmarî mobilenet_v2 ji guhertoya heyî kera di bikaranîna R nikare ji ber guhertinên ku di pakêta R de nehatine hesibandin, em li bendê ne ku ew sererast bikin.

Vê nêzîkatiya hanê gengaz kir ku ceribandinên bi modelên cûda re li gorî destpêkirina kevneşopî ya nivîsarên li RStudio-yê bi girîngî bileztir bikin (em pakêtê wekî alternatîfek mimkun destnîşan dikin tfruns). Lê avantajê sereke ew e ku meriv bi hêsanî destpêkirina nivîsarên li Docker an bi tenê li ser serverê birêve bibe, bêyî ku ji bo vê yekê RStudio saz bike.

6. Dockerkirina senaryoyan

Me Docker bikar anî da ku veguheztina jîngehê ji bo modelên perwerdehiyê di navbera endamên tîmê de û ji bo bicîhkirina bilez di ewr de misoger bike. Hûn dikarin bi vê amûrê re, ku ji bo bernamenûsek R-yê bi nisbet neasayî ye, dest pê bikin ev rêze weşanên an kursa vîdyoyê.

Docker dihêle hûn hem wêneyên xwe ji nû ve biafirînin hem jî wêneyên din wekî bingehek ji bo afirandina xweya xwe bikar bînin. Dema ku vebijarkên berdest analîz kirin, em gihîştin wê encamê ku sazkirina ajokarên NVIDIA, CUDA + cuDNN û pirtûkxaneyên Python beşek pir mezin a wêneyê ye, û me biryar da ku em wêneya fermî wekî bingeh bigirin. tensorflow/tensorflow:1.12.0-gpu, pakêtên R yên pêwîst li wir zêde dikin.

Pelê dawîn docker wiha xuya bû:

dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && 
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && 
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && 
    add-apt-repository -y ppa:marutter/c2d4u3.5 && 
    add-apt-repository -y ppa:timsc/opencv-3.4 && 
    apt-get update && 
    apt-get install -y locales && 
    locale-gen ${LOCALE} && 
    apt-get install -y --no-install-recommends ${APT_PKG} && 
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && 
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && 
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && 
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && 
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && 
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && 
    install.r ${R_SRC_PKG} && 
    pip install ${PY_PIP_PKG} && 
    mkdir -p ${DIRS} && 
    chmod 777 ${DIRS} && 
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && 
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

Ji bo rehetiyê, pakêtên ku hatine bikar anîn di nav guhêrbaran de hatin danîn; piraniya nivîsarên nivîskî di dema kombûnê de di hundurê konteyneran de têne kopî kirin. Me şêla fermanê jî guhert /bin/bash ji bo hêsaniya karanîna naverokê /etc/os-release. Vê yekê ji hewcedariya danasîna guhertoya OS-ê di kodê de dûr xist.

Wekî din, skrîptek bashek piçûk hate nivîsandin ku dihêle hûn konteynirek bi fermanên cihêreng bidin destpêkirin. Mînakî, ev dikarin skrîptên ji bo perwerdekirina torên neuralî yên ku berê di hundurê konteynerê de hatine danîn, an şêlek fermanê ji bo verastkirin û şopandina xebata konteynerê bin:

Skrîpta destpêkirina konteynerê

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

Ger ev skrîpta bash bêyî pîvanan were xebitandin, dê skrîpt di hundurê konteynerê de were gazî kirin train_nn.R bi nirxên xwerû; heke argumana pozîsyonê ya yekem "bash" be, wê hingê konteynir dê bi şepêlek fermanê bi înteraktîf dest pê bike. Di hemî rewşên din de, nirxên argumanên pozîsyonê têne veguheztin: CMD="Rscript /app/train_nn.R $@".

Hêjayî gotinê ye ku pelrêçên bi daneya çavkaniyê û databasê, û her weha pelrêça ji bo hilanîna modelên perwerdekirî, di hundurê konteynerê de ji pergala mêvandar têne siwar kirin, ku dihêle hûn bigihîjin encamên nivîsan bêyî manipulasyonên nehewce.

7. Bikaranîna gelek GPU li ser Google Cloud

Yek ji taybetmendiyên pêşbaziyê daneyên pir bi deng bû (li wêneya sernavê binêre, ji @Leigh.plt ji ODS slack hatî deyn kirin). Parçeyên mezin di şerkirina vê yekê de dibin alîkar, û piştî ceribandinên li ser PC-ya bi 1 GPU, me biryar da ku em modelên perwerdehiyê li ser çend GPU-yên di ewrê de master bikin. GoogleCloud bikar anîn (rêberê baş ji bo bingehîn) ji ber bijartina mezin a veavakirinên berdest, bihayên maqûl û 300 $ bonus. Ji çavnebariyê, min mînakek 4xV100 bi SSD û tonek RAM ferman da, û ew xeletiyek mezin bû. Makîneyek wusa zû drav dixwe; hûn dikarin bêyî boriyek pejirandî ceribandinek têk bibin. Ji bo armancên perwerdehiyê, çêtir e ku hûn K80 bistînin. Lê mîqdara mezin a RAM bi kêr hat - cloud SSD bi performansa xwe bandor nekir, ji ber vê yekê databas hate veguheztin dev/shm.

Balkêşiya herî mezin parçeya kodê ye ku berpirsiyarê karanîna gelek GPU-yan e. Pêşîn, modela li ser CPU-yê bi karanîna rêveberek kontekstê ve hatî çêkirin, mîna Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics =(top_3_categorical_accuracy,
    compile = FALSE
  )
})

Dûv re modela nehevkirî (ev girîng e) li hejmarek diyarkirî ya GPU-yên berdest tê kopî kirin, û tenê piştî wê tê berhev kirin:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

Teknîka klasîk a cemidandina hemî qatan ji bilî ya paşîn, perwerdekirina qata paşîn, venekirin û ji nû ve perwerdekirina modela tevahî ji bo çend GPU-yan nekarî were bicîh kirin.

Perwerde bê bikaranîn dihat şopandin. tensorboard, xwe bi tomarkirina têketin û hilanîna modelên bi navên agahdar piştî her serdemê sînordar dikin:

Callbacks

# Шаблон имени файла лога
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Шаблон имени файла модели
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # уменьшаем lr в 2 раза
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)

8. Li şûna encamekê

Çend kêşeyên ku em pê re rû bi rû mane hê jî nehatine derbaskirin:

  • в kera fonksiyonek amade tune ku bixweber li rêjeya fêrbûna çêtirîn (analog lr_finder li pirtûkxanê zû.ai); Bi hin hewildanan re, gengaz e ku meriv pêkanînên partiya sêyemîn li R-yê veguhezîne, mînakî, ev;
  • di encama xala berê de, dema ku çend GPU bikar tînin ne gengaz bû ku leza perwerdehiya rast hilbijêrin;
  • kêmasiya mîmariya tora neuralî ya nûjen heye, nemaze yên ku li ser imagenet-ê berê hatine perwerdekirin;
  • tu kes polîtîka û rêjeyên fêrbûnê yên cihêkar naxebitîne (li gorî daxwaza me vekirina kozînê bû pêkanîn, Spas dikim skeydan).

Çi tiştên kêrhatî ji vê pêşbaziyê hîn bûn:

  • Li ser hardware-a nisbeten kêm-hêza, hûn dikarin bi cildên daneya hêja (gelek caran mezinahiya RAM-ê) bêyî êş bixebitin. Çenteyê plastîk data.sifre ji ber guheztina tabloyên di cîh de, ku ji kopîkirina wan dûr dikeve, bîranînê diparêze, û gava ku rast were bikar anîn, kapasîteyên wê hema hema her gav leza herî bilind di nav hemî amûrên ku ji me re têne zanîn ji bo zimanên nivîsandinê nîşan dide. Tomarkirina daneyan di danegehekê de dihêle hûn, di pir rewşan de, qet li ser hewcedariya ku hûn tevahiya databasê di RAM-ê de biqelişînin nefikirin.
  • Fonksiyonên hêdî di R-ê de dikarin bi yên bilez ên di C++ de bi karanîna pakêtê werin guheztin Rcpp. Ger ji bilî bikaranîna RcppThread an RcppParallel, em pêkanînên pir-mijal ên cross-platformê digirin, ji ber vê yekê ne hewce ye ku kodê di asta R-yê de paralel bikin.
  • Pakêt Rcpp dikare bêyî zanîna ciddî ya C++-ê were bikar anîn, hindiktirîn hewce tête diyar kirin vir. Pelên sernavê ji bo hejmarek pirtûkxaneyên C-ya xweş ên mîna xtensor li ser CRAN-ê peyda dibe, ango, binesaziyek ji bo pêkanîna projeyên ku koda amade-performansa bilind a amadekirî ya C++-ê di R-yê de yek dike, tê damezrandin. Rehetiya pêvek ronîkirina hevoksaziyê û analîzkerek koda C++ ya statîk di RStudio de ye.
  • docopt destûrê dide te ku hûn bi pîvanan nivîsarên xweser bimeşînin. Ev ji bo karanîna li ser serverek dûr hêsan e, di nav de. di bin dokerê de. Di RStudio de, nerehet e ku meriv gelek demjimêran ceribandinan bi perwerdekirina torên neuralî re bike, û sazkirina IDE-yê li ser serverê bixwe her gav ne rastdar e.
  • Docker di navbera pêşdebiran de bi guhertoyên cihêreng ên OS û pirtûkxaneyan, û hem jî hêsankirina darvekirinê li ser pêşkêşkeran veguheztina kodê û dubarekirina encaman misoger dike. Hûn dikarin bi tenê yek fermanê tevahiya boriyê perwerdehiyê bidin destpêkirin.
  • Google Cloud rêgezek budceyê ye ku meriv li ser hardware biha biceribîne, lê hûn hewce ne ku bi baldarî mîhengan hilbijêrin.
  • Pîvandina leza perçeyên kodê yên takekesî pir bikêr e, nemaze dema ku R û C++, û bi pakêtê re tê hev kirin. dika - di heman demê de pir hêsan.

Bi tevayî ev ezmûn pir bikêrhatî bû û em berdewam dikin ji bo çareserkirina hin pirsgirêkên ku hatine raber kirin.

Source: www.habr.com

Add a comment