cboettig commited on
Commit
1ec6b81
·
1 Parent(s): 179db2b
Files changed (1) hide show
  1. test.R +11 -13
test.R CHANGED
@@ -15,22 +15,27 @@ schema <- read_file("schema.yml")
15
  system_prompt <- glue::glue(readr::read_file("system-prompt.md"),
16
  .open = "<", .close = ">")
17
 
 
 
 
18
  chat <- ellmer::chat_vllm(
19
- base_url = "https://llm.nrp-nautilus.io/",
20
- model = "llama3",
21
- api_key = Sys.getenv("NRP_API_KEY"),
22
  system_prompt = system_prompt,
23
  api_args = list(temperature = 0)
24
  )
25
 
 
26
  chat <- ellmer::chat_vllm(
27
- base_url = "https://llm.cirrus.carlboettiger.info/v1/",
28
- model = "kosbu/Llama-3.3-70B-Instruct-AWQ",
29
- api_key = Sys.getenv("CIRRUS_LLM_KEY"),
30
  system_prompt = system_prompt,
31
  api_args = list(temperature = 0)
32
  )
33
 
 
34
  # Test a chat-based response
35
  chat$chat("Which columns describes racial components of social vulnerability?")
36
  ## A query-based response
@@ -39,13 +44,6 @@ response <- jsonlite::fromJSON(stream)
39
 
40
  con <- duckdbfs::cached_connection()
41
  filtered_data <- DBI::dbGetQuery(con, response$query)
42
- full_data <- svi
43
- response_query <- "
44
- SELECT COUNTY, AVG(RPL_THEME1) as avg_soc_vuln FROM
45
- svi WHERE STATE = 'California' GROUP BY COUNTY ORDER BY
46
- avg_soc_vuln DESC LIMIT 10;
47
- "
48
-
49
 
50
  filter_column <- function(full_data, filtered_data, id_col) {
51
  if (nrow(filtered_data) < 1) return(NULL)
 
15
  system_prompt <- glue::glue(readr::read_file("system-prompt.md"),
16
  .open = "<", .close = ">")
17
 
18
+
19
+
20
+ # Or optionally test with cirrus
21
  chat <- ellmer::chat_vllm(
22
+ base_url = "https://llm.cirrus.carlboettiger.info/v1/",
23
+ model = "kosbu/Llama-3.3-70B-Instruct-AWQ",
24
+ api_key = Sys.getenv("CIRRUS_LLM_KEY"),
25
  system_prompt = system_prompt,
26
  api_args = list(temperature = 0)
27
  )
28
 
29
+ # or use the NRP model
30
  chat <- ellmer::chat_vllm(
31
+ base_url = "https://llm.nrp-nautilus.io/",
32
+ model = "llama3",
33
+ api_key = Sys.getenv("NRP_API_KEY"),
34
  system_prompt = system_prompt,
35
  api_args = list(temperature = 0)
36
  )
37
 
38
+
39
  # Test a chat-based response
40
  chat$chat("Which columns describes racial components of social vulnerability?")
41
  ## A query-based response
 
44
 
45
  con <- duckdbfs::cached_connection()
46
  filtered_data <- DBI::dbGetQuery(con, response$query)
 
 
 
 
 
 
 
47
 
48
  filter_column <- function(full_data, filtered_data, id_col) {
49
  if (nrow(filtered_data) < 1) return(NULL)