Pectics committed
Commit aa819ab · verified · 1 Parent(s): 1325e72

Fix ext threads invoking @GPU

Files changed (1):
  1. app.py +7 -7
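Context for the fix (a sketch, not part of the commit): on ZeroGPU Spaces, the spaces.GPU decorator intercepts the call so the runtime can attach a GPU around it, and that interception works when the decorated function is entered from the request thread. The old code launched infer itself on a Thread, so @GPU was invoked from a worker thread; the fix calls infer directly and moves the threading inside, around model.generate. A hypothetical minimal contrast, assuming that ZeroGPU rule (requires the huggingface "spaces" package and a ZeroGPU Space to actually run):

# Hypothetical minimal contrast (not from the repo); assumes the ZeroGPU
# rule that a @GPU-decorated function must be entered from the request thread.
from threading import Thread
from spaces import GPU

@GPU
def gpu_task(work):
    # ZeroGPU attaches a GPU for the duration of this call; blocking work
    # can still be handed to a worker thread from inside it.
    Thread(target=work).start()

def request_handler(work):
    # Before the fix (broken): the decorated call itself ran off-thread,
    # so the ZeroGPU runtime never saw it from the request thread:
    #   Thread(target=gpu_task, args=(work,)).start()
    # After the fix: call the decorated function directly.
    gpu_task(work)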
app.py CHANGED

@@ -1,7 +1,7 @@
-from threading import Thread
-from spaces import GPU
 from gradio import ChatInterface, Textbox, Slider
-
+from spaces import GPU
+from threading import Thread
+from torch import bfloat16
 from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, TextIteratorStreamer, AutoProcessor, BatchFeature
 from qwen_vl_utils import process_vision_info
 
@@ -9,7 +9,7 @@ model_path = "Pectics/Softie-VL-7B-250123"
 
 model = Qwen2VLForConditionalGeneration.from_pretrained(
     model_path,
-    torch_dtype="auto",
+    torch_dtype=bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
 )
@@ -18,9 +18,9 @@ max_pixels = 1280 * 28 * 28
 processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
 
 @GPU
-def infer(inputs: BatchFeature, **kwargs) -> None:
+def infer(inputs: BatchFeature, **kwargs):
     inputs = inputs.to("cuda")
-    model.generate(**kwargs)
+    Thread(target=model.generate, kwargs=kwargs).start()
 
 def respond(
     message,
@@ -51,7 +51,7 @@ def respond(
         temperature=temperature,
         top_p=top_p,
     )
-    Thread(target=infer, kwargs=kwargs).start()
+    infer(**kwargs)
     response = ""
     for token in streamer:
         response += token
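Streaming still works after the change because app.py follows the standard TextIteratorStreamer pattern: model.generate blocks until decoding finishes, so it runs on the worker thread and feeds the streamer while the request thread iterates it. A self-contained sketch of that pattern with a small stand-in model (gpt2 here is an illustrative assumption; the app uses Qwen2VLForConditionalGeneration):

# Self-contained sketch of the TextIteratorStreamer pattern used in respond().
# gpt2 is only an illustrative stand-in, not the model from app.py.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until decoding finishes, so it runs on a worker thread
# and pushes decoded text into the streamer as tokens are produced...
Thread(target=model.generate,
       kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32)).start()

# ...while the calling thread consumes the streamer incrementally,
# mirroring the "for token in streamer" loop in respond().
response = ""
for token in streamer:
    response += token
print(response)

The dtype change in the same commit fits this setup: flash_attention_2 supports only half precision, so pinning torch_dtype=bfloat16 makes that requirement explicit instead of relying on "auto" resolving to the checkpoint's dtype.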