WICKED4950 commited on
Commit
085d77c
·
verified ·
1 Parent(s): ca143a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -60
app.py CHANGED
@@ -1,74 +1,78 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
 
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
9
- # Simple rule-based response function
10
- def rule_based_response(history, message):
11
- if "hello" in message.lower():
12
- return "Hello! How can I help you today?"
13
- elif "how are you" in message.lower():
14
- return "I'm doing great, thanks for asking! How about you?"
15
- elif "bye" in message.lower():
16
- return "Goodbye! Have a nice day!"
17
- else:
18
- return "Sorry, I don't understand. Can you ask something else?"
19
 
20
- def respond(
21
- message,
22
- history: list[tuple[str, str]],
23
- system_message,
24
- max_tokens,
25
- temperature,
26
- top_p,
27
- ):
28
- # Rule-based response logic (for now)
29
- response = rule_based_response(history, message)
30
-
31
- # If the rule-based model cannot respond, fall back to the HuggingFace model
32
- if response:
33
- return response, history
34
 
35
- # Otherwise, use the HuggingFace model
36
- messages = [{"role": "system", "content": system_message}]
37
- for val in history:
38
- if val[0]:
39
- messages.append({"role": "user", "content": val[0]})
40
- if val[1]:
41
- messages.append({"role": "assistant", "content": val[1]})
42
 
43
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- response = ""
46
- for message in client.chat_completion(
47
- messages,
48
- max_tokens=max_tokens,
49
- stream=True,
50
- temperature=temperature,
51
- top_p=top_p,
52
- ):
53
- token = message.choices[0].delta.content
54
- response += token
55
- yield response
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Gradio Chat Interface Setup
58
  demo = gr.ChatInterface(
59
- respond,
60
- additional_inputs=[
61
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
62
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
63
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
64
- gr.Slider(
65
- minimum=0.1,
66
- maximum=1.0,
67
- value=0.95,
68
- step=0.05,
69
- label="Top-p (nucleus sampling)",
70
- ),
71
- ],
72
  )
73
 
74
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ import tensorflow as tf
4
+ from huggingface_hub import login, create_repo, upload_file
5
+ from transformers import AutoTokenizer, TFAutoModelForCausalLM
6
+ policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
7
+ tf.keras.mixed_precision.set_global_policy(policy)
8
+ strategy = tf.distribute.MultiWorkerMirroredStrategy()
9
 
10
+ login(os.environ.get("hf_token"))
 
 
 
11
 
12
+ name = "WICKED4950/GPT2-InstEsther0.21eV3.1"
13
+ tokenizer = AutoTokenizer.from_pretrained(name)
14
+ tokenizer.pad_token = tokenizer.eos_token
15
+ with strategy.scope():
16
+ model = TFAutoModelForCausalLM.from_pretrained(name)
 
 
 
 
 
17
 
18
+ def raw_pred(input, model, tokenizer, max_length=50, temperature=0.2):
19
+ input_ids = tokenizer.encode(input, return_tensors='tf')
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Initialize variables
22
+ generated_ids = input_ids
23
+ stop_token_id = tokenizer.encode("<|SOH|>", add_special_tokens=False)[0] # ID for <|SOH|>
24
+ all_generated_tokens = [] # To store generated token IDs
25
+ tokens_yielded = [] # To store tokens as they are yielded
 
 
26
 
27
+ with strategy.scope():
28
+ for _ in range(max_length // 1): # Generate in chunks of 3 tokens
29
+ # Generate three tokens at a time
30
+ outputs = model.generate(
31
+ generated_ids,
32
+ max_length=generated_ids.shape[1] + 1, # Increment max length by 3
33
+ temperature=temperature,
34
+ pad_token_id=tokenizer.eos_token_id,
35
+ eos_token_id=stop_token_id, # Stop generation at <|SOH|>
36
+ do_sample=True,
37
+ num_return_sequences=1
38
+ )
39
 
40
+ # Get the newly generated tokens (last 3 tokens)
41
+ new_tokens = outputs[0, -1:]
42
+ generated_ids = outputs # Update the generated_ids with the new tokens
43
+
44
+ # Store the generated tokens as numbers (IDs)
45
+ all_generated_tokens.extend(new_tokens.numpy().tolist())
46
+
47
+ # Decode and yield the tokens as they are generated (as numbers)
48
+ tokens_text = tokenizer.decode(new_tokens, skip_special_tokens=False)
49
+ tokens_yielded.append(tokens_text)
50
+ yield tokens_text
51
 
52
+ # Stop if the generated tokens include <|SOH|>
53
+ if stop_token_id in new_tokens.numpy():
54
+ final_text = tokenizer.decode(all_generated_tokens, skip_special_tokens=False)
55
+ yield ("<|Clean|>" + final_text)
56
+ break
57
+
58
+ def respond(message, history):
59
+ # Prepare input for the model
60
+ give_mod = ""
61
+ for chunk in history:
62
+ give_mod = give_mod + "<|SOH|>" + chunk[0] + "<|SOB|>" + chunk[1]
63
+ give_mod = give_mod + "<|SOH|>" + message + "<|SOB|>"
64
+ print(give_mod)
65
+ response = ""
66
+ for token in raw_pred(give_mod, model, tokenizer):
67
+ if "<|Clean|>" in token:
68
+ response = token
69
+ else:
70
+ response += token
71
+ yield response.replace("<|SOH|>","").replace("<|Clean|>","")
72
+ print(response)
73
  # Gradio Chat Interface Setup
74
  demo = gr.ChatInterface(
75
+ respond
 
 
 
 
 
 
 
 
 
 
 
 
76
  )
77
 
78
  if __name__ == "__main__":