add basic migration support

This commit is contained in:
io 2021-09-17 06:34:44 +00:00
parent 191214dbd6
commit b906abe2b1
3 changed files with 41 additions and 1 deletions

View file

@ -9,10 +9,10 @@ import pendulum
import operator import operator
import aiosqlite import aiosqlite
import contextlib import contextlib
from utils import shield
from pleroma import Pleroma from pleroma import Pleroma
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from functools import partial from functools import partial
from utils import shield, suppress
from typing import Iterable, NewType from typing import Iterable, NewType
from third_party.utils import extract_post_content from third_party.utils import extract_post_content
@ -26,6 +26,8 @@ UTC = pendulum.timezone('UTC')
JSON_CONTENT_TYPE = 'application/json' JSON_CONTENT_TYPE = 'application/json'
ACTIVITYPUB_CONTENT_TYPE = 'application/activity+json' ACTIVITYPUB_CONTENT_TYPE = 'application/activity+json'
MIGRATION_VERSION = 1
class PostFetcher: class PostFetcher:
def __init__(self, *, config): def __init__(self, *, config):
self.config = config self.config = config
@ -47,10 +49,26 @@ class PostFetcher:
), ),
) )
self._db = await stack.enter_async_context(aiosqlite.connect(self.config['db_path'])) self._db = await stack.enter_async_context(aiosqlite.connect(self.config['db_path']))
await self._maybe_run_migrations()
self._db.row_factory = aiosqlite.Row self._db.row_factory = aiosqlite.Row
self._ctx_stack = stack self._ctx_stack = stack
return self return self
async def _maybe_run_migrations(self):
async with self._db.cursor() as cur, suppress(aiosqlite.OperationalError):
if await (await cur.execute('SELECT migration_version FROM migrations')).fetchone(): return
await self._run_migrations()
async def _run_migrations(self):
# TODO proper migrations, not just "has the schema ever been run" migrations
async with await (anyio.Path(__file__).parent/'schema.sql').open() as f:
schema = await f.read()
async with self._db.cursor() as cur:
await cur.executescript(schema)
await cur.execute('INSERT INTO migrations (migration_version) VALUES (?)', (MIGRATION_VERSION,))
async def __aexit__(self, *excinfo): async def __aexit__(self, *excinfo):
return await self._ctx_stack.__aexit__(*excinfo) return await self._ctx_stack.__aexit__(*excinfo)

View file

@ -6,3 +6,7 @@ CREATE TABLE posts (
-- UTC Unix timestamp in seconds -- UTC Unix timestamp in seconds
published_at REAL NOT NULL published_at REAL NOT NULL
); );
CREATE TABLE migrations (
migration_version INTEGER NOT NULL
);

View file

@ -1,7 +1,25 @@
# SPDX-License-Identifier: AGPL-3.0-only # SPDX-License-Identifier: AGPL-3.0-only
import anyio import anyio
import contextlib
from functools import wraps from functools import wraps
from datetime import datetime, timezone
def as_corofunc(f):
@wraps(f)
async def wrapped(*args, **kwargs):
# can't decide if i want an `anyio.sleep(0)` here.
return f(*args, **kwargs)
return wrapped
def as_async_cm(cls):
@wraps(cls, updated=()) # cls.__dict__ doesn't support .update()
class wrapped(cls, contextlib.AbstractAsyncContextManager):
__aenter__ = as_corofunc(cls.__enter__)
__aexit__ = as_corofunc(cls.__exit__)
return wrapped
suppress = as_async_cm(contextlib.suppress)
def shield(f): def shield(f):
@wraps(f) @wraps(f)