1024m committed
Commit cfc8af7 · verified · 1 Parent(s): aa82a83

Update app.py

Files changed (1)
  1. app.py +38 -13
app.py CHANGED
@@ -6,15 +6,39 @@ from threading import Thread
 import time
 import pytz
 from datetime import datetime
+import gradio as gr
+import torch
+import time
+import pytz
+from datetime import datetime
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
 print("Loading model and tokenizer...")
 model_name = "large-traversaal/Phi-4-Hindi"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"
+)
 print("Model and tokenizer loaded successfully!")
-def generate_response(message, temperature, max_new_tokens, top_p):
-    print(f"Input: {message}")
+option_mapping = {
+    "translation": "### TRANSLATION ###",
+    "mcq": "### MCQ ###",
+    "nli": "### NLI ###",
+    "summarization": "### SUMMARIZATION ###",
+    "long response": "### LONG RESPONSE ###",
+    "short response": "### SHORT RESPONSE ###",
+    "direct response": "### DIRECT RESPONSE ###",
+    "paraphrase": "### PARAPHRASE ###",
+    "code": "### CODE ###"
+}
+def generate_response(message, temperature, max_new_tokens, top_p, task):
+    append_text = option_mapping.get(task, "")
+    prompt = f"INPUT : {message} {append_text} RESPONSE : "
+    print(f"Prompt: {prompt}")
     start_time = time.time()
-    inputs = tokenizer(message, return_tensors="pt").to(model.device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
     gen_kwargs = {
         "input_ids": inputs["input_ids"],
@@ -29,16 +53,12 @@ def generate_response(message, temperature, max_new_tokens, top_p):
     result = []
     for text in streamer:
         result.append(text)
-        current_output = "".join(result)
-        if current_output.startswith(message):
-            yield current_output[len(message):]
-        else:
-            yield current_output
+        yield "".join(result)
     end_time = time.time()
     time_taken = end_time - start_time
     output_text = "".join(result)
-    if output_text.startswith(message):
-        output_text = output_text[len(message):]
+    if "RESPONSE : " in output_text:
+        output_text = output_text.split("RESPONSE : ", 1)[1].strip()
     print(f"Output: {output_text}")
     print(f"Time taken: {time_taken:.2f} seconds")
     pst_timezone = pytz.timezone('America/Los_Angeles')
@@ -53,6 +73,11 @@ with gr.Blocks() as demo:
         placeholder="Enter your text here...",
         lines=5
     )
+    task_dropdown = gr.Dropdown(
+        choices=["translation", "mcq", "nli", "summarization", "long response", "short response", "direct response", "paraphrase", "code"],
+        value="long response",
+        label="Task"
+    )
     with gr.Row():
         with gr.Column():
             temperature = gr.Slider(
@@ -88,11 +113,11 @@ with gr.Blocks() as demo:
     )
     send_btn.click(
         fn=generate_response,
-        inputs=[input_text, temperature, max_new_tokens, top_p],
+        inputs=[input_text, temperature, max_new_tokens, top_p, task_dropdown],
         outputs=output_text
     )
     clear_btn.click(
-        fn=lambda: ("", "", "", ""),
+        fn=lambda: ("", ""),
         inputs=None,
         outputs=[input_text, output_text]
    )
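
Note: the gap between the first and second hunks (old lines 21-28, new lines 45-52) hides how generation is actually started. A TextIteratorStreamer yields nothing until model.generate runs, typically on a background thread. A minimal sketch of that elided step, assuming the standard transformers pattern; the actual gen_kwargs in app.py may differ:

# Sketch only (not the commit's exact elided lines): run generate on a worker
# thread so the main thread can iterate the streamer as tokens arrive.
from threading import Thread

gen_kwargs = {
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],  # assumption: mask passed alongside ids
    "max_new_tokens": max_new_tokens,
    "temperature": temperature,
    "top_p": top_p,
    "do_sample": True,       # assumption: sampling enabled so temperature/top_p apply
    "streamer": streamer,
}
Thread(target=model.generate, kwargs=gen_kwargs).start()
for text in streamer:        # blocks until the worker produces each decoded chunk
    result.append(text)
    yield "".join(result)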
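
For illustration only (not part of the commit): what the new prompt template produces for a sample input, and why the "RESPONSE : " split recovers just the completion when the decoded output echoes the prompt (TextIteratorStreamer includes the prompt unless skip_prompt=True):

message = "Translate to Hindi: Hello"        # hypothetical user input
append_text = option_mapping["translation"]  # "### TRANSLATION ###"
prompt = f"INPUT : {message} {append_text} RESPONSE : "
# -> "INPUT : Translate to Hindi: Hello ### TRANSLATION ### RESPONSE : "

output_text = prompt + "नमस्ते"              # hypothetical decoded output: prompt echo + reply
if "RESPONSE : " in output_text:
    output_text = output_text.split("RESPONSE : ", 1)[1].strip()
print(output_text)                           # prints: नमस्ते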