howard-hou commited on
Commit
6b7e620
·
1 Parent(s): 41c8853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -6
app.py CHANGED
@@ -3,8 +3,6 @@ import os, gc
3
  from datetime import datetime
4
  from transformers import CLIPImageProcessor
5
  from huggingface_hub import hf_hub_download
6
- from typing import List, Dict
7
- from dataclasses import dataclass
8
  DEFAULT_IMAGE_TOKEN = "<image>"
9
 
10
 
@@ -56,7 +54,6 @@ def generate(
56
  out_str = ''
57
  occurrence = {}
58
  state = None
59
- print("in shape", model.w["emb.weight"].shape)
60
  for i in range(int(token_count)):
61
  if i == 0:
62
  input_ids = (image_ids + pipeline.encode(ctx))[-ctx_limit:]
@@ -105,10 +102,8 @@ def chatbot(image, question):
105
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
106
  image_features = visual_encoder.encode_images(image.unsqueeze(0))
107
  emb_mixer.set_image_embeddings(image_features.squeeze(0))
108
- print(emb_mixer.embedding.shape)
109
  model.w["emb.weight"] = emb_mixer.get_input_embeddings()
110
- print(emb_mixer.get_input_embeddings().shape)
111
- print("out shape", model.w["emb.weight"].shape)
112
  image_ids = [i for i in range(emb_mixer.image_start_index, emb_mixer.image_start_index + len(image_features))]
113
  input_text = generate_prompt(question)
114
  for output in generate(input_text, image_ids):
 
3
  from datetime import datetime
4
  from transformers import CLIPImageProcessor
5
  from huggingface_hub import hf_hub_download
 
 
6
  DEFAULT_IMAGE_TOKEN = "<image>"
7
 
8
 
 
54
  out_str = ''
55
  occurrence = {}
56
  state = None
 
57
  for i in range(int(token_count)):
58
  if i == 0:
59
  input_ids = (image_ids + pipeline.encode(ctx))[-ctx_limit:]
 
102
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
103
  image_features = visual_encoder.encode_images(image.unsqueeze(0))
104
  emb_mixer.set_image_embeddings(image_features.squeeze(0))
105
+ global model
106
  model.w["emb.weight"] = emb_mixer.get_input_embeddings()
 
 
107
  image_ids = [i for i in range(emb_mixer.image_start_index, emb_mixer.image_start_index + len(image_features))]
108
  input_text = generate_prompt(question)
109
  for output in generate(input_text, image_ids):