Merge pull request #20 from int3l/master

Add graceful shutdown handling
This commit is contained in:
Anthony Sottile 2020-05-11 11:46:06 -07:00 committed by GitHub
commit 30b423fd0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 41 additions and 6 deletions

45
bot.py
View File

@ -1,5 +1,8 @@
from __future__ import annotations
import argparse import argparse
import asyncio.subprocess import asyncio.subprocess
import contextlib
import datetime import datetime
import functools import functools
import hashlib import hashlib
@ -7,6 +10,7 @@ import json
import os.path import os.path
import random import random
import re import re
import signal
import struct import struct
import sys import sys
import tempfile import tempfile
@ -17,7 +21,6 @@ from typing import Dict
from typing import List from typing import List
from typing import Match from typing import Match
from typing import NamedTuple from typing import NamedTuple
from typing import NoReturn
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Tuple from typing import Tuple
@ -117,8 +120,6 @@ async def recv(
quiet: bool = False, quiet: bool = False,
) -> bytes: ) -> bytes:
data = await reader.readline() data = await reader.readline()
if not data:
raise SystemExit('unexpected EOF')
if not quiet: if not quiet:
sys.stderr.buffer.write(b'> ') sys.stderr.buffer.write(b'> ')
sys.stderr.buffer.write(data) sys.stderr.buffer.write(data)
@ -561,16 +562,48 @@ def dt_str() -> str:
return f'[{dt_now.hour:02}:{dt_now.minute:02}]' return f'[{dt_now.hour:02}:{dt_now.minute:02}]'
async def amain(config: Config, *, quiet: bool) -> NoReturn: def _shutdown(
writer: asyncio.StreamWriter,
loop: asyncio.AbstractEventLoop,
shutdown_task: Optional[asyncio.Task[Any]] = None,
) -> None:
print('bye!')
ignored_tasks = set()
if shutdown_task is not None:
ignored_tasks.add(shutdown_task)
if writer:
writer.close()
closing_task = loop.create_task(writer.wait_closed())
def cancel_tasks(fut: asyncio.Future[Any]) -> None:
tasks = [t for t in asyncio.all_tasks() if t not in ignored_tasks]
for task in tasks:
task.cancel()
closing_task.add_done_callback(cancel_tasks)
async def amain(config: Config, *, quiet: bool) -> None:
reader, writer = await asyncio.open_connection(HOST, PORT, ssl=True) reader, writer = await asyncio.open_connection(HOST, PORT, ssl=True)
loop = asyncio.get_event_loop()
shutdown_cb = functools.partial(_shutdown, writer, loop)
try:
loop.add_signal_handler(signal.SIGINT, shutdown_cb)
except NotImplementedError:
# Doh... Windows...
signal.signal(signal.SIGINT, lambda *_: shutdown_cb())
await send(writer, f'PASS {config.oauth_token}\r\n', quiet=True) await send(writer, f'PASS {config.oauth_token}\r\n', quiet=True)
await send(writer, f'NICK {config.username}\r\n', quiet=quiet) await send(writer, f'NICK {config.username}\r\n', quiet=quiet)
await send(writer, f'JOIN #{config.channel}\r\n', quiet=quiet) await send(writer, f'JOIN #{config.channel}\r\n', quiet=quiet)
await send(writer, 'CAP REQ :twitch.tv/tags\r\n', quiet=quiet) await send(writer, 'CAP REQ :twitch.tv/tags\r\n', quiet=quiet)
while True: while not writer.is_closing():
data = await recv(reader, quiet=quiet) data = await recv(reader, quiet=quiet)
if not data:
return
msg = data.decode('UTF-8', errors='backslashreplace') msg = data.decode('UTF-8', errors='backslashreplace')
msg_match = MSG_RE.match(msg) msg_match = MSG_RE.match(msg)
@ -634,7 +667,9 @@ def main() -> int:
with open(args.config) as f: with open(args.config) as f:
config = Config(**json.load(f)) config = Config(**json.load(f))
with contextlib.suppress(KeyboardInterrupt):
asyncio.run(amain(config, quiet=not args.verbose)) asyncio.run(amain(config, quiet=not args.verbose))
return 0 return 0