Spaces:
Runtime error
Runtime error
File size: 1,889 Bytes
5ea344f 7360ef0 f013f59 7360ef0 f013f59 7360ef0 f013f59 7360ef0 5ef8568 f013f59 7360ef0 f013f59 7360ef0 f013f59 7360ef0 f013f59 7360ef0 59c5924 7360ef0 f013f59 7360ef0 f013f59 7360ef0 f013f59 7360ef0 f013f59 7360ef0 f013f59 7360ef0 f013f59 59c5924 f013f59 5d7ac94 f013f59 02ae823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import gradio as gr
from gradio_client import Client
# Maps the model name shown in the dropdown to the Space id handed to
# gradio_client.Client.
MODELS = {
    "OLMo-2-1124-13B-Instruct": "akhaliq/olmo-anychat",
    "Llama-3.1-Tulu-3-8B": "akhaliq/allen-test",
}
def create_chat_fn(client):
    """Build a chat handler bound to ``client``.

    Args:
        client: Object exposing ``predict(**kwargs)``. Elsewhere in this file
            it is the ``gradio_client.Client`` kept in the session state.

    Returns:
        A ``chat(message, history)`` callable that sends ``message`` to the
        client's ``/chat`` endpoint with fixed generation settings and returns
        whatever ``predict`` returns.
    """

    def chat(message, history):
        # ``history`` is part of the handler signature but is not forwarded:
        # only the latest message is sent to the remote endpoint.
        return client.predict(
            message=message,
            system_prompt="You are a helpful AI assistant.",
            temperature=0.7,
            max_new_tokens=1024,
            top_k=40,
            repetition_penalty=1.1,
            top_p=0.95,
            api_name="/chat",
        )

    return chat
def set_client_for_session(model_name, request: gr.Request):
    """Create a ``Client`` for the selected model, forwarding the caller's IP token.

    Args:
        model_name: Key into ``MODELS`` selecting which Space to connect to.
        request: The Gradio request for the current event. Gradio injects it
            because of the ``gr.Request`` annotation; it may be falsy.

    Returns:
        A ``gradio_client.Client`` connected to ``MODELS[model_name]``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
    """
    headers = {}
    # Guarded lookups: only read the header when the wrapped request object
    # is actually present and exposes headers.
    if request and hasattr(request, "request") and hasattr(request.request, "headers"):
        x_ip_token = request.request.headers.get("x-ip-token")
        if x_ip_token:
            # NOTE(review): presumably forwarded so the downstream Space
            # attributes usage to the visiting user rather than to this
            # Space - confirm against the Gradio client docs.
            headers["X-IP-Token"] = x_ip_token
    # A client is created even without a token; it just sends no extra headers.
    return Client(MODELS[model_name], headers=headers)
def safe_chat_fn(message, history, client):
    """Chat handler that tolerates a session whose client is not set yet.

    Args:
        message: The user's latest message.
        history: Conversation history supplied by the chat UI.
        client: The per-session client, or ``None`` if it has not been
            initialized.

    Returns:
        A fixed error string when ``client`` is ``None``; otherwise the result
        of the handler built by ``create_chat_fn(client)``.
    """
    if client is None:
        return "Error: Client not initialized. Please refresh the page."
    return create_chat_fn(client)(message, history)
with gr.Blocks() as demo:
    # Per-session slot holding the gradio_client.Client for the chosen model.
    client = gr.State()
    model_dropdown = gr.Dropdown(
        choices=list(MODELS),
        value="OLMo-2-1124-13B-Instruct",
        label="Select Model",
        interactive=True,
    )
    # The session client is passed to the chat handler as an additional input.
    chat_interface = gr.ChatInterface(fn=safe_chat_fn, additional_inputs=[client])

    # Update client when model changes
    def update_model(model_name, request: gr.Request):
        # The gr.Request annotation is required: without it Gradio treats
        # ``request`` as a second ordinary input, and the event below (which
        # supplies only the dropdown value) fails with a missing argument.
        return set_client_for_session(model_name, request)

    model_dropdown.change(
        fn=update_model,
        inputs=[model_dropdown],
        outputs=[client],
    )

    # Initialize client on page load
    demo.load(
        fn=set_client_for_session,
        inputs=gr.State("OLMo-2-1124-13B-Instruct"),
        outputs=client,
    )

if __name__ == "__main__":
    demo.launch()
|