Framormar committed on
Commit 0a650b5 (verified)
Parent: 6583844

Update app.py

Files changed (1)
  1. app.py +92 -47
app.py CHANGED
@@ -1,50 +1,95 @@
-import os
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import torch
-
-# 1. Authentication
-HF_TOKEN = os.environ["HF_TOKEN"]
-os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN
-
-# 2. Load the model and tokenizer on GPU
-MODEL_ID = "arcee-ai/AFM-4.5B"
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=HF_TOKEN)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    use_auth_token=HF_TOKEN,
-    torch_dtype=torch.bfloat16,
-    device_map="auto"
-)
-chat = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    device_map="auto"
 )
 
-# 3. Inference function
-def genera_respuesta(prompt: str, max_tokens: int = 256, temp: float = 0.5, top_p: float = 0.95):
-    out = chat(
-        prompt,
-        max_new_tokens=max_tokens,
-        temperature=temp,
-        top_p=top_p,
-        do_sample=True
-    )
-    return out[0]["generated_text"].strip()
-
-# 4. Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("### AFM-4.5B in your Space")
-    with gr.Row():
-        inp = gr.Textbox(label="Question", lines=2)
-        out = gr.Textbox(label="Answer")
-    with gr.Row():
-        max_toks = gr.Slider(50, 512, value=256, label="Max new tokens")
-        temp = gr.Slider(0.1, 1.0, value=0.5, label="Temperature")
-        top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
-    btn = gr.Button("Send")
-    btn.click(fn=genera_respuesta, inputs=[inp, max_toks, temp, top_p], outputs=out)
-
-demo.launch()
 import gradio as gr
+import requests
+import json
+import os
+
+"""
+Using Together AI API for chat completions
+"""
+TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
+if not TOGETHER_API_KEY:
+    raise ValueError("TOGETHER_API_KEY environment variable is not set")
+
+TOGETHER_API_URL = "https://api.together.xyz/v1/chat/completions"
+
+
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+):
+    messages = [{"role": "system", "content": system_message}]
+
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
+
+    messages.append({"role": "user", "content": message})
+
+    headers = {
+        "Authorization": f"Bearer {TOGETHER_API_KEY}",
+        "Content-Type": "application/json"
+    }
+
+    data = {
+        "model": "arcee-ai/AFM-4.5B",
+        "messages": messages,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "stream": True
+    }
+
+    response = ""
+
+    try:
+        with requests.post(TOGETHER_API_URL, headers=headers, json=data, stream=True) as r:
+            r.raise_for_status()
+            for line in r.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+                    if line.startswith('data: '):
+                        line = line[6:]  # Remove 'data: ' prefix
+                        if line.strip() == '[DONE]':
+                            break
+                        try:
+                            chunk = json.loads(line)
+                            if 'choices' in chunk and len(chunk['choices']) > 0:
+                                delta = chunk['choices'][0].get('delta', {})
+                                if 'content' in delta:
+                                    token = delta['content']
+                                    response += token
+                                    yield response
+                        except json.JSONDecodeError:
+                            continue
+    except requests.exceptions.RequestException as e:
+        yield f"Error: {str(e)}"
+
+
+"""
+For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+"""
+demo = gr.ChatInterface(
+    respond,
+    additional_inputs=[
+        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-p (nucleus sampling)",
+        ),
+    ],
 )
 
+
+if __name__ == "__main__":
+    demo.launch()
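
The updated respond is a plain Python generator, so the Together AI streaming path can be smoke-tested without launching the Gradio UI. A minimal sketch, assuming TOGETHER_API_KEY is set in the environment and that the snippet sits next to app.py so it can import it as a module named app (the module name and the prompt text are illustrative assumptions, not part of the commit):

# smoke_test.py -- hypothetical helper, not part of this commit.
# Assumes TOGETHER_API_KEY is exported and app.py is importable as `app`.
from app import respond

stream = respond(
    "Say hello in one sentence.",   # message
    [],                             # empty chat history
    "You are a friendly Chatbot.",  # system_message
    128,                            # max_tokens
    0.7,                            # temperature
    0.95,                           # top_p
)

final = ""
for partial in stream:
    final = partial  # each yield carries the accumulated response so far

print(final)

Because each yield returns the accumulated text, the last value seen by the loop is the full completion; any request failure surfaces as a single "Error: ..." string from the generator.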