-- Inference.lua
local nn = require('nn')
local network = require('arcades.network')
-- Build one convolution spec whose receptive field, stride and zero padding
-- are the same along width and height. Each call returns brand-new tables,
-- so a default spec is never shared between two factory invocations.
local function square_conv(n_filters, field, stride, padding)
  return {
    n_filters = n_filters,
    field_size = {width = field, height = field},
    stride = {width = stride, height = stride},
    zero_padding = {width = padding, height = padding},
  }
end

--- Build the inference network, filling in any missing configuration.
-- Each field below is replaced by its default when it is nil or false. The
-- table that is filled in is the caller's `args` table itself, which is
-- updated in place. A fresh table is used when `args` is nil. That table is
-- then passed to `network.create_network`.
-- Defaults:
--   input_size  = {1, 256, 256}
--     NOTE(review): presumably {channels, height, width}; confirm against
--     arcades.network.
--   conv_layers = four square convolutions (table below)
--   fc_layers   = {512}
--   nl_layer    = "ReLU"
--   output_size = {1, 33}
-- @tparam[opt] table args network configuration
-- @return whatever `network.create_network` returns for the filled-in table
return function(args)
  local opts = args or {}

  opts.input_size = opts.input_size or {1, 256, 256}

  -- Assuming the standard floor((in + 2*pad - field) / stride) + 1 rule,
  -- a 256x256 input shrinks to 64 -> 21 -> 10 -> 8 through these layers.
  opts.conv_layers = opts.conv_layers or {
    --          filters  field  stride  padding
    square_conv(16,      8,     4,      2),
    square_conv(32,      6,     3,      1),
    square_conv(64,      3,     2,      0),
    square_conv(64,      3,     1,      0),
  }

  opts.fc_layers = opts.fc_layers or {512}
  opts.nl_layer = opts.nl_layer or "ReLU"
  opts.output_size = opts.output_size or {1, 33}

  return network.create_network(opts)
end