util-soc-imp-copas.lua /size: 25 Kb    last modification: 2020-07-01 14:35
1-- original file : copas.lua
2-- for more info : see util-soc.lua
3-- copyright     : see below
4-- comment       : this version is a bit cleaned up and adapted
5
6-- there is an official update but i'll wait till it is stable before i check
7-- it out (after all what we have now seems to work ok)
8
9local socket = socket or require("socket")
10local ssl    = ssl or nil -- only loaded upon demand
11
12local WATCH_DOG_TIMEOUT =  120
13local UDP_DATAGRAM_MAX  = 8192
14
15local type, next, pcall, getmetatable, tostring = type, next, pcall, getmetatable, tostring
16local min, max, random = math.min, math.max, math.random
17local find = string.find
18local insert, remove = table.insert, table.remove
19
20local gettime          = socket.gettime
21local selectsocket     = socket.select
22
23local createcoroutine  = coroutine.create
24local resumecoroutine  = coroutine.resume
25local yieldcoroutine   = coroutine.yield
26local runningcoroutine = coroutine.running
27
28-- Redefines LuaSocket functions with coroutine safe versions (this allows the use
29-- of socket.http from within copas).
30
31-- Meta information is public even if beginning with an "_"
32
-- Reports a message prefixed with "copas".
--
-- When a global `logs` table is available, the local `report` is rebound to
-- logs.reporter("copas") and the message is passed on to it; later calls of
-- `report` in this file then go straight to that reporter. Without `logs`
-- the message is formatted and printed.
--
-- fmt   : format string; nothing is printed when it is nil or false
-- first : first format argument; when it is absent `fmt` is printed as is,
--         so a plain message containing "%" does not trip string.format
local function report(fmt,first,...)
    if logs then
        report = logs.reporter("copas")
        report(fmt,first,...)
    elseif fmt then
        fmt = "copas: " .. fmt
        if first then
            -- there is no local `format` in this file, use the library one
            print(string.format(fmt,first,...))
        else
            print(fmt)
        end
    end
end
46
-- The module table. The underscore fields are public meta information.
-- autoclose : close a TCP socket once the coroutine handling it has ended
-- running   : true while copas.loop is running
-- trace     : report errors caught by the socket.protect/newtry replacements
local copas = {
    _COPYRIGHT   = "Copyright (C) 2005-2016 Kepler Project",
    _DESCRIPTION = "Coroutine Oriented Portable Asynchronous Services",
    _VERSION     = "Copas 2.0.1",
    autoclose    = true,
    running      = false,
    trace        = false,
    report       = report,
}
61
-- Converts the results of a pcall into LuaSocket style results. On success
-- the values returned by the protected function are passed through. On
-- failure nil plus the error is returned; an error raised as a table is
-- unwrapped to its first element. The error is reported when copas.trace
-- is set.
local function statushandler(status, ...)
    if not status then
        local problem = ...
        if type(problem) == "table" then
            problem = problem[1]
        end
        if copas.trace then
            report("error: %s",tostring(problem))
        end
        return nil, problem
    end
    return ...
end
75
-- Replacement for socket.protect: the returned function runs `func` in
-- protected mode and turns a raised error into nil plus the error message.
function socket.protect(func)
    local function protected(...)
        return statushandler(pcall(func,...))
    end
    return protected
end
81
-- Replacement for socket.newtry. The returned function passes all of its
-- arguments through when the first one is truthy. Otherwise the finalizer is
-- called (protected) with the error detail, the detail is reported when
-- copas.trace is set, and nothing is returned.
--
-- NOTE(review): LuaSocket's own newtry raises the error after calling the
-- finalizer (statushandler above still unwraps such table errors). This
-- version returns instead, so code following a failed try keeps running --
-- confirm that this is intended.
function socket.newtry(finalizer)
    return function(...)
        local okay = (...)
        if okay then
            return ...
        end
        local detail = select(2,...)
        pcall(finalizer,detail)
        if copas.trace then
            report("error: %s",tostring(detail))
        end
    end
