make third_party.utils.make_toot async
This commit is contained in:
parent
4e4619fbe0
commit
d0965d437b
3 changed files with 7 additions and 31 deletions
2
gen.py
2
gen.py
|
@ -22,7 +22,7 @@ async def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
cfg = utils.load_config(args.cfg)
|
cfg = utils.load_config(args.cfg)
|
||||||
|
|
||||||
toot = utils.make_toot(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
|
toot = await utils.make_post(cfg, mode=utils.TextGenerationMode.__members__[args.mode])
|
||||||
if cfg['strip_paired_punctuation']:
|
if cfg['strip_paired_punctuation']:
|
||||||
toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot)
|
toot = re.sub(r"[\[\]\(\)\{\}\"“”«»„]", "", toot)
|
||||||
if not args.simulate:
|
if not args.simulate:
|
||||||
|
|
29
third_party/utils.py
vendored
29
third_party/utils.py
vendored
|
@ -12,6 +12,7 @@ import argparse
|
||||||
import itertools
|
import itertools
|
||||||
import json5 as json
|
import json5 as json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import anyio.to_process
|
||||||
from random import randint
|
from random import randint
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
@ -61,37 +62,13 @@ def remove_mention(cfg, sentence):
|
||||||
|
|
||||||
return sentence
|
return sentence
|
||||||
|
|
||||||
def _wrap_pipe(f):
|
async def make_post(cfg, *, mode=TextGenerationMode.markov):
|
||||||
def g(pout, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
pout.send(f(*args, **kwargs))
|
|
||||||
except ValueError as exc:
|
|
||||||
pout.send(exc.args[0])
|
|
||||||
return g
|
|
||||||
|
|
||||||
def make_toot(cfg, *, mode=TextGenerationMode.markov):
|
|
||||||
toot = None
|
|
||||||
pin, pout = multiprocessing.Pipe(False)
|
|
||||||
|
|
||||||
if mode is TextGenerationMode.markov:
|
if mode is TextGenerationMode.markov:
|
||||||
from generators.markov import make_sentence
|
from generators.markov import make_sentence
|
||||||
elif mode is TextGenerationMode.gpt_2:
|
elif mode is TextGenerationMode.gpt_2:
|
||||||
from generators.gpt_2 import make_sentence
|
from generators.gpt_2 import make_sentence
|
||||||
else:
|
|
||||||
raise ValueError('Invalid text generation mode')
|
|
||||||
|
|
||||||
p = multiprocessing.Process(target=_wrap_pipe(make_sentence), args=[pout, cfg])
|
return await anyio.to_process.run_sync(make_sentence, cfg)
|
||||||
p.start()
|
|
||||||
p.join(5) # wait 5 seconds to get something
|
|
||||||
if p.is_alive(): # if it's still trying to make a toot after 5 seconds
|
|
||||||
p.terminate()
|
|
||||||
p.join()
|
|
||||||
else:
|
|
||||||
toot = pin.recv()
|
|
||||||
|
|
||||||
if toot is None:
|
|
||||||
toot = 'Toot generation failed! Contact io@csdisaster.club for assistance.'
|
|
||||||
return toot
|
|
||||||
|
|
||||||
def extract_post_content(text):
|
def extract_post_content(text):
|
||||||
soup = BeautifulSoup(text, "html.parser")
|
soup = BeautifulSoup(text, "html.parser")
|
||||||
|
|
7
utils.py
7
utils.py
|
@ -1,12 +1,11 @@
|
||||||
# SPDX-License-Identifier: AGPL-3.0-only
|
# SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
import functools
|
from functools import wraps
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
|
|
||||||
def shield(f):
|
def shield(f):
|
||||||
@functools.wraps(f)
|
@wraps(f)
|
||||||
async def shielded(*args, **kwargs):
|
async def shielded(*args, **kwargs):
|
||||||
with anyio.CancelScope(shield=True) as cs:
|
with anyio.CancelScope(shield=True):
|
||||||
return await f(*args, **kwargs)
|
return await f(*args, **kwargs)
|
||||||
return shielded
|
return shielded
|
||||||
|
|
Loading…
Reference in a new issue