daranaka commited on
Commit
44741f9
1 Parent(s): 78e178a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -39
app.py CHANGED
@@ -11,54 +11,40 @@ if "memory" not in st.session_state:
11
 
12
  @st.cache_resource
13
  def load_model():
14
- try:
15
- model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
- model.to(device)
18
- return model
19
- except Exception as e:
20
- st.error(f"Error loading model: {e}")
21
- return None
22
 
23
  @st.cache_data
24
  def read_image_as_np_array(image_path):
25
- try:
26
- if "http" in image_path:
27
- image = Image.open(urllib.request.urlopen(image_path)).convert("L").convert("RGB")
28
- else:
29
- image = Image.open(image_path).convert("L").convert("RGB")
30
- image = np.array(image)
31
- return image
32
- except Exception as e:
33
- st.error(f"Error reading image: {e}")
34
- return None
35
 
36
  @st.cache_data
37
  def predict_detections_and_associations(
38
- image_path,
39
- char_detect_thresh,
40
- panel_detect_thresh,
41
- text_detect_thresh,
42
- char_char_match_thresh,
43
- text_char_match_thresh,
44
  ):
45
  image = read_image_as_np_array(image_path)
46
- if image is None:
47
- return None
48
- try:
49
- with torch.no_grad():
50
- result = model.predict_detections_and_associations(
51
- [image],
52
- character_detection_threshold=char_detect_thresh,
53
- panel_detection_threshold=panel_detect_thresh,
54
- text_detection_threshold=text_detect_thresh,
55
- character_character_matching_threshold=char_char_match_thresh,
56
- text_character_matching_threshold=text_char_match_thresh,
57
  )[0]
58
- return result
59
- except Exception as e:
60
- st.error(f"Error during prediction: {e}")
61
- return None
62
 
63
  @st.cache_data
64
  def predict_ocr(
 
11
 
12
  @st.cache_resource
13
  def load_model():
14
+ model = AutoModel.from_pretrained("ragavsachdeva/magi", trust_remote_code=True)
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ model.to(device)
17
+ return model
 
 
 
 
18
 
19
  @st.cache_data
20
  def read_image_as_np_array(image_path):
21
+ if "http" in image_path:
22
+ image = Image.open(urllib.request.urlopen(image_path)).convert("L").convert("RGB")
23
+ else:
24
+ image = Image.open(image_path).convert("L").convert("RGB")
25
+ image = np.array(image)
26
+ return image
 
 
 
 
27
 
28
  @st.cache_data
29
  def predict_detections_and_associations(
30
+ image_path,
31
+ character_detection_threshold,
32
+ panel_detection_threshold,
33
+ text_detection_threshold,
34
+ character_character_matching_threshold,
35
+ text_character_matching_threshold,
36
  ):
37
  image = read_image_as_np_array(image_path)
38
+ with torch.no_grad():
39
+ result = model.predict_detections_and_associations(
40
+ [image],
41
+ character_detection_threshold=character_detection_threshold,
42
+ panel_detection_threshold=panel_detection_threshold,
43
+ text_detection_threshold=text_detection_threshold,
44
+ character_character_matching_threshold=character_character_matching_threshold,
45
+ text_character_matching_threshold=text_character_matching_threshold,
 
 
 
46
  )[0]
47
+ return result
 
 
 
48
 
49
  @st.cache_data
50
  def predict_ocr(