Add graceful shutdown handling

Don't ask me how I know this 🍡
Check https://github.com/python/cpython/blob/master/Lib/asyncio/runners.py for details ...
This commit is contained in:
int3l 2020-05-06 01:55:46 +03:00 committed by Anthony Sottile
parent 0378fc80ef
commit f6173de8d0
1 changed files with 41 additions and 6 deletions

47
bot.py
View File

@ -1,5 +1,8 @@
from __future__ import annotations
import argparse import argparse
import asyncio.subprocess import asyncio.subprocess
import contextlib
import datetime import datetime
import functools import functools
import hashlib import hashlib
@ -7,6 +10,7 @@ import json
import os.path import os.path
import random import random
import re import re
import signal
import struct import struct
import sys import sys
import tempfile import tempfile
@ -17,7 +21,6 @@ from typing import Dict
from typing import List from typing import List
from typing import Match from typing import Match
from typing import NamedTuple from typing import NamedTuple
from typing import NoReturn
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import Tuple
@ -117,8 +120,6 @@ async def recv(
quiet: bool = False, quiet: bool = False,
) -> bytes: ) -> bytes:
data = await reader.readline() data = await reader.readline()
if not data:
raise SystemExit('unexpected EOF')
if not quiet: if not quiet:
sys.stderr.buffer.write(b'> ') sys.stderr.buffer.write(b'> ')
sys.stderr.buffer.write(data) sys.stderr.buffer.write(data)
@ -561,16 +562,48 @@ def dt_str() -> str:
return f'[{dt_now.hour:02}:{dt_now.minute:02}]' return f'[{dt_now.hour:02}:{dt_now.minute:02}]'
async def amain(config: Config, *, quiet: bool) -> NoReturn: def _shutdown(
writer: asyncio.StreamWriter,
loop: asyncio.AbstractEventLoop,
shutdown_task: Optional[asyncio.Task[Any]] = None,
) -> None:
print('bye!')
ignored_tasks = set()
if shutdown_task is not None:
ignored_tasks.add(shutdown_task)
if writer:
writer.close()
closing_task = loop.create_task(writer.wait_closed())
def cancel_tasks(fut: asyncio.Future[Any]) -> None:
tasks = [t for t in asyncio.all_tasks() if t not in ignored_tasks]
for task in tasks:
task.cancel()
closing_task.add_done_callback(cancel_tasks)
async def amain(config: Config, *, quiet: bool) -> None:
reader, writer = await asyncio.open_connection(HOST, PORT, ssl=True) reader, writer = await asyncio.open_connection(HOST, PORT, ssl=True)
loop = asyncio.get_event_loop()
shutdown_cb = functools.partial(_shutdown, writer, loop)
try:
loop.add_signal_handler(signal.SIGINT, shutdown_cb)
except NotImplementedError:
# Doh... Windows...
signal.signal(signal.SIGINT, lambda *_: shutdown_cb())
await send(writer, f'PASS {config.oauth_token}\r\n', quiet=True) await send(writer, f'PASS {config.oauth_token}\r\n', quiet=True)
await send(writer, f'NICK {config.username}\r\n', quiet=quiet) await send(writer, f'NICK {config.username}\r\n', quiet=quiet)
await send(writer, f'JOIN #{config.channel}\r\n', quiet=quiet) await send(writer, f'JOIN #{config.channel}\r\n', quiet=quiet)
await send(writer, 'CAP REQ :twitch.tv/tags\r\n', quiet=quiet) await send(writer, 'CAP REQ :twitch.tv/tags\r\n', quiet=quiet)
while True: while not writer.is_closing():
data = await recv(reader, quiet=quiet) data = await recv(reader, quiet=quiet)
if not data:
return
msg = data.decode('UTF-8', errors='backslashreplace') msg = data.decode('UTF-8', errors='backslashreplace')
msg_match = MSG_RE.match(msg) msg_match = MSG_RE.match(msg)
@ -634,7 +667,9 @@ def main() -> int:
with open(args.config) as f: with open(args.config) as f:
config = Config(**json.load(f)) config = Config(**json.load(f))
asyncio.run(amain(config, quiet=not args.verbose)) with contextlib.suppress(KeyboardInterrupt):
asyncio.run(amain(config, quiet=not args.verbose))
return 0 return 0