FIRE / src /serve /api_provider.py
zhangbofei
fix: src
2238fe2
raw
history blame
32.7 kB
"""Call API providers."""
import json
import os
import random
import re
from typing import Optional
import time
import requests
from src.utils import build_logger
# Module-level logger used by the provider iterators in this file for request
# and error logging.
logger = build_logger("gradio_web_server", "gradio_web_server.log")
def get_api_provider_stream_iter(
    conv,
    model_name,
    model_api_dict,
    temperature,
    top_p,
    max_new_tokens,
    state,
):
    """Pick the provider-specific streaming iterator for one request.

    The provider is selected by ``model_api_dict["api_type"]``. Each branch
    converts ``conv`` into the message format its iterator takes and returns
    that iterator; nothing is iterated here.

    Args:
        conv: Conversation object. Only its ``to_*_api_messages()``
            converters, ``messages`` and ``system_message`` are used.
        model_name: Forwarded unchanged on the anthropic, anthropic_message,
            mistral, nvidia, ai2 and vertex branches; the other branches use
            ``model_api_dict["model_name"]`` instead.
        model_api_dict: Provider configuration. ``api_type`` is always read;
            the remaining keys depend on the branch taken.
        temperature: Sampling temperature. The yandexgpt branch replaces it
            with ``fixed_temperature`` from the config when that is set.
        top_p: Nucleus-sampling value; not every branch forwards it.
        max_new_tokens: Generation length limit; not every branch forwards it.
        state: Session state, handed only to the OpenAI assistant iterator.

    Returns:
        Whatever the matching ``*_api_stream_iter`` function returns.

    Raises:
        NotImplementedError: If ``api_type`` matches no known provider.
    """
    api_type = model_api_dict["api_type"]

    if api_type == "openai":
        prompt = (
            conv.to_openai_vision_api_messages()
            if model_api_dict["vision-arena"]
            else conv.to_openai_api_messages()
        )
        return openai_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )

    if api_type == "openai_assistant":
        # The second-to-last entry of the conversation holds the prompt text
        # that is sent to the assistant thread.
        last_prompt = conv.messages[-2][1]
        return openai_assistant_api_stream_iter(
            state,
            last_prompt,
            assistant_id=model_api_dict["assistant_id"],
            api_key=model_api_dict["api_key"],
        )

    if api_type == "anthropic":
        prompt = (
            conv.to_anthropic_vision_api_messages()
            if model_api_dict["vision-arena"]
            else conv.to_openai_api_messages()
        )
        return anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )

    if api_type == "anthropic_message":
        prompt = (
            conv.to_anthropic_vision_api_messages()
            if model_api_dict["vision-arena"]
            else conv.to_openai_api_messages()
        )
        return anthropic_message_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )

    if api_type == "anthropic_message_vertex":
        prompt = (
            conv.to_anthropic_vision_api_messages()
            if model_api_dict["vision-arena"]
            else conv.to_openai_api_messages()
        )
        return anthropic_message_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            vertex_ai=True,
        )

    if api_type == "gemini":
        prompt = conv.to_gemini_api_messages()
        return gemini_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_key=model_api_dict["api_key"],
        )

    if api_type == "bard":
        prompt = conv.to_openai_api_messages()
        return bard_api_stream_iter(
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            api_key=model_api_dict["api_key"],
        )

    if api_type == "mistral":
        prompt = conv.to_openai_api_messages()
        return mistral_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )

    if api_type == "nvidia":
        prompt = conv.to_openai_api_messages()
        return nvidia_api_stream_iter(
            model_name,
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            model_api_dict["api_base"],
        )

    if api_type == "ai2":
        prompt = conv.to_openai_api_messages()
        return ai2_api_stream_iter(
            model_name,
            model_api_dict["model_name"],
            prompt,
            temperature,
            top_p,
            max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )

    if api_type == "vertex":
        prompt = conv.to_vertex_api_messages()
        return vertex_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )

    if api_type == "yandexgpt":
        # note: top_p parameter is unused by yandexgpt
        yandex_messages = (
            [{"role": "system", "text": conv.system_message}]
            if conv.system_message
            else []
        )
        yandex_messages.extend(
            {"role": role, "text": text}
            for role, text in conv.messages
            if text is not None
        )
        # A configured fixed temperature overrides the caller's value.
        fixed_temperature = model_api_dict.get("fixed_temperature")
        if fixed_temperature is not None:
            temperature = fixed_temperature
        return yandexgpt_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=yandex_messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict.get("api_key"),
            folder_id=model_api_dict.get("folder_id"),
        )

    if api_type == "cohere":
        cohere_messages = conv.to_openai_api_messages()
        return cohere_api_stream_iter(
            client_name=model_api_dict.get("client_name", "FastChat"),
            model_id=model_api_dict["model_name"],
            messages=cohere_messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )

    if api_type == "reka":
        reka_messages = conv.to_reka_api_messages()
        return reka_api_stream_iter(
            model_name=model_api_dict["model_name"],
            messages=reka_messages,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            api_base=model_api_dict["api_base"],
            api_key=model_api_dict["api_key"],
        )

    raise NotImplementedError()
