Arcades

Inference.lua

--- A convolutional neural network for image processing.
-- @module network.Inference
-- @see Inference.lua
-- @see network
-- @usage local network = require("network.Inference")(args)
-- @author Alexis BRENON <alexis.brenon@imag.fr>

local nn = require('nn')

local network = require('arcades.network')

-- Describe one convolutional layer whose kernel, stride and zero padding
-- use the same value for width and height.
-- @tparam number n_filters number of output feature maps
-- @tparam number field kernel side length
-- @tparam number step stride along both axes
-- @tparam number pad zero padding along both axes
-- @treturn table layer description in the shape consumed by
--   `arcades.network.create_network`
local function square_conv(n_filters, field, step, pad)
  return {
    n_filters = n_filters,
    field_size = {width = field, height = field},
    stride = {width = step, height = step},
    zero_padding = {width = pad, height = pad}
  }
end

-- Build the "Inference" network description and hand it to
-- `arcades.network.create_network`.
--
-- Every recognised field of `args` that is nil or false is replaced by the
-- default below. The caller's table is completed IN PLACE, and that same
-- table is what gets forwarded.
--
-- Recognised fields and their defaults:
--   input_size  = {1, 256, 256}
--   conv_layers = four square conv layers (16, 32, 64, 64 filters)
--   fc_layers   = {512}
--   nl_layer    = "ReLU"
--   output_size = {1, 33}
--
-- NOTE(review): if create_network uses the usual
-- floor((in + 2*pad - field) / stride) + 1 rule, the default conv stack maps
-- 256x256 -> 64 -> 21 -> 10 -> 8. This is unverified here; confirm against
-- arcades.network.
-- @tparam[opt] table args partial network description
-- @return whatever `network.create_network` returns for the completed table
return function(args)
  local config = args or {}

  -- Fresh tables are built on every call, so callers never share (or
  -- accidentally mutate) a common default instance.
  local fallback = {
    input_size = {1, 256, 256},
    conv_layers = {
      square_conv(16, 8, 4, 2),
      square_conv(32, 6, 3, 1),
      square_conv(64, 3, 2, 0),
      square_conv(64, 3, 1, 0)
    },
    fc_layers = {512},
    nl_layer = "ReLU",
    output_size = {1, 33}
  }

  -- Assign unconditionally with `or`, exactly like the explicit per-field
  -- form, so falsy values are replaced and present values are rewritten
  -- unchanged. The order in which `pairs` visits the keys does not matter,
  -- because each key is handled independently.
  for field, default in pairs(fallback) do
    config[field] = config[field] or default
  end

  return network.create_network(config)
end