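"""Gradio chat demo that answers with either the Hugging Face Inference
API (streamed) or the Cohere chat API (non-streamed), chosen at runtime
with a checkbox in the UI.

Expects HF_API_KEY and COHERE_API_KEY to be set in the environment.
"""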
import gradio as gr
import os
from huggingface_hub import InferenceClient
import cohere

HF_API_KEY = os.getenv("HF_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")  # Get Cohere API key

# Candidate Hugging Face models; the client below uses Llama-3.1-8B-Instruct.
models = [
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/Phi-4-mini-instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]
client_hf = InferenceClient(model=models[3], token=HF_API_KEY)
client_cohere = cohere.Client(COHERE_API_KEY)

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    use_cohere,  # Checkbox value
):
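    """Stream chat responses for gr.ChatInterface.

    Gradio passes the additional_inputs values (system message, max
    tokens, temperature, top-p, and the Cohere checkbox) as positional
    arguments after the message and history.
    """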
    messages = [{"role": "system", "content": system_message}]
    
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})
    
    response = ""

    if use_cohere:  # Route the request to Cohere
        # Cohere's v1 chat API takes prior turns as chat_history with
        # USER/CHATBOT roles and the system prompt as a preamble.
        chat_history = [
            {"role": role, "message": text}
            for user_msg, bot_msg in history
            for role, text in (("USER", user_msg), ("CHATBOT", bot_msg))
            if text
        ]
        cohere_response = client_cohere.chat(
            message=message,
            model="command-r",  # Or "command", depending on your plan
            preamble=system_message,
            chat_history=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # This endpoint returns the full reply at once, so yield it
        # in a single chunk.
        yield cohere_response.text
    
    else:  # Route the request to the Hugging Face Inference API
        for chunk in client_hf.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Delta content can be None on some stream chunks (e.g. the
            # final one), so fall back to an empty string.
            token = chunk.choices[0].delta.content or ""
            response += token
            yield response

# Gradio UI
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Checkbox(label="Use Cohere API"),  # Checkbox to switch API
    ],
)

if __name__ == "__main__":
    demo.launch()
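
# To run locally (assuming this file is saved as app.py and both API
# keys are available in the environment):
#   HF_API_KEY=... COHERE_API_KEY=... python app.py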