Quick Draw Doodle Recognition: how to make friends with R, C++ and neural networks

Hi, Habr!

Last fall, Kaggle hosted the Quick, Draw! Doodle Recognition competition for classifying hand-drawn images, in which a team of R users took part: Artem Klevtsov, Philipp Upravitelev and Andrey Ogurtsov. We will not describe the competition in detail here; that has already been done in a recent publication.

This time medal farming didn't work out, but a lot of valuable experience was gained, so I would like to tell the community about a number of interesting and useful things from Kaggle and from everyday work. Among the topics covered: the hard life without OpenCV, JSON parsing (these examples look at integrating C++ code into R scripts or packages using Rcpp), script parametrization, and dockerization of the final solution. All the code from this post, in a form ready to run, is available in the repository.

Contents:

  1. Efficient loading of data from CSV into a MonetDB database
  2. Preparing batches
  3. Iterators for unloading batches from the database
  4. Choosing the model architecture
  5. Script parametrization
  6. Dockerization of scripts
  7. Using multiple GPUs on Google Cloud
  8. Instead of a conclusion

1. Efficient loading of data from CSV into a MonetDB database

The data in this competition is provided not as ready-made images but as 340 CSV files (one file per class) containing JSONs with point coordinates. By connecting these points with lines, we obtain the final image of 256x256 pixels. Each record also carries a label indicating whether the image was correctly recognized by the classifier in use at the time the data was collected, the two-letter country code of the drawing author's residence, a unique identifier, a timestamp, and a class name matching the file name. The simplified version of the source data weighs 7.4 GB in the archive and roughly 20 GB after unpacking; the full data takes 240 GB unpacked. The organizers guaranteed that both versions reproduce the same drawings, which makes the full version redundant. In any case, storing 50 million images as image files or as arrays was immediately judged unprofitable, and we decided to merge all the CSV files from the train_simplified.zip archive into a database, with subsequent generation of images of the required size "on the fly" for each batch.

As the DBMS we chose the well-proven MonetDB, namely its R implementation, the MonetDBLite package. The package includes an embedded version of the database server and lets you start the server directly from an R session and work with it there. Creating a database and connecting to it is done with a single command:

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

We will need to create two tables: one for all the data, the other for service information about the loaded files (useful if something goes wrong and the process has to be resumed after several files have already been loaded):

Creating the tables

if (!DBI::dbExistsTable(con, "doodles")) {
  DBI::dbCreateTable(
    con = con,
    name = "doodles",
    fields = c(
      "countrycode" = "char(2)",
      "drawing" = "text",
      "key_id" = "bigint",
      "recognized" = "bool",
      "timestamp" = "timestamp",
      "word" = "text"
    )
  )
}

if (!DBI::dbExistsTable(con, "upload_log")) {
  DBI::dbCreateTable(
    con = con,
    name = "upload_log",
    fields = c(
      "id" = "serial",
      "file_name" = "text UNIQUE",
      "uploaded" = "bool DEFAULT false"
    )
  )
}

The fastest way to load data into the database is to copy CSV files directly with the SQL command COPY OFFSET 2 INTO tablename FROM path USING DELIMITERS ',','\n','"' NULL AS '' BEST EFFORT, where tablename is the table name and path is the path to the file. While working with the archive, it turned out that the built-in unzip implementation in R does not handle a number of files from the archive correctly, so we used the system unzip (via the getOption("unzip") parameter).

Function for writing to the database

#' @title Extract and Load Files
#'
#' @description
#' Extracts CSV files from a ZIP archive and loads them into the database
#'
#' @param con Database connection object (class `MonetDBEmbeddedConnection`).
#' @param tablename Name of the table in the database.
#' @param zipfile Path to the ZIP archive.
#' @param filename Name of the file inside the ZIP archive.
#' @param preprocess Preprocessing function that will be applied to the extracted file.
#'   Must accept a single argument `data` (a `data.table` object).
#'
#' @return `TRUE`.
#'
upload_file <- function(con, tablename, zipfile, filename, preprocess = NULL) {
  # Check arguments
  checkmate::assert_class(con, "MonetDBEmbeddedConnection")
  checkmate::assert_string(tablename)
  checkmate::assert_string(filename)
  checkmate::assert_true(DBI::dbExistsTable(con, tablename))
  checkmate::assert_file_exists(zipfile, access = "r", extension = "zip")
  checkmate::assert_function(preprocess, args = c("data"), null.ok = TRUE)

  # Extract the file
  path <- file.path(tempdir(), filename)
  unzip(zipfile, files = filename, exdir = tempdir(), 
        junkpaths = TRUE, unzip = getOption("unzip"))
  on.exit(unlink(file.path(path)))

  # Apply the preprocessing function
  if (!is.null(preprocess)) {
    .data <- data.table::fread(file = path)
    .data <- preprocess(data = .data)
    data.table::fwrite(x = .data, file = path, append = FALSE)
    rm(.data)
  }

  # SQL query for importing the CSV
  sql <- sprintf(
    "COPY OFFSET 2 INTO %s FROM '%s' USING DELIMITERS ',','\\n','\"' NULL AS '' BEST EFFORT",
    tablename, path
  )
  # Execute the query
  DBI::dbExecute(con, sql)

  # Record the successful upload in the service table
  DBI::dbExecute(con, sprintf("INSERT INTO upload_log(file_name, uploaded) VALUES('%s', true)",
                              filename))

  return(invisible(TRUE))
}

If you need to transform the table before writing it to the database, it is enough to pass a function in the preprocess argument that will transform the data.

Code for sequentially loading the data into the database:

Writing data to the database

# List of files to write
files <- unzip(zipfile, list = TRUE)$Name

# List of exclusions, in case some files have already been loaded
to_skip <- DBI::dbGetQuery(con, "SELECT file_name FROM upload_log")[[1L]]
files <- setdiff(files, to_skip)

if (length(files) > 0L) {
  # Start the timer
  tictoc::tic()
  # Progress bar
  pb <- txtProgressBar(min = 0L, max = length(files), style = 3)
  for (i in seq_along(files)) {
    upload_file(con = con, tablename = "doodles", 
                zipfile = zipfile, filename = files[i])
    setTxtProgressBar(pb, i)
  }
  close(pb)
  # Stop the timer
  tictoc::toc()
}

# 526.141 sec elapsed - copying SSD->SSD
# 558.879 sec elapsed - copying USB->SSD

The data loading time may vary depending on the speed characteristics of the drive used. In our case, reading and writing within one SSD, or from a flash drive (the source file) to an SSD (the database), takes less than 10 minutes.

It takes a few more seconds to create a column with integer class labels and an index column (ORDERED INDEX) with the row numbers by which observations will be sampled when creating batches:

Creating additional columns and an index

message("Generate labels")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD label_int int"))
invisible(DBI::dbExecute(con, "UPDATE doodles SET label_int = dense_rank() OVER (ORDER BY word) - 1"))

message("Generate row numbers")
invisible(DBI::dbExecute(con, "ALTER TABLE doodles ADD id serial"))
invisible(DBI::dbExecute(con, "CREATE ORDERED INDEX doodles_id_ord_idx ON doodles(id)"))

To solve the problem of forming batches on the fly, we needed to achieve the maximum speed of extracting random rows from the doodles table. For this we used 3 tricks. The first was to reduce the size of the type that stores the observation ID. In the original dataset the type required to store the ID is bigint, but the number of observations makes it possible to fit their identifiers, equal to the ordinal number, into the int type. The lookup is much faster in this case. The second trick was to use ORDERED INDEX: we arrived at this decision empirically, after going through all the available options. The third was to use parametrized queries. The essence of the method is to execute the PREPARE command once and then reuse the prepared expression when creating a bunch of queries of the same type; in practice, though, the gain compared to a simple SELECT turned out to be within the range of statistical error.

The data loading process consumes no more than 450 MB of RAM. That is, the described approach lets you move datasets weighing tens of gigabytes on almost any budget hardware, including some single-board devices, which is quite nice.

All that remains is to measure the speed of retrieving (random) data and evaluate how it scales when sampling batches of different sizes:

Database benchmark

library(ggplot2)

set.seed(0)
# Connect to the database
con <- DBI::dbConnect(MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

# Total number of rows in the table (used by fetch_data() below)
n <- DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]]

# Function for preparing the query on the server side
prep_sql <- function(batch_size) {
  sql <- sprintf("PREPARE SELECT id FROM doodles WHERE id IN (%s)",
                 paste(rep("?", batch_size), collapse = ","))
  res <- DBI::dbSendQuery(con, sql)
  return(res)
}

# Function for fetching the data
fetch_data <- function(rs, batch_size) {
  ids <- sample(seq_len(n), batch_size)
  res <- DBI::dbFetch(DBI::dbBind(rs, as.list(ids)))
  return(res)
}

# Running the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    rs <- prep_sql(batch_size)
    bench::mark(
      fetch_data(rs, batch_size),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16   23.6ms  54.02ms  93.43ms     18.8        2.6s    49
# 2         32     38ms  84.83ms 151.55ms     11.4       4.29s    49
# 3         64   63.3ms 175.54ms 248.94ms     5.85       8.54s    50
# 4        128   83.2ms 341.52ms 496.24ms     3.00      16.69s    50
# 5        256  232.8ms 653.21ms 847.44ms     1.58      31.66s    50
# 6        512  784.6ms    1.41s    1.98s     0.740       1.1m    49
# 7       1024  681.7ms    2.72s    4.06s     0.377      2.16m    49

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)


2. Preparing batches

The whole batch preparation process consists of the following steps:

  1. Parsing several JSONs containing vectors of strings with point coordinates.
  2. Drawing colored lines from those point coordinates on an image of the required size (for example, 256×256 or 128×128).
  3. Converting the resulting images into a tensor.

In the competition's Python kernels, the problem was solved mostly with OpenCV. One of the simplest and most obvious analogues in R would look like this:

JSON to tensor conversion implemented in R

r_process_json_str <- function(json, line.width = 3, 
                               color = TRUE, scale = 1) {
  # Parse the JSON
  coords <- jsonlite::fromJSON(json, simplifyMatrix = FALSE)
  tmp <- tempfile()
  # Delete the temporary file when the function exits
  on.exit(unlink(tmp))
  png(filename = tmp, width = 256 * scale, height = 256 * scale, pointsize = 1)
  # Empty plot
  plot.new()
  # Plot window size
  plot.window(xlim = c(256 * scale, 0), ylim = c(256 * scale, 0))
  # Line colors
  cols <- if (color) rainbow(length(coords)) else "#000000"
  for (i in seq_along(coords)) {
    lines(x = coords[[i]][[1]] * scale, y = coords[[i]][[2]] * scale, 
          col = cols[i], lwd = line.width)
  }
  dev.off()
  # Convert the image to a 3-dimensional array
  res <- png::readPNG(tmp)
  return(res)
}

r_process_json_vector <- function(x, ...) {
  res <- lapply(x, r_process_json_str, ...)
  # Combine the 3-dimensional image arrays into a 4-dimensional tensor
  res <- do.call(abind::abind, c(res, along = 0))
  return(res)
}

The drawing is done with standard R tools and saved to a temporary PNG stored in RAM (on Linux, R's temporary directories live in the /tmp directory, which is mounted in RAM). This file is then read as a three-dimensional array with numbers ranging from 0 to 1. This matters, because a more conventional BMP would be read into a raw array with hex color codes.

Let's test the result:

zip_file <- file.path("data", "train_simplified.zip")
csv_file <- "cat.csv"
unzip(zip_file, files = csv_file, exdir = tempdir(), 
      junkpaths = TRUE, unzip = getOption("unzip"))
tmp_data <- data.table::fread(file.path(tempdir(), csv_file), sep = ",", 
                              select = "drawing", nrows = 10000)
arr <- r_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))


The batch itself is formed as follows:

res <- r_process_json_vector(tmp_data[1:4, drawing], scale = 0.5)
str(res)
 # num [1:4, 1:128, 1:128, 1:3] 1 1 1 1 1 1 1 1 1 1 ...
 # - attr(*, "dimnames")=List of 4
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL
 #  ..$ : NULL

This implementation seemed suboptimal to us, since forming large batches takes an indecently long time, and we decided to make use of our colleagues' experience by using a powerful library, OpenCV. At that time there was no ready-made package for R (there still isn't one), so a minimal implementation of the required functionality was written in C++ with integration into R code using Rcpp.

The following packages and libraries were used to solve the problem:

  1. OpenCV for working with images and drawing lines. We used pre-installed system libraries and header files, as well as dynamic linking.

  2. xtensor for working with multidimensional arrays and tensors. We used the header files included in the R package of the same name. The library lets you work with multidimensional arrays in both row-major and column-major order.

  3. ndjson for parsing JSON. This library is used in xtensor automatically if it is present in the project.

  4. RcppThread for organizing multi-threaded processing of the vector of strings from JSON. We used the header files provided by this package. Unlike the more popular RcppParallel package, it has, among other things, a built-in loop interruption mechanism.

It is worth noting that xtensor turned out to be a godsend: besides having extensive functionality and high performance, its developers proved very responsive and answered questions promptly and in detail. With their help we were able to implement transformations of OpenCV matrices into xtensor tensors, as well as a way to combine 3-dimensional image tensors into a 4-dimensional tensor of the correct dimension (the batch itself).

Materials for learning Rcpp, xtensor and RcppThread

https://thecoatlessprofessor.com/programming/unofficial-rcpp-api-documentation

https://docs.opencv.org/4.0.1/d7/dbd/group__imgproc.html

https://xtensor.readthedocs.io/en/latest/

https://xtensor.readthedocs.io/en/latest/file_loading.html#loading-json-data-into-xtensor

https://cran.r-project.org/web/packages/RcppThread/vignettes/RcppThread-vignette.pdf

To compile files that use system header files and dynamic linking with libraries installed on the system, we used the plugin mechanism implemented in the Rcpp package. To find paths and flags automatically, we used the popular Linux utility pkg-config.

Implementing an Rcpp plugin to use the OpenCV library

Rcpp::registerPlugin("opencv", function() {
  # Possible package names
  pkg_config_name <- c("opencv", "opencv4")
  # pkg-config binary
  pkg_config_bin <- Sys.which("pkg-config")
  # Check that the utility is present on the system
  checkmate::assert_file_exists(pkg_config_bin, access = "x")
  # Check for an OpenCV settings file for pkg-config
  check <- sapply(pkg_config_name, 
                  function(pkg) system(paste(pkg_config_bin, pkg)))
  if (all(check != 0)) {
    stop("OpenCV config for the pkg-config not found", call. = FALSE)
  }

  pkg_config_name <- pkg_config_name[check == 0]
  list(env = list(
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), 
                          intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), 
                      intern = TRUE)
  ))
})

As a result of the plugin's work, the following values are substituted during compilation:

Rcpp:::.plugins$opencv()$env

# $PKG_CXXFLAGS
# [1] "-I/usr/include/opencv"
#
# $PKG_LIBS
# [1] "-lopencv_shape -lopencv_stitching -lopencv_superres -lopencv_videostab -lopencv_aruco -lopencv_bgsegm -lopencv_bioinspired -lopencv_ccalib -lopencv_datasets -lopencv_dpm -lopencv_face -lopencv_freetype -lopencv_fuzzy -lopencv_hdf -lopencv_line_descriptor -lopencv_optflow -lopencv_video -lopencv_plot -lopencv_reg -lopencv_saliency -lopencv_stereo -lopencv_structured_light -lopencv_phase_unwrapping -lopencv_rgbd -lopencv_viz -lopencv_surface_matching -lopencv_text -lopencv_ximgproc -lopencv_calib3d -lopencv_features2d -lopencv_flann -lopencv_xobjdetect -lopencv_objdetect -lopencv_ml -lopencv_xphoto -lopencv_highgui -lopencv_videoio -lopencv_imgcodecs -lopencv_photo -lopencv_imgproc -lopencv_core"

The implementation code for parsing JSON and forming batches for passing to the model is given under the spoiler. First, we add a local project directory to the header-file search path (needed for ndjson):

Sys.setenv("PKG_CXXFLAGS" = paste0("-I", normalizePath(file.path("src"))))

JSON to tensor conversion in C++

// [[Rcpp::plugins(cpp14)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(xtensor)]]
// [[Rcpp::depends(RcppThread)]]

#include <xtensor/xjson.hpp>
#include <xtensor/xadapt.hpp>
#include <xtensor/xview.hpp>
#include <xtensor-r/rtensor.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <Rcpp.h>
#include <RcppThread.h>

// Type aliases
using RcppThread::parallelFor;
using json = nlohmann::json;
using points = xt::xtensor<double,2>;     // Point coordinates extracted from JSON
using strokes = std::vector<points>;      // Strokes extracted from JSON
using xtensor3d = xt::xtensor<double, 3>; // Tensor for storing an image matrix
using xtensor4d = xt::xtensor<double, 4>; // Tensor for storing a set of images
using rtensor3d = xt::rtensor<double, 3>; // Wrapper for exporting to R
using rtensor4d = xt::rtensor<double, 4>; // Wrapper for exporting to R

// Static constants
// Image size in pixels
const static int SIZE = 256;
// Line type
// See https://en.wikipedia.org/wiki/Pixel_connectivity#2-dimensional
const static int LINE_TYPE = cv::LINE_4;
// Line width in pixels
const static int LINE_WIDTH = 3;
// Resize algorithm
// https://docs.opencv.org/3.1.0/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
const static int RESIZE_TYPE = cv::INTER_LINEAR;

// Template for converting an OpenCV matrix to a tensor
template <typename T, int NCH, typename XT=xt::xtensor<T,3,xt::layout_type::column_major>>
XT to_xt(const cv::Mat_<cv::Vec<T, NCH>>& src) {
  // Dimensions of the target tensor
  std::vector<int> shape = {src.rows, src.cols, NCH};
  // Total number of elements in the array
  size_t size = src.total() * NCH;
  // Convert cv::Mat to xt::xtensor
  XT res = xt::adapt((T*) src.data, size, xt::no_ownership(), shape);
  return res;
}

// Convert JSON into a list of point coordinates
strokes parse_json(const std::string& x) {
  auto j = json::parse(x);
  // The parsing result must be an array
  if (!j.is_array()) {
    throw std::runtime_error("'x' must be JSON array.");
  }
  strokes res;
  res.reserve(j.size());
  for (const auto& a: j) {
    // Each array element must be a 2-dimensional array
    if (!a.is_array() || a.size() != 2) {
      throw std::runtime_error("'x' must include only 2d arrays.");
    }
    // Extract the vector of points
    auto p = a.get<points>();
    res.push_back(p);
  }
  return res;
}

// Draw the lines
// HSV colors
cv::Mat ocv_draw_lines(const strokes& x, bool color = true) {
  // Source matrix type
  auto stype = color ? CV_8UC3 : CV_8UC1;
  // Target matrix type
  auto dtype = color ? CV_32FC3 : CV_32FC1;
  auto bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  auto col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of lines
  size_t n = x.size();
  for (const auto& s: x) {
    // Number of points in the line
    size_t n_points = s.shape()[1];
    for (size_t i = 0; i < n_points - 1; ++i) {
      // Stroke start point
      cv::Point from(s(0, i), s(1, i));
      // Stroke end point
      cv::Point to(s(0, i + 1), s(1, i + 1));
      // Draw the line
      cv::line(img, from, to, col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // Change the line color
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Convert the color representation to RGB
    cv::cvtColor(img, img, cv::COLOR_HSV2RGB);
  }
  // Convert the representation to float32 with range [0, 1]
  img.convertTo(img, dtype, 1 / 255.0);
  return img;
}

// Process JSON and obtain a tensor with the image data
xtensor3d process(const std::string& x, double scale = 1.0, bool color = true) {
  auto p = parse_json(x);
  auto img = ocv_draw_lines(p, color);
  if (scale != 1) {
    cv::Mat out;
    cv::resize(img, out, cv::Size(), scale, scale, RESIZE_TYPE);
    cv::swap(img, out);
    out.release();
  }
  xtensor3d arr = color ? to_xt<double,3>(img) : to_xt<double,1>(img);
  return arr;
}

// [[Rcpp::export]]
rtensor3d cpp_process_json_str(const std::string& x, 
                               double scale = 1.0, 
                               bool color = true) {
  xtensor3d res = process(x, scale, color);
  return res;
}

// [[Rcpp::export]]
rtensor4d cpp_process_json_vector(const std::vector<std::string>& x, 
                                  double scale = 1.0, 
                                  bool color = false) {
  size_t n = x.size();
  size_t dim = floor(SIZE * scale);
  size_t channels = color ? 3 : 1;
  xtensor4d res({n, dim, dim, channels});
  parallelFor(0, n, [&x, &res, scale, color](int i) {
    xtensor3d tmp = process(x[i], scale, color);
    auto view = xt::view(res, i, xt::all(), xt::all(), xt::all());
    view = tmp;
  });
  return res;
}

This code should be placed in the file src/cv_xt.cpp and compiled with the command Rcpp::sourceCpp(file = "src/cv_xt.cpp", env = .GlobalEnv); it also requires nlohmann/json.hpp from the repository. The code is divided into several functions:

  • to_xt - a templated function for converting an image matrix (cv::Mat) into an xt::xtensor tensor;

  • parse_json - parses a JSON string, extracts the point coordinates, and packs them into a vector;

  • ocv_draw_lines - draws multicolored lines from the resulting vector of points;

  • process - combines the functions above and adds the ability to scale the resulting image;

  • cpp_process_json_str - a wrapper over process that exports the result into an R object (a multidimensional array);

  • cpp_process_json_vector - a wrapper over cpp_process_json_str that makes it possible to process a vector of strings in multi-threaded mode.

To draw the multicolored lines, the HSV color model was used, with subsequent conversion to RGB. Let's test the result:

arr <- cpp_process_json_str(tmp_data[4, drawing])
dim(arr)
# [1] 256 256   3
plot(magick::image_read(arr))

Comparing the speed of the R and C++ implementations

res_bench <- bench::mark(
  r_process_json_str(tmp_data[4, drawing], scale = 0.5),
  cpp_process_json_str(tmp_data[4, drawing], scale = 0.5),
  check = FALSE,
  min_iterations = 100
)
# Benchmark parameters
cols <- c("expression", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   expression                min     median       max `itr/sec` total_time  n_itr
#   <chr>                <bch:tm>   <bch:tm>  <bch:tm>     <dbl>   <bch:tm>  <int>
# 1 r_process_json_str     3.49ms     3.55ms    4.47ms      273.      490ms    134
# 2 cpp_process_json_str   1.94ms     2.02ms    5.32ms      489.      497ms    243

library(ggplot2)
# Running the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    .data <- tmp_data[sample(seq_len(.N), batch_size), drawing]
    bench::mark(
      r_process_json_vector(.data, scale = 0.5),
      cpp_process_json_vector(.data,  scale = 0.5),
      min_iterations = 50,
      check = FALSE
    )
  }
)

res_bench[, cols]

#    expression   batch_size      min   median      max `itr/sec` total_time n_itr
#    <chr>             <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
#  1 r                   16   50.61ms  53.34ms  54.82ms    19.1     471.13ms     9
#  2 cpp                 16    4.46ms   5.39ms   7.78ms   192.      474.09ms    91
#  3 r                   32   105.7ms 109.74ms 212.26ms     7.69        6.5s    50
#  4 cpp                 32    7.76ms  10.97ms  15.23ms    95.6     522.78ms    50
#  5 r                   64  211.41ms 226.18ms 332.65ms     3.85      12.99s    50
#  6 cpp                 64   25.09ms  27.34ms  32.04ms    36.0        1.39s    50
#  7 r                  128   534.5ms 627.92ms 659.08ms     1.61      31.03s    50
#  8 cpp                128   56.37ms  58.46ms  66.03ms    16.9        2.95s    50
#  9 r                  256     1.15s    1.18s    1.29s     0.851     58.78s    50
# 10 cpp                256  114.97ms 117.39ms 130.09ms     8.45       5.92s    50
# 11 r                  512     2.09s    2.15s    2.32s     0.463       1.8m    50
# 12 cpp                512  230.81ms  235.6ms 261.99ms     4.18      11.97s    50
# 13 r                 1024        4s    4.22s     4.4s     0.238       3.5m    50
# 14 cpp               1024  410.48ms 431.43ms 462.44ms     2.33      21.45s    50

ggplot(res_bench, aes(x = factor(batch_size), y = median, 
                      group =  expression, color = expression)) +
  geom_point() +
  geom_line() +
  ylab("median time, s") +
  theme_minimal() +
  scale_color_discrete(name = "", labels = c("cpp", "r")) +
  theme(legend.position = "bottom") 


As you can see, the speedup turned out to be very significant, and it is not possible to catch up with the C++ code by parallelizing the R code.

3. Iterators for unloading batches from the database

R has a well-deserved reputation for processing data that fits in RAM, while Python is more characterized by iterative data processing, which makes it easy and natural to implement out-of-core computations (computations using external memory). A classic example relevant to us in the context of the described problem is deep neural networks trained by gradient descent, with the gradient approximated at each step from a small portion of the observations, a mini-batch.

Deep learning frameworks written in Python have special classes that implement iterators over data: tables, images in folders, binary formats, and so on. You can use the ready-made options or write your own for specific tasks. In R we can take advantage of all the features of the Python library keras, with its various backends, via the package of the same name, which in turn works on top of the reticulate package. The latter deserves a separate long article; it not only lets you run Python code from R, but also transfers objects between R and Python sessions, automatically performing all the necessary type conversions.

We got rid of the need to keep all the data in RAM by using MonetDBLite, and all the "neural network" work is done by the original Python code; we only have to write an iterator over the data, since there is nothing ready-made for such a situation in either R or Python. There are essentially only two requirements for it: it must return batches in an endless loop and preserve its state between iterations (the latter is implemented in R in the simplest way, using closures). Previously it was required to explicitly convert R arrays into numpy arrays inside the iterator, but the current version of the keras package does this itself.

The iterator for the training and validation data looked like this:

Iterator for the training and validation data

train_generator <- function(db_connection = con,
                            samples_index,
                            num_classes = 340,
                            batch_size = 32,
                            scale = 1,
                            color = FALSE,
                            imagenet_preproc = FALSE) {
  # Check arguments
  checkmate::assert_class(con, "DBIConnection")
  checkmate::assert_integerish(samples_index)
  checkmate::assert_count(num_classes)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Shuffle, so that used batch indices can be taken and removed in order
  dt <- data.table::data.table(id = sample(samples_index))
  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  # Keep only complete batches and set the key
  dt <- dt[, if (.N == batch_size) .SD, keyby = batch]
  # Initialize the counter
  i <- 1
  # Number of batches
  max_i <- dt[, max(batch)]

  # Prepare the expression for unloading the data
  sql <- sprintf(
    "PREPARE SELECT drawing, label_int FROM doodles WHERE id IN (%s)",
    paste(rep("?", batch_size), collapse = ",")
  )
  res <- DBI::dbSendQuery(con, sql)

  # Analog of keras::to_categorical
  to_categorical <- function(x, num) {
    n <- length(x)
    m <- numeric(n * num)
    m[x * n + seq_len(n)] <- 1
    dim(m) <- c(n, num)
    return(m)
  }

  # Closure
  function() {
    # Start a new epoch
    if (i > max_i) {
      dt[, id := sample(id)]
      data.table::setkey(dt, batch)
      # Reset the counter
      i <<- 1
      max_i <<- dt[, max(batch)]
    }

    # IDs for unloading the data
    batch_ind <- dt[batch == i, id]
    # Unload the data
    batch <- DBI::dbFetch(DBI::dbBind(res, as.list(batch_ind)), n = -1)

    # Increment the counter
    i <<- i + 1

    # Parse JSON and prepare the array
    batch_x <- cpp_process_json_vector(batch$drawing, scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }

    batch_y <- to_categorical(batch$label_int, num_classes)
    result <- list(batch_x, batch_y)
    return(result)
  }
}

The function takes as input a variable with a database connection, the numbers of the rows to be used, the number of classes, the batch size, the scale factor (scale = 1 corresponds to rendering images of 256x256 pixels, scale = 0.5 to 128x128 pixels), a color indicator (with color = FALSE drawings are rendered in grayscale; with color = TRUE each stroke is drawn in a new color), and a preprocessing indicator for networks pre-trained on imagenet. The latter is needed to rescale pixel values from the interval [0, 1] to the interval [-1, 1], which was used when training the supplied keras models.

The outer function contains argument checks, a data.table with randomly shuffled row numbers from samples_index plus batch numbers, a counter and the maximum number of batches, as well as the SQL statement for fetching data from the database. In addition, we define a fast analog of the function keras::to_categorical(). We used almost all the data for training, leaving half a percent for validation, so the epoch size was limited by the steps_per_epoch parameter when calling keras::fit_generator(), and the condition if (i > max_i) only fires for the validation iterator.
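The to_categorical() analog above exploits R's column-major storage: for 0-based labels x, the one in row i must land at linear position x[i] * n + i. A standalone check of the same trick:

```r
# Fast one-hot encoding for 0-based integer labels, as in the generator above
to_categorical <- function(x, num) {
  n <- length(x)
  m <- numeric(n * num)
  # column-major: element (i, x[i] + 1) sits at linear index x[i] * n + i
  m[x * n + seq_len(n)] <- 1
  dim(m) <- c(n, num)
  m
}

to_categorical(c(0L, 2L, 1L), 3)
#      [,1] [,2] [,3]
# [1,]    1    0    0
# [2,]    0    0    1
# [3,]    0    1    0
```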

In the inner function, the row indices for the next batch are fetched, the corresponding records are retrieved from the database with the batch counter incrementing, JSON parsing is performed (via the function cpp_process_json_vector(), written in C++) and arrays corresponding to the images are built. Then one-hot vectors with class labels are created, and the arrays with pixel values and the labels are combined into a list, which is the return value. To speed things up, we used indexing in the data.table and modification by reference - without these tricks of the data.table package it is hard to imagine working effectively with any significant amount of data in R.

The results of speed measurements on a Core i5 laptop are as follows:

Iterator benchmark

library(Rcpp)
library(keras)
library(ggplot2)

source("utils/rcpp.R")
source("utils/keras_iterator.R")

con <- DBI::dbConnect(drv = MonetDBLite::MonetDBLite(), Sys.getenv("DBDIR"))

ind <- seq_len(DBI::dbGetQuery(con, "SELECT count(*) FROM doodles")[[1L]])
num_classes <- DBI::dbGetQuery(con, "SELECT max(label_int) + 1 FROM doodles")[[1L]]

# Indices for the training set
train_ind <- sample(ind, floor(length(ind) * 0.995))
# Indices for the validation set
val_ind <- ind[-train_ind]
rm(ind)
# Scale factor
scale <- 0.5

# Run the benchmark
res_bench <- bench::press(
  batch_size = 2^(4:10),
  {
    it1 <- train_generator(
      db_connection = con,
      samples_index = train_ind,
      num_classes = num_classes,
      batch_size = batch_size,
      scale = scale
    )
    bench::mark(
      it1(),
      min_iterations = 50L
    )
  }
)
# Benchmark parameters
cols <- c("batch_size", "min", "median", "max", "itr/sec", "total_time", "n_itr")
res_bench[, cols]

#   batch_size      min   median      max `itr/sec` total_time n_itr
#        <dbl> <bch:tm> <bch:tm> <bch:tm>     <dbl>   <bch:tm> <int>
# 1         16     25ms  64.36ms   92.2ms     15.9       3.09s    49
# 2         32   48.4ms 118.13ms 197.24ms     8.17       5.88s    48
# 3         64   69.3ms 117.93ms 181.14ms     8.57       5.83s    50
# 4        128  157.2ms 240.74ms 503.87ms     3.85      12.71s    49
# 5        256  359.3ms 613.52ms 988.73ms     1.54       30.5s    47
# 6        512  884.7ms    1.53s    2.07s     0.674      1.11m    45
# 7       1024     2.7s    3.83s    5.47s     0.261      2.81m    44

ggplot(res_bench, aes(x = factor(batch_size), y = median, group = 1)) +
    geom_point() +
    geom_line() +
    ylab("median time, s") +
    theme_minimal()

DBI::dbDisconnect(con, shutdown = TRUE)

[Plot: median batch generation time versus batch size]

If you have a sufficient amount of RAM, you can seriously speed up database access by moving the database into that same RAM (32 GB was enough for our task). On Linux, a tmpfs partition is mounted by default at /dev/shm, occupying up to half of the RAM capacity. You can allocate more by editing /etc/fstab to get an entry like tmpfs /dev/shm tmpfs defaults,size=25g 0 0. Be sure to reboot and check the result by running the command df -h.

The iterator for the test data is much simpler, since the test dataset fits entirely in RAM:

Iterator for test data

test_generator <- function(dt,
                           batch_size = 32,
                           scale = 1,
                           color = FALSE,
                           imagenet_preproc = FALSE) {

  # Argument checks
  checkmate::assert_data_table(dt)
  checkmate::assert_count(batch_size)
  checkmate::assert_number(scale, lower = 0.001, upper = 5)
  checkmate::assert_flag(color)
  checkmate::assert_flag(imagenet_preproc)

  # Assign batch numbers
  dt[, batch := (.I - 1L) %/% batch_size + 1L]
  data.table::setkey(dt, batch)
  i <- 1
  max_i <- dt[, max(batch)]

  # Closure
  function() {
    batch_x <- cpp_process_json_vector(dt[batch == i, drawing], 
                                       scale = scale, color = color)
    if (imagenet_preproc) {
      # Rescale from the interval [0, 1] to the interval [-1, 1]
      batch_x <- (batch_x - 0.5) * 2
    }
    result <- list(batch_x)
    i <<- i + 1
    return(result)
  }
}

4. Choosing the model architecture

The first architecture we used was mobilenet v1, discussed in this post. It ships as part of the standard keras distribution and, accordingly, is available in the R package of the same name. But when trying to use it with single-channel images, a strange thing turned out: the input tensor must always have the shape (batch, height, width, 3), that is, the number of channels cannot be changed. There is no such limitation in Python, so we hurried to write our own implementation of this architecture, following the original paper (without the dropout present in the keras version):

MobileNet v1 architecture

library(keras)

top_3_categorical_accuracy <- custom_metric(
    name = "top_3_categorical_accuracy",
    metric_fn = function(y_true, y_pred) {
         metric_top_k_categorical_accuracy(y_true, y_pred, k = 3)
    }
)

layer_sep_conv_bn <- function(object, 
                              filters,
                              alpha = 1,
                              depth_multiplier = 1,
                              strides = c(2, 2)) {

  # NB! depth_multiplier !=  resolution multiplier
  # https://github.com/keras-team/keras/issues/10349

  layer_depthwise_conv_2d(
    object = object,
    kernel_size = c(3, 3), 
    strides = strides,
    padding = "same",
    depth_multiplier = depth_multiplier
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() %>%
  layer_conv_2d(
    filters = filters * alpha,
    kernel_size = c(1, 1), 
    strides = c(1, 1)
  ) %>%
  layer_batch_normalization() %>% 
  layer_activation_relu() 
}

get_mobilenet_v1 <- function(input_shape = c(224, 224, 1),
                             num_classes = 340,
                             alpha = 1,
                             depth_multiplier = 1,
                             optimizer = optimizer_adam(lr = 0.002),
                             loss = "categorical_crossentropy",
                             metrics = c("categorical_crossentropy",
                                         top_3_categorical_accuracy)) {

  inputs <- layer_input(shape = input_shape)

  outputs <- inputs %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), strides = c(2, 2), padding = "same") %>%
    layer_batch_normalization() %>% 
    layer_activation_relu() %>%
    layer_sep_conv_bn(filters = 64, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 128, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 256, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 512, strides = c(1, 1)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(2, 2)) %>%
    layer_sep_conv_bn(filters = 1024, strides = c(1, 1)) %>%
    layer_global_average_pooling_2d() %>%
    layer_dense(units = num_classes) %>%
    layer_activation_softmax()

    model <- keras_model(
      inputs = inputs,
      outputs = outputs
    )

    model %>% compile(
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )

    return(model)
}

The disadvantages of this approach are obvious. I want to test many models, but I do not want to rewrite each architecture by hand. We were also deprived of the opportunity to use the weights of models pretrained on imagenet. As usual, studying the documentation helped. The function get_config() lets you obtain a description of the model in a form suitable for editing (base_model_conf$layers is an ordinary R list), and the function from_config() performs the reverse conversion into a model object:

base_model_conf <- get_config(base_model)
base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
base_model <- from_config(base_model_conf)

Now it is not difficult to write a universal function for obtaining any of the supplied keras models, with or without weights trained on imagenet:

Function for loading ready-made architectures

get_model <- function(name = "mobilenet_v2",
                      input_shape = NULL,
                      weights = "imagenet",
                      pooling = "avg",
                      num_classes = NULL,
                      optimizer = keras::optimizer_adam(lr = 0.002),
                      loss = "categorical_crossentropy",
                      metrics = NULL,
                      color = TRUE,
                      compile = FALSE) {
  # Argument checks
  checkmate::assert_string(name)
  checkmate::assert_integerish(input_shape, lower = 1, upper = 256, len = 3)
  checkmate::assert_count(num_classes)
  checkmate::assert_flag(color)
  checkmate::assert_flag(compile)

  # Get the object from the keras package
  model_fun <- get0(paste0("application_", name), envir = asNamespace("keras"))
  # Check that the object exists in the package
  if (is.null(model_fun)) {
    stop("Model ", shQuote(name), " not found.", call. = FALSE)
  }

  base_model <- model_fun(
    input_shape = input_shape,
    include_top = FALSE,
    weights = weights,
    pooling = pooling
  )

  # If the image is not colour, change the input dimensionality
  if (!color) {
    base_model_conf <- keras::get_config(base_model)
    base_model_conf$layers[[1]]$config$batch_input_shape[[4]] <- 1L
    base_model <- keras::from_config(base_model_conf)
  }

  predictions <- keras::get_layer(base_model, "global_average_pooling2d_1")$output
  predictions <- keras::layer_dense(predictions, units = num_classes, activation = "softmax")
  model <- keras::keras_model(
    inputs = base_model$input,
    outputs = predictions
  )

  if (compile) {
    keras::compile(
      object = model,
      optimizer = optimizer,
      loss = loss,
      metrics = metrics
    )
  }

  return(model)
}

When using single-channel images, the pretrained weights are not used. This could be fixed: using the function get_weights(), obtain the model weights as a list of R arrays, change the dimensionality of the first element of that list (by taking one colour channel or averaging all three), and then load the weights back into the model with the function set_weights(). We never added this functionality, because at that point it was already clear that it was more productive to work with colour images.
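The averaging variant described above boils down to a single array operation; here is a minimal sketch on a plain array (the function name is illustrative - get_weights()/set_weights() would then just move such arrays in and out of the model):

```r
# Collapse a conv kernel of shape (kh, kw, 3, filters) to (kh, kw, 1, filters)
# by averaging over the colour axis
to_single_channel <- function(kernel) {
  stopifnot(length(dim(kernel)) == 4, dim(kernel)[3] == 3)
  out <- apply(kernel, c(1, 2, 4), mean)  # result: (kh, kw, filters)
  array(out, dim = c(dim(kernel)[1:2], 1L, dim(kernel)[4]))
}

k <- array(rnorm(3 * 3 * 3 * 32), dim = c(3, 3, 3, 32))
dim(to_single_channel(k))  # 3 3 1 32
```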

We ran most of the experiments using mobilenet versions 1 and 2, as well as resnet34. More modern architectures such as SE-ResNeXt performed well in this competition. Unfortunately, we did not have ready-made implementations at our disposal, and we did not write our own (but we certainly will).

5. Script parameterization

For convenience, all the code for launching training was designed as a single script, parameterized using docopt as follows:

doc <- '
Usage:
  train_nn.R --help
  train_nn.R --list-models
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -l --list-models            List available models.
  -m --model=<model>          Neural network model name [default: mobilenet_v2].
  -b --batch-size=<size>      Batch size [default: 32].
  -s --scale-factor=<ratio>   Scale factor [default: 0.5].
  -c --color                  Use color lines [default: FALSE].
  -d --db-dir=<path>          Path to database directory [default: Sys.getenv("db_dir")].
  -r --validate-ratio=<ratio> Validate sample ratio [default: 0.995].
  -n --n-gpu=<number>         Number of GPUs [default: 1].
'
args <- docopt::docopt(doc)

The docopt package is an implementation of http://docopt.org/ for R. With its help, scripts are launched with simple commands like Rscript bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db or ./bin/train_nn.R -m resnet50 -c -d /home/andrey/doodle_db, if the file train_nn.R is executable (this command will start training a resnet50 model on three-colour 128x128 pixel images; the database must be located in the folder /home/andrey/doodle_db). You can add the learning rate, the optimizer type, and any other tunable parameters to the list. While preparing this publication, it turned out that the mobilenet_v2 architecture from the current version of keras cannot be used from R due to changes not yet accounted for in the R package; we are waiting for a fix.
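One caveat worth noting: docopt returns every option value as a string, so numeric parameters need explicit coercion before use. A sketch (the args list below is a hypothetical stand-in for the real docopt::docopt(doc) result, and the key names are illustrative):

```r
# Hypothetical docopt result: all values arrive as strings
args <- list(`batch-size` = "32", `scale-factor` = "0.5", `n-gpu` = "1")

batch_size <- as.integer(args$`batch-size`)
scale      <- as.numeric(args$`scale-factor`)
n_gpu      <- as.integer(args$`n-gpu`)
```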

This approach made it possible to speed up experiments with different models significantly compared to the more traditional launching of scripts in RStudio (we note the tfruns package as a possible alternative). But the main advantage is the ability to easily manage script launches in Docker or simply on a server, without installing RStudio for this.

6. Dockerization of scripts

We used Docker to ensure a portable environment for training models between team members and for rapid deployment in the cloud. You can start getting acquainted with this tool, unusual as it is for an R programmer, with this series of publications or a video course.

Docker allows you both to create your own images from scratch and to use other images as the basis for your own. When analysing the available options, we came to the conclusion that installing the NVIDIA drivers, CUDA+cuDNN and the Python libraries is a rather bulky part of the image, so we decided to take the official image tensorflow/tensorflow:1.12.0-gpu as a base and add the necessary R packages to it.

The final Dockerfile looked like this:

Dockerfile

FROM tensorflow/tensorflow:1.12.0-gpu

MAINTAINER Artem Klevtsov <[email protected]>

SHELL ["/bin/bash", "-c"]

ARG LOCALE="en_US.UTF-8"
ARG APT_PKG="libopencv-dev r-base r-base-dev littler"
ARG R_BIN_PKG="futile.logger checkmate data.table rcpp rapidjsonr dbi keras jsonlite curl digest remotes"
ARG R_SRC_PKG="xtensor RcppThread docopt MonetDBLite"
ARG PY_PIP_PKG="keras"
ARG DIRS="/db /app /app/data /app/models /app/logs"

RUN source /etc/os-release && \
    echo "deb https://cloud.r-project.org/bin/linux/ubuntu ${UBUNTU_CODENAME}-cran35/" > /etc/apt/sources.list.d/cran35.list && \
    apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E084DAB9 && \
    add-apt-repository -y ppa:marutter/c2d4u3.5 && \
    add-apt-repository -y ppa:timsc/opencv-3.4 && \
    apt-get update && \
    apt-get install -y locales && \
    locale-gen ${LOCALE} && \
    apt-get install -y --no-install-recommends ${APT_PKG} && \
    ln -s /usr/lib/R/site-library/littler/examples/install.r /usr/local/bin/install.r && \
    ln -s /usr/lib/R/site-library/littler/examples/install2.r /usr/local/bin/install2.r && \
    ln -s /usr/lib/R/site-library/littler/examples/installGithub.r /usr/local/bin/installGithub.r && \
    echo 'options(Ncpus = parallel::detectCores())' >> /etc/R/Rprofile.site && \
    echo 'options(repos = c(CRAN = "https://cloud.r-project.org"))' >> /etc/R/Rprofile.site && \
    apt-get install -y $(printf "r-cran-%s " ${R_BIN_PKG}) && \
    install.r ${R_SRC_PKG} && \
    pip install ${PY_PIP_PKG} && \
    mkdir -p ${DIRS} && \
    chmod 777 ${DIRS} && \
    rm -rf /tmp/downloaded_packages/ /tmp/*.rds && \
    rm -rf /var/lib/apt/lists/*

COPY utils /app/utils
COPY src /app/src
COPY tests /app/tests
COPY bin/*.R /app/

ENV DBDIR="/db"
ENV CUDA_HOME="/usr/local/cuda"
ENV PATH="/app:${PATH}"

WORKDIR /app

VOLUME /db
VOLUME /app

CMD bash

For convenience, the packages used were placed into variables; the bulk of the written scripts is copied inside the container during the build. We also changed the command shell to /bin/bash for the convenience of using the contents of /etc/os-release. This avoided the need to hard-code the OS version.

Additionally, a small bash script was written that allows launching the container with various commands. For example, these can be scripts for training neural networks that were previously placed inside the container, or a command shell for debugging and monitoring the container's operation:

Bash script for launching the container

#!/bin/sh

DBDIR=${PWD}/db
LOGSDIR=${PWD}/logs
MODELDIR=${PWD}/models
DATADIR=${PWD}/data
ARGS="--runtime=nvidia --rm -v ${DBDIR}:/db -v ${LOGSDIR}:/app/logs -v ${MODELDIR}:/app/models -v ${DATADIR}:/app/data"

if [ -z "$1" ]; then
    CMD="Rscript /app/train_nn.R"
elif [ "$1" = "bash" ]; then
    ARGS="${ARGS} -ti"
else
    CMD="Rscript /app/train_nn.R $@"
fi

docker run ${ARGS} doodles-tf ${CMD}

If this bash script is run without parameters, the script train_nn.R will be called inside the container with default values; if the first positional argument is "bash", the container will start interactively with a command shell. In all other cases, the values of the positional arguments are substituted: CMD="Rscript /app/train_nn.R $@".

It is worth noting that the directories with the source data and the database, as well as the directory for saving trained models, are mounted inside the container from the host system, which lets you access the results of the scripts without unnecessary manipulations.

7. Using multiple GPUs on Google Cloud

One of the features of the competition was very noisy data (see the title picture, borrowed from @Leigh.plt from the ODS slack). Large batches help fight this, and after experiments on a PC with 1 GPU, we decided to master training models on several GPUs in the cloud. We used GoogleCloud (a good guide to the basics) because of the large selection of available configurations, reasonable prices and the $300 bonus. Out of greed, I ordered a 4xV100 instance with an SSD and a ton of RAM, and that was a big mistake. Such a machine eats money quickly; you can go broke experimenting without a proven pipeline. For educational purposes, it is better to take a K80. But the large amount of RAM came in handy - the cloud SSD did not impress with its performance, so the database was moved to /dev/shm.

Of greatest interest is the code fragment responsible for using multiple GPUs. First, the model is created on the CPU using a context manager, just as in Python:

with(tensorflow::tf$device("/cpu:0"), {
  model_cpu <- get_model(
    name = model_name,
    input_shape = input_shape,
    weights = weights,
    metrics = c(top_3_categorical_accuracy),
    compile = FALSE
  )
})

Then the uncompiled model (this is important) is copied onto the given number of GPUs, and only after that is it compiled:

model <- keras::multi_gpu_model(model_cpu, gpus = n_gpu)
keras::compile(
  object = model,
  optimizer = keras::optimizer_adam(lr = 0.0004),
  loss = "categorical_crossentropy",
  metrics = c(top_3_categorical_accuracy)
)

The technique of freezing all layers except the last one, training the last layer, then unfreezing and retraining the entire model could not be implemented for multiple GPUs.

Training was monitored without using tensorboard; we limited ourselves to writing logs and saving models with informative names after each epoch:

Callbacks

# Log file name template
log_file_tmpl <- file.path("logs", sprintf(
  "%s_%d_%dch_%s.csv",
  model_name,
  dim_size,
  channels,
  format(Sys.time(), "%Y%m%d%H%M%OS")
))
# Model file name template
model_file_tmpl <- file.path("models", sprintf(
  "%s_%d_%dch_{epoch:02d}_{val_loss:.2f}.h5",
  model_name,
  dim_size,
  channels
))

callbacks_list <- list(
  keras::callback_csv_logger(
    filename = log_file_tmpl
  ),
  keras::callback_early_stopping(
    monitor = "val_loss",
    min_delta = 1e-4,
    patience = 8,
    verbose = 1,
    mode = "min"
  ),
  keras::callback_reduce_lr_on_plateau(
    monitor = "val_loss",
    factor = 0.5, # halve the lr
    patience = 4,
    verbose = 1,
    min_delta = 1e-4,
    mode = "min"
  ),
  keras::callback_model_checkpoint(
    filepath = model_file_tmpl,
    monitor = "val_loss",
    save_best_only = FALSE,
    save_weights_only = FALSE,
    mode = "min"
  )
)

8. In lieu of a conclusion

A number of the problems we encountered have not yet been overcome:

  • keras has no ready-made function for automatically finding the optimal learning rate (an analogue of lr_finder in the fast.ai library); with some effort, third-party implementations can be ported to R, for example, this one;
  • as a consequence of the previous point, it was impossible to pick the right training speed when using several GPUs;
  • there is a shortage of modern neural network architectures, especially ones pretrained on imagenet;
  • there is no one cycle policy and no discriminative learning rates (cosine annealing was implemented at our request, thanks skydan).
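The cosine annealing mentioned in the last point reduces to a one-line schedule; a minimal sketch (the parameter names and defaults are illustrative), which could presumably be attached in R keras via callback_learning_rate_scheduler():

```r
# Cosine-annealed learning rate for epoch t in 1..n_epochs
cosine_lr <- function(t, n_epochs, lr_max = 1e-3, lr_min = 1e-5) {
  lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * (t - 1) / (n_epochs - 1)))
}

cosine_lr(1, 10)   # equals lr_max at the first epoch
cosine_lr(10, 10)  # equals lr_min at the last epoch
```

Note that keras::callback_learning_rate_scheduler() passes a 0-based epoch index to its schedule function, so the sketch would need a `t = epoch + 1` shift there.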

What useful things did we learn from this competition:

  • On relatively low-powered hardware, you can work with decent (many times the size of RAM) volumes of data without pain. The data.table package saves memory thanks to in-place modification of tables, which avoids copying them, and when used correctly, its capabilities almost always demonstrate the highest speed among all the tools known to us for scripting languages. Saving the data in a database lets you, in many cases, not think at all about the need to squeeze the entire dataset into RAM.
  • Slow functions in R can be replaced with fast ones in C++ using the Rcpp package. If, in addition, RcppThread or RcppParallel is used, we get cross-platform multi-threaded implementations, so there is no need to parallelize the code at the R level.
  • The Rcpp package can be used without serious knowledge of C++; the required minimum is outlined here. Header files for a number of cool C libraries like xtensor are available on CRAN, that is, infrastructure is being formed for implementing projects that embed ready-made high-performance C++ code into R. An additional convenience is syntax highlighting and a static C++ code analyzer in RStudio.
  • docopt lets you run self-contained scripts with parameters. This is convenient for use on a remote server, including under docker. In RStudio, it is inconvenient to conduct many-hour experiments with training neural networks, and installing an IDE on the server itself is not always justified.
  • Docker ensures portability of code and reproducibility of results between developers with different OS versions and libraries, as well as ease of running on servers. You can launch the entire training pipeline with just one command.
  • Google Cloud is a budget-friendly way to experiment on expensive hardware, but you need to choose configurations carefully.
  • Measuring the speed of individual code fragments is very useful, especially when combining R and C++, and with the bench package it is also very easy.

Overall, this experience was very rewarding, and we continue to work on resolving some of the issues raised.

source: www.habr.com
