alan5543 committed
Commit 547153c · 1 Parent(s): 014edc2

improve the UI

Files changed (1)
  1. app.py +41 -5
app.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import gradio as gr
+import re
 
 # Check if GPU is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -8,7 +9,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 # Load the model and tokenizer with GPU optimizations
 model = AutoModelForCausalLM.from_pretrained(
     "AlanYky/phi-3.5_tweets_instruct",
-    device_map="cpu",
     torch_dtype=torch.float16,  # Use FP16 for GPU
     trust_remote_code=True,
     load_in_8bit=True
@@ -37,17 +37,53 @@ generation_args = {
 
 
 # Function for tweet generation
-def generate_tweet(instruction):
-    messages = [{"role": "user", "content": instruction}]
+def generate_tweet(tweet_idea):
+    max_length = 50  # Set a maximum input length
+    if len(tweet_idea) > max_length:
+        return f"Error: Tweet idea exceeds the maximum allowed length of {max_length} characters."
+
+    prompt = f"Could you generate a tweet about {tweet_idea}?"
+    messages = [{"role": "user", "content": prompt}]
+    output = pipe(messages, **generation_args)
+    return clean_tweet(output[0]['generated_text'])
+
+
+def generate_tweet_with_rag(tweet_idea):
+    prompt = f"Could you generate a tweet about {tweet_idea}?"
+    messages = [{"role": "user", "content": prompt}]
     output = pipe(messages, **generation_args)
     return output[0]['generated_text']
 
 
+def clean_tweet(tweet):
+    # Remove Unicode replacement characters left by decoding errors (like `����`)
+    cleaned_tweet = tweet.replace("�", "")
+
+    cleaned_tweet = re.sub(r"\(\)", "", cleaned_tweet)  # drop empty "()" pairs
+
+    # Collapse runs of four or more dots into an ellipsis
+    cleaned_tweet = re.sub(r"\.{4,}", "...", cleaned_tweet)
+
+    # Collapse runs of four or more question marks into three
+    cleaned_tweet = re.sub(r"\?{4,}", "???", cleaned_tweet)
+
+    # Collapse runs of four or more exclamation marks into three
+    cleaned_tweet = re.sub(r"!{4,}", "!!!", cleaned_tweet)
+
+    # Remove repeated double quotes at the end
+    cleaned_tweet = re.sub(r'"+$', '', cleaned_tweet)
+
+    return cleaned_tweet.strip()
+
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Tweet Generator (Optimized for GPU)")
+    gr.Markdown("# Tweet Generator")
     with gr.Row():
-        instruction_input = gr.Textbox(label="Instruction", placeholder="Generate a tweet about...")
+        instruction_input = gr.Textbox(
+            label="Instruction",
+            placeholder="Enter your tweet idea (It can be a topic, hashtag, sentence, or any format)..."
+        )
+
     generate_button = gr.Button("Generate")
     output_box = gr.Textbox(label="Generated Tweet", placeholder="Your tweet will appear here.")
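
Note that the commit drops the explicit device_map="cpu" while keeping load_in_8bit=True, which routes loading through bitsandbytes. Recent transformers versions prefer an explicit BitsAndBytesConfig plus a device map; a minimal sketch of an equivalent load, assuming bitsandbytes and accelerate are installed and a CUDA device is available (this is not the code in the commit):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch: BitsAndBytesConfig is the non-deprecated way to request 8-bit weights
model = AutoModelForCausalLM.from_pretrained(
    "AlanYky/phi-3.5_tweets_instruct",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",  # let accelerate place the quantized weights
    trust_remote_code=True,
)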
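
The new clean_tweet pipeline only collapses runs of four or more repeated punctuation marks down to three, and strips decoding artifacts, empty parentheses, and trailing quotes. A standalone mirror of the added code, for quick verification of the regexes:

import re

def clean_tweet(tweet):
    cleaned = tweet.replace("�", "")             # drop replacement characters
    cleaned = re.sub(r"\(\)", "", cleaned)        # drop empty "()" pairs
    cleaned = re.sub(r"\.{4,}", "...", cleaned)   # 4+ dots -> ellipsis
    cleaned = re.sub(r"\?{4,}", "???", cleaned)   # 4+ ? -> ???
    cleaned = re.sub(r"!{4,}", "!!!", cleaned)    # 4+ ! -> !!!
    cleaned = re.sub(r'"+$', '', cleaned)         # trailing double quotes
    return cleaned.strip()

print(clean_tweet('Launch day!!!!! So excited......""'))
# -> Launch day!!! So excited...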
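
The hunk ends before the button is connected to the handler, so the click wiring presumably sits further down app.py and is unchanged by this commit. For context, a minimal sketch of the standard Gradio wiring using the variable names above (hypothetical, not shown in this diff; the click call belongs inside the `with gr.Blocks() as demo:` block):

    # Hypothetical wiring, not part of this diff
    generate_button.click(
        fn=generate_tweet,
        inputs=instruction_input,
        outputs=output_box,
    )

demo.launch()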