make third_party.utils.make_toot async

This commit is contained in:
io 2021-07-26 06:29:20 +00:00
parent 4e4619fbe0
commit d0965d437b
3 changed files with 7 additions and 31 deletions

2
gen.py
View file

@ -22,7 +22,7 @@ async def main():
args = parse_args() args = parse_args()
cfg = utils.load_config(args.cfg) cfg = utils.load_config(args.cfg)
toot = utils.make_toot(cfg, mode=utils.TextGenerationMode.__members__[args.mode]) toot = await utils.make_post(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
if cfg['strip_paired_punctuation']: if cfg['strip_paired_punctuation']:
toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot) toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot)
if not args.simulate: if not args.simulate:

29
third_party/utils.py vendored
View file

@ -12,6 +12,7 @@ import argparse
import itertools import itertools
import json5 as json import json5 as json
import multiprocessing import multiprocessing
import anyio.to_process
from random import randint from random import randint
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -61,37 +62,13 @@ def remove_mention(cfg, sentence):
return sentence return sentence
def _wrap_pipe(f): async def make_post(cfg, *, mode=TextGenerationMode.markov):
def g(pout, *args, **kwargs):
try:
pout.send(f(*args, **kwargs))
except ValueError as exc:
pout.send(exc.args[0])
return g
def make_toot(cfg, *, mode=TextGenerationMode.markov):
toot = None
pin, pout = multiprocessing.Pipe(False)
if mode is TextGenerationMode.markov: if mode is TextGenerationMode.markov:
from generators.markov import make_sentence from generators.markov import make_sentence
elif mode is TextGenerationMode.gpt_2: elif mode is TextGenerationMode.gpt_2:
from generators.gpt_2 import make_sentence from generators.gpt_2 import make_sentence
else:
raise ValueError('Invalid text generation mode')
p = multiprocessing.Process(target=_wrap_pipe(make_sentence), args=[pout, cfg]) return await anyio.to_process.run_sync(make_sentence, cfg)
p.start()
p.join(5) # wait 5 seconds to get something
if p.is_alive(): # if it's still trying to make a toot after 5 seconds
p.terminate()
p.join()
else:
toot = pin.recv()
if toot is None:
toot = 'Toot generation failed! Contact io@csdisaster.club for assistance.'
return toot
def extract_post_content(text): def extract_post_content(text):
soup = BeautifulSoup(text, "html.parser") soup = BeautifulSoup(text, "html.parser")

View file

@ -1,12 +1,11 @@
# SPDX-License-Identifier: AGPL-3.0-only # SPDX-License-Identifier: AGPL-3.0-only
import anyio import anyio
import functools from functools import wraps
from bs4 import BeautifulSoup
def shield(f): def shield(f):
@functools.wraps(f) @wraps(f)
async def shielded(*args, **kwargs): async def shielded(*args, **kwargs):
with anyio.CancelScope(shield=True) as cs: with anyio.CancelScope(shield=True):
return await f(*args, **kwargs) return await f(*args, **kwargs)
return shielded return shielded