Merge pull request #238 from highkay/main

Add Cloudflare Workers AI as an LLM backend
Authored by Harry on 2024-04-11 23:11:17 +08:00, committed by GitHub
5 changed files with 33 additions and 2 deletions

.gitignore vendored

@@ -9,4 +9,6 @@
/app/utils/__pycache__/
/*/__pycache__/*
.vscode
/**/.streamlit
__pycache__
logs/


@@ -59,6 +59,11 @@ def _generate_response(prompt: str) -> str:
api_key = config.app.get("qwen_api_key") api_key = config.app.get("qwen_api_key")
model_name = config.app.get("qwen_model_name") model_name = config.app.get("qwen_model_name")
base_url = "***" base_url = "***"
elif llm_provider == "cloudflare":
api_key = config.app.get("cloudflare_api_key")
model_name = config.app.get("cloudflare_model_name")
account_id = config.app.get("cloudflare_account_id")
base_url = "***"
else: else:
raise ValueError("llm_provider is not set, please set it in the config.toml file.") raise ValueError("llm_provider is not set, please set it in the config.toml file.")
@@ -117,6 +122,22 @@ def _generate_response(prompt: str) -> str:
    convo.send_message(prompt)
    return convo.last.text
if llm_provider == "cloudflare":
    import requests
    response = requests.post(
        f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
        headers={"Authorization": f"Bearer {api_key}"},
        json={
            "messages": [
                {"role": "system", "content": "You are a friendly assistant"},
                {"role": "user", "content": prompt}
            ]
        }
    )
    result = response.json()
    logger.info(result)
    return result["result"]["response"]
if llm_provider == "azure":
    client = AzureOpenAI(
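
For reference, a minimal config.toml sketch matching the keys this new branch reads (llm_provider, cloudflare_api_key, cloudflare_account_id, cloudflare_model_name). The [app] table name and all example values are assumptions for illustration only; the model id is just a placeholder for any Workers AI text-generation model:

    [app]
    llm_provider = "cloudflare"
    # API token created in the Cloudflare dashboard with Workers AI permission (placeholder value)
    cloudflare_api_key = "your-api-token"
    # account id, visible in the Cloudflare dashboard URL (placeholder value)
    cloudflare_account_id = "your-account-id"
    # placeholder model id; substitute whichever Workers AI text model you are authorized to use
    cloudflare_model_name = "@cf/meta/llama-2-7b-chat-fp16"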


@@ -175,7 +175,7 @@ with st.expander(tr("Basic Settings"), expanded=False):
# qwen (通义千问)
# gemini
# ollama
llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Gemini', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"]
saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower()
saved_llm_provider_index = 0
for i, provider in enumerate(llm_providers):
@@ -190,6 +190,7 @@ with st.expander(tr("Basic Settings"), expanded=False):
llm_api_key = config.app.get(f"{llm_provider}_api_key", "") llm_api_key = config.app.get(f"{llm_provider}_api_key", "")
llm_base_url = config.app.get(f"{llm_provider}_base_url", "") llm_base_url = config.app.get(f"{llm_provider}_base_url", "")
llm_model_name = config.app.get(f"{llm_provider}_model_name", "") llm_model_name = config.app.get(f"{llm_provider}_model_name", "")
llm_account_id = config.app.get(f"{llm_provider}_account_id", "")
st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password") st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password")
st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url) st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url)
st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name) st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name)
@@ -200,6 +201,11 @@ with st.expander(tr("Basic Settings"), expanded=False):
if st_llm_model_name:
    config.app[f"{llm_provider}_model_name"] = st_llm_model_name
if llm_provider == 'cloudflare':
    st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id)
    if st_llm_account_id:
        config.app[f"{llm_provider}_account_id"] = st_llm_account_id
config.save_config()
with right_config_panel:


@@ -55,6 +55,7 @@
"LLM Provider": "LLM Provider", "LLM Provider": "LLM Provider",
"API Key": "API Key (:red[Required])", "API Key": "API Key (:red[Required])",
"Base Url": "Base Url", "Base Url": "Base Url",
"Account ID": "Account ID (Get from Cloudflare dashboard)",
"Model Name": "Model Name", "Model Name": "Model Name",
"Please Enter the LLM API Key": "Please Enter the **LLM API Key**", "Please Enter the LLM API Key": "Please Enter the **LLM API Key**",
"Please Enter the Pexels API Key": "Please Enter the **Pexels API Key**", "Please Enter the Pexels API Key": "Please Enter the **Pexels API Key**",


@@ -55,6 +55,7 @@
"LLM Provider": "大模型提供商", "LLM Provider": "大模型提供商",
"API Key": "API Key (:red[必填,需要到大模型提供商的后台申请])", "API Key": "API Key (:red[必填,需要到大模型提供商的后台申请])",
"Base Url": "Base Url (可选)", "Base Url": "Base Url (可选)",
"Account ID": "账户ID (Cloudflare的dash面板url中获取)",
"Model Name": "模型名称 (:blue[需要到大模型提供商的后台确认被授权的模型名称])", "Model Name": "模型名称 (:blue[需要到大模型提供商的后台确认被授权的模型名称])",
"Please Enter the LLM API Key": "请先填写大模型 **API Key**", "Please Enter the LLM API Key": "请先填写大模型 **API Key**",
"Please Enter the Pexels API Key": "请先填写 **Pexels API Key**", "Please Enter the Pexels API Key": "请先填写 **Pexels API Key**",