asv7j committed on
Commit
8ce2931
·
verified ·
1 Parent(s): 2141949

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -8,15 +8,15 @@ access_token = os.getenv("read_access")
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
  device = "cpu" # the device to load the model onto
10
 
11
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
12
 
13
- model1 = AutoModelForCausalLM.from_pretrained(
14
- "Qwen/Qwen2-1.5B-Instruct",
15
- device_map="auto"
16
- )
17
 
18
- tokenizer2 = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=access_token)
19
- model2 = AutoModelForCausalLM.from_pretrained(
20
  "google/gemma-2-2b-it",
21
  device_map="auto",
22
  token=access_token
@@ -106,9 +106,9 @@ async def read_droot():
106
  tokenize=False,
107
  add_generation_prompt=True
108
  )
109
- model_inputs = tokenizer2([text], return_tensors="pt").to(device)
110
 
111
- generated_ids = model2.generate(
112
  model_inputs.input_ids,
113
  max_new_tokens=64
114
  )
@@ -116,11 +116,10 @@ async def read_droot():
116
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
117
  ]
118
 
119
- response = tokenizer2.batch_decode(generated_ids, skip_special_tokens=True)[0]
120
- respons = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
121
  print(response)
122
  end_time = time.time()
123
  time_taken = end_time - starttime
124
  print(time_taken)
125
- return {"Hello": respons}
126
  #return {response: time}
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
  device = "cpu" # the device to load the model onto
10
 
11
+ #tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
12
 
13
+ #model1 = AutoModelForCausalLM.from_pretrained(
14
+ # "Qwen/Qwen2-1.5B-Instruct",
15
+ # device_map="auto"
16
+ #)
17
 
18
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=access_token)
19
+ model = AutoModelForCausalLM.from_pretrained(
20
  "google/gemma-2-2b-it",
21
  device_map="auto",
22
  token=access_token
 
106
  tokenize=False,
107
  add_generation_prompt=True
108
  )
109
+ model_inputs = tokenizer([text], return_tensors="pt").to(device)
110
 
111
+ generated_ids = model.generate(
112
  model_inputs.input_ids,
113
  max_new_tokens=64
114
  )
 
116
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
117
  ]
118
 
119
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
120
  print(response)
121
  end_time = time.time()
122
  time_taken = end_time - starttime
123
  print(time_taken)
124
+ return {"Hello": "resps"}
125
  #return {response: time}