File size: 3,693 Bytes
4ffdbdc
 
 
770775f
6957169
 
 
 
 
 
 
e1361b1
6957169
 
 
 
 
ab82892
 
 
 
e1361b1
 
 
6957169
e1361b1
2eb2d02
6957169
 
 
 
 
 
 
 
2acb8d8
15b745f
6957169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2053e3b
6957169
 
2acb8d8
 
e1361b1
2acb8d8
e1361b1
2acb8d8
6957169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1361b1
6957169
 
 
 
 
 
 
 
 
 
 
 
 
 
770775f
6957169
 
0776050
 
6957169
3ce2b2b
6957169
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# A100 Zero GPU
import spaces

import time
import torch
import gradio as gr
from config import *
from PIL import Image
from utils.utils import *
from threading import Thread
import torch.nn.functional as F
from accelerate import Accelerator
from meteor.load_mmamba import load_mmamba
from meteor.load_meteor import load_meteor
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor

# flash attention
# NOTE(review): installs flash-attn at import time. shell=True is acceptable
# here only because the command string is a fixed literal (no untrusted input).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes pip use a prebuilt wheel rather
# than compiling CUDA kernels on the Space at startup.
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# accel
# Accelerator selects the execution device (accel.device) used throughout.
accel = Accelerator()

# loading meteor model
# Two-stage pipeline: Meteor-Mamba produces traversal-of-rationale ("tor")
# features; Meteor-MLM is the main LLVM, loaded 4-bit quantized (bits=4).
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=4)

# freeze model
# Inference-only demo: freeze all parameters of both models.
freeze_model(mmamba)
freeze_model(meteor)

# previous length
# NOTE(review): appears unused anywhere in this file — candidate for removal.
previous_length = 0

def threading_function(inputs, image_token_number, streamer, device):
    """Run the two-stage Meteor pipeline, streaming tokens into `streamer`.

    Stage 1 (Meteor-Mamba) produces rationale ("tor") features; stage 2
    (Meteor-MLM) consumes them and generates the answer. Intended to run in a
    worker thread while the caller iterates `streamer`.

    Args:
        inputs: list with one dict holding 'question' and optionally 'image'.
        image_token_number: number of image tokens fed to the models.
        streamer: transformers.TextIteratorStreamer receiving new tokens.
        device: torch device the processed inputs are placed on.

    Returns:
        The value of meteor.generate(...) (tokens also arrive via `streamer`).
    """
    # Stage 1: Meteor-Mamba produces the rationale (tor) features.
    mmamba_inputs = mmamba.eval_process(
        inputs=inputs, tokenizer=tok_meteor, device=device,
        img_token_number=image_token_number)

    # Bind clip_features unconditionally so the stage-2 branch below can never
    # hit a NameError (original code relied on two separate membership tests
    # staying in sync).
    clip_features = None
    if 'image' in mmamba_inputs:
        clip_features = meteor.clip_features(mmamba_inputs['image'])
        mmamba_inputs['image_features'] = clip_features
    mmamba_outputs = mmamba(**mmamba_inputs)

    # Stage 2: Meteor-MLM conditioned on the tor features from stage 1.
    meteor_inputs = meteor.eval_process(
        inputs=inputs, data='demo', tokenizer=tok_meteor, device=device,
        img_token_number=image_token_number)
    if clip_features is not None:
        # Reuse the CLIP features computed in stage 1 instead of re-encoding.
        meteor_inputs['image_features'] = clip_features
    meteor_inputs['tor_features'] = mmamba_outputs.tor_features

    # Sampling configuration; generated tokens are pushed into `streamer`.
    generation_kwargs = dict(
        meteor_inputs,
        streamer=streamer,
        do_sample=True,
        max_new_tokens=128,
        top_p=0.95,
        temperature=0.9,
        use_cache=True,
    )
    return meteor.generate(**generation_kwargs)

@spaces.GPU
def bot_streaming(message, history):
    """Gradio chat handler: answer `message` with Meteor, streamed char-by-char.

    Args:
        message: gr.MultimodalTextbox payload — dict with 'text' and a
            (possibly empty) 'files' list of uploaded image paths.
        history: chat history (unused; required by gr.ChatInterface).

    Yields:
        Progressively longer prefixes of the final response (typewriter UI).
    """
    # Move weights onto the accelerator device here: under ZeroGPU the GPU is
    # only allocated inside a @spaces.GPU-decorated call.
    for param in mmamba.parameters():
        param.data = param.to(accel.device)
    for param in meteor.parameters():
        param.data = param.to(accel.device)

    # 490x490 input with 14-pixel patches -> (490/14)^2 image tokens.
    image_token_number = int((490 / 14) ** 2)
    if message['files']:
        # Load the first attached image and resize to the model input size.
        image = F.interpolate(
            pil_to_tensor(Image.open(message['files'][0]).convert("RGB")).unsqueeze(0),
            size=(490, 490), mode='bicubic').squeeze(0)
        inputs = [{'image': image, 'question': message['text']}]
    else:
        inputs = [{'question': message['text']}]

    # Meteor generation: run the pipeline in a worker thread and drain the
    # streamer here. (Original code accumulated with += — quadratic — and
    # contained a dead no-op expression; it also never joined the thread.)
    with torch.inference_mode():
        streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
        thread = Thread(target=threading_function,
                        kwargs=dict(inputs=inputs,
                                    image_token_number=image_token_number,
                                    streamer=streamer,
                                    device=accel.device))
        thread.start()
        generated_text = "".join(streamer)
        thread.join()  # ensure generation has fully finished before decoding

    # Keep only the assistant turn; strip trailing special-token residue.
    response = generated_text.split('assistant\n')[-1].split('[U')[0].strip()

    # Re-stream the final answer with a typewriter effect.
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.02)
        yield buffer

# Build and launch the multimodal Gradio chat UI for Meteor.
chat_kwargs = dict(
    fn=bot_streaming,
    title="☄️ Meteor",
    description="Meteor is efficient 7B size Large Language and Vision Model built on the help of traversal of rationale",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo = gr.ChatInterface(**chat_kwargs)
demo.launch()