1
2
3
4
5
6
7
8
9local socket = socket or require("socket")
10local ssl = ssl or nil
11
12local WATCH_DOG_TIMEOUT = 120
13local UDP_DATAGRAM_MAX = 8192
14
15local type, next, pcall, getmetatable, tostring = type, next, pcall, getmetatable, tostring
16local min, max, random = math.min, math.max, math.random
17local find = string.find
18local insert, remove = table.insert, table.remove
19
20local gettime = socket.gettime
21local selectsocket = socket.select
22
23local createcoroutine = coroutine.create
24local resumecoroutine = coroutine.resume
25local yieldcoroutine = coroutine.yield
26local runningcoroutine = coroutine.running
27
28
29
30
31
32
--- Fallback reporter.
-- When a global `logs` table exists, the local `report` upvalue is replaced
-- by `logs.reporter("copas")` and the message is forwarded to it. Otherwise
-- the message is prefixed with "copas: " and printed; when extra arguments
-- are supplied, `fmt` is used as a `string.format` pattern.
-- Note: `copas.report` keeps a reference to this original function.
-- @param fmt   message or format string (nothing is printed when nil)
-- @param first first format argument (when nil, `fmt` is printed verbatim)
local function report(fmt,first,...)
    if logs then
        report = logs.reporter("copas")
        report(fmt,first,...)
    elseif fmt then
        fmt = "copas: " .. fmt
        if first then
            -- `format` is not localized in this file; call string.format
            -- explicitly so the fallback path does not hit a nil global.
            print(string.format(fmt,first,...))
        else
            print(fmt)
        end
    end
end
46
-- Public module table.
--   autoclose : when true, TCP sockets are closed once their handler
--               coroutine finishes
--   running   : true while copas.loop is executing
--   report    : the reporter function defined above
--   trace     : when true, errors seen by the socket helpers are reported
local copas = {
    _COPYRIGHT   = "Copyright (C) 2005-2016 Kepler Project",
    _DESCRIPTION = "Coroutine Oriented Portable Asynchronous Services",
    _VERSION     = "Copas 2.0.1",
    autoclose    = true,
    running      = false,
    report       = report,
    trace        = false,
}
61
-- Convert pcall results into the LuaSocket "nil, err" convention.
-- On success the remaining pcall results are passed through unchanged.
-- On failure the error value is returned as `nil, err`. When the error is a
-- table, its first element is used instead. The error is reported when
-- copas.trace is set.
local function statushandler(status, ...)
    if not status then
        local err = ...
        if type(err) == "table" then
            err = err[1]
        end
        if copas.trace then
            report("error: %s",tostring(err))
        end
        return nil, err
    end
    return ...
end
75
-- Replacement for LuaSocket's protect.
-- The returned wrapper runs `func` under pcall. A raised error is turned
-- into `nil, err` by statushandler.
function socket.protect(func)
    local function protected(...)
        return statushandler(pcall(func, ...))
    end
    return protected
end
81
-- Replacement for LuaSocket's newtry.
-- The returned function passes all its arguments through when the first
-- one is truthy. Otherwise it:
--   * calls `finalizer` in protected mode with the error detail (the
--     second argument),
--   * reports the detail when copas.trace is set,
--   * returns nothing. No error is raised here.
function socket.newtry(finalizer)
    return function(...)
        if (...) then
            return ...
        end
        local detail = select(2, ...)
        pcall(finalizer, detail)
        if copas.trace then
            report("error: %s",tostring(detail))
        end
        return
    end