def openai_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    """Stream a chat completion from an OpenAI-style endpoint.

    Args:
        model_name: Model identifier sent to the API. A name containing
            ``"azure"`` selects the ``AzureOpenAI`` client.
        messages: Chat messages. ``content`` is either a string or a list of
            typed parts (the vision format).
        temperature: Sampling temperature sent to the API.
        top_p: Only written to the request log.
        max_new_tokens: Sent to the API as ``max_tokens``.
        api_base: Endpoint URL; defaults to ``https://api.openai.com/v1``.
        api_key: API key; falls back to the ``OPENAI_API_KEY`` env var.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` for every streamed
        chunk that carries at least one choice.

    Raises:
        KeyError: If no key is passed and ``OPENAI_API_KEY`` is unset.
    """
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]

    if "azure" in model_name:
        client = openai.AzureOpenAI(
            api_version="2023-07-01-preview",
            azure_endpoint=api_base or "https://api.openai.com/v1",
            api_key=api_key,
        )
    else:
        client = openai.OpenAI(
            base_url=api_base or "https://api.openai.com/v1",
            api_key=api_key,
            timeout=180,
        )

    # Build a copy of the conversation for logging that keeps only the
    # "text" parts of multi-part (vision) messages.
    text_messages = []
    for message in messages:
        content = message["content"]
        if isinstance(content, str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            text_messages.append(
                {
                    "role": message["role"],
                    "content": [part for part in content if part["type"] == "text"],
                }
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # NOTE(review): top_p is logged above but not sent with the request --
    # confirm whether that omission is intentional before forwarding it.
    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        max_tokens=max_new_tokens,
        stream=True,
    )

    text = ""
    for chunk in res:
        if not chunk.choices:
            continue
        # A delta's content can be None; treat it as an empty string.
        text += chunk.choices[0].delta.content or ""
        yield {
            "text": text,
            "error_code": 0,
        }
def upload_openai_file_to_gcs(file_id):
    """Copy an OpenAI file into the ``arena_user_content`` GCS bucket.

    The file content is fetched through ``openai.files.content``, uploaded to
    a blob named after ``file_id``, and the blob is made public.

    Args:
        file_id: Identifier of the OpenAI file; also used as the blob name.

    Returns:
        The public URL of the uploaded blob.
    """
    import openai
    from google.cloud import storage

    gcs_client = storage.Client()
    remote_file = openai.files.content(file_id)

    # upload file to GCS
    target_blob = gcs_client.get_bucket("arena_user_content").blob(f"{file_id}")
    target_blob.upload_from_string(remote_file.read())
    target_blob.make_public()
    return target_blob.public_url
def openai_assistant_api_stream_iter(
    state,
    prompt,
    assistant_id,
    api_key=None,
):
    """Run an OpenAI assistant on the session's thread and stream its reply.

    A thread is created on first use and its id is stored on
    ``state.oai_thread_id``. The prompt is appended to the thread, a streamed
    run is started over HTTP, and the event stream is folded into text.
    Citation markers are renumbered, and file or image references are
    replaced with public URLs from ``upload_openai_file_to_gcs``.

    Args:
        state: Session object; ``oai_thread_id`` is read and possibly set.
        prompt: User text appended to the thread.
        assistant_id: Identifier of the assistant to run.
        api_key: API key; falls back to the ``OPENAI_API_KEY`` env var.

    Yields:
        ``{"text": <all message parts joined by newlines>, "error_code": 0}``
        after each content delta, or a single ``error_code`` 1 dict when the
        stream reports a ``failed`` status.

    Raises:
        KeyError: If no key is passed and ``OPENAI_API_KEY`` is unset.
    """
    import openai

    api_key = api_key or os.environ["OPENAI_API_KEY"]
    client = openai.OpenAI(base_url="https://api.openai.com/v1", api_key=api_key)

    if state.oai_thread_id is None:
        logger.info("==== create thread ====")
        thread = client.beta.threads.create()
        state.oai_thread_id = thread.id
    logger.info(f"==== thread_id ====\n{state.oai_thread_id}")

    # Append the prompt to the thread. The parsed message is not used
    # afterwards, but parse() is still called as before.
    raw_response = client.beta.threads.messages.with_raw_response.create(
        state.oai_thread_id,
        role="user",
        content=prompt,
        timeout=3,
    )
    raw_response.parse()

    # Make requests
    gen_params = {
        "assistant_id": assistant_id,
        "thread_id": state.oai_thread_id,
        "message": prompt,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = requests.post(
        f"https://api.openai.com/v1/threads/{state.oai_thread_id}/runs",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "OpenAI-Beta": "assistants=v1",
        },
        json={"assistant_id": assistant_id, "stream": True},
        timeout=30,
        stream=True,
    )

    list_of_text = []  # rendered text, one slot per content part
    list_of_raw_text = []  # unmodified text, used to apply annotations
    offset_idx = 0  # slots taken by messages that have already completed
    idx_mapping = {}  # original citation marker -> renumbered citation
    citation_pattern = r"【\d+】"

    for line in res.iter_lines():
        if not line:
            continue
        data = line.decode("utf-8")
        if data.endswith("[DONE]"):
            break
        if data.startswith("event"):
            event = data.split(":", 1)[1].strip()
            if event == "thread.message.completed":
                # Content indices restart at 0 for every message, so the next
                # message must start right after the slots used so far.
                # Accumulating with += overshoots from the third message on.
                offset_idx = len(list_of_text)
            continue
        # Skip the 6-character line prefix before the JSON payload.
        data = json.loads(data[6:])

        if data.get("status") == "failed":
            yield {
                "text": f"**API REQUEST ERROR** Reason: {data['last_error']['message']}",
                "error_code": 1,
            }
            return

        if data.get("status") == "completed":
            logger.info(f"[debug]: {data}")

        if data["object"] != "thread.message.delta":
            continue

        for delta in data["delta"]["content"]:
            text_index = delta["index"] + offset_idx
            # Grow both lists until the slot exists, even if an index is
            # skipped.
            while len(list_of_text) <= text_index:
                list_of_text.append("")
                list_of_raw_text.append("")

            text = list_of_text[text_index]
            raw_text = list_of_raw_text[text_index]

            if delta["type"] == "text":
                # text, url_citation or file_path
                content = delta["text"]
                if "annotations" in content and len(content["annotations"]) > 0:
                    # Annotations index into the raw text; cur_offset tracks
                    # how far earlier replacements shifted later positions.
                    cur_offset = 0
                    raw_text_copy = raw_text
                    for anno in content["annotations"]:
                        if anno["type"] == "url_citation":
                            anno_text = anno["text"]
                            if anno_text not in idx_mapping:
                                continue
                            citation_number = idx_mapping[anno_text]

                            start_idx = anno["start_index"] + cur_offset
                            end_idx = anno["end_index"] + cur_offset
                            url = anno["url_citation"]["url"]

                            citation = f" [[{citation_number}]]({url})"
                            raw_text_copy = (
                                raw_text_copy[:start_idx]
                                + citation
                                + raw_text_copy[end_idx:]
                            )
                            cur_offset += len(citation) - (end_idx - start_idx)
                        elif anno["type"] == "file_path":
                            file_public_url = upload_openai_file_to_gcs(
                                anno["file_path"]["file_id"]
                            )
                            raw_text_copy = raw_text_copy.replace(
                                anno["text"], f"{file_public_url}"
                            )
                    text = raw_text_copy
                else:
                    text_content = content["value"]
                    raw_text += text_content

                    # re-index citation number
                    for match in re.findall(citation_pattern, content["value"]):
                        if match not in idx_mapping:
                            idx_mapping[match] = len(idx_mapping) + 1
                        citation_number = idx_mapping[match]
                        text_content = text_content.replace(
                            match, f" [{citation_number}]"
                        )
                    text += text_content
            elif delta["type"] == "image_file":
                image_public_url = upload_openai_file_to_gcs(
                    delta["image_file"]["file_id"]
                )
                text += f"![image]({image_public_url})"

            list_of_text[text_index] = text
            list_of_raw_text[text_index] = raw_text

            yield {"text": "\n".join(list_of_text), "error_code": 0}
def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
    """Stream a completion through ``anthropic``'s ``completions.create``.

    Args:
        model_name: Model identifier sent to the API.
        prompt: Prompt passed through unchanged as ``prompt``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_tokens_to_sample``.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` per streamed chunk.

    Raises:
        KeyError: If the ``ANTHROPIC_API_KEY`` env var is unset.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    # Make requests
    logged_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{logged_params}")

    stream = client.completions.create(
        prompt=prompt,
        stop_sequences=[anthropic.HUMAN_PROMPT],
        max_tokens_to_sample=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        model=model_name,
        stream=True,
    )

    generated = ""
    for piece in stream:
        generated += piece.completion
        yield {"text": generated, "error_code": 0}
def anthropic_message_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    vertex_ai=False,
):
    """Stream a reply through ``anthropic``'s ``messages.stream``.

    A leading ``system`` message is lifted out of ``messages`` and passed as
    the ``system`` argument instead.

    Args:
        model_name: Model identifier sent to the API.
        messages: Chat messages. ``content`` is either a string or a list of
            typed parts (the vision format).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_tokens``.
        vertex_ai: When true, use ``AnthropicVertex`` configured from the
            ``GCP_LOCATION`` and ``GCP_PROJECT_ID`` env vars instead of the
            direct client keyed by ``ANTHROPIC_API_KEY``.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` per streamed chunk.

    Raises:
        KeyError: If a required environment variable is unset.
    """
    import anthropic

    if vertex_ai:
        client = anthropic.AnthropicVertex(
            region=os.environ["GCP_LOCATION"],
            project_id=os.environ["GCP_PROJECT_ID"],
            max_retries=5,
        )
    else:
        client = anthropic.Anthropic(
            api_key=os.environ["ANTHROPIC_API_KEY"],
            max_retries=5,
        )

    # Build a copy of the conversation for logging that keeps only the
    # "text" parts of multi-part (vision) messages.
    text_messages = []
    for message in messages:
        content = message["content"]
        if isinstance(content, str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            text_messages.append(
                {
                    "role": message["role"],
                    "content": [part for part in content if part["type"] == "text"],
                }
            )

    # Make requests for logging
    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    system_prompt = ""
    # Guard against an empty conversation so the lookup below cannot raise
    # IndexError before the request is even attempted.
    if messages and messages[0]["role"] == "system":
        system_content = messages[0]["content"]
        if isinstance(system_content, dict):
            system_prompt = system_content["text"]
        elif isinstance(system_content, str):
            system_prompt = system_content
        # remove system prompt
        messages = messages[1:]

    text = ""
    with client.messages.stream(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_new_tokens,
        messages=messages,
        model=model_name,
        system=system_prompt,
    ) as stream:
        for chunk in stream.text_stream:
            text += chunk
            yield {
                "text": text,
                "error_code": 0,
            }
def gemini_api_stream_iter(
    model_name, messages, temperature, top_p, max_new_tokens, api_key=None
):
    """Stream a Gemini chat reply via ``google.generativeai``.

    All messages but the last form the chat history; a ``system`` message
    among them becomes the model's ``system_instruction``. The last message
    is sent as the new turn. All four safety categories are set to
    ``BLOCK_NONE``.

    Args:
        model_name: Model identifier.
        messages: Dicts with ``role`` and ``content``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_output_tokens``.
        api_key: API key; falls back to the ``GEMINI_API_KEY`` env var.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` per streamed chunk,
        or one ``error_code`` 1 dict if reading the stream raises.

    Raises:
        KeyError: If no key is passed and ``GEMINI_API_KEY`` is unset.
    """
    import google.generativeai as genai  # pip install google-generativeai

    if api_key is None:
        api_key = os.environ["GEMINI_API_KEY"]
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": temperature,
        "max_output_tokens": max_new_tokens,
        "top_p": top_p,
    }
    params = {
        "model": model_name,
        "prompt": messages,
    }
    params.update(generation_config)
    logger.info(f"==== request ====\n{params}")

    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
    ]

    history = []
    system_prompt = None
    for message in messages[:-1]:
        if message["role"] == "system":
            system_prompt = message["content"]
            continue
        history.append({"role": message["role"], "parts": message["content"]})

    model = genai.GenerativeModel(
        model_name=model_name,
        system_instruction=system_prompt,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    convo = model.start_chat(history=history)
    response = convo.send_message(messages[-1]["content"], stream=True)

    # Track the most recent chunk so the error handler can report it. It
    # stays None when the stream fails before yielding anything, which
    # previously caused an UnboundLocalError inside the handler.
    chunk = None
    try:
        text = ""
        for chunk in response:
            text += chunk.candidates[0].content.parts[0].text
            yield {
                "text": text,
                "error_code": 0,
            }
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        reason = chunk.candidates if chunk is not None else e
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }
def bard_api_stream_iter(model_name, conv, temperature, top_p, api_key=None):
    """Call the ``generateMessage`` endpoint and replay the reply as a stream.

    The endpoint is called once without streaming; the reply is then yielded
    in growing prefixes of 3-6 characters per step to imitate streaming.

    Args:
        model_name: Model identifier placed in the request URL.
        conv: Messages with ``role`` ("user" or "assistant") and ``content``.
        temperature: Ignored (not supported).
        top_p: Ignored (not supported).
        api_key: API key; falls back to the ``BARD_API_KEY`` env var.

    Yields:
        ``{"text": <prefix of the reply>, "error_code": 0}`` per step, or a
        single ``error_code`` 1 dict when the request fails, returns a
        non-200 status, or carries no candidates.

    Raises:
        KeyError: If no key is passed and ``BARD_API_KEY`` is unset.
        ValueError: If a message has a role other than user/assistant.
    """
    del top_p  # not supported
    del temperature  # not supported

    if api_key is None:
        api_key = os.environ["BARD_API_KEY"]

    # convert conv to conv_bard
    author_by_role = {"user": "0", "assistant": "1"}
    conv_bard = []
    for turn in conv:
        if turn["role"] not in author_by_role:
            raise ValueError(f"Unsupported role: {turn['role']}")
        conv_bard.append(
            {"author": author_by_role[turn["role"]], "content": turn["content"]}
        )

    params = {
        "model": model_name,
        "prompt": conv_bard,
    }
    logger.info(f"==== request ====\n{params}")

    # NOTE(review): the key travels in the URL query string, so exception
    # text may include it -- confirm before surfacing errors to end users.
    try:
        res = requests.post(
            f"https://generativelanguage.googleapis.com/v1beta2/models/{model_name}:generateMessage?key={api_key}",
            json={
                "prompt": {
                    "messages": conv_bard,
                },
            },
            timeout=30,
        )
    except Exception as e:
        logger.error(f"==== error ====\n{e}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}.",
            "error_code": 1,
        }
        # Stop here: `res` was never assigned.
        return

    if res.status_code != 200:
        logger.error(f"==== error ==== ({res.status_code}): {res.text}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: status code {res.status_code}.",
            "error_code": 1,
        }
        return

    response_json = res.json()
    if "candidates" not in response_json:
        logger.error(f"==== error ==== response blocked: {response_json}")
        # Tolerate a response that carries neither candidates nor filters.
        filters = response_json.get("filters") or [{}]
        reason = filters[0].get("reason", "unknown")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {reason}.",
            "error_code": 1,
        }
        return

    response = response_json["candidates"][0]["content"]
    pos = 0
    while pos < len(response):
        # simulate token streaming
        pos += random.randint(3, 6)
        time.sleep(0.002)
        yield {
            "text": response[:pos],
            "error_code": 0,
        }
def ai2_api_stream_iter(
    model_name,
    model_id,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream a completion from an AI2 InferD endpoint.

    Args:
        model_name: Name used only in the request log.
        model_id: Identifier sent to the endpoint as ``model_id``.
        messages: Chat messages placed under ``input.messages``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value; must be 1 when temperature is 0.0.
        max_new_tokens: Sent as ``max_tokens``.
        api_key: Bearer token; falls back to the ``AI2_API_KEY`` env var.
        api_base: Endpoint URL; defaults to the InferD infer URL.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` for each non-empty
        response line.

    Raises:
        ValueError: For greedy sampling with ``top_p < 1.0``, a non-200
            response, or a response line without ``result.output``.
    """
    # get keys and needed values
    token = api_key or os.environ.get("AI2_API_KEY")
    endpoint = api_base or "https://inferd.allen.ai/api/v1/infer"

    # Make requests
    logged_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{logged_params}")

    # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling:
    # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157
    if temperature == 0.0 and top_p < 1.0:
        raise ValueError("top_p must be 1 when temperature is 0.0")

    # This input format is specific to the Tulu2 model. Other models
    # may require different input formats. See the model's schema
    # documentation on InferD for more information.
    request_body = {
        "model_id": model_id,
        "input": {
            "messages": messages,
            "opts": {
                "max_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "logprobs": 1,  # increase for more choices
            },
        },
    }
    res = requests.post(
        endpoint,
        stream=True,
        headers={"Authorization": f"Bearer {token}"},
        json=request_body,
        timeout=5,
    )

    if res.status_code != 200:
        logger.error(f"unexpected response ({res.status_code}): {res.text}")
        raise ValueError("unexpected response from InferD", res)

    generated = ""
    for raw_line in res.iter_lines():
        if not raw_line:
            continue
        part = json.loads(raw_line)
        if "result" not in part or "output" not in part["result"]:
            logger.error(f"unexpected part: {part}")
            raise ValueError("empty result in InferD response")
        generated += "".join(part["result"]["output"]["text"])
        yield {"text": generated, "error_code": 0}
def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    """Stream a chat reply through ``MistralClient.chat_stream``.

    Args:
        model_name: Model identifier sent to the API.
        messages: Dicts with ``role`` and ``content``; each is wrapped in a
            ``ChatMessage``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_tokens``.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` for every chunk
        whose delta content is not ``None``.

    Raises:
        KeyError: If the ``MISTRAL_API_KEY`` env var is unset.
    """
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"], timeout=5)

    # Make requests
    logged_params = {
        "model": model_name,
        "prompt": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{logged_params}")

    chat_messages = []
    for message in messages:
        chat_messages.append(
            ChatMessage(role=message["role"], content=message["content"])
        )

    stream = client.chat_stream(
        model=model_name,
        temperature=temperature,
        messages=chat_messages,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )

    generated = ""
    for chunk in stream:
        piece = chunk.choices[0].delta.content
        if piece is None:
            continue
        generated += piece
        yield {"text": generated, "error_code": 0}
def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base):
    """Stream a chat reply from an NVIDIA endpoint over server-sent events.

    Args:
        model_name: Accepted but not placed in the request payload.
        messages: Chat messages sent unchanged.
        temp: Sampling temperature; 0.0 is replaced by 0.000001.
        top_p: Nucleus-sampling value.
        max_tokens: Generation length limit.
        api_base: URL the request is posted to.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` per non-empty event
        line, until a line ending in ``[DONE]``.

    Raises:
        KeyError: If the ``NVIDIA_API_KEY`` env var is unset.
    """
    token = os.environ["NVIDIA_API_KEY"]
    request_headers = {
        "Authorization": f"Bearer {token}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }

    # nvidia api does not accept 0 temperature
    effective_temp = 0.000001 if temp == 0.0 else temp

    payload = {
        "messages": messages,
        "temperature": effective_temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info(f"==== request ====\n{payload}")

    response = requests.post(
        api_base, headers=request_headers, json=payload, stream=True, timeout=1
    )

    generated = ""
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        event = raw_line.decode("utf-8")
        if event.endswith("[DONE]"):
            break
        # Skip the 6-character line prefix before the JSON payload.
        generated += json.loads(event[6:])["choices"][0]["delta"]["content"]
        yield {"text": generated, "error_code": 0}
def yandexgpt_api_stream_iter(
    model_name, messages, temperature, max_tokens, api_base, api_key, folder_id
):
    """Stream a YandexGPT completion.

    Each response line carries the text of the top alternative, which is
    yielded as-is rather than accumulated across lines.

    Args:
        model_name: Model name used in the ``gpt://`` model URI.
        messages: Dicts with ``role`` and ``text``.
        temperature: Sampling temperature.
        max_tokens: Generation length limit.
        api_base: URL the request is posted to.
        api_key: API key; falls back to the ``YANDEXGPT_API_KEY`` env var.
        folder_id: Folder id used in the ``gpt://`` model URI.

    Yields:
        ``{"text": <text of the top alternative>, "error_code": 0}`` per
        non-empty line, stopping after a final or truncated-final status.

    Raises:
        KeyError: If no key is passed and ``YANDEXGPT_API_KEY`` is unset.
    """
    api_key = api_key or os.environ["YANDEXGPT_API_KEY"]
    request_headers = {
        "Authorization": f"Api-Key {api_key}",
        "content-type": "application/json",
    }
    completion_options = {
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
    }
    payload = {
        "modelUri": f"gpt://{folder_id}/{model_name}",
        "completionOptions": completion_options,
        "messages": messages,
    }
    logger.info(f"==== request ====\n{payload}")

    # https://llm.api.cloud.yandex.net/foundationModels/v1/completion
    response = requests.post(
        api_base, headers=request_headers, json=payload, stream=True, timeout=60
    )

    final_statuses = (
        "ALTERNATIVE_STATUS_FINAL",
        "ALTERNATIVE_STATUS_TRUNCATED_FINAL",
    )
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        result = json.loads(raw_line.decode("utf-8"))["result"]
        best = result["alternatives"][0]
        yield {"text": best["message"]["text"], "error_code": 0}

        if best["status"] in final_statuses:
            break
def cohere_api_stream_iter(
    client_name: str,
    model_id: str,
    messages: list,
    # The SDK or API handles None for all parameters following
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var CO_API_KEY
    api_base: Optional[str] = None,
):
    """Stream a chat reply through ``cohere.Client.chat_stream``.

    All messages but the last become the chat history, with roles mapped to
    "User", "Chatbot" and "System"; the last message's content is the prompt.

    Args:
        client_name: Client name passed to ``cohere.Client``.
        model_id: Model identifier.
        messages: Dicts with ``role`` and ``content``.
        temperature: Sampling temperature.
        top_p: Sent as ``p``.
        max_new_tokens: Sent as ``max_tokens``.
        api_key: API key passed to the client.
        api_base: Base URL passed to the client.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` per
        ``text-generation`` event, or one ``error_code`` 1 dict if iterating
        the stream raises ``cohere.core.ApiError``.
    """
    import cohere

    role_names = {
        "user": "User",
        "assistant": "Chatbot",
        "system": "System",
    }

    client = cohere.Client(
        api_key=api_key,
        base_url=api_base,
        client_name=client_name,
    )

    # prepare and log requests
    *earlier_messages, final_message = messages
    chat_history = [
        {"role": role_names[entry["role"]], "message": entry["content"]}
        for entry in earlier_messages
    ]
    actual_prompt = final_message["content"]

    logged_params = {
        "model": model_id,
        "messages": messages,
        "chat_history": chat_history,
        "prompt": actual_prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{logged_params}")

    # make request and stream response
    stream = client.chat_stream(
        message=actual_prompt,
        chat_history=chat_history,
        model=model_id,
        temperature=temperature,
        max_tokens=max_new_tokens,
        p=top_p,
    )
    try:
        generated = ""
        for event in stream:
            if event.event_type != "text-generation":
                continue
            generated += event.text
            yield {"text": generated, "error_code": 0}
    except cohere.core.ApiError as e:
        logger.error(f"==== error from cohere api: {e} ====")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {e}",
            "error_code": 1,
        }
def vertex_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    """Stream generated content from a Vertex AI ``GenerativeModel``.

    ``vertexai`` is initialised from the ``GCP_PROJECT_ID`` and
    ``GCP_LOCATION`` env vars, and all four safety categories are set to
    ``BLOCK_NONE``.

    Args:
        model_name: Model identifier.
        messages: Content passed unchanged to ``generate_content``; only the
            plain-string entries are written to the log.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling value.
        max_new_tokens: Sent as ``max_output_tokens``.

    Yields:
        ``{"text": <all text so far>, "error_code": 0}`` per streamed chunk.
    """
    import vertexai
    from vertexai import generative_models
    from vertexai.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )

    vertexai.init(
        project=os.environ.get("GCP_PROJECT_ID", None),
        location=os.environ.get("GCP_LOCATION", None),
    )

    # Only plain-string parts are logged.
    text_messages = [entry for entry in messages if type(entry) is str]
    logged_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{logged_params}")

    harm_categories = (
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety_settings = [
        generative_models.SafetySetting(
            category=category,
            threshold=generative_models.HarmBlockThreshold.BLOCK_NONE,
        )
        for category in harm_categories
    ]

    generation_config = GenerationConfig(
        top_p=top_p, max_output_tokens=max_new_tokens, temperature=temperature
    )
    generator = GenerativeModel(model_name).generate_content(
        messages,
        stream=True,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    generated = ""
    for chunk in generator:
        # NOTE(chris): This may be a vertex api error, below is HOTFIX: https://github.com/googleapis/python-aiplatform/issues/3129
        generated += chunk.candidates[0].content.parts[0]._raw_part.text
        yield {"text": generated, "error_code": 0}
def reka_api_stream_iter(
    model_name: str,
    messages: list,
    # The SDK or API handles None for all parameters following
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    api_key: Optional[str] = None,  # default is env var REKA_API_KEY
    api_base: Optional[str] = None,
):
    """Stream a chat reply from a Reka endpoint.

    A ``-online`` marker in ``model_name`` is removed and turns on
    ``use_search_engine`` in the request.

    Args:
        model_name: Model identifier, optionally containing ``-online``.
        messages: Conversation history; each entry has ``type`` and ``text``.
        temperature: Sampling temperature.
        top_p: Sent as ``runtime_top_p``.
        max_new_tokens: Sent as ``request_output_len``.
        api_key: API key; falls back to the ``REKA_API_KEY`` env var.
        api_base: URL the request is posted to.

    Yields:
        ``{"text": <text field of the event>, "error_code": 0}`` for every
        ``data: `` line, or one ``error_code`` 1 dict on a non-200 response.

    Raises:
        KeyError: If no key is passed and ``REKA_API_KEY`` is unset.
    """
    api_key = api_key or os.environ["REKA_API_KEY"]

    use_search_engine = "-online" in model_name
    if use_search_engine:
        model_name = model_name.replace("-online", "")

    request = {
        "model_name": model_name,
        "conversation_history": messages,
        "temperature": temperature,
        "request_output_len": max_new_tokens,
        "runtime_top_p": top_p,
        "stream": True,
        "use_search_engine": use_search_engine,
    }

    # Make requests for logging
    logged_request = dict(request)
    logged_request["conversation_history"] = [
        {"type": entry["type"], "text": entry["text"]} for entry in messages
    ]
    logger.info(f"==== request ====\n{logged_request}")

    response = requests.post(
        api_base,
        stream=True,
        json=request,
        headers={
            "X-Api-Key": api_key,
        },
    )

    if response.status_code != 200:
        error_message = response.text
        logger.error(f"==== error from reka api: {error_message} ====")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {error_message}",
            "error_code": 1,
        }
        return

    for raw_line in response.iter_lines():
        event = raw_line.decode("utf8")
        if event.startswith("data: "):
            # Skip the "data: " prefix before the JSON payload.
            payload = json.loads(event[6:])
            yield {"text": payload["text"], "error_code": 0}