Merge pull request #20 from int3l/master

Add graceful shutdown handling
This commit is contained in:
Anthony Sottile 2020-05-11 11:46:06 -07:00 committed by GitHub
commit 30b423fd0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 41 additions and 6 deletions

45
bot.py
View File

@ -1,5 +1,8 @@
from __future__ import annotations
import argparse
import asyncio.subprocess
import contextlib
import datetime
import functools
import hashlib
@ -7,6 +10,7 @@ import json
import os.path
import random
import re
import signal
import struct
import sys
import tempfile
@ -17,7 +21,6 @@ from typing import Dict
from typing import List
from typing import Match
from typing import NamedTuple
from typing import NoReturn
from typing import Optional
from typing import Pattern
from typing import Tuple
@ -117,8 +120,6 @@ async def recv(
quiet: bool = False,
) -> bytes:
data = await reader.readline()
if not data:
raise SystemExit('unexpected EOF')
if not quiet:
sys.stderr.buffer.write(b'> ')
sys.stderr.buffer.write(data)
@ -561,16 +562,48 @@ def dt_str() -> str:
return f'[{dt_now.hour:02}:{dt_now.minute:02}]'
async def amain(config: Config, *, quiet: bool) -> NoReturn:
def _shutdown(
writer: asyncio.StreamWriter,
loop: asyncio.AbstractEventLoop,
shutdown_task: Optional[asyncio.Task[Any]] = None,
) -> None:
print('bye!')
ignored_tasks = set()
if shutdown_task is not None:
ignored_tasks.add(shutdown_task)
if writer:
writer.close()
closing_task = loop.create_task(writer.wait_closed())
def cancel_tasks(fut: asyncio.Future[Any]) -> None:
tasks = [t for t in asyncio.all_tasks() if t not in ignored_tasks]
for task in tasks:
task.cancel()
closing_task.add_done_callback(cancel_tasks)
async def amain(config: Config, *, quiet: bool) -> None:
reader, writer = await asyncio.open_connection(HOST, PORT, ssl=True)
loop = asyncio.get_event_loop()
shutdown_cb = functools.partial(_shutdown, writer, loop)
try:
loop.add_signal_handler(signal.SIGINT, shutdown_cb)
except NotImplementedError:
# Doh... Windows...
signal.signal(signal.SIGINT, lambda *_: shutdown_cb())
await send(writer, f'PASS {config.oauth_token}\r\n', quiet=True)
await send(writer, f'NICK {config.username}\r\n', quiet=quiet)
await send(writer, f'JOIN #{config.channel}\r\n', quiet=quiet)
await send(writer, 'CAP REQ :twitch.tv/tags\r\n', quiet=quiet)
while True:
while not writer.is_closing():
data = await recv(reader, quiet=quiet)
if not data:
return
msg = data.decode('UTF-8', errors='backslashreplace')
msg_match = MSG_RE.match(msg)
@ -634,7 +667,9 @@ def main() -> int:
with open(args.config) as f:
config = Config(**json.load(f))
with contextlib.suppress(KeyboardInterrupt):
asyncio.run(amain(config, quiet=not args.verbose))
return 0