jatingocodeo committed on
Commit
9ebfecb
·
verified ·
1 Parent(s): c5150ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -25
app.py CHANGED
@@ -34,36 +34,57 @@ def load_model(model_id):
34
  return model, tokenizer
35
 
36
def generate_description(image, model, tokenizer, max_length=100, temperature=0.7, top_p=0.9):
    """Generate a text description for a PIL image using a causal LM.

    Args:
        image: PIL image; converted to RGB and resized to 32x32 before use.
        model: a Hugging Face causal language model exposing ``generate``.
        tokenizer: the matching tokenizer (used for encoding and decoding).
        max_length: upper bound on total sequence length (prompt + new tokens).
        temperature: sampling temperature passed to ``generate``.
        top_p: nucleus-sampling threshold passed to ``generate``.

    Returns:
        The generated description string, or an error report string if
        generation fails (keeps the calling UI responsive instead of crashing).
    """
    try:
        # Normalize the image to the format the model was trained on.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((32, 32))

        # Format the input text
        input_text = """Below is an image. Please describe it in detail.

Image: [IMAGE]
Description: """

        # Fix: some tokenizers ship without a pad token, which makes
        # generate() fail when padding is requested — fall back to EOS.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Tokenize input
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

        # Fix: max_length includes the prompt tokens. Guarantee at least 20
        # freshly generated tokens even when the prompt is already long.
        min_length = inputs['input_ids'].shape[1] + 20

        # Generate response (no grad tracking needed for inference).
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max(min_length, max_length),
                min_length=min_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode and keep only the text after the "Description: " marker.
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text.split("Description: ")[-1].strip()

    except Exception as e:
        # Surface the failure to the UI rather than crashing the app.
        import traceback
        return f"Error generating description: {str(e)}\n{traceback.format_exc()}"
67
 
68
  def create_demo(model_id):
69
  # Load model and tokenizer
 
34
  return model, tokenizer
35
 
36
def generate_description(image, model, tokenizer, max_length=100, temperature=0.7, top_p=0.9):
    """Describe *image* with the given causal LM, returning plain text.

    The image is normalized (RGB, 32x32), a fixed prompt is tokenized, and
    the model samples a continuation. On any failure the exception and its
    traceback are returned as a string so the caller's UI keeps working.
    """
    try:
        # The model expects RGB input at a fixed 32x32 resolution.
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((32, 32))

        prompt = """Below is an image. Please describe it in detail.

Image: [IMAGE]
Description: """

        # Tokenizers without a pad token borrow the EOS id so that
        # padding-enabled encoding and generation both work.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        encoding = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=True,
        )

        # Require at least 20 tokens beyond the prompt; max_length counts
        # the prompt, so it is raised to the floor when necessary.
        floor_length = encoding["input_ids"].shape[1] + 20
        ceiling_length = max(floor_length, max_length)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=encoding["input_ids"],
                attention_mask=encoding["attention_mask"],
                max_length=ceiling_length,
                min_length=floor_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )

        # Keep only the text following the "Description: " marker.
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return decoded.split("Description: ")[-1].strip()

    except Exception as e:
        import traceback
        return f"Error generating description: {str(e)}\n{traceback.format_exc()}"
88
 
89
  def create_demo(model_id):
90
  # Load model and tokenizer