daranaka commited on
Commit
78e178a
1 Parent(s): d164e2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -25
app.py CHANGED
@@ -11,40 +11,54 @@ if "memory" not in st.session_state:
11
 
12
  @st.cache_resource
13
  def load_model():
14
- model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- model.to(device)
17
- return model
 
 
 
 
18
 
19
  @st.cache_data
20
  def read_image_as_np_array(image_path):
21
- if "http" in image_path:
22
- image = Image.open(urllib.request.urlopen(image_path)).convert("L").convert("RGB")
23
- else:
24
- image = Image.open(image_path).convert("L").convert("RGB")
25
- image = np.array(image)
26
- return image
 
 
 
 
27
 
28
  @st.cache_data
29
  def predict_detections_and_associations(
30
- image_path,
31
- character_detection_threshold,
32
- panel_detection_threshold,
33
- text_detection_threshold,
34
- character_character_matching_threshold,
35
- text_character_matching_threshold,
36
  ):
37
  image = read_image_as_np_array(image_path)
38
- with torch.no_grad():
39
- result = model.predict_detections_and_associations(
40
- [image],
41
- character_detection_threshold=character_detection_threshold,
42
- panel_detection_threshold=panel_detection_threshold,
43
- text_detection_threshold=text_detection_threshold,
44
- character_character_matching_threshold=character_character_matching_threshold,
45
- text_character_matching_threshold=text_character_matching_threshold,
 
 
 
46
  )[0]
47
- return result
 
 
 
48
 
49
  @st.cache_data
50
  def predict_ocr(
 
11
 
12
  @st.cache_resource
13
  def load_model():
14
+ try:
15
+ model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model.to(device)
18
+ return model
19
+ except Exception as e:
20
+ st.error(f"Error loading model: {e}")
21
+ return None
22
 
23
  @st.cache_data
24
  def read_image_as_np_array(image_path):
25
+ try:
26
+ if "http" in image_path:
27
+ image = Image.open(urllib.request.urlopen(image_path)).convert("L").convert("RGB")
28
+ else:
29
+ image = Image.open(image_path).convert("L").convert("RGB")
30
+ image = np.array(image)
31
+ return image
32
+ except Exception as e:
33
+ st.error(f"Error reading image: {e}")
34
+ return None
35
 
36
  @st.cache_data
37
  def predict_detections_and_associations(
38
+ image_path,
39
+ char_detect_thresh,
40
+ panel_detect_thresh,
41
+ text_detect_thresh,
42
+ char_char_match_thresh,
43
+ text_char_match_thresh,
44
  ):
45
  image = read_image_as_np_array(image_path)
46
+ if image is None:
47
+ return None
48
+ try:
49
+ with torch.no_grad():
50
+ result = model.predict_detections_and_associations(
51
+ [image],
52
+ character_detection_threshold=char_detect_thresh,
53
+ panel_detection_threshold=panel_detect_thresh,
54
+ text_detection_threshold=text_detect_thresh,
55
+ character_character_matching_threshold=char_char_match_thresh,
56
+ text_character_matching_threshold=text_char_match_thresh,
57
  )[0]
58
+ return result
59
+ except Exception as e:
60
+ st.error(f"Error during prediction: {e}")
61
+ return None
62
 
63
  @st.cache_data
64
  def predict_ocr(