Krish45 committed on
Commit
a9670a5
·
verified ·
1 Parent(s): 1489919

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -4
app.py CHANGED
@@ -1,9 +1,31 @@
1
  import gradio as gr
 
2
 
3
- def predict(text):
4
- return f"Echo: {text}"
5
 
6
- iface = gr.Interface(fn=predict, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Launch with API access
9
- iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
+ model_name = config["model_name"]
 
5
 
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ model_name, low_cpu_mem_usage=True, device_map="auto", torch_dtype="auto"
9
+ )
10
+
11
+ def predict(messages):
12
+ text = tokenizer.apply_chat_template(
13
+ messages, tokenize=False, add_generation_prompt=True
14
+ )
15
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
16
+
17
+ logger.info(f"Model generation process started at - {process_id}")
18
+ generated_ids = model.generate(**model_inputs, max_new_tokens=512)
19
+ generated_ids = [
20
+ output_ids[len(input_ids) :]
21
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
22
+ ]
23
+ logger.info(f"Model generation process completed [{process_id}]")
24
+
25
+ reply = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
26
+ return reply
27
+
28
+ iface = gr.Interface(fn=predict, inputs="messages", outputs="reply")
29
 
30
  # Launch with API access
31
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=False)