mrbeliever committed on
Commit
17a46e3
·
verified ·
1 Parent(s): dea4c4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -8
app.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  from PIL import Image
6
  from transformers import AutoModelForCausalLM, LlamaTokenizer
7
 
 
8
  DEFAULT_PARAMS = {
9
  "do_sample": False,
10
  "max_new_tokens": 256,
@@ -23,6 +24,7 @@ DEFAULT_QUERY = (
23
  DTYPE = torch.bfloat16
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
 
26
  tokenizer = LlamaTokenizer.from_pretrained(
27
  pretrained_model_name_or_path="lmsys/vicuna-7b-v1.5",
28
  )
@@ -43,7 +45,7 @@ def generate_caption(
43
  ) -> str:
44
  inputs = model.build_conversation_input_ids(
45
  tokenizer=tokenizer,
46
- query=DEFAULT_QUERY, # Use the default query directly
47
  history=[],
48
  images=[image],
49
  )
@@ -61,13 +63,41 @@ def generate_caption(
61
  result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip().capitalize()
62
  return result
63
 
64
- with gr.Blocks() as demo:
65
- with gr.Row():
66
- with gr.Column():
67
- input_image = gr.Image(type="pil") # Image input remains
68
- run_button = gr.Button(value="Generate Caption")
69
- with gr.Column():
70
- output_caption = gr.Textbox(label="Generated Caption", show_copy_button=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  run_button.click(
73
  fn=generate_caption,
 
5
  from PIL import Image
6
  from transformers import AutoModelForCausalLM, LlamaTokenizer
7
 
8
+ # Constants
9
  DEFAULT_PARAMS = {
10
  "do_sample": False,
11
  "max_new_tokens": 256,
 
24
  DTYPE = torch.bfloat16
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
 
27
+ # Load model and tokenizer
28
  tokenizer = LlamaTokenizer.from_pretrained(
29
  pretrained_model_name_or_path="lmsys/vicuna-7b-v1.5",
30
  )
 
45
  ) -> str:
46
  inputs = model.build_conversation_input_ids(
47
  tokenizer=tokenizer,
48
+ query=DEFAULT_QUERY,
49
  history=[],
50
  images=[image],
51
  )
 
63
  result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip().capitalize()
64
  return result
65
 
66
+ # CSS for design enhancements with a dark button and white text
67
+ css = """
68
+ #container {
69
+ background-color: #f9f9f9;
70
+ padding: 20px;
71
+ border-radius: 15px;
72
+ border: 2px solid #333; /* Darker outline */
73
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); /* Enhanced shadow */
74
+ max-width: 400px;
75
+ margin: auto;
76
+ }
77
+ #input_image, #output_caption, #run_button {
78
+ margin-top: 15px;
79
+ border: 2px solid #333; /* Darker outline */
80
+ border-radius: 8px;
81
+ }
82
+ #run_button {
83
+ background-color: #000000; /* Dark button color */
84
+ color: white; /* White text */
85
+ border-radius: 10px;
86
+ padding: 10px;
87
+ cursor: pointer;
88
+ transition: background-color 0.3s ease;
89
+ }
90
+ #run_button:hover {
91
+ background-color: #333; /* Slightly lighter on hover */
92
+ }
93
+ """
94
+
95
+ # Gradio interface with vertical alignment
96
+ with gr.Blocks(css=css) as demo:
97
+ with gr.Column(elem_id="container"):
98
+ input_image = gr.Image(type="pil", elem_id="input_image")
99
+ run_button = gr.Button(value="Generate Caption", elem_id="run_button")
100
+ output_caption = gr.Textbox(label="Generated Caption", show_copy_button=True, elem_id="output_caption")
101
 
102
  run_button.click(
103
  fn=generate_caption,