ai01firebird committed on
Commit
6eb3e9b
Β·
verified Β·
1 Parent(s): 4352744

change to pattern-based prompt for distilgpt2

Browse files
Files changed (1) hide show
  1. app.py +45 -6
app.py CHANGED
@@ -10,19 +10,19 @@ import torch
10
  #model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
11
 
12
  # distilgpt2 is only 80MB -> NOK, no emojis
13
- #tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
14
- #model = AutoModelForCausalLM.from_pretrained("distilgpt2")
15
 
16
  # tiny-gpt2 is only 20MB -> NOK, no emojis
17
  #tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
18
  #model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
19
 
20
  # TinyLlama
21
- tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
22
- model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
23
 
24
- # conversion method
25
- def text_to_emoji(input_text):
26
  # Eingabetext bereinigen (optional)
27
  cleaned_text = re.sub(r"[.,!?;:]", "", input_text)
28
 
@@ -47,6 +47,45 @@ def text_to_emoji(input_text):
47
 
48
  return emoji_part
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Gradio UI
52
  iface = gr.Interface(
 
10
  #model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
11
 
12
  # distilgpt2 is only 80MB -> NOK, no emojis
13
+ tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
14
+ model = AutoModelForCausalLM.from_pretrained("distilgpt2")
15
 
16
  # tiny-gpt2 is only 20MB -> NOK, no emojis
17
  #tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
18
  #model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
19
 
20
  # TinyLlama
21
+ #tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
22
+ #model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
23
 
24
+ # OLD conversion method
25
+ def text_to_emoji_OLD(input_text):
26
  # Eingabetext bereinigen (optional)
27
  cleaned_text = re.sub(r"[.,!?;:]", "", input_text)
28
 
 
47
 
48
  return emoji_part
49
 
50
# conversion method
def text_to_emoji(input_text):
    """Convert a short text into an emoji sequence.

    Builds a few-shot, pattern-based prompt ("text β†’ emojis" examples) and
    lets the causal LM continue the pattern for *input_text*.

    Parameters
    ----------
    input_text : str
        Free-form user text; common punctuation is stripped before prompting.

    Returns
    -------
    str
        The model's emoji continuation for the input line (may be empty if
        the model generates nothing useful).

    NOTE(review): relies on module-level ``tokenizer``, ``model`` and ``re``.
    """
    # Strip punctuation so the input matches the shape of the few-shot examples
    cleaned_text = re.sub(r"[.,!?;:]", "", input_text)

    # Pure pattern-based prompt
    prompt = (
        "Hi there β†’ πŸ‘‹πŸ™‚\n"
        "Good night β†’ πŸŒ™πŸ˜΄\n"
        "I love pizza β†’ β€οΈπŸ•\n"
        "It's raining β†’ πŸŒ§οΈβ˜”\n"
        "Happy birthday β†’ πŸŽ‰πŸŽ‚πŸ₯³\n"
        "I am so tired β†’ πŸ˜΄πŸ’€\n"
        "Let’s go to the beach β†’ πŸ–οΈπŸŒŠπŸ˜Ž\n"
        "I’m feeling lucky β†’ πŸ€πŸ€ž\n"
        "We’re getting married β†’ πŸ’πŸ‘°πŸ€΅\n"
        "Merry Christmas β†’ πŸŽ„πŸŽπŸŽ…\n"
        "Let’s party β†’ πŸŽ‰πŸ•ΊπŸ’ƒ\n"
        f"{cleaned_text} β†’"
    )

    # Tokenization and generation
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        do_sample=True,
        temperature=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id  # prevents the missing-pad-token warning
    )

    # Decode ONLY the newly generated tokens. Decoding the full sequence and
    # splitting on the arrow (as before) is fragile: with sampling enabled the
    # model often hallucinates further "text β†’ emoji" example lines, and
    # taking the text after the LAST arrow then returns the answer to a
    # hallucinated example instead of the user's input.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    )

    # Keep only the first generated line, and cut at any arrow the model
    # emits when it starts a new example of its own.
    emoji_part = generated_text.strip().split("\n")[0].split("β†’")[0].strip()

    return emoji_part
89
 
90
  # Gradio UI
91
  iface = gr.Interface(