curry tang committed
Commit 1a92b4b · 1 Parent(s): f3a6c77
Files changed (4)
  1. .env +1 -0
  2. app.py +89 -41
  3. config.py +1 -0
  4. llm.py +67 -0
.env CHANGED
@@ -1,2 +1,3 @@
 DEEP_SEEK_API_KEY=
+OPEN_ROUTER_API_KEY=
 DEBUG=False
app.py CHANGED
@@ -1,21 +1,15 @@
 import gradio as gr
-from langchain_openai import ChatOpenAI
 from langchain_core.messages import HumanMessage, AIMessage
-from llm import DeepSeekLLM
+from llm import DeepSeekLLM, OpenRouterLLM
 from config import settings
 
 
 deep_seek_llm = DeepSeekLLM(api_key=settings.deep_seek_api_key)
+open_router_llm = OpenRouterLLM(api_key=settings.open_router_api_key)
 
 
 def init_chat():
-    return ChatOpenAI(
-        model=deep_seek_llm.default_model,
-        api_key=deep_seek_llm.api_key,
-        base_url=deep_seek_llm.base_url,
-        temperature=deep_seek_llm.default_temperature,
-        max_tokens=deep_seek_llm.default_max_tokens,
-    )
+    return deep_seek_llm.get_chat_engine()
 
 
 def predict(message, history, chat):
@@ -33,14 +27,12 @@ def predict(message, history, chat):
     yield response_message
 
 
-def update_chat(_chat, _model: str, _temperature: float, _max_tokens: int):
-    _chat = ChatOpenAI(
-        model=_model,
-        api_key=deep_seek_llm.api_key,
-        base_url=deep_seek_llm.base_url,
-        temperature=_temperature,
-        max_tokens=_max_tokens,
-    )
+def update_chat(_provider: str, _chat, _model: str, _temperature: float, _max_tokens: int):
+    print('?????', _provider, _chat, _model, _temperature, _max_tokens)
+    if _provider == 'DeepSeek':
+        _chat = deep_seek_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
+    if _provider == 'OpenRouter':
+        _chat = open_router_llm.get_chat_engine(model=_model, temperature=_temperature, max_tokens=_max_tokens)
     return _chat
 
 
@@ -59,30 +51,86 @@ with gr.Blocks() as app:
         with gr.Column(scale=1, min_width=300):
             with gr.Accordion('Select Model', open=True):
                 with gr.Column():
-                    model = gr.Dropdown(
-                        label='模型',
-                        choices=deep_seek_llm.support_models,
-                        value=deep_seek_llm.default_model
-                    )
-                    temperature = gr.Slider(
-                        minimum=0.0,
-                        maximum=1.0,
-                        step=0.1,
-                        value=deep_seek_llm.default_temperature,
-                        label="Temperature",
-                        key="temperature",
-                    )
-                    max_tokens = gr.Number(
-                        minimum=1024,
-                        maximum=1024 * 20,
-                        step=128,
-                        value=deep_seek_llm.default_max_tokens,
-                        label="Max Tokens",
-                        key="max_tokens",
-                    )
-                    model.change(fn=update_chat, inputs=[chat_engine, model, temperature, max_tokens], outputs=[chat_engine])
-                    temperature.change(fn=update_chat, inputs=[chat_engine, model, temperature, max_tokens], outputs=[chat_engine])
-                    max_tokens.change(fn=update_chat, inputs=[chat_engine, model, temperature, max_tokens], outputs=[chat_engine])
+                    provider = gr.Dropdown(label='Provider', choices=['DeepSeek', 'OpenRouter'], value='DeepSeek')
+
+                    @gr.render(inputs=provider)
+                    def show_model_config_panel(_provider):
+                        if _provider == 'DeepSeek':
+                            with gr.Column():
+                                model = gr.Dropdown(
+                                    label='模型',
+                                    choices=deep_seek_llm.support_models,
+                                    value=deep_seek_llm.default_model
+                                )
+                                temperature = gr.Slider(
+                                    minimum=0.0,
+                                    maximum=1.0,
+                                    step=0.1,
+                                    value=deep_seek_llm.default_temperature,
+                                    label="Temperature",
+                                    key="temperature",
+                                )
+                                max_tokens = gr.Number(
+                                    minimum=1024,
+                                    maximum=1024 * 20,
+                                    step=128,
+                                    value=deep_seek_llm.default_max_tokens,
+                                    label="Max Tokens",
+                                    key="max_tokens",
+                                )
+                                model.change(
+                                    fn=update_chat,
+                                    inputs=[provider, chat_engine, model, temperature, max_tokens],
+                                    outputs=[chat_engine],
+                                )
+                                temperature.change(
+                                    fn=update_chat,
+                                    inputs=[provider, chat_engine, model, temperature, max_tokens],
+                                    outputs=[chat_engine],
+                                )
+                                max_tokens.change(
+                                    fn=update_chat,
+                                    inputs=[provider, chat_engine, model, temperature, max_tokens],
+                                    outputs=[chat_engine],
+                                )
+                        if _provider == 'OpenRouter':
+                            with gr.Column():
+                                model = gr.Dropdown(
+                                    label='模型',
+                                    choices=open_router_llm.support_models,
+                                    value=open_router_llm.default_model
+                                )
+                                temperature = gr.Slider(
+                                    minimum=0.0,
+                                    maximum=1.0,
+                                    step=0.1,
+                                    value=open_router_llm.default_temperature,
+                                    label="Temperature",
+                                    key="temperature",
+                                )
+                                max_tokens = gr.Number(
+                                    minimum=1024,
+                                    maximum=1024 * 20,
+                                    step=128,
+                                    value=open_router_llm.default_max_tokens,
+                                    label="Max Tokens",
+                                    key="max_tokens",
+                                )
+                                model.change(
+                                    fn=update_chat,
+                                    inputs=[provider, chat_engine, model, temperature, max_tokens],
+                                    outputs=[chat_engine],
+                                )
+                                temperature.change(
+                                    fn=update_chat,
+                                    inputs=[provider, chat_engine, model, temperature, max_tokens],
+                                    outputs=[chat_engine],
+                                )
+                                max_tokens.change(
+                                    fn=update_chat,
+                                    inputs=[provider, chat_engine, model, temperature, max_tokens],
+                                    outputs=[chat_engine],
+                                )
 
     with gr.Tab('画图'):
         with gr.Row():
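
Note: the reworked panel relies on Gradio's @gr.render decorator, which re-runs the decorated function (and rebuilds every component created inside it) whenever one of its inputs changes; that is what lets the model, temperature, and max-tokens controls switch between the DeepSeek and OpenRouter defaults. A minimal self-contained sketch of the pattern, illustrative only and not part of this commit:

    import gradio as gr

    with gr.Blocks() as demo:
        provider = gr.Dropdown(choices=['DeepSeek', 'OpenRouter'], value='DeepSeek')

        @gr.render(inputs=provider)
        def show_panel(_provider):
            # This body re-executes on every change of `provider`,
            # so the components below are rebuilt for the new selection.
            gr.Markdown(f'Config panel for {_provider}')

    demo.launch()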
config.py CHANGED
@@ -3,6 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 
 class Settings(BaseSettings):
     deep_seek_api_key: str
+    open_router_api_key: str
     debug: bool
 
     model_config = SettingsConfigDict(env_file=('.env', '.env.local'), env_file_encoding='utf-8')
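
Note: pydantic-settings matches environment variables to field names case-insensitively by default, so the new OPEN_ROUTER_API_KEY line in .env populates open_router_api_key. A quick illustrative check, assuming config.py instantiates Settings() as `settings` (as the app.py import suggests):

    from config import settings

    # Building Settings() fails with a validation error if
    # OPEN_ROUTER_API_KEY is absent from .env/.env.local,
    # because the field is typed `str` with no default.
    print(bool(settings.open_router_api_key))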
llm.py CHANGED
@@ -1,5 +1,6 @@
 from typing import List
 from abc import ABC
+from langchain_openai import ChatOpenAI
 
 
 class DeepSeekLLM(ABC):
@@ -37,3 +38,69 @@ class DeepSeekLLM(ABC):
     def default_max_tokens(self) -> int:
         return self._default_max_tokens
 
+    def get_chat_engine(self, *, model: str = None, temperature: float = None, max_tokens: int = None):
+        model = model or self.default_model
+        temperature = temperature or self.default_temperature
+        max_tokens = max_tokens or self.default_max_tokens
+        return ChatOpenAI(
+            model=model,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+
+class OpenRouterLLM(ABC):
+    _support_models = [
+        'anthropic/claude-3.5-sonnet', 'openai/gpt-4o',
+        'nvidia/nemotron-4-340b-instruct', 'deepseek/deepseek-coder',
+        'google/gemini-flash-1.5', 'deepseek/deepseek-chat',
+        'liuhaotian/llava-yi-34b', 'qwen/qwen-110b-chat',
+        'qwen/qwen-72b-chat', 'google/gemini-pro-1.5',
+        'cohere/command-r-plus', 'anthropic/claude-3-haiku',
+    ]
+    _base_url = 'https://openrouter.ai/api/v1'
+    _default_model = 'anthropic/claude-3.5-sonnet'
+    _api_key: str
+    _default_temperature: float = 0.5
+    _default_max_tokens: int = 4096
+
+    def __init__(self, *, api_key: str):
+        self._api_key = api_key
+
+    @property
+    def support_models(self) -> List[str]:
+        return self._support_models
+
+    @property
+    def default_model(self) -> str:
+        return self._default_model
+
+    @property
+    def base_url(self) -> str:
+        return self._base_url
+
+    @property
+    def api_key(self) -> str:
+        return self._api_key
+
+    @property
+    def default_temperature(self) -> float:
+        return self._default_temperature
+
+    @property
+    def default_max_tokens(self) -> int:
+        return self._default_max_tokens
+
+    def get_chat_engine(self, *, model: str = None, temperature: float = None, max_tokens: int = None):
+        model = model or self.default_model
+        temperature = temperature or self.default_temperature
+        max_tokens = max_tokens or self.default_max_tokens
+        return ChatOpenAI(
+            model=model,
+            api_key=self.api_key,
+            base_url=self.base_url,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
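
For reference, a minimal sketch of how the new classes are driven from app.py (the key and model below are placeholders, not values from this commit):

    from llm import OpenRouterLLM

    llm = OpenRouterLLM(api_key='<OPEN_ROUTER_API_KEY>')

    # get_chat_engine falls back to the class defaults via `or`, so a
    # falsy override such as temperature=0.0 is silently replaced by
    # the default (0.5 here); worth noting if 0 is ever intended.
    chat = llm.get_chat_engine(model='openai/gpt-4o', temperature=0.2)
    print(chat.invoke('Hello').content)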