summaryrefslogtreecommitdiff
path: root/lib/supple/comms.lua
blob: efe737f0c2fb8fc3e31cdabb40ca56ad1ea7e116 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
-- lib/supple/comms.lua
--
-- Sandbox (for) Untrusted Procedure Partitioning (in) Lua Engine
--
-- Management of communications between host and sandbox
--
-- Copyright 2012 Daniel Silverstone <dsilvers@digital-scurf.org>
--
-- For licence terms, see COPYING
--

local luxio = require "luxio"

local capi = require "supple.capi"
local request = require "supple.request"
local objects = require "supple.objects"

local unpack = unpack
local tonumber = tonumber
local error = error

-- Descriptor carrying all message traffic to the peer; -1 until set_fd
-- is called.
local fd = -1

--- Choose the descriptor that send_msg and recv_msg read from and write to.
-- @param new_fd descriptor number (presumably an already-open pipe or
--   socket to the peer -- confirm with the caller)
local function set_fd(new_fd)
   fd = new_fd
end

--- Send one framed message to the peer on fd.
-- The frame is a 5-digit zero-padded decimal length followed by the
-- payload bytes, which is the format recv_msg parses.
-- @param msg payload string, at most 99999 bytes
-- @raise "Message too long" if msg exceeds 99999 bytes; an error if
--   luxio.write reports a failure or makes no progress
local function send_msg(msg)
   if #msg > 99999 then
      error("Message too long")
   end
   local frame = ("%05d"):format(#msg) .. msg
   local sent = 0
   -- write() may accept only part of the buffer (pipes and sockets can
   -- return short counts), so keep writing until the whole frame is out.
   -- Otherwise the peer would see a truncated frame and lose sync.
   while sent < #frame do
      local n, errno = luxio.write(fd, frame:sub(sent + 1))
      -- A negative return is luxio's error signal. Zero progress is also
      -- treated as fatal so this loop can never spin forever.
      if type(n) ~= "number" or n < 1 then
         error("Unable to write message: " .. tostring(errno))
      end
      sent = sent + n
   end
end

-- Read up to `count` bytes from fd, retrying after short reads (which pipes
-- and sockets may legitimately return). Stops early if luxio.read yields a
-- non-string (its error return) or an empty string (end of file). The
-- result may therefore be shorter than `count`; callers must check #result.
local function read_exact(count)
   local data = ""
   while #data < count do
      local chunk = luxio.read(fd, count - #data)
      if type(chunk) ~= "string" or chunk == "" then
         break
      end
      data = data .. chunk
   end
   return data
end

--- Receive one framed message from the peer on fd.
-- The frame is a 5-digit zero-padded decimal length followed by that many
-- payload bytes, which is the format send_msg produces.
-- @return the payload string
-- @raise if the header or payload cannot be read in full, or the header is
--   not a decimal length in 1..99999
local function recv_msg()
   local header = read_exact(5)
   if #header < 5 then
      error("Unable to read 5 byte length")
   end
   -- Accept only plain decimal digits; tonumber alone would also take
   -- forms such as "0x00A" that the sender never produces.
   local len = header:match("^%d+$") and tonumber(header)
   if not len or len < 1 or len > 99999 then
      error("Odd, len didn't translate properly")
   end
   local str = read_exact(len)
   if #str ~= len then
      error("Unable to read " .. tostring(len) .. " bytes of msg")
   end
   return str
end

-- Collect a call's return values into a table whose `n` field records the
-- true count, so nils in any position survive unpack(t, 1, t.n).
local function pack_results(...)
   return { n = select("#", ...), ... }
end

-- Run `fn` under pcall and send the outcome to the peer. `fn` must return a
-- pack_results table. On success the peer gets a response carrying every
-- packed value; on failure it gets an error with the caught message and an
-- empty traceback.
local function reply_with(fn)
   local ok, res = pcall(fn)
   if ok then
      send_msg(request.response(unpack(res, 1, res.n)))
   else
      send_msg(request.error(res, ""))
   end
end

--- Block until the peer answers our outstanding request.
-- While waiting, the peer may invoke methods on objects we own: __gc,
-- __call, __len, __index, __newindex, or any other metamethod looked up
-- via capi.raw_getmm. Each is executed and answered here, then we go back
-- to waiting.
-- @return the values carried in the peer's result message
-- @raise the peer's error message and traceback if it reports an error
local function wait_for_response()
   while true do
      local back = request.deserialise(recv_msg())
      -- An error from the peer: re-raise it locally.
      if back.error then
         error(back.message .. "\n" .. back.traceback)
      end
      -- A result (error explicitly false): hand its values back.
      -- NOTE(review): unpack(back.results) relies on #back.results; confirm
      -- with supple.request whether results can carry embedded nils.
      if back.error == false then
         return unpack(back.results)
      end
      -- Otherwise the peer is calling a method on one of our objects.
      local method = back.method
      if method == "__gc" then
         -- Presumably the peer's proxy was collected; release our side.
         objects.forget_mine(back.object)
         send_msg(request.response())
      elseif method == "__call" then
         reply_with(function()
            local obj = objects.receive { tag = back.object }
            return pack_results(obj(unpack(back.args)))
         end)
      elseif method == "__len" then
         reply_with(function()
            local obj = objects.receive { tag = back.object }
            return pack_results(#obj)
         end)
      elseif method == "__index" then
         reply_with(function()
            local obj = objects.receive { tag = back.object }
            return pack_results(obj[back.args[1]])
         end)
      elseif method == "__newindex" then
         reply_with(function()
            local obj = objects.receive { tag = back.object }
            obj[back.args[1]] = back.args[2]
            return pack_results()
         end)
      else
         reply_with(function()
            local obj = objects.receive { tag = back.object }
            local meth = capi.raw_getmm(obj, method)
            if not meth then
               -- error()'s optional second argument is a numeric level,
               -- so only the message is passed here.
               error("Unknown or disallowed method: " .. method)
            end
            return pack_results(meth(obj, unpack(back.args)))
         end)
      end
   end
end

--- Invoke `method` on the remote `object` and wait for the answer.
-- Extra arguments are forwarded into the request. Whatever the peer
-- returns is passed straight back; a remote error is re-raised by
-- wait_for_response.
local function make_call(object, method, ...)
   send_msg(request.request(object, method, ...))
   return wait_for_response()
end

-- Module interface: `call` is the public entry point. The underscore-
-- prefixed entries are internal hooks, presumably for other supple
-- modules rather than end users.
local comms = {
   call = make_call,
   _wait = wait_for_response,
   _set_fd = set_fd,
}

return comms