From d0965d437bc0fb042596641368c85b431214829b Mon Sep 17 00:00:00 2001 From: io Date: Mon, 26 Jul 2021 06:29:20 +0000 Subject: [PATCH] make third_party.utils.make_toot async --- gen.py | 2 +- third_party/utils.py | 29 +++-------------------------- utils.py | 7 +++---- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/gen.py b/gen.py index 5577d1d..a46c292 100755 --- a/gen.py +++ b/gen.py @@ -22,7 +22,7 @@ async def main(): args = parse_args() cfg = utils.load_config(args.cfg) - toot = utils.make_toot(cfg, mode=utils.TextGenerationMode.__members__[args.mode]) + toot = await utils.make_post(cfg, mode=utils.TextGenerationMode.__members__[args.mode]) if cfg['strip_paired_punctuation']: toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot) if not args.simulate: diff --git a/third_party/utils.py b/third_party/utils.py index ee8d62f..f36afb3 100644 --- a/third_party/utils.py +++ b/third_party/utils.py @@ -12,6 +12,7 @@ import argparse import itertools import json5 as json import multiprocessing +import anyio.to_process from random import randint from bs4 import BeautifulSoup @@ -61,37 +62,13 @@ def remove_mention(cfg, sentence): return sentence -def _wrap_pipe(f): - def g(pout, *args, **kwargs): - try: - pout.send(f(*args, **kwargs)) - except ValueError as exc: - pout.send(exc.args[0]) - return g - -def make_toot(cfg, *, mode=TextGenerationMode.markov): - toot = None - pin, pout = multiprocessing.Pipe(False) - +async def make_post(cfg, *, mode=TextGenerationMode.markov): if mode is TextGenerationMode.markov: from generators.markov import make_sentence elif mode is TextGenerationMode.gpt_2: from generators.gpt_2 import make_sentence - else: - raise ValueError('Invalid text generation mode') - p = multiprocessing.Process(target=_wrap_pipe(make_sentence), args=[pout, cfg]) - p.start() - p.join(5) # wait 5 seconds to get something - if p.is_alive(): # if it's still trying to make a toot after 5 seconds - p.terminate() - p.join() - else: - toot = pin.recv() - - if toot is None: - toot = 'Toot generation failed! Contact io@csdisaster.club for assistance.' - return toot + return await anyio.to_process.run_sync(make_sentence, cfg) def extract_post_content(text): soup = BeautifulSoup(text, "html.parser") diff --git a/utils.py b/utils.py index 2dcbb8b..b552598 100644 --- a/utils.py +++ b/utils.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: AGPL-3.0-only import anyio -import functools -from bs4 import BeautifulSoup +from functools import wraps def shield(f): - @functools.wraps(f) + @wraps(f) async def shielded(*args, **kwargs): - with anyio.CancelScope(shield=True) as cs: + with anyio.CancelScope(shield=True): return await f(*args, **kwargs) return shielded