CreativeWorks committed on
Commit
ab711c8
·
verified ·
1 Parent(s): 0ce19af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -9,6 +9,14 @@ from threading import Thread
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
 
11
 
 
 
 
 
 
 
 
 
12
  DESCRIPTION = '''
13
  <div>
14
  <h1 style="text-align: center;">Meta Llama3 8B</h1>
@@ -59,13 +67,13 @@ terminators = [
59
  ]
60
 
61
  @spaces.GPU(duration=120)
62
- def chat_llama3_8b(message: str,
63
  history: list,
64
  temperature: float,
65
  max_new_tokens: int
66
  ) -> str:
67
  """
68
- Generate a streaming response using the llama3-8b model.
69
  Args:
70
  message (str): The input message.
71
  history (list): The conversation history used by ChatInterface.
@@ -78,13 +86,11 @@ def chat_llama3_8b(message: str,
78
  for user, assistant in history:
79
  conversation.extend([{"from": "human", "value": user}, {"from": "assistant", "value": assistant}])
80
  conversation.append({"from": "human", "value": message})
81
-
82
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
83
 
84
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
85
-
86
  generate_kwargs = dict(
87
- input_ids= input_ids,
88
  streamer=streamer,
89
  max_new_tokens=max_new_tokens,
90
  do_sample=True,
@@ -92,17 +98,18 @@ def chat_llama3_8b(message: str,
92
  eos_token_id=terminators,
93
  pad_token_id=tokenizer.eos_token_id
94
  )
95
- # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
 
96
  if temperature == 0:
97
  generate_kwargs['do_sample'] = False
98
 
99
  t = Thread(target=model.generate, kwargs=generate_kwargs)
100
  t.start()
101
-
102
  outputs = []
103
  for text in streamer:
 
 
104
  outputs.append(text)
105
- #print(outputs)
106
  yield "".join(outputs)
107
 
108
 
@@ -114,7 +121,7 @@ with gr.Blocks(fill_height=True, css=css) as demo:
114
  gr.Markdown(DESCRIPTION)
115
  #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
116
  gr.ChatInterface(
117
- fn=chat_llama3_8b,
118
  chatbot=chatbot,
119
  fill_height=True,
120
  additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
@@ -145,5 +152,5 @@ with gr.Blocks(fill_height=True, css=css) as demo:
145
  gr.Markdown(LICENSE)
146
 
147
  if __name__ == "__main__":
148
- demo.launch(share=True)
149
 
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
 
11
 
12
+ # Read the environment variables for authentication and sharing
13
+ auth_users = os.getenv("GRADIO_AUTH_USERS")
14
+ auth_passwords = os.getenv("GRADIO_AUTH_PASSWORDS")
15
+ # Convert the comma-separated user and password strings into lists
16
+ auth_users = [user.strip() for user in auth_users.split(",")]
17
+ auth_passwords = [password.strip() for password in auth_passwords.split(",")]
18
+ # Build an authentication dictionary mapping users to passwords
19
+ auth_credentials = dict(zip(auth_users, auth_passwords))
20
  DESCRIPTION = '''
21
  <div>
22
  <h1 style="text-align: center;">Meta Llama3 8B</h1>
 
67
  ]
68
 
69
  @spaces.GPU(duration=120)
70
+ def CreativeWorks_Mistral_7b_Chat_V1(message: str,
71
  history: list,
72
  temperature: float,
73
  max_new_tokens: int
74
  ) -> str:
75
  """
76
+ Generate a streaming response using the Mistral model.
77
  Args:
78
  message (str): The input message.
79
  history (list): The conversation history used by ChatInterface.
 
86
  for user, assistant in history:
87
  conversation.extend([{"from": "human", "value": user}, {"from": "assistant", "value": assistant}])
88
  conversation.append({"from": "human", "value": message})
 
89
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
90
 
91
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
92
  generate_kwargs = dict(
93
+ input_ids=input_ids,
94
  streamer=streamer,
95
  max_new_tokens=max_new_tokens,
96
  do_sample=True,
 
98
  eos_token_id=terminators,
99
  pad_token_id=tokenizer.eos_token_id
100
  )
101
+
102
+ # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
103
  if temperature == 0:
104
  generate_kwargs['do_sample'] = False
105
 
106
  t = Thread(target=model.generate, kwargs=generate_kwargs)
107
  t.start()
 
108
  outputs = []
109
  for text in streamer:
110
+ # Remove the unwanted prefix if present
111
+ text = text.replace("<|im_start|>assistant", " ")
112
  outputs.append(text)
 
113
  yield "".join(outputs)
114
 
115
 
 
121
  gr.Markdown(DESCRIPTION)
122
  #gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
123
  gr.ChatInterface(
124
+ fn=CreativeWorks_Mistral_7b_Chat_V1,
125
  chatbot=chatbot,
126
  fill_height=True,
127
  additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
 
152
  gr.Markdown(LICENSE)
153
 
154
  if __name__ == "__main__":
155
+ demo.launch(auth=auth_credentials, share=True)
156