Xianbao QIAN committed
Commit d76943e · Parent: 45eaf57

initial version

Files changed (3):
  1. app.py +76 -47
  2. requirements.txt +2 -1
  3. sambanova.py +89 -0
app.py CHANGED
@@ -1,63 +1,92 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
+ import os
+ from typing import Iterator
+ import sambanova
+
+
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     for user, assistant in chat_history:
+         conversation.extend(
+             [
+                 {"role": "user", "content": user},
+                 {"role": "assistant", "content": assistant},
+             ]
+         )
+     conversation.append({"role": "user", "content": message})
+
+     outputs = []
+     for text in sambanova.Streamer(conversation, max_tokens=max_new_tokens,
+                                    temperature=temperature, top_k=top_k, top_p=top_p):
+         outputs.append(text)
+         yield "".join(outputs)
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ chat_interface = gr.ChatInterface(
+     fn=generate,
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
              minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
              maximum=1.0,
-             value=0.95,
              step=0.05,
-             label="Top-p (nucleus sampling)",
+             value=0.9,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
          ),
      ],
+     stop_btn=None,
+     fill_height=True,
+     examples=[
+         ["Which one is bigger? 4.9 or 4.11"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     cache_examples=False,
  )
+ with gr.Blocks() as demo:
+     gr.Markdown('# Sambanova model inference LLAMA 405B')
  
-
+     chat_interface.render()
+
  if __name__ == "__main__":
-     demo.launch()
+     demo.queue(max_size=20).launch()
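Note: the new app.py streams everything through sambanova.Streamer, which requires the URL and TOKEN environment variables (see sambanova.py below). A minimal smoke test of the generate() handler outside the Gradio UI might look like the following sketch; the endpoint and credential values here are illustrative placeholders, not part of the commit, and a real SambaNova-compatible endpoint is assumed.

# Hypothetical local check for generate(); assumes URL/TOKEN point at a live,
# SambaNova-compatible endpoint. The values below are placeholders and will
# fail to connect as written.
import os

os.environ.setdefault("URL", "https://api.example.com/v1/chat/completions")  # placeholder
os.environ.setdefault("TOKEN", "base64-credentials-here")                    # placeholder

from app import generate

# generate() is a generator that yields the cumulative response text after
# each streamed chunk, so the last yielded value is the full reply.
final = ""
for partial in generate("Tell me a joke", chat_history=[]):
    final = partial
print(final)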
requirements.txt CHANGED
@@ -1 +1,2 @@
- huggingface_hub==0.22.2
+ huggingface_hub==0.22.2
+ gradio
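Aside: sambanova.py (added below) imports requests, which is not pinned here and presumably arrives as a transitive dependency of gradio. A more explicit pin set might look like this sketch (the unversioned entries are deliberate, matching the commit's style):

huggingface_hub==0.22.2
gradio
requests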
sambanova.py ADDED
@@ -0,0 +1,89 @@
+ import requests
+ import json
+ import os
+
+ def _stream_chat_response(url, headers, payload):
+     """
+     Streams the chat response from the given URL with the specified headers and payload.
+
+     Args:
+         url (str): The URL to send the POST request to.
+         headers (dict): The headers for the POST request.
+         payload (dict): The payload for the POST request.
+
+     Raises:
+         ValueError: If the payload does not have the 'stream' key set.
+         ConnectionError: If the request fails.
+
+     Yields:
+         str: The content of the streamed response.
+     """
+     if not payload.get('stream'):
+         raise ValueError('This method can only handle stream payload')
+
+     try:
+         # Make the POST request
+         response = requests.post(url, headers=headers, json=payload, stream=True)
+         response.raise_for_status()  # Raise an error for bad status codes
+
+         # Process the streamed response
+         for line in response.iter_lines():
+             if line:
+                 decoded_line = line.decode('utf-8')
+                 DATA_PREFIX = "data: "
+                 if decoded_line.startswith(DATA_PREFIX):
+                     decoded_line = decoded_line[len(DATA_PREFIX):]  # Remove the "data: " prefix
+                 if decoded_line.strip() == "[DONE]":
+                     break
+                 try:
+                     json_data = json.loads(decoded_line)
+                     content = json_data.get('choices', [{}])[0].get('delta', {}).get('content', '')
+                     if content:
+                         yield content
+                 except json.JSONDecodeError:
+                     print(f"Warning: Error decoding JSON: {decoded_line}. Skipping this line.")
+     except requests.RequestException as e:
+         raise ConnectionError(f"Request failed: {e}") from e
+
+ def Streamer(history, **kwargs):
+     """
+     Streams the chat response based on the provided history and additional kwargs.
+
+     Args:
+         history (list): The chat history as a list of message dicts.
+         **kwargs: Additional parameters to update the payload.
+
+     Yields:
+         str: The content of the streamed response.
+     """
+     url = os.getenv('URL')
+     token = os.getenv('TOKEN')
+
+     if not url or not token:
+         raise EnvironmentError("URL or TOKEN environment variable is not set.")
+
+     headers = {
+         "Authorization": f"Basic {token}",
+         "Content-Type": "application/json"
+     }
+     payload = {
+         "messages": history,
+         "max_tokens": 1000,
+         "stop": ["<|eot_id|>"],
+         "model": "llama3-405b",
+         "stream": True
+     }
+     payload.update(kwargs)
+
+     for update in _stream_chat_response(url, headers, payload):
+         yield update
+
+ # Example usage
+ if __name__ == "__main__":
+     try:
+         history = [{"role": "user", "content": "Tell me a joke"}]
+
+         for content in Streamer(history):
+             print(content, end='')
+     except Exception as e:
+         print(f"An error occurred: {e}")