zamalali commited on
Commit
efede83
·
1 Parent(s): a8b7cf1

Refactor model initialization to always load on GPU

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -46,14 +46,15 @@ ocr_model = ocr_predictor(
46
  )
47
 
48
 
49
- if torch.cuda.is_available():
50
- processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
51
- vision_model = LlavaNextForConditionalGeneration.from_pretrained(
52
- "llava-hf/llava-v1.6-mistral-7b-hf",
53
- torch_dtype=torch.float16,
54
- low_cpu_mem_usage=True,
55
- load_in_4bit=True,
56
- )
 
57
 
58
 
59
  @spaces.GPU
 
46
  )
47
 
48
 
49
+
50
+ processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
51
+ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
52
+ "llava-hf/llava-v1.6-mistral-7b-hf",
53
+ torch_dtype=torch.float16,
54
+ low_cpu_mem_usage=True,
55
+ load_in_4bit=True,
56
+ )
57
+ vision_model.to("cuda:0")
58
 
59
 
60
  @spaces.GPU