end
96
-- A set of values stored as a dense array (this is what gets handed to
-- socket.select) plus a reverse index, so membership checks and removals do
-- not need a scan. Based on the tinyirc.lua example of LuaSocket. Next to
-- the set, a FIFO queue can be kept per key with push and pop.

local function newset()
    local positions = { } -- value -> index in the array
    local waiting   = { } -- key -> list of queued items
    local methods   = { }

    -- Appends `value` unless it is already in the set.
    function methods.insert(self, value)
        if positions[value] then
            return
        end
        local slot = #self + 1
        self[slot] = value
        positions[value] = slot
    end

    -- Removes `value`; the former last element fills the hole.
    function methods.remove(self, value)
        local slot = positions[value]
        if not slot then
            return
        end
        positions[value] = nil
        local last  = #self
        local moved = self[last]
        self[last] = nil
        if moved ~= value then
            positions[moved] = slot
            self[slot]       = moved
        end
    end

    -- Appends `item` to the queue of `key`.
    function methods.push(self, key, item)
        local list = waiting[key]
        if list == nil then
            waiting[key] = { item }
        else
            list[#list + 1] = item
        end
    end

    -- Takes the oldest item from the queue of `key` (nothing when empty).
    function methods.pop(self, key)
        local list = waiting[key]
        if list == nil then
            return
        end
        local first = remove(list,1)
        if list[1] == nil then
            waiting[key] = nil
        end
        return first
    end

    return setmetatable({ }, { __index = methods })
end
152
-- Bookkeeping of sleeping coroutines. `times` holds wake-up times in
-- ascending order and `cos` the coroutine belonging to the time at the same
-- index. Coroutines sleeping without a wake-up time sit in `lethargy` until
-- they are woken explicitly. The empty insert/remove methods let this table
-- act as a queue for _doTick, which calls insert before push.

local _sleeping = {
    times    = { },
    cos      = { },
    lethargy = { },
}

function _sleeping.insert()
end

function _sleeping.remove()
end

-- Schedules `co` to wake up `sleeptime` seconds from now. A negative
-- sleeptime means: sleep until copas.wakeup is called. Coroutines with equal
-- wake-up times keep their order.
function _sleeping:push(sleeptime, co)
    if not co then
        return
    end
    if sleeptime < 0 then
        self.lethargy[co] = true
        return
    end
    local deadline = gettime() + sleeptime
    local times    = self.times
    local slot     = #times + 1
    for i = 1, #times do
        if times[i] > deadline then
            slot = i
            break
        end
    end
    insert(times,slot,deadline)
    insert(self.cos,slot,co)
end

-- Returns the delay (never negative) until the first wake-up time, or nil
-- when no coroutine is waiting for one.
function _sleeping:getnext()
    local first = self.times[1]
    if first == nil then
        return nil
    end
    return max(first - gettime(), 0)
end

-- Returns the first coroutine whose wake-up time is not after `time` and
-- drops it from the lists; returns nothing when none is due.
function _sleeping:pop(time)
    local times = self.times
    local first = times[1]
    if first == nil or time < first then
        return
    end
    local co = self.cos[1]
    remove(times,1)
    remove(self.cos,1)
    return co
end

-- Reschedules `co` to wake up now, whether it sleeps with or without a
-- wake-up time.
function _sleeping:wakeup(co)
    local lethargy = self.lethargy
    if lethargy[co] then
        self:push(0, co)
        lethargy[co] = nil
        return
    end
    local cos, times = self.cos, self.times
    for i = 1, #cos do
        if cos[i] == co then
            remove(cos,i)
            remove(times,i)
            self:push(0, co)
            return
        end
    end
end
226
-- Scheduler state: the registered servers, the sockets waiting until they
-- can be read or written, the moment each socket started waiting (checked by
-- the watchdog in _select) and the errors that mean "try again later".

local _servers = newset() -- server socket -> connection handler
local _reading = newset() -- sockets waiting for input
local _writing = newset() -- sockets waiting for output

local _reading_log = { } -- socket -> time it started waiting to read
local _writing_log = { } -- socket -> time it started waiting to write

local _is_timeout = {
    timeout   = true, -- LuaSocket
    wantread  = true, -- LuaSec
    wantwrite = true, -- LuaSec
}
239
240-- Coroutine based socket I/O functions.
241
-- Tells whether a socket is TCP by looking at its string form: anything not
-- starting with "udp" counts as TCP.
local function isTCP(skt)
    local udp = find(tostring(skt),"^udp")
    return udp == nil
end
245
-- Reads `pattern` (default "*l") from a client. The coroutine yields to the
-- scheduler for as long as the socket reports a timeout. LuaSec's
-- "wantwrite" makes it wait until the socket is writable instead.
-- For UDP sockets the pattern must be a number (the datagram size): the
-- string default would make the UDP receive raise a 'bad argument' error.
-- Returns the results of the final client:receive call.

local function copasreceive(client, pattern, part)
    pattern = (pattern and pattern ~= "") and pattern or "*l"
    local log = _reading_log
    while true do
        local data, err
        data, err, part = client:receive(pattern, part)
        if data or not _is_timeout[err] then
            log[client] = nil
            return data, err, part
        end
        local wait
        if err == "wantwrite" then
            log, wait = _writing_log, _writing
        else
            log, wait = _reading_log, _reading
        end
        log[client] = gettime()
        yieldcoroutine(client, wait)
    end
end
274
-- Receives a datagram from a UDP socket (not available for TCP), yielding to
-- the scheduler while the socket reports a timeout. This is the receivefrom
-- counterpart of copasreceive. A missing or zero size means
-- UDP_DATAGRAM_MAX. On success the results are the data, the sender's ip
-- address and port.

local function copasreceivefrom(client, size)
    local limit = size
    if not limit or limit == 0 then
        limit = UDP_DATAGRAM_MAX
    end
    while true do
        local data, ip, port = client:receivefrom(limit)
        if data or ip ~= "timeout" then
            _reading_log[client] = nil
            return data, ip, port
        end
        _reading_log[client] = gettime()
        yieldcoroutine(client, _reading)
    end
end
294
-- Like copasreceive, but with a numeric pattern it returns as soon as any
-- data arrived (a non-empty partial result) instead of waiting for the full
-- count.

local function copasreceivepartial(client, pattern, part)
    if not pattern or pattern == "" then
        pattern = "*l"
    end
    local numeric = type(pattern) == "number"
    local log, wait = _reading_log, _reading
    while true do
        local data, err
        data, err, part = client:receive(pattern, part)
        local gotsome = numeric and part ~= "" and part
        if data or gotsome or not _is_timeout[err] then
            log[client] = nil
            return data, err, part
        end
        if err == "wantwrite" then
            log, wait = _writing_log, _writing
        else
            log, wait = _reading_log, _reading
        end
        log[client] = gettime()
        yieldcoroutine(client, wait)
    end
end
322
-- Sends data[from..to] to a client, yielding to the scheduler while the
-- socket reports a timeout. LuaSec's "wantread" makes the coroutine wait
-- until the socket is readable instead. About one round in ten it also
-- yields voluntarily, so that a high throughput connection cannot starve
-- the other coroutines. UDP sockets ignore `from` and `to`.
-- Returns the results of the last client:send call.

local function copassend(client, data, from, to)
    local last = (from or 1) - 1
    local log, wait = _writing_log, _writing
    while true do
        local ok, err
        ok, err, last = client:send(data, last + 1, to)
        if random(100) > 90 then
            -- voluntary yield against starvation
            log[client] = gettime()
            yieldcoroutine(client, wait)
        end
        if ok or not _is_timeout[err] then
            log[client] = nil
            return ok, err, last
        end
        if err == "wantread" then
            log, wait = _reading_log, _reading
        else
            log, wait = _writing_log, _writing
        end
        log[client] = gettime()
        yieldcoroutine(client, wait)
    end
end
357
-- Sends a datagram over UDP (not available for TCP). Like copassend it
-- yields to the scheduler on timeouts and, about one round in ten,
-- voluntarily to avoid starving other coroutines.

local function copassendto(client, data, ip, port)
    while true do
        local ok, err = client:sendto(data, ip, port)
        if random(100) > 90 then
            _writing_log[client] = gettime()
            yieldcoroutine(client, _writing)
        end
        if ok or err ~= "timeout" then
            _writing_log[client] = nil
            return ok, err
        end
        _writing_log[client] = gettime()
        yieldcoroutine(client, _writing)
    end
end
378
-- Connects a socket without blocking. The socket is made non-blocking and
-- the coroutine yields to the writing set until the connect has completed.
-- Returns the results of the final connect call.

local function copasconnect(skt, host, port)
    skt:settimeout(0)
    local first = true
    while true do
        local ret, err = skt:connect(host, port)
        -- On Windows a pending non-blocking connect reports "Operation
        -- already in progress", which means the same as "timeout".
        local pending = err == "timeout" or err == "Operation already in progress"
        if ret or not pending then
            -- When an async connect completes, Windows reports "already
            -- connected". That only means success after we have been
            -- waiting; on the very first call the socket was already
            -- connected to something else and the error is returned.
            if not ret and err == "already connected" and not first then
                ret, err = 1, nil
            end
            _writing_log[skt] = nil
            return ret, err
        end
        first = false
        _writing_log[skt] = gettime()
        yieldcoroutine(skt, _writing)
    end
end
406
-- Performs an (async) ssl handshake on a connected TCP client socket. On
-- success the new ssl wrapped socket is returned; it replaces all previous
-- references to the plain socket. On failure the error is reported and nil
-- plus the error message is returned, so callers must check the result.
--
-- skt  : connected TCP socket
-- sslt : ssl parameters passed on to ssl.wrap

local function copasdohandshake(skt, sslt)
    if not ssl then
        -- The ssl library is loaded on demand. A plain require raises an
        -- error when it is missing, so pcall is used to make the report
        -- below reachable.
        local okay, library = pcall(require, "ssl")
        if okay then
            ssl = library
        end
    end
    if not ssl then
        report("error: no ssl library")
        return nil, "no ssl library"
    end
    local nskt, err = ssl.wrap(skt, sslt)
    if not nskt then
        report("error: %s",tostring(err))
        return nil, err
    end
    nskt:settimeout(0)
    repeat
        local success, reason = nskt:dohandshake()
        local queue
        if success then
            return nskt
        elseif reason == "wantwrite" then
            queue = _writing
        elseif reason == "wantread" then
            queue = _reading
        else
            report("error: %s",tostring(reason))
            return nil, reason
        end
        yieldcoroutine(nskt, queue)
    until false
end
441
-- Flushes a client write buffer. copassend hands data directly to the
-- socket, so there is nothing to flush and this does nothing (presumably it
-- is kept so that wrapped sockets offer a flush method).

local function copasflush(client)
    return
end
446
-- Public API. The partial receive is exported as receivepartial and as
-- receivePartial (the upstream Copas name); the older copasreceivepartial /
-- copasreceivePartial keys are kept for existing callers.

copas.connect             = copasconnect
copas.send                = copassend
copas.sendto              = copassendto
copas.receive             = copasreceive
copas.receivefrom         = copasreceivefrom
copas.receivepartial      = copasreceivepartial
copas.receivePartial      = copasreceivepartial
copas.copasreceivepartial = copasreceivepartial
copas.copasreceivePartial = copasreceivepartial
copas.dohandshake         = copasdohandshake
copas.flush               = copasflush
458
459-- Wraps a TCP socket to use Copas methods (send, receive, flush and settimeout).
460
-- __tostring of wrapped sockets: the inner socket's name plus a marker.
local function _skt_mt_tostring(self)
    local inner = tostring(self.socket)
    return inner .. " (copas wrapped)"
end
464
-- Methods of a wrapped TCP socket. Sending, receiving, flushing, connecting
-- and the ssl handshake go through the coroutine aware functions above; the
-- other methods are forwarded to the inner socket as is.
local _skt_mt_tcp_index = {
    send =
        function(self, data, from, to)
            return copassend(self.socket, data, from, to)
        end,
    -- With a timeout of 0 (see settimeout) a read returns as soon as some
    -- data arrived, otherwise it waits for the complete pattern.
    receive =
        function(self, pattern, prefix)
            if self.timeout == 0 then
                return copasreceivepartial(self.socket, pattern, prefix)
            else
                return copasreceive(self.socket, pattern, prefix)
            end
        end,

    flush =
        function(self)
            return copasflush(self.socket)
        end,

    -- Only records the timeout, which selects the receive mode; the inner
    -- socket is left untouched.
    settimeout =
        function(self, time)
            self.timeout = time
            return true
        end,
    -- TODO: socket.connect is a shortcut, and must be provided with an
    -- alternative. When ssl parameters are known the handshake is done too.
    connect =
        function(self, ...)
            local res, err = copasconnect(self.socket, ...)
            if res and self.ssl_params then
                res, err = self:dohandshake()
            end
            return res, err
        end,
    close =
        function(self, ...)
            return self.socket:close(...)
        end,
    -- TODO: socket.bind is a shortcut, and must be provided with an alternative
    bind =
        function(self, ...)
            return self.socket:bind(...)
        end,
    -- TODO: is this DNS related? hence blocking?
    getsockname =
        function(self, ...)
            return self.socket:getsockname(...)
        end,
    getstats =
        function(self, ...)
            return self.socket:getstats(...)
        end,
    setstats =
        function(self, ...)
            return self.socket:setstats(...)
        end,
    listen =
        function(self, ...)
            return self.socket:listen(...)
        end,
    accept =
        function(self, ...)
            return self.socket:accept(...)
        end,
    setoption =
        function(self, ...)
            return self.socket:setoption(...)
        end,
    -- TODO: is this DNS related? hence blocking?
    getpeername =
        function(self, ...)
            return self.socket:getpeername(...)
        end,
    shutdown =
        function(self, ...)
            return self.socket:shutdown(...)
        end,
    -- Performs the ssl handshake with `sslt` (or the parameters given to
    -- wrap) and replaces the inner socket by the ssl wrapped one. Returns
    -- the wrapped socket itself, or the failure results of copasdohandshake.
    dohandshake =
        function(self, sslt)
            self.ssl_params = sslt or self.ssl_params
            local nskt, err = copasdohandshake(self.socket, self.ssl_params)
            if not nskt then
                return nskt, err
            end
            self.socket = nskt
            return self
        end,
}
553
-- Metatable of wrapped TCP sockets.
local _skt_mt_tcp = {
    __index    = _skt_mt_tcp_index,
    __tostring = _skt_mt_tostring,
}
558
-- Methods of a wrapped UDP socket, modelled on the TCP ones. Methods not
-- defined here are copied over from the TCP method table.

local _skt_mt_udp_index = {
    -- UDP sending does not block, but copassendto adds starvation
    -- prevention, so it is replaced anyway.
    sendto =
        function(self, ...)
            return copassendto(self.socket,...)
        end,
    receive =
        function(self, size)
            return copasreceive(self.socket, size or UDP_DATAGRAM_MAX)
        end,
    receivefrom =
        function(self, size)
            return copasreceivefrom(self.socket, size or UDP_DATAGRAM_MAX)
        end,
    -- TODO: is this DNS related? hence blocking?
    setpeername =
        function(self, ...)
            return self.socket:setpeername(...)
        end,
    setsockname =
        function(self, ...)
            return self.socket:setsockname(...)
        end,
    -- Do not close the client: for UDP it is also the server socket.
    close =
        function(self, ...)
            return true
        end,
}
591
-- Metatable of wrapped UDP sockets.
local _skt_mt_udp = {
    __index    = _skt_mt_udp_index,
    __tostring = _skt_mt_tostring,
}
596
-- Give wrapped UDP sockets every TCP method they do not define themselves.
do
    local udp = _skt_mt_udp_index
    for name, method in next, _skt_mt_tcp_index do
        udp[name] = udp[name] or method
    end
end
602
-- Wraps a LuaSocket socket in an asynchronous Copas socket object.
--
-- skt  : the socket to wrap; an already wrapped socket is returned as is
-- sslt : optional table with ssl parameters (kept for TCP sockets only);
--        use an empty table to get ssl with the defaults
--
-- The inner socket is made non-blocking. Returns the wrapped socket.

local function wrap(skt, sslt)
    local mt = getmetatable(skt)
    if mt == _skt_mt_tcp or mt == _skt_mt_udp then
        return skt
    end
    skt:settimeout(0)
    if not isTCP(skt) then
        return setmetatable({ socket = skt }, _skt_mt_udp)
    end
    return setmetatable({ socket = skt, ssl_params = sslt }, _skt_mt_tcp)
end
620
copas.wrap = wrap

-- Wraps a connection handler so that it receives a wrapped socket. When
-- `sslparams` is given, the ssl handshake is done first. If that handshake
-- fails, the handler is not run and nil plus the error is returned, so a
-- connection that should be encrypted is never served in plain text.

function copas.handler(handler, sslparams)
    return function(skt,...)
        skt = wrap(skt)
        if sslparams then
            local secured, err = skt:dohandshake(sslparams)
            if not secured then
                return nil, err
            end
        end
        return handler(skt,...)
    end
end
635
-- Error handling: each coroutine can have its own error handler. _doTick
-- calls it with (message, coroutine, socket) when the coroutine dies with an
-- error.

local _errhandlers = { }

-- Sets the error handler of the running coroutine (ignored outside one).
function copas.setErrorHandler(err)
    local current = runningcoroutine()
    if not current then
        return
    end
    _errhandlers[current] = err
end
646
-- Default error handler: reports the error with the coroutine and socket.
local function _deferror(msg, co, skt)
    local where, what = tostring(co), tostring(skt)
    report("%s (%s) (%s)", msg, where, what)
end
650
651-- Thread handling
652
-- Resumes coroutine `co` with the socket (and any extra arguments).
-- - If the coroutine yields (key, queue), it is parked on that queue under
--   that key.
-- - Otherwise it has finished or died:
--   - an error is passed to its error handler (or _deferror);
--   - a TCP socket is closed when copas.autoclose is on;
--   - the error handler registration is dropped.
local function _doTick(co, skt, ...)
    if not co then
        return
    end
    local ok, res, new_q = resumecoroutine(co, skt, ...)
    if ok and res and new_q then
        new_q:insert(res)
        new_q:push(res, co)
        return
    end
    if not ok then
        pcall(_errhandlers[co] or _deferror, res, co, skt)
    end
    -- UDP sockets are not closed: the handler socket is the server socket too.
    if skt and copas.autoclose and isTCP(skt) then
        skt:close()
    end
    _errhandlers[co] = nil
end
674
-- Accepts a pending connection on server socket `input` and starts `handler`
-- for it in a new coroutine. Returns the accepted client, or nil when there
-- was nothing to accept.

local function _accept(input, handler)
    local client = input:accept()
    if not client then
        return client
    end
    client:settimeout(0)
    _doTick(createcoroutine(handler), client)
    return client
end
687
-- Resume the coroutine parked on a socket that became readable / writable.

local function _tickRead(skt)
    local co = _reading:pop(skt)
    _doTick(co, skt)
end

local function _tickWrite(skt)
    local co = _writing:pop(skt)
    _doTick(co, skt)
end
697
-- Adds a server/handler pair to the dispatcher (`timeout` defaults to 0).
-- A TCP server is watched for incoming connections; every accepted
-- connection gets its own coroutine running `handler`. A UDP server is
-- watched as well, and `handler` is started right away in a coroutine with
-- the server socket itself.

local function addTCPserver(server, handler, timeout)
    local wait = timeout or 0
    server:settimeout(wait)
    _servers[server] = handler
    _reading:insert(server)
end

local function addUDPserver(server, handler, timeout)
    local wait = timeout or 0
    server:settimeout(wait)
    local co = createcoroutine(handler)
    _reading:insert(server)
    _doTick(co, server)
end
712
-- Registers a server with the dispatcher, as TCP or UDP server depending on
-- the socket.
function copas.addserver(server, handler, timeout)
    local add = isTCP(server) and addTCPserver or addUDPserver
    add(server, handler, timeout)
end

-- Removes a server (plain or wrapped) from the dispatcher and closes it,
-- unless `keep_open` is given (then true is returned).
function copas.removeserver(server, keep_open)
    local raw = server
    local mt  = getmetatable(server)
    if mt == _skt_mt_tcp or mt == _skt_mt_udp then
        raw = server.socket
    end
    _servers[raw] = nil
    _reading:remove(raw)
    if keep_open then
        return true
    end
    return server:close()
end
734
-- Adds a new coroutine thread to the dispatcher and starts it right away.
-- The scheduler always resumes coroutines with a socket as first argument
-- (nil for a thread), so that argument is dropped before `handler` is
-- called with the remaining arguments. Returns the coroutine.

function copas.addthread(handler, ...)
    local function body(_, ...)
        return handler(...)
    end
    local thread = createcoroutine(body)
    _doTick(thread, nil, ...)
    return thread
end
744
-- Task registry. A task provides `events` (an iterator over the sockets that
-- are ready) and `tick` (handles one of them); `def_tick` is the default way
-- to resume the coroutine waiting on a socket.

local _tasks = { }

-- Registers `task` with `tick` as its default tick function.
local function registertask(task, tick)
    task.def_tick = tick
    _tasks[task] = true
end

local function addtaskRead(task)
    registertask(task, _tickRead)
end

local function addtaskWrite(task)
    registertask(task, _tickWrite)
end

-- Iterates over the registered tasks: for task in tasks() do ... end
local function tasks()
    return next, _tasks
end
766
-- Task handling sockets that became readable. A readable server socket has
-- a connection waiting to be accepted; any other socket is taken out of the
-- reading set and the coroutine waiting on it is resumed.

local _readable_t = { }

-- Iterates over the ready sockets stored in _evs by _select.
function _readable_t:events()
    local index = 0
    local function nextevent()
        index = index + 1
        return self._evs[index]
    end
    return nextevent
end

function _readable_t:tick(input)
    local handler = _servers[input]
    if not handler then
        _reading:remove(input)
        self.def_tick(input)
        return
    end
    _accept(input, handler)
end

addtaskRead(_readable_t)
791
-- Task handling sockets that became writable: the socket is taken out of the
-- writing set and the coroutine waiting on it is resumed.

local _writable_t = { }

-- Iterates over the ready sockets stored in _evs by _select.
function _writable_t:events()
    local index = 0
    local function nextevent()
        index = index + 1
        return self._evs[index]
    end
    return nextevent
end

function _writable_t:tick(output)
    _writing:remove(output)
    self.def_tick(output)
end

addtaskWrite(_writable_t)
811
-- Task waking sleeping coroutines: resumes at most one coroutine whose
-- wake-up time is not after `time`, passing on the extra arguments.

local _sleeping_t = { }

function _sleeping_t:tick(time, ...)
    local co = _sleeping:pop(time)
    _doTick(co, ...)
end
819
-- Yields the current coroutine and resumes it after `sleeptime` seconds
-- (default 0). With a negative sleeptime it sleeps until copas.wakeup is
-- called for it.
function copas.sleep(sleeptime)
    local delay = sleeptime or 0
    yieldcoroutine(delay, _sleeping)
end

-- Wakes up the sleeping coroutine `co`.

function copas.wakeup(co)
    _sleeping:wakeup(co)
end
831
-- Checks for reads and writes on sockets.
--
-- Waits (at most `timeout` seconds) until sockets in the reading or writing
-- sets are ready and stores the lists of ready sockets in the read and write
-- tasks.
--
-- Watchdog: at most once every WATCH_DOG_TIMEOUT seconds, sockets that have
-- waited longer than WATCH_DOG_TIMEOUT without select reporting them are
-- added to the ready lists anyway, so that their coroutine retries the
-- operation.
--
-- Returns nil when there is something to handle, otherwise the error from
-- select (for instance "timeout").

local last_cleansing = 0

-- Moves the sockets in `log` that have waited too long into `evs`, keeping
-- both the array part and the socket -> index part that select returns.
local function _cleanse(log, evs, now)
    for skt, time in next, log do
        if not evs[skt] and (now - time) > WATCH_DOG_TIMEOUT then
            local n = #evs + 1
            log[skt] = nil -- clearing a field while traversing is allowed
            evs[n]   = skt
            evs[skt] = n
        end
    end
end

local function _select(timeout)
    local now = gettime()

    local r_evs, w_evs, err = selectsocket(_reading, _writing, timeout)

    _readable_t._evs = r_evs
    _writable_t._evs = w_evs

    -- elapsed times are "now minus then"
    if (now - last_cleansing) > WATCH_DOG_TIMEOUT then
        last_cleansing = now
        _cleanse(_reading_log, r_evs, now)
        _cleanse(_writing_log, w_evs, now)
    end

    if err == "timeout" and #r_evs + #w_evs > 0 then
        return nil
    end
    return err
end
887
-- Tells whether there is nothing left to do. Returns true when:
-- - no socket is waiting to be read or written, and
-- - no sleeping coroutine has a wake-up time.
-- Coroutines sleeping until an explicit wakeup do not count.

local function copasfinished()
    if next(_reading) ~= nil or next(_writing) ~= nil then
        return false
    end
    return _sleeping:getnext() == nil
end
894
-- One step of the dispatcher:
-- 1. wakes a sleeping coroutine that is due;
-- 2. waits for socket activity, but never past the next wake-up time;
-- 3. handles the ready sockets.
-- Returns true when events were handled, false on a timeout or when there
-- is nothing to do, and nil plus the error when select failed.

local function copasstep(timeout)
    _sleeping_t:tick(gettime())

    local nextwait = _sleeping:getnext()
    if nextwait == nil then
        if copasfinished() then
            return false
        end
    elseif timeout then
        timeout = min(nextwait,timeout)
    else
        timeout = nextwait
    end

    local err = _select(timeout)
    if err == "timeout" then
        return false
    elseif err then
        return nil, err
    end

    for task in tasks() do
        for event in task:events() do
            task:tick(event)
        end
    end
    return true
end
924
copas.finished = copasfinished
copas.step     = copasstep

-- Runs the dispatcher until there is nothing left to do, passing `timeout`
-- to every step. copas.running is true while the loop runs.

function copas.loop(timeout)
    copas.running = true
    while true do
        if copasfinished() then
            break
        end
        copasstep(timeout)
    end
    copas.running = false
end
937
-- Register the module as "copas" so that require("copas") returns this
-- implementation. (Making it a global, _G.copas = copas, stays disabled.)

package.loaded.copas = copas

return copas
943