mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
feat: 完善session维护代码
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
import time
|
||||
|
||||
import pymysql
|
||||
|
||||
import config
|
||||
|
||||
inst = None
|
||||
|
||||
|
||||
@@ -26,9 +30,70 @@ class DatabaseManager:
|
||||
|
||||
def reconnect(self):
|
||||
self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password,
|
||||
database=self.database)
|
||||
database=self.database, autocommit=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def initialize_database(self):
|
||||
self.cursor.execute("""
|
||||
create table if not exists `sessions` (
|
||||
`id` bigint not null auto_increment primary key,
|
||||
`name` varchar(255) not null,
|
||||
`type` varchar(255) not null,
|
||||
`number` bigint not null,
|
||||
`create_timestamp` bigint not null,
|
||||
`last_interact_timestamp` bigint not null,
|
||||
`prompt` text not null
|
||||
)
|
||||
""")
|
||||
print('Database initialized.')
|
||||
|
||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||
last_interact_timestamp: int, prompt: str):
|
||||
# 检查是否已经有了此name和create_timestamp的session
|
||||
# 如果有,就更新prompt和last_interact_timestamp
|
||||
# 如果没有,就插入一条新的记录
|
||||
self.cursor.execute("""
|
||||
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
||||
""".format(subject_type, subject_number, create_timestamp))
|
||||
count = self.cursor.fetchone()[0]
|
||||
if count == 0:
|
||||
self.cursor.execute("""
|
||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
|
||||
values ('{}', '{}', {}, {}, {}, '{}')
|
||||
""".format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||
last_interact_timestamp, prompt))
|
||||
else:
|
||||
self.cursor.execute("""
|
||||
update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}'
|
||||
where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
||||
""".format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp))
|
||||
|
||||
# 记载还没过期的session数据
|
||||
def load_valid_sessions(self) -> dict:
|
||||
# 从数据库中加载所有还没过期的session
|
||||
self.cursor.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`
|
||||
from `sessions` where `last_interact_timestamp` > {}
|
||||
""".format(int(time.time()) - config.session_expire_time))
|
||||
results = self.cursor.fetchall()
|
||||
sessions = {}
|
||||
for result in results:
|
||||
session_name = result[0]
|
||||
subject_type = result[1]
|
||||
subject_number = result[2]
|
||||
create_timestamp = result[3]
|
||||
last_interact_timestamp = result[4]
|
||||
prompt = result[5]
|
||||
|
||||
sessions[session_name] = {
|
||||
'subject_type': subject_type,
|
||||
'subject_number': subject_number,
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt
|
||||
}
|
||||
return sessions
|
||||
|
||||
|
||||
def get_inst() -> DatabaseManager:
|
||||
global inst
|
||||
|
||||
Reference in New Issue
Block a user