end
96
97
98
99
-- Create a socket set that can be passed to socket.select as a plain array.
-- A private reverse index (value -> array slot) gives constant-time insert
-- and remove. A separate per-key FIFO (push/pop) holds the items, e.g.
-- coroutines, waiting on a key.
local function newset()
    local position = { }   -- value -> slot in the array part
    local waiting  = { }   -- key -> FIFO list of queued items
    local methods  = { }

    -- Append `value` unless it is already present.
    function methods.insert(self, value)
        if position[value] then
            return
        end
        local slot = #self + 1
        self[slot] = value
        position[value] = slot
    end

    -- Remove `value` by moving the last element into its slot.
    function methods.remove(self, value)
        local slot = position[value]
        if not slot then
            return
        end
        position[value] = nil
        local last = #self
        local moved = self[last]
        self[last] = nil
        if moved ~= value then
            position[moved] = slot
            self[slot] = moved
        end
    end

    -- Queue `itm` behind any items already waiting on `key`.
    function methods.push(self, key, itm)
        local list = waiting[key]
        if list == nil then
            waiting[key] = { itm }
        else
            list[#list + 1] = itm
        end
    end

    -- Take the oldest item queued on `key`; drop the queue once it is empty.
    function methods.pop(self, key)
        local list = waiting[key]
        if list == nil then
            return
        end
        local first = remove(list, 1)
        if list[1] == nil then
            waiting[key] = nil
        end
        return first
    end

    return setmetatable({ }, { __index = methods })
end
152
-- Queue of sleeping coroutines.
--   times/cos : parallel arrays sorted by absolute wake-up time (gettime())
--   lethargy  : coroutines parked indefinitely (negative sleep time) until
--               wakeup() is called for them
-- insert/remove are no-ops so that _doTick can treat this queue like the
-- socket sets (it calls queue:insert(res) and then queue:push(res, co)).
local _sleeping = {
    times    = { },
    cos      = { },
    lethargy = { },
}

function _sleeping.insert()
end

function _sleeping.remove()
end

-- Schedule `co` to wake after `sleeptime` seconds. A negative time parks it
-- in `lethargy` instead. Equal wake-up times keep insertion order.
function _sleeping.push(self, sleeptime, co)
    if not co then
        return
    end
    if sleeptime < 0 then
        self.lethargy[co] = true
        return
    end
    local wakeat = gettime() + sleeptime
    local times, cos = self.times, self.cos
    local slot = #times + 1
    for i = 1, #times do
        if times[i] > wakeat then
            slot = i
            break
        end
    end
    insert(times, slot, wakeat)
    insert(cos, slot, co)
end

-- Seconds until the earliest wake-up (never negative), or nil if no
-- timed sleeper exists.
function _sleeping.getnext(self)
    local first = self.times[1]
    if not first then
        return nil
    end
    return max(first - gettime(), 0)
end

-- Remove and return the earliest sleeper if it is due at `time`.
function _sleeping.pop(self, time)
    local times = self.times
    local first = times[1]
    if first == nil or time < first then
        return
    end
    local co = self.cos[1]
    remove(times, 1)
    remove(self.cos, 1)
    return co
end

-- Reschedule `co` with a zero delay, whether it is parked in lethargy or
-- waiting on a timer. Unknown coroutines are ignored.
function _sleeping.wakeup(self, co)
    local parked = self.lethargy
    if parked[co] then
        self:push(0, co)
        parked[co] = nil
        return
    end
    local cos, times = self.cos, self.times
    for i = 1, #cos do
        if cos[i] == co then
            remove(cos, i)
            remove(times, i)
            self:push(0, co)
            return
        end
    end
end
226
-- Socket sets: listening TCP servers, sockets waiting to read, and sockets
-- waiting to write.
local _servers, _reading, _writing = newset(), newset(), newset()

-- When each socket started waiting (gettime()); consulted by the watchdog
-- in _select.
local _reading_log, _writing_log = { }, { }

-- Error strings meaning "not ready yet; wait and retry".
local _is_timeout = { timeout = true, wantread = true, wantwrite = true }
239
240
241
-- Classify a socket by its string form: anything whose tostring() starts
-- with "udp" is UDP; everything else is treated as TCP.
local function isTCP(socket)
    local name = tostring(socket)
    return find(name, "^udp") == nil
end
245
246
247
248
249
250
-- Cooperative receive for stream sockets.
-- Retries client:receive until it returns data or an error that is not a
-- "try again" code. Between attempts the coroutine yields on the reading
-- set, or on the writing set when receive reports "wantwrite". The wait
-- start time is logged for the watchdog and cleared on completion.
-- @param pattern receive pattern; false/nil/"" default to "*l"
-- @param part    partial data to pass back to receive as its prefix
-- @return the data, err, partial values of the final client:receive call
local function copasreceive(client, pattern, part)
    if not pattern or pattern == "" then
        pattern = "*l"
    end
    local log = _reading_log
    while true do
        local data, err
        data, err, part = client:receive(pattern, part)
        if data or not _is_timeout[err] then
            log[client] = nil
            return data, err, part
        end
        local waitset
        if err == "wantwrite" then
            log, waitset = _writing_log, _writing
        else
            log, waitset = _reading_log, _reading
        end
        log[client] = gettime()
        yieldcoroutine(client, waitset)
    end
end
274
275
276
277
-- Cooperative receivefrom for datagram sockets.
-- Retries until a datagram or a non-"timeout" error arrives, yielding on
-- the reading set in between.
-- @param size maximum datagram size; false/nil/0 mean UDP_DATAGRAM_MAX
-- @return the three values of the final client:receivefrom call
local function copasreceivefrom(client, size)
    size = (size and size ~= 0) and size or UDP_DATAGRAM_MAX
    while true do
        local data, err, port = client:receivefrom(size)
        if data or err ~= "timeout" then
            _reading_log[client] = nil
            return data, err, port
        end
        _reading_log[client] = gettime()
        yieldcoroutine(client, _reading)
    end
end
294
295
296
297
-- Like copasreceive, but when `pattern` is a byte count it returns as soon
-- as any partial data is available. Used by wrapped sockets whose timeout
-- is 0.
-- @param pattern receive pattern; false/nil/"" default to "*l"
-- @param part    partial data to pass back to receive as its prefix
local function copasreceivepartial(client, pattern, part)
    pattern = (pattern and pattern ~= "") and pattern or "*l"
    local numeric = type(pattern) == "number"
    local log, waitset = _reading_log, _reading
    while true do
        local data, err
        data, err, part = client:receive(pattern, part)
        local gotpartial = numeric and part ~= "" and part
        if data or gotpartial or not _is_timeout[err] then
            log[client] = nil
            return data, err, part
        end
        if err == "wantwrite" then
            log, waitset = _writing_log, _writing
        else
            log, waitset = _reading_log, _reading
        end
        log[client] = gettime()
        yieldcoroutine(client, waitset)
    end
end
322
323
324
325
-- Cooperative send.
-- Calls client:send from the byte after the last one sent until the send
-- succeeds or fails with a code that is not "try again". After every
-- attempt the coroutine also yields voluntarily about one time in ten
-- (random(100) > 90), apparently to let other coroutines run. A "wantread"
-- result waits on the reading set; otherwise the writing set is used.
-- @param from first byte to send (default 1)
-- @param to   last byte to send (nil = to the end)
-- @return the values of the final client:send call
local function copassend(client, data, from, to)
    local last = (from or 1) - 1
    local log, waitset = _writing_log, _writing
    while true do
        local sent, err
        sent, err, last = client:send(data, last + 1, to)
        if random(100) > 90 then
            log[client] = gettime()
            yieldcoroutine(client, waitset)
        end
        if sent or not _is_timeout[err] then
            log[client] = nil
            return sent, err, last
        end
        if err == "wantread" then
            log, waitset = _reading_log, _reading
        else
            log, waitset = _writing_log, _writing
        end
        log[client] = gettime()
        yieldcoroutine(client, waitset)
    end
end
357
358
359
360
-- Cooperative sendto for datagram sockets.
-- Retries until the send succeeds or fails with something other than
-- "timeout". Like copassend, it yields voluntarily about one attempt in ten.
-- @return the two values of the final client:sendto call
local function copassendto(client, data, ip, port)
    while true do
        local ok, err = client:sendto(data, ip, port)
        if random(100) > 90 then
            _writing_log[client] = gettime()
            yieldcoroutine(client, _writing)
        end
        if ok or err ~= "timeout" then
            _writing_log[client] = nil
            return ok, err
        end
        _writing_log[client] = gettime()
        yieldcoroutine(client, _writing)
    end
end
378
379
380
-- Cooperative connect.
-- Switches the socket to non-blocking mode and retries connect while it
-- reports "timeout" or "Operation already in progress". Between retries
-- the coroutine yields on the writing set. An "already connected" answer
-- after at least one retry is treated as success (returns 1).
-- @return the result of connect (or 1), otherwise nil and the error
local function copasconnect(skt, host, port)
    skt:settimeout(0)
    local retried = false
    while true do
        local ret, err = skt:connect(host, port)
        local pending = (not ret)
            and (err == "timeout" or err == "Operation already in progress")
        if not pending then
            if not ret and err == "already connected" and retried then
                ret, err = 1, nil
            end
            _writing_log[skt] = nil
            return ret, err
        end
        retried = true
        _writing_log[skt] = gettime()
        yieldcoroutine(skt, _writing)
    end
end
406
407
408
409
410
-- Wrap `skt` with the "ssl" library and run the TLS handshake cooperatively.
-- Between attempts the coroutine yields on the writing or reading set,
-- depending on what the handshake asks for. The library is loaded lazily
-- on first use.
-- @param sslt parameter table passed to ssl.wrap
-- @return the wrapped socket on success, otherwise nil and an error message
local function copasdohandshake(skt, sslt)
    if not ssl then
        -- require raises when the module cannot be found; catch that so the
        -- "no ssl library" branch below is reachable.
        local ok, lib = pcall(require, "ssl")
        if ok then
            ssl = lib
        end
    end
    if not ssl then
        report("error: no ssl library")
        return nil, "no ssl library"
    end
    local nskt, err = ssl.wrap(skt, sslt)
    if not nskt then
        report("error: %s",tostring(err))
        return nil, err
    end
    nskt:settimeout(0)
    while true do
        local success, herr = nskt:dohandshake()
        local waitset
        if success then
            return nskt
        elseif herr == "wantwrite" then
            waitset = _writing
        elseif herr == "wantread" then
            waitset = _reading
        else
            report("error: %s",tostring(herr))
            return nil, herr
        end
        yieldcoroutine(nskt, waitset)
    end
end
441
442
443
-- Flushing is a no-op: copassend passes data straight to client:send, so
-- copas holds nothing to flush.
local copasflush = function(client)
end
446
447
448
-- Public cooperative socket API.
copas.connect     = copasconnect
copas.send        = copassend
copas.sendto      = copassendto
copas.receive     = copasreceive
copas.receivefrom = copasreceivefrom
copas.dohandshake = copasdohandshake
copas.flush       = copasflush

-- Partial receive, exported under the plain names as well as the
-- "copas"-prefixed names this module has always exposed (kept for
-- backward compatibility).
copas.receivepartial      = copasreceivepartial
copas.receivePartial      = copasreceivepartial
copas.copasreceivepartial = copasreceivepartial
copas.copasreceivePartial = copasreceivepartial
458
459
460
-- __tostring for wrapped sockets: the inner socket's name plus a marker.
local _skt_mt_tostring = function(self)
    local inner = tostring(self.socket)
    return inner .. " (copas wrapped)"
end
464
-- Method table for copas-wrapped TCP (and TLS) sockets.
-- I/O goes through the cooperative copas helpers. The remaining methods
-- forward to the raw socket object stored in `self.socket`.
local _skt_mt_tcp_index = {
    send =
        function(self, data, from, to)
            return copassend(self.socket, data, from, to)
        end,

    -- A wrapped timeout of 0 selects partial receive, which returns as soon
    -- as some data is available. Any other value waits for the full pattern.
    -- copasreceivepartial is the local defined in this file; the misspelled
    -- global copasreceivePartial does not exist.
    receive =
        function(self, pattern, prefix)
            if self.timeout == 0 then
                return copasreceivepartial(self.socket, pattern, prefix)
            else
                return copasreceive(self.socket, pattern, prefix)
            end
        end,

    flush =
        function(self)
            return copasflush(self.socket)
        end,

    -- Only records the value used by receive(); the raw socket itself stays
    -- non-blocking.
    settimeout =
        function(self, time)
            self.timeout = time
            return true
        end,

    -- Connect cooperatively. When TLS parameters are known, also perform
    -- the handshake.
    connect =
        function(self, ...)
            local res, err = copasconnect(self.socket, ...)
            if res and self.ssl_params then
                res, err = self:dohandshake()
            end
            return res, err
        end,

    close =
        function(self, ...)
            return self.socket:close(...)
        end,

    bind =
        function(self, ...)
            return self.socket:bind(...)
        end,

    getsockname =
        function(self, ...)
            return self.socket:getsockname(...)
        end,

    getstats =
        function(self, ...)
            return self.socket:getstats(...)
        end,

    setstats =
        function(self, ...)
            return self.socket:setstats(...)
        end,

    listen =
        function(self, ...)
            return self.socket:listen(...)
        end,

    accept =
        function(self, ...)
            return self.socket:accept(...)
        end,

    setoption =
        function(self, ...)
            return self.socket:setoption(...)
        end,

    getpeername =
        function(self, ...)
            return self.socket:getpeername(...)
        end,

    shutdown =
        function(self, ...)
            return self.socket:shutdown(...)
        end,

    -- Upgrade the connection using `sslt`, or the parameters remembered
    -- from wrap(). On success the wrapper switches to the TLS socket and
    -- `self` is returned. On failure the results of copasdohandshake are
    -- passed back.
    dohandshake =
        function(self, sslt)
            self.ssl_params = sslt or self.ssl_params
            local nskt, err = copasdohandshake(self.socket, self.ssl_params)
            if not nskt then
                return nskt, err
            end
            self.socket = nskt
            return self
        end,
}
553
-- Metatable identifying (and providing methods for) wrapped TCP sockets.
local _skt_mt_tcp = { __index = _skt_mt_tcp_index, __tostring = _skt_mt_tostring }
558
559
560
-- Method table for copas-wrapped UDP sockets. Methods not defined here are
-- filled in from the TCP method table further down.
local _skt_mt_udp_index = {
    sendto =
        function(self, ...)
            return copassendto(self.socket, ...)
        end,

    receive =
        function(self, size)
            return copasreceive(self.socket, size or UDP_DATAGRAM_MAX)
        end,

    receivefrom =
        function(self, size)
            return copasreceivefrom(self.socket, size or UDP_DATAGRAM_MAX)
        end,

    -- Forward to the underlying setpeername; getpeername would only query
    -- the peer, never set it.
    setpeername =
        function(self, ...)
            return self.socket:setpeername(...)
        end,

    setsockname =
        function(self, ...)
            return self.socket:setsockname(...)
        end,

    -- Reports success without closing the underlying socket.
    -- NOTE(review): presumably deliberate because the UDP socket doubles as
    -- the server socket; confirm.
    close =
        function(self, ...)
            return true
        end,
}
591
-- Metatable identifying (and providing methods for) wrapped UDP sockets.
local _skt_mt_udp = { __index = _skt_mt_udp_index, __tostring = _skt_mt_tostring }

-- Give UDP every TCP method it does not override itself.
for name, method in next, _skt_mt_tcp_index do
    if not _skt_mt_udp_index[name] then
        _skt_mt_udp_index[name] = method
    end
end
602
603
604
605
606
607
608
-- Wrap a raw socket so that its I/O cooperates with the copas scheduler.
-- Objects that are already wrapped are returned unchanged. Otherwise the
-- raw socket is put in non-blocking mode (timeout 0).
-- @param sslt optional TLS parameters remembered for connect/dohandshake
--             (TCP sockets only)
local function wrap(skt, sslt)
    local mt = getmetatable(skt)
    if mt == _skt_mt_tcp or mt == _skt_mt_udp then
        return skt
    end
    skt:settimeout(0)
    if not isTCP(skt) then
        return setmetatable({ socket = skt }, _skt_mt_udp)
    end
    return setmetatable({ socket = skt, ssl_params = sslt }, _skt_mt_tcp)
end

copas.wrap = wrap
622
623
624
625
-- Build a connection handler that wraps the incoming socket before calling
-- `handler`. When `sslparams` is given, the TLS handshake runs first. If it
-- fails, `handler` is NOT called and `nil, err` is returned, so the
-- connection is never served over an un-upgraded socket.
function copas.handler(handler, sslparams)
    return function(skt, ...)
        skt = wrap(skt)
        if sslparams then
            local ok, err = skt:dohandshake(sslparams)
            if not ok then
                return nil, err
            end
        end
        return handler(skt, ...)
    end
end
635
636
637
-- Error handlers installed per coroutine (coroutine -> handler).
local _errhandlers = { }

-- Install `err` as the error handler of the running coroutine. Does nothing
-- when coroutine.running() returns nil.
function copas.setErrorHandler(err)
    local co = runningcoroutine()
    if not co then
        return
    end
    _errhandlers[co] = err
end

-- Default error handler: report the message with the coroutine and socket.
local function _deferror(msg, co, skt)
    report("%s (%s) (%s)", msg, tostring(co), tostring(skt))
end
650
651
652
-- Resume `co` with `skt` and any extra arguments, then act on the result.
-- If the coroutine yields `(res, queue)`, it is re-queued via
-- queue:insert(res) and queue:push(res, co). Otherwise it has finished:
--   * an error is passed to its handler (or _deferror) under pcall,
--   * a TCP socket is closed when copas.autoclose is set,
--   * its error handler entry is dropped.
local function _doTick(co, skt, ...)
    if not co then
        return
    end

    local ok, res, new_q = resumecoroutine(co, skt, ...)
    if ok and res and new_q then
        new_q:insert(res)
        new_q:push(res, co)
        return
    end

    if not ok then
        local handler = _errhandlers[co] or _deferror
        pcall(handler, res, co, skt)
    end
    if skt and copas.autoclose and isTCP(skt) then
        skt:close()
    end
    _errhandlers[co] = nil
end
674
675
676
-- Accept one connection on server socket `input` and start `handler` for it
-- in a new coroutine. Returns the first value of input:accept().
local function _accept(input, handler)
    local client = input:accept()
    if not client then
        return client
    end
    client:settimeout(0)
    _doTick(createcoroutine(handler), client)
    return client
end
687
688
689
-- Resume the oldest coroutine queued on `skt` in the reading set.
local function _tickRead(skt)
    local co = _reading:pop(skt)
    _doTick(co, skt)
end

-- Resume the oldest coroutine queued on `skt` in the writing set.
local function _tickWrite(skt)
    local co = _writing:pop(skt)
    _doTick(co, skt)
end
697
698
699
-- Register a TCP server. _servers maps the listening socket to its handler,
-- and the socket is watched for readability; connections are accepted when
-- it becomes readable.
local function addTCPserver(server, handler, timeout)
    server:settimeout(timeout or 0)
    _servers[server] = handler
    _reading:insert(server)
end

-- Register a UDP server. There is no accept step: the handler starts at
-- once in its own coroutine, with the server socket.
local function addUDPserver(server, handler, timeout)
    server:settimeout(timeout or 0)
    local co = createcoroutine(handler)
    _reading:insert(server)
    _doTick(co, server)
end

-- Register `server` with the scheduler, dispatching on its socket type.
function copas.addserver(server, handler, timeout)
    local add = isTCP(server) and addTCPserver or addUDPserver
    add(server, handler, timeout)
end
720
-- Unregister `server`, which may be a raw or a copas-wrapped socket.
-- Returns true when `keep_open` is set. Otherwise it closes the server and
-- returns the result of server:close().
function copas.removeserver(server, keep_open)
    local raw = server
    local mt = getmetatable(server)
    if mt == _skt_mt_tcp or mt == _skt_mt_udp then
        raw = server.socket
    end
    _servers[raw] = nil
    _reading:remove(raw)
    if keep_open then
        return true
    end
    return server:close()
end
734
735
736
737
738
-- Run `handler(...)` in a new scheduler-managed coroutine and return that
-- coroutine. The wrapper discards the socket slot that _doTick passes as
-- the first resume argument.
function copas.addthread(handler, ...)
    local function body(_, ...)
        return handler(...)
    end
    local thread = createcoroutine(body)
    _doTick(thread, nil, ...)
    return thread
end
744
745
746
-- Registered event tasks, used as a set (task -> true).
local _tasks = { }

-- Register `task` with `tick` as its default per-event handler.
local function addtask(task, tick)
    task.def_tick = tick
    _tasks[task] = true
end

-- Register a task whose default handler resumes readers.
local function addtaskRead(task)
    addtask(task, _tickRead)
end

-- Register a task whose default handler resumes writers.
local function addtaskWrite(task)
    addtask(task, _tickWrite)
end

-- Iterate registered tasks: `for task in tasks() do ... end`.
local function tasks()
    return next, _tasks
end
766
767
768
-- Task for sockets that select() reported readable. `_evs` is filled in by
-- _select before each dispatch pass.
local _readable_t = { }

-- Iterate the readable sockets in array order.
function _readable_t.events(self)
    local i = 0
    return function()
        i = i + 1
        return self._evs[i]
    end
end

-- A readable server socket means a pending connection. Any other socket
-- wakes the coroutine waiting on it.
function _readable_t.tick(self, input)
    local handler = _servers[input]
    if handler then
        _accept(input, handler)
        return
    end
    _reading:remove(input)
    self.def_tick(input)
end

addtaskRead(_readable_t)
791
792
793
-- Task for sockets that select() reported writable. `_evs` is filled in by
-- _select before each dispatch pass.
local _writable_t = { }

-- Iterate the writable sockets in array order.
function _writable_t.events(self)
    local i = 0
    return function()
        i = i + 1
        return self._evs[i]
    end
end

-- Stop watching the socket and wake the coroutine waiting to write on it.
function _writable_t.tick(self, output)
    _writing:remove(output)
    self.def_tick(output)
end

addtaskWrite(_writable_t)
811
812
813
-- Pseudo-task: resume at most one sleeper that is due at `time`.
local _sleeping_t = { }

function _sleeping_t.tick(self, time, ...)
    local co = _sleeping:pop(time)
    _doTick(co, ...)
end

-- Suspend the calling coroutine for `sleeptime` seconds (default 0).
-- A negative value parks it until copas.wakeup() is called for it.
function copas.sleep(sleeptime)
    yieldcoroutine(sleeptime or 0, _sleeping)
end

-- Make sleeping coroutine `co` due immediately.
function copas.wakeup(co)
    _sleeping:wakeup(co)
end
831
832
833
-- Time (gettime()) of the last watchdog pass.
local last_cleansing = 0

-- Wait for socket activity and publish the results to the read/write tasks.
-- At most once every WATCH_DOG_TIMEOUT seconds a watchdog also runs. It
-- adds to the event lists any logged socket that has been waiting longer
-- than WATCH_DOG_TIMEOUT without select reporting it. Its waiting coroutine
-- is then resumed and retries its operation.
-- @param timeout select timeout in seconds (nil blocks)
-- @return nil when events should be processed, otherwise select's error
local function _select(timeout)
    local now = gettime()

    local r_evs, w_evs, err = selectsocket(_reading, _writing, timeout)

    _readable_t._evs = r_evs
    _writable_t._evs = w_evs

    -- Elapsed time is `now - earlier`; the operands must not be swapped,
    -- or the watchdog can never fire.
    if (now - last_cleansing) > WATCH_DOG_TIMEOUT then
        last_cleansing = now

        -- Append each expired socket to `evs`. Both the array entry and the
        -- socket -> index entry are set, matching how the
        -- `not evs[skt]` check looks sockets up.
        local function expire(log, evs)
            for skt, since in next, log do
                if not evs[skt] and (now - since) > WATCH_DOG_TIMEOUT then
                    local n = #evs + 1
                    log[skt] = nil
                    evs[n] = skt
                    evs[skt] = n
                end
            end
        end

        expire(_reading_log, r_evs)
        expire(_writing_log, w_evs)
    end

    if err == "timeout" and #r_evs + #w_evs > 0 then
        return nil
    else
        return err
    end
end
887
888
889
890
-- True when no socket is waiting to read or write and no coroutine is
-- sleeping on a timer.
local function copasfinished()
    if next(_reading) ~= nil or next(_writing) ~= nil then
        return false
    end
    return not _sleeping:getnext()
end
894
895
896
897
898
-- Run one scheduler iteration:
--   1. resume one due sleeper,
--   2. wait for socket events, no longer than `timeout` or the next
--      wake-up, whichever is sooner,
--   3. dispatch the events to the registered tasks.
-- @return true after dispatching; false when there is nothing left to do or
--   select timed out; nil and the error when select fails
local function copasstep(timeout)
    _sleeping_t:tick(gettime())

    local nextwait = _sleeping:getnext()
    if nextwait == nil then
        if copasfinished() then
            return false
        end
    elseif timeout then
        timeout = min(nextwait, timeout)
    else
        timeout = nextwait
    end

    local err = _select(timeout)
    if err == "timeout" then
        return false
    elseif err then
        return nil, err
    end

    for task in tasks() do
        for event in task:events() do
            task:tick(event)
        end
    end
    return true
end

copas.finished = copasfinished
copas.step     = copasstep
927
928
929
-- Step the scheduler until copasfinished() reports no remaining work.
-- copas.running is true while the loop runs. `timeout` is passed to every
-- step.
function copas.loop(timeout)
    copas.running = true
    repeat
        if copasfinished() then
            break
        end
        copasstep(timeout)
    until false
    copas.running = false
end
937
938
939
-- Register the module so later require("copas") calls receive this table.
package.loaded.copas = copas

return copas