MohamedRashad committed
Commit 9e39b36 · verified · 1 Parent(s): 81e957f

Update app.py

Files changed (1)
  1. app.py +3 -2
app.py CHANGED

@@ -8,8 +8,9 @@ import os
 from threading import Thread
 
 # Load model directly
+device = "cuda" if torch.cuda.is_available() else "cpu"
 tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
-model = AutoModelForCausalLM.from_pretrained("Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN"))
+model = AutoModelForCausalLM.from_pretrained("Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device)
 streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
 
 def respond(
@@ -30,7 +31,7 @@ def respond(
         messages.append({"role": "assistant", "content": val[1]})
 
     messages.append({"role": "user", "content": message})
-    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, enable_reasoning=enable_reasoning)
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, enable_reasoning=enable_reasoning).to(device)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
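
Taken together, the commit swaps device_map="auto" for one explicitly chosen device and moves the tokenized prompt onto that same device before model.generate runs in the background thread. Below is a minimal, self-contained sketch of the resulting pattern, not the app itself: it assumes HF_TOKEN grants access to Navid-AI/Mulhem-1-Mini, it drops attn_implementation="flash_attention_2" (which needs the flash-attn package and a CUDA GPU) and the model-specific enable_reasoning flag, and it adds return_dict=True as an assumption so the chat-template output is a mapping that can be expanded into generation kwargs.

import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Pick one explicit device instead of relying on device_map="auto".
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    "Navid-AI/Mulhem-1-Mini",
    torch_dtype=torch.bfloat16,
    token=os.getenv("HF_TOKEN"),
).to(device)  # the whole model lives on one device, as in the first hunk
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

messages = [{"role": "user", "content": "Hello!"}]
# return_dict=True (an assumption, not in the diff) makes the template
# output a mapping we can splat into generate(); .to(device) moves its
# tensors to the model's device, mirroring the second hunk.
inputs = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True, return_dict=True
).to(device)

# generate() runs in a background thread while the streamer yields text.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

Calling .to(device) on the template output moves every tensor it holds. Once the model is placed explicitly rather than via device_map="auto", leaving the inputs on CPU while the model sits on CUDA would make model.generate fail with a device mismatch, which is what the second hunk guards against.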