import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr
import re

# Check whether a GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned model (weights stay in float32; device placement is
# handled by the pipeline below)
model = AutoModelForCausalLM.from_pretrained(
    "AlanYky/phi-3.5_tweets_instruct",
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
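
# The tokenizer comes from the base model; the assumption here is that the
# fine-tune reuses the Phi-3.5-mini-instruct tokenizer unchanged.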
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")

# Define the text-generation pipeline with device placement
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if device == "cuda" else -1,  # 0 = first CUDA GPU, -1 = CPU
)
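
# Note: passing chat-style message lists straight to the pipeline relies on the
# tokenizer's chat template and needs a recent transformers release; the exact
# minimum version is an assumption, so pin transformers in requirements.txt.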

# Decoding arguments
generation_args = {
    "max_new_tokens": 70,       # keep outputs tweet-sized
    "return_full_text": False,  # return only the generated continuation
    "temperature": 0.7,
    "top_k": 20,
    "top_p": 0.8,
    "repetition_penalty": 1.2,
    "do_sample": True,
}
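
# Illustrative call shape (example values, not actual model output):
#   pipe([{"role": "user", "content": "Could you generate a tweet about coffee?"}],
#        **generation_args)
#   -> [{"generated_text": "Nothing beats that first sip in the morning..."}]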

# Generate a tweet from a user-supplied idea
def generate_tweet(tweet_idea):
    max_length = 50  # maximum input length in characters
    if len(tweet_idea) > max_length:
        return f"Error: Tweet idea exceeds the maximum allowed length of {max_length} characters."
    prompt = f"Could you generate a tweet about {tweet_idea}?"
    messages = [{"role": "user", "content": prompt}]
    output = pipe(messages, **generation_args)
    return clean_tweet(output[0]["generated_text"])
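
# Note: despite its name, this variant performs no retrieval step; it mirrors
# generate_tweet but skips the clean_tweet post-processing, and nothing in the
# UI below calls it yet.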
def generate_tweet_with_rag(tweet_idea):
    prompt = f"Could you generate a tweet about {tweet_idea}?"
    messages = [{"role": "user", "content": prompt}]
    output = pipe(messages, **generation_args)
    return output[0]["generated_text"]

# Post-process the raw model output
def clean_tweet(tweet):
    # Strip replacement characters left by undecodable bytes (e.g. "����")
    cleaned_tweet = tweet.replace("�", "")
    # Remove empty parentheses
    cleaned_tweet = re.sub(r"\(\)", "", cleaned_tweet)
    # Collapse runs of four or more dots into an ellipsis
    cleaned_tweet = re.sub(r"\.{4,}", "...", cleaned_tweet)
    # Collapse runs of four or more question marks into three
    cleaned_tweet = re.sub(r"\?{4,}", "???", cleaned_tweet)
    # Collapse runs of four or more exclamation marks into three
    cleaned_tweet = re.sub(r"!{4,}", "!!!", cleaned_tweet)
    # Remove repeated double quotes at the end
    cleaned_tweet = re.sub(r'"+$', "", cleaned_tweet)
    return cleaned_tweet.strip()
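
# Quick sanity check (illustrative input, not part of the app):
#   clean_tweet('Great coffee!!!!! .....""')  ->  'Great coffee!!! ...'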

custom_css = """
#header {
    text-align: center;
    margin-top: 20px;
}
#input-box, #output-box {
    margin: 0 auto;
    width: 80%;
}
#generate-button {
    margin: 10px auto;
    display: block;
}
"""

with gr.Blocks(css=custom_css) as demo:
    # Title with the X.com logo
    gr.Markdown(
        """
        <div id="header">
            <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/e/e3/X.com_logo.svg/1024px-X.com_logo.svg.png"
                 alt="X.com Logo" width="100">
            <h1 style="font-size: 2.5em; margin: 0;">Tweet Generator</h1>
            <p style="font-size: 1.2em; color: gray;">
                Powered by <b>AlanYky/phi-3.5_tweets_instruct</b>
            </p>
        </div>
        """
    )

    # Centered input and output components
    instruction_input = gr.Textbox(
        label="Tweet Idea",
        placeholder="Enter your tweet idea (a topic, hashtag, sentence, or any other format)...",
        lines=2,
        elem_id="input-box",
    )
    generate_button = gr.Button("Generate", elem_id="generate-button")
    output_box = gr.Textbox(
        label="Generated Tweet",
        placeholder="Your tweet will appear here.",
        lines=3,
        elem_id="output-box",
    )

    # Wire the button to the generation function
    generate_button.click(generate_tweet, inputs=instruction_input, outputs=output_box)
print("Model loaded on:", next(model.parameters()).device) | |
demo.launch() |