naveensp committed · Commit 5076284 · verified · 1 Parent(s): 0cc08b4

Upload llava_olmo.py with huggingface_hub

Files changed (1):
  1. llava_olmo.py (+98, -0)
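For reference, the commit message says the file was uploaded with huggingface_hub. A minimal sketch of the kind of call that produces such a commit is below; the repo id and local path are assumptions for illustration, not taken from this page:

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="llava_olmo.py",       # local file to upload (assumed path)
    path_in_repo="llava_olmo.py",          # destination path inside the repo
    repo_id="naveensp/LlavaOLMoBitnet1B",  # hypothetical repo id, not shown on this page
    commit_message="Upload llava_olmo.py with huggingface_hub",
)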
llava_olmo.py ADDED
@@ -0,0 +1,98 @@
import json

import torch
from transformers import AutoTokenizer

import llava.model.language_model.llava_olmo1p58b as llava_olmo
import llava.model.language_model.llava_llama as llava_llama

from OLMo_Bitnet_1B.modeling_olmo import OLMoForCausalLM
from PIL import Image
import requests
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.conversation import conv_templates


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200

# Define image and text inputs
text = "What are the four major tournaments of the sport shown in the image?"
url = "https://farm3.staticflickr.com/2157/2439959136_d932f4e816_z.jpg"
image = Image.open(requests.get(url, stream=True).raw)


# LOAD MODEL FROM CHECKPOINT
with open('./checkpoints/llava-LlavaOLMoBitnet1B-Run3-finetune/config.json') as json_file:
    data = json.load(json_file)

config_class = llava_olmo.LlavaOLMoBitnet1BConfig(**data)
model = llava_olmo.LlavaOLMoBitnet1BForCausalLM(config_class).to(device)
weight_checkpoint = torch.load('./checkpoints/llava-LlavaOLMoBitnet1B-Run3-finetune/pytorch_model.bin')
model.load_state_dict(weight_checkpoint)

# Pre-process the image; apply the chat template and tokenize the text
image_processor = model.model.vision_tower.image_processor
tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/OLMo-Bitnet-1B",
    model_max_length=2048,
    padding_side="right",
    pad_token_id=1,
    use_fast=True,
    legacy=False,
    unk_token='<|padding|>',
)

image_tensor = process_images([image], image_processor, model.config)[0]

# Prepend the image token and build the llava_v1 conversation prompt
text = DEFAULT_IMAGE_TOKEN + '\n' + text
conv = conv_templates['llava_v1'].copy()
conv.append_message(conv.roles[0], text)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

text_tokens = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)

# Generate response from the model
response = model.generate(images=image_tensor.unsqueeze(0).to(device), inputs=text_tokens, max_new_tokens=400)
decoded_text = tokenizer.batch_decode(response, skip_special_tokens=True)[0]
print("\n\n", "-"*100)
# Truncate at the end-of-sequence marker; the replace() strips an unwanted
# token that gets introduced at the start of the generation
print(decoded_text[:decoded_text.find('</s>')].replace('|||IP_ADDRESS|||', ''))
print("-"*100)


'''
# ORIGINAL CODE WITH ONLY OLMO:
with open('llava/config.json') as json_file:
    data = json.load(json_file)

text = "Paris is a historic city with architectural marvels. It is also "
# text = ["Language modeling is "]

config_class = llava_olmo.LlavaOLMoBitnet1BConfig(**data)
lolmo = llava_olmo.LlavaOLMoBitnet1BForCausalLM(config_class).to(device)
lolmo.load_state_dict(torch.load('OLMo_Bitnet_1B/pytorch_model.bin'), strict=False)

olmo = OLMoForCausalLM(config_class).to(device)
olmo.load_state_dict(torch.load('OLMo_Bitnet_1B/pytorch_model.bin'))
actual_olmo = OLMoForCausalLM.from_pretrained("allenai/OLMo-1B").to(device)

actual_olmo_tokenizer = OLMoTokenizerFast.from_pretrained("allenai/OLMo-1B")
olmo_tokenizer = AutoTokenizer.from_pretrained("NousResearch/OLMo-Bitnet-1B")

olmo_tokens = olmo_tokenizer(text, return_tensors='pt', return_token_type_ids=False).to(device)
# olmo_tokens = actual_olmo_tokenizer(text, return_tensors='pt', return_token_type_ids=False).to(device)

response = lolmo.generate(inputs=olmo_tokens['input_ids'], attention_mask=olmo_tokens['attention_mask'], max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
# response = olmo.generate(inputs=olmo_tokens['input_ids'], attention_mask=olmo_tokens['attention_mask'], max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)

print(olmo_tokenizer.batch_decode(response, skip_special_tokens=True)[0])
'''