Update app.py
app.py CHANGED

@@ -34,36 +34,57 @@ def load_model(model_id):
     return model, tokenizer
 
 def generate_description(image, model, tokenizer, max_length=100, temperature=0.7, top_p=0.9):
-
-
-    image
-
-
-
-
+    try:
+        # Convert and resize image
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+        image = image.resize((32, 32))
+
+        # Format the input text
+        input_text = """Below is an image. Please describe it in detail.
 
 Image: [IMAGE]
 Description: """
-
-
-
-
-
-
-
-
-
-
-
-
-            num_return_sequences=1,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id
+
+        # Ensure we have valid token IDs
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        # Tokenize input with explicit token IDs
+        inputs = tokenizer(
+            input_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            add_special_tokens=True
         )
+
+        # Calculate minimum length to ensure we generate new tokens
+        min_length = inputs['input_ids'].shape[1] + 20
+
+        # Generate response
+        with torch.no_grad():
+            outputs = model.generate(
+                input_ids=inputs['input_ids'],
+                attention_mask=inputs['attention_mask'],
+                max_length=max(min_length, max_length),  # Ensure max_length is greater than input length
+                min_length=min_length,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=True,
+                num_return_sequences=1,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        # Decode and return the response
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return generated_text.split("Description: ")[-1].strip()
 
-
-
-
+    except Exception as e:
+        import traceback
+        return f"Error generating description: {str(e)}\n{traceback.format_exc()}"
 
 def create_demo(model_id):
     # Load model and tokenizer
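For context, here is a minimal sketch of how the patched generate_description could be wired into the Gradio app that create_demo builds. Only the two function signatures and the "# Load model and tokenizer" comment are visible in this hunk, so the loader body, the interface layout, the demo title, and the model id below are assumptions, not the Space's actual code:

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_id):
    # Assumed loader; the Space's real load_model is defined above this hunk.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.eval()
    return model, tokenizer

def create_demo(model_id):
    # Load model and tokenizer once, then close over them in the handler
    # so each request only pays for generation.
    model, tokenizer = load_model(model_id)

    def describe(image):
        return generate_description(image, model, tokenizer)

    return gr.Interface(
        fn=describe,
        inputs=gr.Image(type="pil"),        # PIL input matches image.convert/.resize above
        outputs=gr.Textbox(label="Description"),
        title="Image Description Demo",     # hypothetical title
    )

if __name__ == "__main__":
    create_demo("your-username/your-model").launch()  # hypothetical model id

Because generate_description now catches exceptions and returns the traceback as a string, a demo wired this way surfaces errors in the output textbox instead of crashing the Space.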