xzerus committed
Commit c5e37aa · verified · 1 Parent(s): 895c285

Update app.py

Files changed (1):
  app.py +22 -45
app.py CHANGED
@@ -1,22 +1,22 @@
 import torch
 import torchvision.transforms as T
 from PIL import Image
-from threading import Thread
-from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModel, AutoTokenizer
 import gradio as gr
 import logging
 
 # Setup logging
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
+# Device Configuration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 # ImageNet normalization values
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
 def build_transform(input_size):
-    """
-    Build preprocessing pipeline for images.
-    """
+    """Build preprocessing pipeline for images."""
     transform = T.Compose([
         T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
         T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
@@ -26,79 +26,56 @@ def build_transform(input_size):
     return transform
 
 def preprocess_image(image, input_size=448):
-    """
-    Preprocess the image to the required format.
-    """
-    logging.info("Starting image preprocessing...")
+    """Preprocess the image to the required format."""
     transform = build_transform(input_size)
-    tensor_image = transform(image).unsqueeze(0)  # Add batch dimension
-    logging.info(f"Image preprocessed. Shape: {tensor_image.shape}")
+    tensor_image = transform(image).unsqueeze(0).to(torch.float32 if device == "cpu" else torch.bfloat16).to(device)
     return tensor_image
 
 # Load the model and tokenizer
 logging.info("Loading model from Hugging Face Hub...")
-model_path = "OpenGVLab/InternVL2_5-1B"  # Use Hugging Face model path
+model_path = "OpenGVLab/InternVL2_5-1B"
 model = AutoModel.from_pretrained(
     model_path,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
     trust_remote_code=True,
-).eval()
+).to(device).eval()
 
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
 
 # Add the `<image>` token if missing
 if "<image>" not in tokenizer.get_vocab():
     tokenizer.add_tokens(["<image>"])
-    logging.info("Added `<image>` token to tokenizer vocabulary.")
     model.resize_token_embeddings(len(tokenizer))  # Resize model embeddings
 
 assert "<image>" in tokenizer.get_vocab(), "Error: `<image>` token is missing from tokenizer vocabulary."
 
 def describe_image(image):
-    """
-    Generate a description for the uploaded image with streamed output.
-    """
+    """Generate a description for the uploaded image."""
     try:
-        logging.info("Processing uploaded image...")
-        pixel_values = preprocess_image(image, input_size=448).to(torch.bfloat16)
-
+        pixel_values = preprocess_image(image, input_size=448)
         prompt = "<image>\nExtract text from the image, respond with only the extracted text."
-        logging.info(f"Prompt: {prompt}")
-
-        # Streamer for live text output
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
-        generation_config = dict(max_new_tokens=512, do_sample=True, streamer=streamer)
-
-        logging.info("Starting model inference...")
-        thread = Thread(target=model.chat, kwargs=dict(
-            tokenizer=tokenizer, pixel_values=pixel_values, question=prompt,
-            history=None, return_history=False, generation_config=generation_config,
-        ))
-        thread.start()
-
-        generated_text = ''
-        for new_text in streamer:
-            if new_text == model.conv_template.sep:
-                break
-            generated_text += new_text
-            yield new_text  # Stream each chunk
 
-        logging.info("Inference complete.")
+        response = model.chat(
+            tokenizer=tokenizer,
+            pixel_values=pixel_values,
+            question=prompt,
+            history=None,
+            return_history=False,
+            generation_config=dict(max_new_tokens=512, do_sample=True)
+        )
+        return response
     except Exception as e:
         logging.error(f"Error during processing: {e}")
-        yield f"Error: {e}"
+        return f"Error: {e}"
 
 # Gradio Interface
-logging.info("Setting up Gradio interface...")
 interface = gr.Interface(
     fn=describe_image,
     inputs=gr.Image(type="pil"),
     outputs=gr.Textbox(label="Extracted Text", lines=10, interactive=False),
     title="Image to Text",
     description="Upload an image to extract text using the pretrained model.",
-    live=True,  # Enables live streaming output
 )
 
 if __name__ == "__main__":
-    logging.info("Launching Gradio interface...")
     interface.launch(server_name="0.0.0.0", server_port=7860)
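Note that both hunks leave lines 23-25 of build_transform outside the diff context, so the tail of the T.Compose list is never shown. Given the IMAGENET_MEAN and IMAGENET_STD constants, the elided lines are presumably the standard tensor-conversion and normalization steps; the sketch below reconstructs the full pipeline under that assumption and is not part of the commit:

    import torchvision.transforms as T

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def build_transform(input_size):
        return T.Compose([
            # Lines shown in the diff:
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
            # Assumed content of the elided lines 23-25 (standard ImageNet steps):
            T.ToTensor(),                                       # PIL image -> float tensor in [0, 1]
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # per-channel normalization
        ])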
 
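One detail in the new preprocess_image: device is a torch.device, and the dtype switch compares it against the string "cpu". Recent PyTorch releases define equality between a torch.device and its string form, so the check works as written; comparing device.type is the version-proof spelling of the same test. A minimal illustration:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Equivalent to the diff's `device == "cpu"` test, but independent of
    # torch.device/string equality semantics across PyTorch versions.
    dtype = torch.float32 if device.type == "cpu" else torch.bfloat16
    print(device, dtype)  # e.g. cpu torch.float32 on a CPU-only host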
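With streaming removed, describe_image is now a plain function that returns a string, so it can be smoke-tested without the Gradio UI. A minimal sketch, assuming the file is saved as app.py and using a hypothetical sample.png as the test image (importing app loads the model but does not launch the server, since launch is guarded by __main__):

    from PIL import Image
    from app import describe_image

    # "sample.png" is a placeholder path; any local image will do.
    image = Image.open("sample.png")
    print(describe_image(image))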