vishalkatheriya committed on
Commit fad0c74
1 Parent(s): 81fe0f6

Update app.py

Files changed (1): app.py +46 -0
app.py CHANGED
@@ -1,6 +1,52 @@
  import streamlit as st
  from PIL import Image
  import inference
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from PIL import Image
+ import requests
+ import copy
+ import os
+ from unittest.mock import patch
+ from transformers.dynamic_module_utils import get_imports
+ import torch
+
+ # Remove flash_attn from the dynamic imports so the model can load on CPU
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
+     if not str(filename).endswith("modeling_florence2.py"):
+         return get_imports(filename)
+     imports = get_imports(filename)
+     imports.remove("flash_attn")
+     return imports
+
+ # Initialize session state for model loading and to block re-running
+ if 'model_loaded' not in st.session_state:
+     st.session_state.model_loaded = False
+
+ # Function to load the model (here, the Florence-2 model)
+ def load_model():
+     # Download the model and processor from the Hub
+     model_id = "microsoft/Florence-2-large"
+     # Load the processor
+     st.session_state.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+     # Load the model weights
+     with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):  # workaround for the unnecessary flash_attn requirement
+         model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa", trust_remote_code=True)
+
+     # Apply dynamic quantization to all Linear layers
+     Qmodel = torch.quantization.quantize_dynamic(
+         model, {torch.nn.Linear}, dtype=torch.qint8
+     )
+     del model
+     st.session_state.model = Qmodel
+     st.session_state.model_loaded = True
+     st.write("Model loading complete")
+ # Load the model only once
+ if not st.session_state.model_loaded:
+     with st.spinner('Loading model...'):
+         load_model()
+
+
  # Initialize session state to block re-running
  if 'has_run' not in st.session_state:
      st.session_state.has_run = False
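
The quantize_dynamic call in this commit swaps every torch.nn.Linear in the model for an int8-weight equivalent that dequantizes on the fly, which is what makes CPU inference practical for a model this size. A minimal standalone sketch of the same transformation, using a toy module that stands in for the real model (illustrative only, not part of the commit):

import torch
import torch.nn as nn

# Toy stand-in for a large model: a stack of Linear layers.
toy = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Same call the commit uses: weights of the listed module types are stored
# as int8 and dequantized during the forward pass; activations stay float.
quantized = torch.quantization.quantize_dynamic(toy, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(1, 512)
print(quantized(x).shape)  # forward pass is unchanged: torch.Size([1, 512])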
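
The commit only caches the quantized model and processor in st.session_state; the actual calls presumably live in the imported inference module, which isn't shown here. A sketch of how downstream code might invoke them, assuming the task-prompt interface documented on the microsoft/Florence-2-large model card (run_caption is a hypothetical helper, not from the commit):

import streamlit as st
from PIL import Image

def run_caption(image: Image.Image):
    # Hypothetical helper: reads the processor/model cached by load_model().
    processor = st.session_state.processor
    model = st.session_state.model

    task = "<CAPTION>"  # Florence-2 selects its task via a prompt token
    inputs = processor(text=task, images=image, return_tensors="pt")

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=256,
        num_beams=3,
    )
    raw = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # post_process_generation parses the raw output into the task's result
    return processor.post_process_generation(
        raw, task=task, image_size=(image.width, image.height)
    )[task]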