import torch from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline import gradio as gr import re # Check if GPU is available device = "cuda" if torch.cuda.is_available() else "cpu" # Load the model and tokenizer with GPU optimizations model = AutoModelForCausalLM.from_pretrained( "AlanYky/phi-3.5_tweets_instruct", torch_dtype=torch.float32, trust_remote_code=True, ) tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct") # Define the pipeline with device optimization pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1 # Set device for pipeline ) # Define generation arguments generation_args = { "max_new_tokens": 70, "return_full_text": False, "temperature": 0.7, "top_k": 20, "top_p": 0.8, "repetition_penalty": 1.2, "do_sample": True, } # Function for tweet generation def generate_tweet(tweet_idea): max_length = 50 # Set a maximum input length if len(tweet_idea) > max_length: return f"Error: Tweet idea exceeds the maximum allowed length of {max_length} characters." prompt = f"Could you generate a tweet about {tweet_idea}?" messages = [{"role": "user", "content": prompt}] output = pipe(messages, **generation_args) return clean_tweet(output[0]['generated_text']) def generate_tweet_with_rag(tweet_idea): prompt = f"Could you generate a tweet about {tweet_idea}?" messages = [{"role": "user", "content": prompt}] output = pipe(messages, **generation_args) return output[0]['generated_text'] def clean_tweet(tweet): # Remove trailing sequences of non-ASCII characters (like `����������������`) cleaned_tweet = tweet.replace("�", "") cleaned_tweet = re.sub(r"\(\)", "", cleaned_tweet) # change more than one dot to one dot cleaned_tweet = re.sub(r"\.{4,}", "...", cleaned_tweet) # change more than two question marks to two question marks cleaned_tweet = re.sub(r"\?{4,}", "???", cleaned_tweet) # change more than two ! to two ! cleaned_tweet = re.sub(r"!{4,}", "!!!", cleaned_tweet) # Remove repeated double quotes at the end cleaned_tweet = re.sub(r'"+$', '', cleaned_tweet) return cleaned_tweet.strip() custom_css = """ #header { text-align: center; margin-top: 20px; } #input-box, #output-box { margin: 0 auto; width: 80%; } #generate-button { margin: 10px auto; display: block; } """ with gr.Blocks(css=custom_css) as demo: # Add a title with the X.com logo gr.Markdown( """
Powered by AlanYky/phi-3.5_tweets_instruct