# NOTE: page-scrape residue from the Hugging Face Spaces header ("Spaces: Sleeping"); not part of the app code.
import gradio as gr | |
import torch | |
import time | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
from threading import Thread | |
import time | |
import pytz | |
from datetime import datetime | |
# Load the tokenizer and the model once, at import time, so that every call to
# the request handler reuses the same objects instead of reloading them.
print("Loading model and tokenizer...")
model_name = "large-traversaal/Phi-4-Hindi"

# Options forwarded to from_pretrained: bfloat16 weights, with device placement
# delegated to device_map="auto".
_model_load_options = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)
print("Model and tokenizer loaded successfully!")
def generate_response(message, temperature, max_new_tokens, top_p):
    """Stream the model's completion of ``message``.

    ``model.generate`` runs on a background thread and pushes decoded text
    into a ``TextIteratorStreamer``; this generator yields the completion
    accumulated so far after every chunk, so a UI bound to it updates
    incrementally. Once generation finishes, the final output, the elapsed
    time and a Pacific-time timestamp are printed to stdout.

    Args:
        message: Prompt text entered by the user.
        temperature: Sampling temperature. A value of 0 (or less) selects
            greedy decoding, in which case ``temperature`` and ``top_p`` are
            not forwarded to ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass; only used when sampling.

    Yields:
        str: The generated text so far, without the prompt.
    """
    print(f"Input: {message}")

    # An empty prompt gives the model nothing to condition on; clear the
    # output instead of starting a generation.
    if not message or not message.strip():
        yield ""
        return

    start_time = time.time()
    inputs = tokenizer(message, return_tensors="pt").to(model.device)

    # skip_prompt=True makes the streamer drop the echoed prompt itself. The
    # previous manual `startswith(message)` stripping leaked partial prompt
    # text into the first streamed chunks and failed entirely whenever the
    # decoded prompt was not byte-identical to `message`.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    do_sample = temperature > 0
    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "streamer": streamer,
        # UI sliders may deliver floats; generate() expects an integer here.
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
    }
    if "attention_mask" in inputs:
        # Forward the mask so generate() does not have to infer it.
        gen_kwargs["attention_mask"] = inputs["attention_mask"]
    if do_sample:
        # Sampling parameters are only meaningful (and only valid) when
        # sampling; passing temperature=0 alongside greedy decoding triggers
        # configuration warnings.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    generated = ""
    for text in streamer:
        generated += text
        yield generated

    # The streamer is exhausted, so generate() is finishing; reap the worker.
    thread.join()

    time_taken = time.time() - start_time
    print(f"Output: {generated}")
    print(f"Time taken: {time_taken:.2f} seconds")

    pst_timezone = pytz.timezone('America/Los_Angeles')
    current_time_pst = datetime.now(pst_timezone).strftime("%Y-%m-%d %H:%M:%S %Z%z")
    print(f"Current timestamp (PST): {current_time_pst}")
# Gradio UI: prompt box, generation controls and buttons on the left, streamed
# model output on the right.
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from the order of the `with` statements — confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Phi-4-Hindi Demo")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input",
                placeholder="Enter your text here...",
                lines=5
            )

            # Generation parameters passed to generate_response.
            with gr.Row():
                with gr.Column():
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.1,
                        step=0.01,
                        label="Temperature"
                    )
                with gr.Column():
                    max_new_tokens = gr.Slider(
                        minimum=50,
                        maximum=1000,
                        value=400,
                        step=10,
                        label="Max New Tokens"
                    )
                with gr.Column():
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.1,
                        step=0.01,
                        label="Top P"
                    )

            with gr.Row():
                clear_btn = gr.Button("Clear")
                send_btn = gr.Button("Send", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Output",
                lines=15
            )

    # "Send" streams the generator's successive yields into the output box.
    send_btn.click(
        fn=generate_response,
        inputs=[input_text, temperature, max_new_tokens, top_p],
        outputs=output_text
    )

    # "Clear" resets both text boxes. The handler must return exactly one
    # value per output component (two here); it previously returned four.
    clear_btn.click(
        fn=lambda: ("", ""),
        inputs=None,
        outputs=[input_text, output_text]
    )
# Script entry point: enable the request queue on the app, then start serving.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()