"""Gradio demo that generates tweets with a fine-tuned Phi-3.5 model."""

import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned model and the base Phi-3.5 tokenizer.
# Weights are loaded in float32 regardless of the selected device.
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# model repository; keep it only if that repository actually requires it.
model = AutoModelForCausalLM.from_pretrained(
    "AlanYky/phi-3.5_tweets_instruct",
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")

# pipeline() takes a device index: 0 = first CUDA device, -1 = CPU.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if device == "cuda" else -1,
)

# Sampling settings shared by every generation call.
generation_args = {
    "max_new_tokens": 70,
    "return_full_text": False,  # return only the generated continuation
    "temperature": 0.7,
    "top_k": 20,
    "top_p": 0.8,
    "repetition_penalty": 1.2,
    "do_sample": True,
}

# Maximum number of characters accepted for a tweet idea.
MAX_IDEA_LENGTH = 50


def _build_messages(tweet_idea: str) -> list[dict[str, str]]:
    """Wrap *tweet_idea* in the single-turn chat prompt sent to the model."""
    prompt = f"Could you generate a tweet about {tweet_idea}?"
    return [{"role": "user", "content": prompt}]


def generate_tweet(tweet_idea: str) -> str:
    """Generate a cleaned-up tweet about *tweet_idea*.

    Surrounding whitespace is stripped from the idea first. Instead of calling
    the model, an ``"Error: ..."`` string is returned when the idea is empty
    (or whitespace-only) or longer than ``MAX_IDEA_LENGTH`` characters.
    """
    idea = (tweet_idea or "").strip()
    if not idea:
        return "Error: Please enter a tweet idea."
    if len(idea) > MAX_IDEA_LENGTH:
        return (
            "Error: Tweet idea exceeds the maximum allowed length of "
            f"{MAX_IDEA_LENGTH} characters."
        )
    output = pipe(_build_messages(idea), **generation_args)
    return clean_tweet(output[0]["generated_text"])


def generate_tweet_with_rag(tweet_idea: str) -> str:
    """Generate a tweet about *tweet_idea* and return the raw model text.

    NOTE(review): despite the name, no retrieval step happens here. The
    prompt is the same one ``generate_tweet`` uses, but the input is not
    length-checked and the output is not passed through ``clean_tweet``.
    The UI below does not call this function.
    """
    output = pipe(_build_messages(tweet_idea), **generation_args)
    return output[0]["generated_text"]


# (pattern, replacement) pairs applied in order by clean_tweet.
_PUNCTUATION_FIXES = (
    (re.compile(r"\(\)"), ""),       # drop empty parentheses
    (re.compile(r"\.{4,}"), "..."),  # 4+ dots -> exactly three
    (re.compile(r"\?{4,}"), "???"),  # 4+ question marks -> exactly three
    (re.compile(r"!{4,}"), "!!!"),   # 4+ exclamation marks -> exactly three
)
_TRAILING_QUOTES = re.compile(r'"+$')


def clean_tweet(tweet: str) -> str:
    """Tidy raw model output into a presentable tweet.

    Steps, in order:
    - remove every U+FFFD replacement character (the model can emit runs
      such as ``����``);
    - drop empty ``()`` pairs;
    - collapse runs of four or more ``.``, ``?`` or ``!`` to exactly three;
    - remove double quotes at the very end of the text;
    - strip surrounding whitespace.
    """
    cleaned = tweet.replace("\ufffd", "")
    for pattern, replacement in _PUNCTUATION_FIXES:
        cleaned = pattern.sub(replacement, cleaned)
    # Strip before removing quotes so that quotes followed by trailing
    # whitespace are still caught by the end-of-string anchor.
    cleaned = _TRAILING_QUOTES.sub("", cleaned.strip())
    return cleaned.strip()


custom_css = """
#header {
    text-align: center;
    margin-top: 20px;
}
#input-box, #output-box {
    margin: 0 auto;
    width: 80%;
}
#generate-button {
    margin: 10px auto;
    display: block;
}
"""

with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): this header was meant to show a title with the X.com
    # logo, but its content is empty and no element uses the #header CSS
    # id. The title markup appears to be missing; confirm.
    gr.Markdown(""" """)

    # Input and output boxes are centred via the #input-box / #output-box CSS.
    instruction_input = gr.Textbox(
        label="Tweet Idea",
        placeholder="Enter your tweet idea (It can be a topic, hashtag, sentence, or any format)...",
        lines=2,
        elem_id="input-box",
    )
    generate_button = gr.Button("Generate", elem_id="generate-button")
    output_box = gr.Textbox(
        label="Generated Tweet",
        placeholder="Your tweet will appear here.",
        lines=3,
        elem_id="output-box",
    )

    # Run generate_tweet on the textbox contents when the button is clicked.
    generate_button.click(generate_tweet, inputs=instruction_input, outputs=output_box)

print("Model loaded on:", next(model.parameters()).device)

# Launch only when run as a script, so importing this module (e.g. for tests
# or Gradio reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()