mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import typing
|
||||
import json
|
||||
@@ -10,12 +9,16 @@ import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
|
||||
import sqlalchemy
|
||||
|
||||
from . import database, migration
|
||||
from ..entity.persistence import base, user, model, pipeline, bot, plugin, metadata
|
||||
from ..entity.persistence import base, pipeline, metadata
|
||||
from ..entity import persistence
|
||||
from ..core import app
|
||||
from .databases import sqlite
|
||||
from ..utils import constants
|
||||
from .migrations import dbm001_migrate_v3_config
|
||||
from ..utils import constants, importutil
|
||||
from ..api.http.service import pipeline as pipeline_service
|
||||
from . import databases, migrations
|
||||
|
||||
importutil.import_modules_in_pkg(databases)
|
||||
importutil.import_modules_in_pkg(migrations)
|
||||
importutil.import_modules_in_pkg(persistence)
|
||||
|
||||
|
||||
class PersistenceManager:
|
||||
@@ -33,9 +36,8 @@ class PersistenceManager:
|
||||
self.meta = base.Base.metadata
|
||||
|
||||
async def initialize(self):
|
||||
self.ap.logger.info('Initializing database...')
|
||||
|
||||
self.ap.logger.info("Initializing database...")
|
||||
|
||||
for manager in database.preregistered_managers:
|
||||
self.db = manager(self.ap)
|
||||
await self.db.initialize()
|
||||
@@ -43,7 +45,6 @@ class PersistenceManager:
|
||||
await self.create_tables()
|
||||
|
||||
async def create_tables(self):
|
||||
|
||||
# create tables
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
await conn.run_sync(self.meta.create_all)
|
||||
@@ -53,26 +54,28 @@ class PersistenceManager:
|
||||
# ======= write initial data =======
|
||||
|
||||
# write initial metadata
|
||||
self.ap.logger.info("Creating initial metadata...")
|
||||
self.ap.logger.info('Creating initial metadata...')
|
||||
for item in metadata.initial_metadata:
|
||||
# check if the item exists
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
|
||||
sqlalchemy.select(metadata.Metadata).where(
|
||||
metadata.Metadata.key == item['key']
|
||||
)
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
await self.execute_async(
|
||||
sqlalchemy.insert(metadata.Metadata).values(item)
|
||||
)
|
||||
|
||||
# write default pipeline
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(pipeline.LegacyPipeline)
|
||||
)
|
||||
if result.first() is None:
|
||||
self.ap.logger.info("Creating default pipeline...")
|
||||
|
||||
pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
|
||||
# write default pipeline
|
||||
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
|
||||
if result.first() is None:
|
||||
self.ap.logger.info('Creating default pipeline...')
|
||||
|
||||
pipeline_config = json.load(
|
||||
open('templates/default-pipeline-config.json', 'r', encoding='utf-8')
|
||||
)
|
||||
|
||||
pipeline_data = {
|
||||
'uuid': str(uuid.uuid4()),
|
||||
@@ -91,7 +94,9 @@ class PersistenceManager:
|
||||
|
||||
# run migrations
|
||||
database_version = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||
sqlalchemy.select(metadata.Metadata).where(
|
||||
metadata.Metadata.key == 'database_version'
|
||||
)
|
||||
)
|
||||
|
||||
database_version = int(database_version.fetchone()[1])
|
||||
@@ -106,24 +111,27 @@ class PersistenceManager:
|
||||
for migration_cls in migrations:
|
||||
migration_instance = migration_cls(self.ap)
|
||||
|
||||
if migration_instance.number > database_version and migration_instance.number <= required_database_version:
|
||||
if (
|
||||
migration_instance.number > database_version
|
||||
and migration_instance.number <= required_database_version
|
||||
):
|
||||
await migration_instance.upgrade()
|
||||
await self.execute_async(
|
||||
sqlalchemy.update(metadata.Metadata).where(metadata.Metadata.key == 'database_version').values(
|
||||
{
|
||||
'value': str(migration_instance.number)
|
||||
}
|
||||
)
|
||||
sqlalchemy.update(metadata.Metadata)
|
||||
.where(metadata.Metadata.key == 'database_version')
|
||||
.values({'value': str(migration_instance.number)})
|
||||
)
|
||||
last_migration_number = migration_instance.number
|
||||
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
self.ap.logger.info(
|
||||
f'Migration {migration_instance.number} completed.'
|
||||
)
|
||||
|
||||
self.ap.logger.info(
|
||||
f'Successfully upgraded database to version {last_migration_number}.'
|
||||
)
|
||||
|
||||
async def execute_async(
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
self, *args, **kwargs
|
||||
) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
@@ -132,9 +140,13 @@ class PersistenceManager:
|
||||
|
||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||
return self.db.get_engine()
|
||||
|
||||
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
|
||||
|
||||
def serialize_model(
|
||||
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base
|
||||
) -> dict:
|
||||
return {
|
||||
column.name: getattr(data, column.name) if not isinstance(getattr(data, column.name), (datetime.datetime)) else getattr(data, column.name).isoformat()
|
||||
column.name: getattr(data, column.name)
|
||||
if not isinstance(getattr(data, column.name), (datetime.datetime))
|
||||
else getattr(data, column.name).isoformat()
|
||||
for column in model.__table__.columns
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user