asv7j committed on
Commit
c79fe72
·
verified ·
1 Parent(s): 19d0396

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -1
app.py CHANGED
@@ -38,7 +38,10 @@ model2 = AutoModelForCausalLM.from_pretrained(
38
  torch_dtype=torch.bfloat16,
39
  token=access_token
40
  )
41
-
 
 
 
42
 
43
  @app.get("/")
44
  async def read_root():
@@ -113,6 +116,37 @@ async def read_droot():
113
  print(time_taken)
114
  return {"Hello": "World!"}
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  @app.get("/text")
117
  async def read_droot():
118
  starttime = time.time()
 
38
  torch_dtype=torch.bfloat16,
39
  token=access_token
40
  )
41
+ model3 = AutoModelForCausalLM.from_pretrained(
42
+ "Qwen/Qwen2-0.5B-Instruct",
43
+ device_map="auto"
44
+ )
45
 
46
  @app.get("/")
47
  async def read_root():
 
116
  print(time_taken)
117
  return {"Hello": "World!"}
118
 
119
+ @app.get("/teat")
120
+ async def read_droot():
121
+ starttime = time.time()
122
+ messages = [
123
+ {"role": "system", "content": "You are a helpful assistant, Sia, developed by Sushma. You will response in polity and brief."},
124
+ {"role": "user", "content": "I'm Alok. Who are you?"},
125
+ {"role": "assistant", "content": "I am Sia, a small language model created by Sushma."},
126
+ {"role": "user", "content": "How are you?"}
127
+ ]
128
+ text = tokenizer.apply_chat_template(
129
+ messages,
130
+ tokenize=False,
131
+ add_generation_prompt=True
132
+ )
133
+ model_inputs = tokenizer([text], return_tensors="pt").to(device)
134
+
135
+ generated_ids = model3.generate(
136
+ model_inputs.input_ids,
137
+ max_new_tokens=64
138
+ )
139
+ generated_ids = [
140
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
141
+ ]
142
+
143
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
144
+ print(response)
145
+ end_time = time.time()
146
+ time_taken = end_time - starttime
147
+ print(time_taken)
148
+ return {"Hello": "World!"}
149
+
150
  @app.get("/text")
151
  async def read_droot():
152
  starttime = time.time()