make third_party.utils.make_toot async

This commit is contained in:
io 2021-07-26 06:29:20 +00:00
parent 4e4619fbe0
commit d0965d437b
3 changed files with 7 additions and 31 deletions

2
gen.py
View file

@ -22,7 +22,7 @@ async def main():
args = parse_args()
cfg = utils.load_config(args.cfg)
toot = utils.make_toot(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
toot = await utils.make_post(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
if cfg['strip_paired_punctuation']:
toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot)
if not args.simulate:

29
third_party/utils.py vendored
View file

@ -12,6 +12,7 @@ import argparse
import itertools
import json5 as json
import multiprocessing
import anyio.to_process
from random import randint
from bs4 import BeautifulSoup
@ -61,37 +62,13 @@ def remove_mention(cfg, sentence):
return sentence
def _wrap_pipe(f):
def g(pout, *args, **kwargs):
try:
pout.send(f(*args, **kwargs))
except ValueError as exc:
pout.send(exc.args[0])
return g
def make_toot(cfg, *, mode=TextGenerationMode.markov):
toot = None
pin, pout = multiprocessing.Pipe(False)
async def make_post(cfg, *, mode=TextGenerationMode.markov):
if mode is TextGenerationMode.markov:
from generators.markov import make_sentence
elif mode is TextGenerationMode.gpt_2:
from generators.gpt_2 import make_sentence
else:
raise ValueError('Invalid text generation mode')
p = multiprocessing.Process(target=_wrap_pipe(make_sentence), args=[pout, cfg])
p.start()
p.join(5) # wait 5 seconds to get something
if p.is_alive(): # if it's still trying to make a toot after 5 seconds
p.terminate()
p.join()
else:
toot = pin.recv()
if toot is None:
toot = 'Toot generation failed! Contact io@csdisaster.club for assistance.'
return toot
return await anyio.to_process.run_sync(make_sentence, cfg)
def extract_post_content(text):
soup = BeautifulSoup(text, "html.parser")

View file

@ -1,12 +1,11 @@
# SPDX-License-Identifier: AGPL-3.0-only
import anyio
import functools
from bs4 import BeautifulSoup
from functools import wraps
def shield(f):
@functools.wraps(f)
@wraps(f)
async def shielded(*args, **kwargs):
with anyio.CancelScope(shield=True) as cs:
with anyio.CancelScope(shield=True):
return await f(*args, **kwargs)
return shielded