asv7j committed on
Commit
e2d20d4
·
verified ·
1 Parent(s): 9d541fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI
2
  import time
3
-
4
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  device = "cpu" # the device to load the model onto
@@ -75,4 +75,39 @@ async def read_droot():
75
  end_time = time.time()
76
  time_taken = end_time - starttime
77
  print(time_taken)
78
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  import time
3
+ import torch
4
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cpu" # the device to load the model onto (CPU-only host — no GPU assumed here)
 
75
  end_time = time.time()
76
  time_taken = end_time - starttime
77
  print(time_taken)
78
+ return {"Hello": "World!"}
79
+
80
+
81
@app.get("/text")
async def readdroot():
    """Run a fixed demo chat through the model and return the reply.

    Sends a hard-coded four-message conversation through the module-level
    ``tokenizer``/``model`` pair, decodes the newly generated tokens, and
    returns them as JSON. Wall-clock generation time is printed to stdout.

    Returns:
        dict: ``{"Hello": "World!", "response": <generated text>}`` —
        the original placeholder key is kept for backward compatibility.
    """
    starttime = time.time()
    messages = [
        {"role": "system", "content": "You are a helpful assistant, Sia. You are developed by Sushma. You will response in polity and brief."},
        {"role": "user", "content": "Who are you?"},
        {"role": "assistant", "content": "I am Sia, a small language model created by Sushma. I am here to assist you."},
        {"role": "user", "content": "Hi, How are you?"}
    ]
    # Render the conversation into the model's chat template as plain text,
    # appending the generation prompt so the model continues as assistant.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    with torch.no_grad():  # inference only — skip gradient bookkeeping
        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=64,   # cap reply length; adjust based on needs
            use_cache=False      # disables the KV cache (slower generation)
        )

    # Strip the prompt tokens so only the newly generated reply remains.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)
    end_time = time.time()
    time_taken = end_time - starttime
    print(time_taken)
    # Fix: the decoded reply was previously computed but never returned —
    # the endpoint answered with only the copy-pasted placeholder payload.
    return {"Hello": "World!", "response": response}