maxschulz-COL commited on
Commit
9fa94f5
·
1 Parent(s): 1c98fd4

Update based on latest release

Browse files
Files changed (4) hide show
  1. actions.py +10 -6
  2. app.py +0 -2
  3. requirements.in +1 -1
  4. requirements.txt +2 -19
actions.py CHANGED
@@ -34,9 +34,7 @@ SUPPORTED_MODELS = {
34
  "OpenAI": [
35
  "gpt-4o-mini",
36
  "gpt-4o",
37
- "gpt-4",
38
  "gpt-4-turbo",
39
- "gpt-3.5-turbo",
40
  ],
41
  "Anthropic": [
42
  "claude-3-opus-latest",
@@ -46,6 +44,8 @@ SUPPORTED_MODELS = {
46
  ],
47
  "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
48
  }
 
 
49
 
50
 
51
  def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
@@ -53,14 +53,18 @@ def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
53
  vendor = SUPPORTED_VENDORS[vendor_input]
54
 
55
  if vendor_input == "OpenAI":
56
- llm = vendor(model_name=model, openai_api_key=api_key, openai_api_base=api_base)
 
 
57
  if vendor_input == "Anthropic":
58
- llm = vendor(model=model, anthropic_api_key=api_key, anthropic_api_url=api_base)
 
 
59
  if vendor_input == "Mistral":
60
- llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base)
61
 
62
  vizro_ai = VizroAI(model=llm)
63
- ai_outputs = vizro_ai.plot(df, user_prompt, return_elements=True)
64
 
65
  return ai_outputs
66
 
 
34
  "OpenAI": [
35
  "gpt-4o-mini",
36
  "gpt-4o",
 
37
  "gpt-4-turbo",
 
38
  ],
39
  "Anthropic": [
40
  "claude-3-opus-latest",
 
44
  ],
45
  "Mistral": ["mistral-large-latest", "open-mistral-nemo", "codestral-latest"],
46
  }
47
+ DEFAULT_TEMPERATURE = 0.1
48
+ DEFAULT_RETRY = 3
49
 
50
 
51
  def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
 
53
  vendor = SUPPORTED_VENDORS[vendor_input]
54
 
55
  if vendor_input == "OpenAI":
56
+ llm = vendor(
57
+ model_name=model, openai_api_key=api_key, openai_api_base=api_base, temperature=DEFAULT_TEMPERATURE
58
+ )
59
  if vendor_input == "Anthropic":
60
+ llm = vendor(
61
+ model=model, anthropic_api_key=api_key, anthropic_api_url=api_base, temperature=DEFAULT_TEMPERATURE
62
+ )
63
  if vendor_input == "Mistral":
64
+ llm = vendor(model=model, mistral_api_key=api_key, mistral_api_url=api_base, temperature=DEFAULT_TEMPERATURE)
65
 
66
  vizro_ai = VizroAI(model=llm)
67
+ ai_outputs = vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
68
 
69
  return ai_outputs
70
 
app.py CHANGED
@@ -61,9 +61,7 @@ SUPPORTED_MODELS = {
61
  "OpenAI": [
62
  "gpt-4o-mini",
63
  "gpt-4o",
64
- "gpt-4",
65
  "gpt-4-turbo",
66
- "gpt-3.5-turbo",
67
  ],
68
  "Anthropic": [
69
  "claude-3-opus-latest",
 
61
  "OpenAI": [
62
  "gpt-4o-mini",
63
  "gpt-4o",
 
64
  "gpt-4-turbo",
 
65
  ],
66
  "Anthropic": [
67
  "claude-3-opus-latest",
requirements.in CHANGED
@@ -1,5 +1,5 @@
1
  gunicorn
2
- vizro-ai>=0.3.0
3
  black
4
  openpyxl
5
  langchain_anthropic
 
1
  gunicorn
2
+ vizro-ai>=0.3.1
3
  black
4
  openpyxl
5
  langchain_anthropic
requirements.txt CHANGED
@@ -15,10 +15,6 @@ anyio==4.4.0
15
  # anthropic
16
  # httpx
17
  # openai
18
- async-timeout==4.0.3
19
- # via
20
- # aiohttp
21
- # langchain
22
  attrs==24.2.0
23
  # via aiohttp
24
  autoflake==2.3.1
@@ -67,8 +63,6 @@ distro==1.9.0
67
  # openai
68
  et-xmlfile==1.1.0
69
  # via openpyxl
70
- exceptiongroup==1.2.2
71
- # via anyio
72
  filelock==3.16.1
73
  # via huggingface-hub
74
  flask==3.0.3
@@ -83,8 +77,6 @@ frozenlist==1.4.1
83
  # aiosignal
84
  fsspec==2024.10.0
85
  # via huggingface-hub
86
- greenlet==3.1.0
87
- # via sqlalchemy
88
  gunicorn==23.0.0
89
  # via -r requirements.in
90
  h11==0.14.0
@@ -108,9 +100,7 @@ idna==3.8
108
  # requests
109
  # yarl
110
  importlib-metadata==8.5.0
111
- # via
112
- # dash
113
- # flask
114
  itsdangerous==2.2.0
115
  # via flask
116
  jinja2==3.1.4
@@ -256,10 +246,6 @@ tokenizers==0.20.1
256
  # via
257
  # anthropic
258
  # langchain-mistralai
259
- tomli==2.0.1
260
- # via
261
- # autoflake
262
- # black
263
  tqdm==4.66.5
264
  # via
265
  # huggingface-hub
@@ -267,12 +253,9 @@ tqdm==4.66.5
267
  typing-extensions==4.12.2
268
  # via
269
  # anthropic
270
- # anyio
271
- # black
272
  # dash
273
  # huggingface-hub
274
  # langchain-core
275
- # multidict
276
  # openai
277
  # pydantic
278
  # pydantic-core
@@ -283,7 +266,7 @@ urllib3==2.2.3
283
  # via requests
284
  vizro==0.1.23
285
  # via vizro-ai
286
- vizro-ai==0.3.0
287
  # via -r requirements.in
288
  werkzeug==3.0.4
289
  # via
 
15
  # anthropic
16
  # httpx
17
  # openai
 
 
 
 
18
  attrs==24.2.0
19
  # via aiohttp
20
  autoflake==2.3.1
 
63
  # openai
64
  et-xmlfile==1.1.0
65
  # via openpyxl
 
 
66
  filelock==3.16.1
67
  # via huggingface-hub
68
  flask==3.0.3
 
77
  # aiosignal
78
  fsspec==2024.10.0
79
  # via huggingface-hub
 
 
80
  gunicorn==23.0.0
81
  # via -r requirements.in
82
  h11==0.14.0
 
100
  # requests
101
  # yarl
102
  importlib-metadata==8.5.0
103
+ # via dash
 
 
104
  itsdangerous==2.2.0
105
  # via flask
106
  jinja2==3.1.4
 
246
  # via
247
  # anthropic
248
  # langchain-mistralai
 
 
 
 
249
  tqdm==4.66.5
250
  # via
251
  # huggingface-hub
 
253
  typing-extensions==4.12.2
254
  # via
255
  # anthropic
 
 
256
  # dash
257
  # huggingface-hub
258
  # langchain-core
 
259
  # openai
260
  # pydantic
261
  # pydantic-core
 
266
  # via requests
267
  vizro==0.1.23
268
  # via vizro-ai
269
+ vizro-ai==0.3.1
270
  # via -r requirements.in
271
  werkzeug==3.0.4
272
  # via