From f6173de8d0c463a252aa420dbb85994e0b4eb598 Mon Sep 17 00:00:00 2001 From: int3l Date: Wed, 6 May 2020 01:55:46 +0300 Subject: [PATCH] Add graceful shutdown handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Don't ask me how I know this 🍡 Check https://github.com/python/cpython/blob/master/Lib/asyncio/runners.py for details ... --- bot.py | 47 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/bot.py b/bot.py index e633ec2..0412de2 100644 --- a/bot.py +++ b/bot.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import argparse import asyncio.subprocess +import contextlib import datetime import functools import hashlib @@ -7,6 +10,7 @@ import json import os.path import random import re +import signal import struct import sys import tempfile @@ -17,7 +21,6 @@ from typing import Dict from typing import List from typing import Match from typing import NamedTuple -from typing import NoReturn from typing import Optional from typing import Pattern from typing import Tuple @@ -117,8 +120,6 @@ async def recv( quiet: bool = False, ) -> bytes: data = await reader.readline() - if not data: - raise SystemExit('unexpected EOF') if not quiet: sys.stderr.buffer.write(b'> ') sys.stderr.buffer.write(data) @@ -561,16 +562,48 @@ def dt_str() -> str: return f'[{dt_now.hour:02}:{dt_now.minute:02}]' -async def amain(config: Config, *, quiet: bool) -> NoReturn: +def _shutdown( + writer: asyncio.StreamWriter, + loop: asyncio.AbstractEventLoop, + shutdown_task: Optional[asyncio.Task[Any]] = None, +) -> None: + print('bye!') + ignored_tasks = set() + if shutdown_task is not None: + ignored_tasks.add(shutdown_task) + + if writer: + writer.close() + closing_task = loop.create_task(writer.wait_closed()) + + def cancel_tasks(fut: asyncio.Future[Any]) -> None: + tasks = [t for t in asyncio.all_tasks() if t not in ignored_tasks] + for task in tasks: + task.cancel() + + closing_task.add_done_callback(cancel_tasks) + + +async def amain(config: Config, *, quiet: bool) -> None: reader, writer = await asyncio.open_connection(HOST, PORT, ssl=True) + loop = asyncio.get_event_loop() + shutdown_cb = functools.partial(_shutdown, writer, loop) + try: + loop.add_signal_handler(signal.SIGINT, shutdown_cb) + except NotImplementedError: + # Doh... Windows... + signal.signal(signal.SIGINT, lambda *_: shutdown_cb()) + await send(writer, f'PASS {config.oauth_token}\r\n', quiet=True) await send(writer, f'NICK {config.username}\r\n', quiet=quiet) await send(writer, f'JOIN #{config.channel}\r\n', quiet=quiet) await send(writer, 'CAP REQ :twitch.tv/tags\r\n', quiet=quiet) - while True: + while not writer.is_closing(): data = await recv(reader, quiet=quiet) + if not data: + return msg = data.decode('UTF-8', errors='backslashreplace') msg_match = MSG_RE.match(msg) @@ -634,7 +667,9 @@ def main() -> int: with open(args.config) as f: config = Config(**json.load(f)) - asyncio.run(amain(config, quiet=not args.verbose)) + with contextlib.suppress(KeyboardInterrupt): + asyncio.run(amain(config, quiet=not args.verbose)) + return 0