Tonic committed on
Commit
d3bf12d
·
verified ·
1 Parent(s): d09bf09

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer
3
+ from gemma.modeling_gemma import GemmaForCausalLM
4
+ import torch
5
+ import time
6
+
7
+ # Assumes the `gemma` package (which provides GemmaForCausalLM) and the model's tokenizer files are installed and importable.
8
+
9
def inference(input_text):
    """Generate a completion for ``input_text`` and report how long it took.

    Uses the module-level ``tokenizer`` and ``model`` objects.

    Args:
        input_text: Prompt string handed to the tokenizer.

    Returns:
        A dict with two keys: ``"output"``, the decoded text of the tokens
        generated after the prompt, and ``"latency"``, the elapsed time
        formatted as ``"<seconds> seconds"`` with two decimals.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    encoded = tokenizer(input_text, return_tensors="pt").to(model.device)
    input_length = encoded["input_ids"].shape[1]

    outputs = model.generate(
        input_ids=encoded["input_ids"],
        max_length=1024,
        do_sample=False,
    )

    # The returned sequence starts with the prompt tokens; keep only the
    # tokens that follow them in the first (and only) sequence.
    new_tokens = outputs[0, input_length:].tolist()
    res = tokenizer.decode(new_tokens)

    elapsed = time.perf_counter() - start_time
    return {"output": res, "latency": f"{elapsed:.2f} seconds"}
21
+
22
# Load the tokenizer and the model once at import time; both are shared by
# every request handled by this module.
model_id = "NexaAIDev/android_API_10k_data"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = GemmaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
28
+
29
def gradio_interface(input_text):
    """Wrap the user's query in the function-calling prompt and run inference.

    Args:
        input_text: Raw query text entered in the UI.

    Returns:
        A ``(output, latency)`` tuple taken from the dict returned by
        ``inference``.
    """
    prompt = (
        "Below is the query from the users, please call the correct function "
        "and generate the parameters to call the function."
        f"\n\nQuery: {input_text} \n\nResponse:"
    )
    reply = inference(prompt)
    return reply["output"], reply["latency"]
33
+
34
# Build the web UI: one text input, two text outputs (model output, latency).
# Components are referenced as gr.Textbox because the gr.inputs / gr.outputs
# namespaces were deprecated in Gradio 3.x and removed in Gradio 4.x.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
    outputs=[gr.Textbox(label="Output"), gr.Textbox(label="Latency")],
    title="Gemma Model Inference",
    description="This application uses the Gemma model for generating responses based on the input query.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()