"""Gradio demo for the Abhishekcr448/Tiny-Hinglish-Chat-21M causal language model.

The UI shows a chat window and a read-only "Suggestion" box. The suggestion is
refreshed while the user types (completing their text) and after each model
reply (proposing what the user might say next).
"""

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "Abhishekcr448/Tiny-Hinglish-Chat-21M"

# Load the model and tokenizer once; every event handler below shares them.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Styling for the suggestion textbox (matched via elem_id="response-text").
CUSTOM_CSS = """
#response-text {
    background-color: #e1bee7; /* Light purple background */
    border-radius: 8px; /* Rounded corners */
    padding: 10px; /* Padding inside the textbox */
    font-size: 16px; /* Font size */
    color: #4a148c; /* Dark purple text color */
}
"""


def _sample(prompt, output_length, temperature, top_k, top_p):
    """Sample a continuation of ``prompt`` with the shared model.

    Returns:
        A pair ``(output_ids, prompt_len)``: the token ids of the first (only)
        generated sequence, which starts with the prompt tokens, and the number
        of prompt tokens at its start.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        # Pass the mask explicitly so generate() does not have to infer it.
        attention_mask=inputs.get("attention_mask"),
        # Same budget as the old max_length = prompt_len + output_length.
        max_new_tokens=int(output_length),
        no_repeat_ngram_size=2,
        temperature=temperature,
        # Slider values may arrive as floats; transformers requires an int top_k.
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
    )
    return output_ids[0], prompt_len


def generate_text(prompt, output_length, temperature, top_k, top_p):
    """Return ``prompt`` followed by a sampled continuation, decoded as one string.

    Generates up to ``output_length`` new tokens using the given sampling
    parameters. Special tokens are dropped from the decoded text.
    """
    output_ids, _ = _sample(prompt, output_length, temperature, top_k, top_p)
    return tokenizer.decode(output_ids, skip_special_tokens=True)


def generate_continuation(prompt, output_length, temperature, top_k, top_p):
    """Return only the newly sampled text after ``prompt``, stripped of whitespace.

    Decodes just the generated token ids instead of slicing the decoded string
    by ``len(prompt)``. String slicing goes wrong whenever the tokenizer does not
    reproduce the prompt text exactly.
    """
    output_ids, prompt_len = _sample(prompt, output_length, temperature, top_k, top_p)
    return tokenizer.decode(output_ids[prompt_len:], skip_special_tokens=True).strip()


def validate_and_generate(prompt, output_length, temperature, top_k, top_p):
    """Suggest a completion of the text being typed.

    Returns the prompt plus its sampled continuation, or "" when the prompt is
    empty or whitespace (so the suggestion box is cleared).
    """
    if not (prompt and prompt.strip()):
        return ""
    print(f"Prompt: {prompt}")
    return generate_text(prompt, output_length, temperature, top_k, top_p)


def chat_interaction(prompt, history, output_length, temperature, top_k, top_p):
    """Send ``prompt`` to the chat, append the model's reply, and suggest a next message.

    Returns:
        ``(history, suggestion, "")``. The history has the user message and the
        assistant reply appended (in "messages" format). The suggestion is a
        continuation of the reply. The trailing "" clears the input box.
    """
    history = history if history is not None else []
    if not (prompt and prompt.strip()):
        return history, "", ""

    response = generate_continuation(prompt, output_length, temperature, top_k, top_p)
    history.append({"role": "user", "content": prompt})
    history.append({"role": "assistant", "content": response})

    # Suggest the user's next message by continuing the model's reply. An empty
    # reply cannot be sampled from (it tokenizes to nothing), so skip it.
    suggestion = (
        generate_continuation(response, output_length, temperature, top_k, top_p)
        if response
        else ""
    )
    return history, suggestion, ""


def regenerate_text(input_text, history, output_length, temperature, top_k, top_p):
    """Produce a fresh suggestion.

    If the input box has text, complete it (returning prompt plus continuation).
    Otherwise continue the last chat message (returning only the new text).
    Returns "" when there is nothing to continue from.
    """
    if input_text and input_text.strip():
        return generate_text(input_text, output_length, temperature, top_k, top_p)
    if not history:
        return ""
    last_message = history[-1]["content"]
    if not last_message:
        return ""
    return generate_continuation(last_message, output_length, temperature, top_k, top_p)


with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown("# Hinglish Chat Prediction")

    # Chat history, seeded with an opening line from the assistant.
    with gr.Row():
        chatbox = gr.Chatbot(
            label="Chat",
            type="messages",
            height=350,
            value=[{"role": "assistant", "content": "Kya kar rahe ho"}],
        )

    # User input (wide column) next to its Submit button (narrow column).
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(label="Start chatting", interactive=True)
        with gr.Column(scale=1):
            submit_button = gr.Button("Submit", variant="primary", elem_id="submit-btn")

    # Read-only suggestion box (styled by CUSTOM_CSS) next to its action buttons.
    with gr.Row():
        with gr.Column(scale=3):
            response_text = gr.Textbox(
                label="Suggestion", interactive=False, elem_id="response-text"
            )
        with gr.Column(scale=1):
            replace_button = gr.Button(
                "Use Suggestion", variant="secondary", elem_id="replace-btn"
            )
            regenerate_button = gr.Button(
                "Regenerate", variant="secondary", elem_id="regenerate-btn"
            )

    # Collapsible sampling controls shared by every generation handler.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Change Parameters", open=False):
                output_length_slider = gr.Slider(
                    1, 20, value=8, label="Output Length", step=1
                )
                temperature_slider = gr.Slider(
                    0.1, 1.0, value=0.8, label="Temperature (Controls randomness)"
                )
                top_k_slider = gr.Slider(
                    1, 100, value=50, label="Top-k (Limits vocabulary size)", step=1
                )
                top_p_slider = gr.Slider(
                    0.1, 1.0, value=0.9, label="Top-p (Nucleus sampling)"
                )

    # Inputs passed to every handler, in the order their signatures expect.
    generation_params = [output_length_slider, temperature_slider, top_k_slider, top_p_slider]
    chat_inputs = [input_text, chatbox, *generation_params]
    chat_outputs = [chatbox, response_text, input_text]

    # Refresh the suggestion on every edit of the input box.
    input_text.input(
        validate_and_generate,
        inputs=[input_text, *generation_params],
        outputs=response_text,
    )
    # Copy the current suggestion into the input box.
    replace_button.click(lambda x: x, inputs=response_text, outputs=input_text)
    # Pressing Enter and clicking Submit both send the message.
    input_text.submit(chat_interaction, inputs=chat_inputs, outputs=chat_outputs)
    submit_button.click(chat_interaction, inputs=chat_inputs, outputs=chat_outputs)
    regenerate_button.click(regenerate_text, inputs=chat_inputs, outputs=response_text)


if __name__ == "__main__":
    demo.launch()