Arcades

Downsample.lua

--- A preprocessing network to downsample a given image.
-- This preprocessing network converts the image to its Y (luma) channel, then scales it to the requested size
-- @module network.Downsample
-- @see Downsample.lua
-- @see network
-- @usage local network = require("network.Downsample")(args)
-- @author Alexis BRENON <alexis.brenon@imag.fr>

local nn = require('nn')
local image = require('image')
local torch = require('torch')

-- Sanity check: the `nn` package must expose its base Module class before
-- we derive custom modules from it.
assert(nn.Module)
-- Register the two custom `nn` modules used by this network.
-- For Scale, the second value returned by `torch.class` (the parent class) is
-- kept as `__module` so that `Scale:__init` can chain to the parent constructor.
local Rgb2Y = torch.class('nn.Rgb2Y', 'nn.Module')
local Scale, __module = torch.class('nn.Scale', 'nn.Module')

--- Construct a module that resizes its input image to a fixed size.
-- @tparam number height target height of the output image
-- @tparam number width target width of the output image
function Scale:__init(height, width)
  -- Run the parent constructor first so the base nn.Module state exists
  -- before the target dimensions are attached.
  __module.__init(self)
  self.width, self.height = width, height
end

--- Resize the input image to the configured dimensions.
-- The result is stored in `self.output`, as the `nn.Module` contract requires,
-- and is also returned.
-- @tparam torch.Tensor input image to resize (layout presumably CxHxW — TODO confirm)
-- @treturn torch.Tensor the image bilinearly resized to `self.width` x `self.height`
function Scale:updateOutput(input)
  -- image.scale takes (src, width, height, mode): width comes before height.
  self.output = image.scale(input, self.width, self.height, 'bilinear')
  return self.output
end

--- Convert an RGB image to its Y (luma) channel.
-- The result is stored in `self.output`, as the `nn.Module` contract requires,
-- and is also returned.
-- @tparam torch.Tensor input RGB image (presumably 3xHxW — TODO confirm)
-- @treturn torch.Tensor the single-channel image produced by `image.rgb2y`
function Rgb2Y:updateOutput(input)
  self.output = image.rgb2y(input)
  return self.output
end

--- Build the downsampling preprocessing network.
-- The network first converts the image to its Y channel, then scales it.
-- @tparam[opt] table args network options; missing defaults are filled in place
-- @tparam[opt={256,256}] table args.scale_size target size, given as `{height, width}`
-- @treturn nn.Sequential network applying `nn.Rgb2Y` followed by `nn.Scale`
return function(args)
  args = args or {}
  args.scale_size = args.scale_size or {256, 256}
  -- The global `unpack` was moved to `table.unpack` in Lua 5.2; support both.
  local unpack_fn = table.unpack or unpack
  return nn.Sequential()
    :add(nn.Rgb2Y())
    :add(nn.Scale(unpack_fn(args.scale_size)))
end