xzerus committed
Commit 2b71a80 · verified · 1 Parent(s): 9e7777f

Update app.py

Files changed (1)
  1. app.py +34 -69
app.py CHANGED
@@ -1,81 +1,46 @@
  import torch
- import torchvision.transforms as T
  from PIL import Image
- from transformers import AutoModel, AutoTokenizer
+ from transformers import AutoModel, CLIPImageProcessor
  import gradio as gr
- import logging
-
- # Setup logging
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-
- # Device Configuration
- device = torch.device("cpu")  # Force CPU usage
-
- # ImageNet normalization values
- IMAGENET_MEAN = (0.485, 0.456, 0.406)
- IMAGENET_STD = (0.229, 0.224, 0.225)
-
- def build_transform(input_size):
-     """Build preprocessing pipeline for images."""
-     transform = T.Compose([
-         T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
-         T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
-         T.ToTensor(),
-         T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
-     ])
-     return transform
-
- def preprocess_image(image, input_size=448):
-     """Preprocess the image to the required format."""
-     transform = build_transform(input_size)
-     tensor_image = transform(image).unsqueeze(0).to(torch.float32)  # Use float32 for CPU
-     return tensor_image
-
- # Load the model and tokenizer
- logging.info("Loading model from Hugging Face Hub...")
- model_path = "OpenGVLab/InternVL2_5-1B"
+
+ # Load the model
  model = AutoModel.from_pretrained(
-     model_path,
-     trust_remote_code=True,
- ).to(device).eval()
+     'OpenGVLab/InternViT-6B-448px-V1-5',
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True
+ ).cuda().eval()

- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+ # Load the image processor
+ image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')

- # Add the `<image>` token if missing
- if "<image>" not in tokenizer.get_vocab():
-     tokenizer.add_tokens(["<image>"])
-     model.resize_token_embeddings(len(tokenizer))  # Resize model embeddings
-
- assert "<image>" in tokenizer.get_vocab(), "Error: `<image>` token is missing from tokenizer vocabulary."
-
- def describe_image(image):
-     """Generate a description for the uploaded image."""
+ # Define the function to process the image and generate outputs
+ def process_image(image):
      try:
-         pixel_values = preprocess_image(image, input_size=448)
-         prompt = "<image>\nExtract text from the image, respond with only the extracted text."
-
-         # Perform inference
-         response = model.chat(
-             tokenizer=tokenizer,
-             pixel_values=pixel_values,
-             question=prompt,
-             history=None,
-             return_history=False,
-             generation_config=dict(max_new_tokens=512, do_sample=True),
-         )
-         return response
+         # Convert uploaded image to RGB
+         image = image.convert('RGB')
+
+         # Preprocess the image
+         pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
+         pixel_values = pixel_values.to(torch.bfloat16).cuda()
+
+         # Run the model
+         outputs = model(pixel_values)
+
+         # Assuming the model returns embeddings or features
+         return f"Output Shape: {outputs.last_hidden_state.shape}"
      except Exception as e:
-         logging.error(f"Error during processing: {e}")
-         return f"Error: {e}"
-
- # Gradio Interface
- interface = gr.Interface(
-     fn=describe_image,
-     inputs=gr.Image(type="pil"),
-     outputs=gr.Textbox(label="Extracted Text", lines=10, interactive=False),
-     title="Image to Text",
-     description="Upload an image to extract text using the pretrained model.",
+         return f"Error: {str(e)}"
+
+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=process_image,  # Function to process the input
+     inputs=gr.Image(type="pil"),  # Accepts images as input
+     outputs=gr.Textbox(label="Model Output"),  # Displays model output
+     title="InternViT Demo",
+     description="Upload an image to process it using the InternViT model from OpenGVLab."
  )

+ # Launch the demo
  if __name__ == "__main__":
-     interface.launch(server_name="0.0.0.0", server_port=7860)
+     demo.launch(server_name="0.0.0.0", server_port=7860)
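
A note on the updated script: it calls .cuda() unconditionally, so it will raise at startup on CPU-only hardware, whereas the version it replaces forced CPU. Below is a minimal, hypothetical sketch of a device-agnostic variant of the same load-and-infer path; the CPU/float32 fallback and the "example.jpg" path are illustrative assumptions, not part of this commit.

    import torch
    from PIL import Image
    from transformers import AutoModel, CLIPImageProcessor

    # Assumption: fall back to CPU and float32 when no GPU is available
    # (the committed code requires CUDA and bfloat16).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

    model = AutoModel.from_pretrained(
        "OpenGVLab/InternViT-6B-448px-V1-5",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device).eval()

    image_processor = CLIPImageProcessor.from_pretrained("OpenGVLab/InternViT-6B-448px-V1-5")

    # Smoke test with a local file; "example.jpg" is an illustrative placeholder.
    image = Image.open("example.jpg").convert("RGB")
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(dtype).to(device)

    with torch.no_grad():
        outputs = model(pixel_values)

    # InternViT is a vision encoder, so the output is a feature map rather than
    # text; the committed app reports last_hidden_state's shape the same way.
    print(outputs.last_hidden_state.shape)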