feat: 完善session维护代码

This commit is contained in:
Rock Chin
2022-12-08 00:41:35 +08:00
parent da3b964a8c
commit 9ca641a2e5
3 changed files with 132 additions and 9 deletions

View File

@@ -1,5 +1,9 @@
import time
import pymysql
import config
inst = None
@@ -26,9 +30,70 @@ class DatabaseManager:
def reconnect(self):
self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password,
database=self.database)
database=self.database, autocommit=True)
self.cursor = self.conn.cursor()
def initialize_database(self):
self.cursor.execute("""
create table if not exists `sessions` (
`id` bigint not null auto_increment primary key,
`name` varchar(255) not null,
`type` varchar(255) not null,
`number` bigint not null,
`create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null,
`prompt` text not null
)
""")
print('Database initialized.')
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str):
# 检查是否已经有了此name和create_timestamp的session
# 如果有就更新prompt和last_interact_timestamp
# 如果没有,就插入一条新的记录
self.cursor.execute("""
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(subject_type, subject_number, create_timestamp))
count = self.cursor.fetchone()[0]
if count == 0:
self.cursor.execute("""
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
values ('{}', '{}', {}, {}, {}, '{}')
""".format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt))
else:
self.cursor.execute("""
update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}'
where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp))
# 记载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
self.cursor.execute("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall()
sessions = {}
for result in results:
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
sessions[session_name] = {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
}
return sessions
def get_inst() -> DatabaseManager:
global inst