Spaces:
Sleeping
Sleeping
File size: 3,693 Bytes
4ffdbdc 770775f 6957169 e1361b1 6957169 ab82892 e1361b1 6957169 e1361b1 2eb2d02 6957169 2acb8d8 15b745f 6957169 2053e3b 6957169 2acb8d8 e1361b1 2acb8d8 e1361b1 2acb8d8 6957169 e1361b1 6957169 770775f 6957169 0776050 6957169 3ce2b2b 6957169 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# A100 Zero GPU
import spaces
import time
import torch
import gradio as gr
from config import *
from PIL import Image
from utils.utils import *
from threading import Thread
import torch.nn.functional as F
from accelerate import Accelerator
from meteor.load_mmamba import load_mmamba
from meteor.load_meteor import load_meteor
from transformers import TextIteratorStreamer
from torchvision.transforms.functional import pil_to_tensor
# flash attention
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE presumably
# tells flash-attn's setup to skip compiling CUDA kernels from source — confirm
# against the flash-attn setup.py. The variable is layered on top of the
# current environment: passing a bare dict as ``env`` would wipe PATH, HOME,
# the active virtualenv, etc. for the child process. Running pip through
# ``sys.executable`` installs into the interpreter that runs this app, and the
# argument list avoids going through a shell. Best-effort: the result is not
# checked, matching the original behavior.
import os
import subprocess
import sys
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# accel: provides ``accel.device``, the device models/inputs are moved to
accel = Accelerator()
# loading meteor model
mmamba = load_mmamba('BK-Lee/Meteor-Mamba')
# NOTE(review): bits=4 presumably selects 4-bit quantized loading — confirm in load_meteor
meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=4)
# freeze model (freeze_model comes from a star import; presumably disables gradients)
freeze_model(mmamba)
freeze_model(meteor)
# previous length
# NOTE(review): not referenced anywhere else in this file
previous_length = 0
def threading_function(inputs, image_token_number, streamer, device):
    """Run the Meteor-Mamba -> Meteor pipeline and stream generated text.

    Intended as a ``Thread`` target: ``meteor.generate`` pushes text into
    ``streamer`` while the caller iterates it on another thread.

    Args:
        inputs: list of dicts with a ``'question'`` key and, optionally, an
            ``'image'`` tensor (as built by ``bot_streaming``).
        image_token_number: image token count forwarded to both models'
            ``eval_process``.
        streamer: streamer object passed to ``meteor.generate``.
        device: device passed to both models' ``eval_process``.

    Returns:
        Whatever ``meteor.generate`` returns.

    Raises:
        Re-raises any exception from the pipeline after signalling
        end-of-stream on ``streamer``.
    """
    try:
        # Grad mode is thread-local in PyTorch, so a ``torch.inference_mode()``
        # block in the calling thread does not apply inside this worker.
        with torch.inference_mode():
            # Meteor Mamba
            mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
            has_image = 'image' in mmamba_inputs
            if has_image:
                # CLIP features are computed once and shared by both models
                clip_features = meteor.clip_features(mmamba_inputs['image'])
                mmamba_inputs.update({"image_features": clip_features})
            mmamba_outputs = mmamba(**mmamba_inputs)

            # Meteor, conditioned on Mamba's tor_features
            meteor_inputs = meteor.eval_process(inputs=inputs, data='demo', tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
            if has_image:
                meteor_inputs.update({"image_features": clip_features})
            meteor_inputs.update({"tor_features": mmamba_outputs.tor_features})

            # Sampling settings for the demo
            generation_kwargs = meteor_inputs
            generation_kwargs.update({
                'streamer': streamer,
                'do_sample': True,
                'max_new_tokens': 128,
                'top_p': 0.95,
                'temperature': 0.9,
                'use_cache': True,
            })
            return meteor.generate(**generation_kwargs)
    except Exception:
        # Unblock the consumer: without the end-of-stream signal the caller's
        # ``for new_text in streamer`` loop would wait forever.
        streamer.end()
        raise
@spaces.GPU
def bot_streaming(message, history):
    """Answer one chat turn, yielding the reply progressively for Gradio.

    Args:
        message: chat message dict with a ``'text'`` entry and a ``'files'``
            list; only the first file (if any) is opened, as an image.
        history: prior chat turns supplied by ``gr.ChatInterface``; unused.

    Yields:
        Growing prefixes of the decoded response, one character at a time.
    """
    # param: move both models' parameters onto the accelerator device
    for model in (mmamba, meteor):
        for param in model.parameters():
            param.data = param.to(accel.device)

    # prompt type -> input prompt
    # 490x490 image; presumably 14-pixel patches -> (490/14)**2 = 1225 tokens
    image_token_number = int((490/14)**2)
    if message['files']:
        # Image Load; the context manager closes the underlying file handle
        with Image.open(message['files'][0]) as img:
            rgb = img.convert("RGB")
        image = F.interpolate(pil_to_tensor(rgb).unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
        inputs = [{'image': image, 'question': message['text']}]
    else:
        inputs = [{'question': message['text']}]

    # [4] Meteor Generation
    # NOTE: inference_mode here only covers this thread; the worker thread
    # runs threading_function, which handles its own grad mode.
    with torch.inference_mode():
        streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
        thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=accel.device))
        thread.start()
        # Drain the streamer; iteration ends once generation signals completion
        generated_text = "".join(streamer)
        # Reap the worker so it does not outlive the request
        thread.join()

    # Text decoding: keep the text after the last 'assistant\n' and before
    # the first '[U'
    response = generated_text.split('assistant\n')[-1].split('[U')[0].strip()
    # Typewriter effect: reveal one more character every 20 ms
    for end in range(1, len(response) + 1):
        time.sleep(0.02)
        yield response[:end]
# Multimodal chat UI (text + file upload) whose replies are streamed by
# bot_streaming; launched as soon as the module runs.
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="☄️ Meteor",
    description=(
        "Meteor is efficient 7B size Large Language and Vision Model "
        "built on the help of traversal of rationale"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch()
|