import time

import gradio as gr
import torch
from transformers import AutoTokenizer, GemmaForCausalLM

# Octopus-v2 is a Gemma-based model, so transformers' GemmaForCausalLM loads it directly.
def inference(input_text):
    start_time = time.time()
    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
    input_length = input_ids["input_ids"].shape[1]
    outputs = model.generate(
        input_ids=input_ids["input_ids"],
        max_length=1024,
        do_sample=False,
    )
    # Strip the prompt tokens so only the newly generated text is decoded.
    generated_sequence = outputs[:, input_length:].tolist()
    res = tokenizer.decode(generated_sequence[0])
    end_time = time.time()
    return {"output": res, "latency": f"{end_time - start_time:.2f} seconds"}
# Initialize the tokenizer and model once at startup so every request reuses them.
model_id = "NexaAIDev/Octopus-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = GemmaForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
def gradio_interface(input_text):
    # Wrap the raw user query in the prompt template Octopus-v2 expects.
    nexa_query = (
        "Below is the query from the users, please call the correct function "
        "and generate the parameters to call the function.\n\n"
        f"Query: {input_text} \n\nResponse:"
    )
    result = inference(nexa_query)
    return result["output"], result["latency"]
iface = gr.Interface(
    fn=gradio_interface,
    # gr.inputs/gr.outputs were removed in Gradio 4.x; use the components directly.
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs=[gr.Textbox(label="Output"), gr.Textbox(label="Latency")],
    title="Octopus-v2 (Gemma) Model Inference",
    description="This application uses the Octopus-v2 model, built on Gemma, to generate function-calling responses for the input query.",
)
if __name__ == "__main__":
    iface.launch()
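
Once the app is running, the endpoint can also be exercised programmatically. Below is a minimal sketch using gradio_client, assuming the app is served locally on Gradio's default port (7860); the sample query is purely illustrative.

# Minimal client sketch (assumes the app above is already running on
# http://127.0.0.1:7860; the query text is just an illustrative sample).
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
output, latency = client.predict(
    "Take a selfie for me with the front camera",
    api_name="/predict",  # default endpoint name for a gr.Interface
)
print(output)
print(latency)

Note that device_map="auto" in the model-loading call requires the accelerate package to be installed alongside transformers and torch.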