import torch.cuda
import gradio as gr
import mdtex2html
import tempfile
from PIL import Image
import scipy
from llama.m2ugen import M2UGen
import llama
import numpy as np
import os
import torch
import torchaudio
import torchvision.transforms as transforms
import av
import subprocess
import librosa
# Demo settings: checkpoint/weight locations and the Hugging Face identifiers
# of the encoders and the music decoder that are handed to M2UGen.
args = {
    # M2UGen checkpoint and LLaMA weights
    "model": "./ckpts/checkpoint.pth",
    "llama_type": "7B",
    "llama_dir": "./ckpts/LLaMA-2",
    # Encoder identifiers
    "mert_path": "m-a-p/MERT-v1-330M",
    "vit_path": "google/vit-base-patch16-224",
    "vivit_path": "google/vivit-b-16x2-kinetics400",
    # Music decoder
    "music_decoder": "musicgen",
    "music_decoder_path": "facebook/musicgen-medium",
}
class dotdict(dict):
    """A dict whose items can also be read, written and deleted as attributes.

    Reading an attribute whose key is absent yields ``None`` (it does not
    raise ``AttributeError``); deleting an absent key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so real dict
        # methods such as ``items`` keep working.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores into the mapping, not ``__dict__``.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)
# Wrap the plain settings dict so values can be read as attributes
# (e.g. ``args.llama_dir``).
args = dotdict(args)

# Module-level list; it starts empty here and is presumably filled by code
# further down the file (not visible in this section).
generated_audio_files = []

# Resolve the LLaMA checkpoint directory and tokenizer location from the settings.
llama_type = args.llama_type
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
llama_tokenzier_path = args.llama_dir
model = M2UGen(llama_ckpt_dir, llama_tokenzier_path, args, knn=False, stage=None, load_llama=False)

print("Loading Model Checkpoint")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider ``weights_only=True`` if the checkpoint format and
# the installed torch version allow it.
checkpoint = torch.load(args.model, map_location='cpu')

# Drop every entry whose name mentions "generation_model" and strip the
# "module." fragment from the remaining names (presumably the prefix left by
# a DataParallel/DDP wrapper at training time -- confirm).
new_ckpt = {
    key.replace("module.", ""): value
    for key, value in checkpoint['model'].items()
    if "generation_model" not in key
}

# strict=False tolerates keys missing from the checkpoint, but a key the model
# does not know is treated as an error. This is an explicit raise rather than
# an ``assert`` so the check still runs under ``python -O``.
load_result = model.load_state_dict(new_ckpt, strict=False)
if load_result.unexpected_keys:
    raise RuntimeError(f"Unexpected keys: {load_result.unexpected_keys}")

model.eval()
model.to("cuda")
def _ensure_three_channels(x):
    """Tile *x* three times along its first dimension when that dimension has size 1.

    Any other input is returned unchanged.
    """
    if x.size(0) == 1:
        return x.repeat(3, 1, 1)
    return x


# Image preprocessing: convert to a tensor, then expand single-channel
# tensors to three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_ensure_three_channels),
])
def postprocess(self, y):
    """Convert every entry of a chat history with ``mdtex2html.convert``.

    Args:
        self: Receiver object; not used by this function.
        y: Sequence of ``(message, response)`` pairs, or ``None``.

    Returns:
        An empty list when ``y`` is ``None``. Otherwise ``y`` itself, with
        each pair replaced in place by a tuple of the converted message and
        response; ``None`` members are kept as ``None``.
    """
    if y is None:
        return []

    def _render(text):
        # None marks an absent side of the exchange and is passed through.
        return None if text is None else mdtex2html.convert(text)

    for index, (message, response) in enumerate(y):
        y[index] = (_render(message), _render(response))
    return y
# Monkey-patch gradio's Chatbot class so that its ``postprocess`` method is the
# module-level ``postprocess`` function (conversion via mdtex2html).
gr.Chatbot.postprocess = postprocess
def parse_text(text, image_path, video_path, audio_path):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
outputs = text
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'
'
else:
lines[i] = f'
'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = " " + line
text = "".join(lines) + " "
if image_path is not None:
text += f' '
outputs = f'{image_path} ' + outputs
if video_path is not None:
text += f'