NourFakih committed on
Commit
6bc3970
·
verified ·
1 Parent(s): 11167d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -38,12 +38,24 @@ model.config.eos_token_id = tokenizer.eos_token_id
38
  model.config.decoder_start_token_id = tokenizer.bos_token_id
39
  model.config.pad_token_id = tokenizer.pad_token_id
40
 
 
 
 
 
 
 
 
 
 
 
41
  def generate_caption(image):
42
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
43
  output_ids = model.generate(pixel_values)
44
  caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
45
  return caption
46
 
 
47
  def get_synonyms(word):
48
  synonyms = set()
49
  for syn in wordnet.synsets(word):
@@ -51,14 +63,6 @@ def get_synonyms(word):
51
  synonyms.add(lemma.name())
52
  return synonyms
53
 
54
- def preprocess_query(query):
55
- doc = nlp(query)
56
- tokens = set()
57
- for token in doc:
58
- tokens.add(token.text)
59
- tokens.add(token.lemma_)
60
- tokens.update(get_synonyms(token.text))
61
- return tokens
62
 
63
  def search_captions(query, captions):
64
  query_tokens = preprocess_query(query)
 
38
  model.config.decoder_start_token_id = tokenizer.bos_token_id
39
  model.config.pad_token_id = tokenizer.pad_token_id
40
 
41
+ def preprocess_query(query):
42
+ doc = nlp(query)
43
+ tokens = set()
44
+ for token in doc:
45
+ tokens.add(token.text)
46
+ tokens.add(token.lemma_)
47
+ tokens.update(get_synonyms(token.text))
48
+ st.write(f"Query tokens: {tokens}") # Debugging line
49
+ return tokens
50
+
51
  def generate_caption(image):
52
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
53
  output_ids = model.generate(pixel_values)
54
  caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
55
+ st.write(f"Generated caption: {caption}") # Debugging line
56
  return caption
57
 
58
+
59
  def get_synonyms(word):
60
  synonyms = set()
61
  for syn in wordnet.synsets(word):
 
63
  synonyms.add(lemma.name())
64
  return synonyms
65
 
 
 
 
 
 
 
 
 
66
 
67
  def search_captions(query, captions):
68
  query_tokens = preprocess_query(query)