annabeth97c commited on
Commit
11872ca
·
verified ·
1 Parent(s): ee0998f

Initial commit app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ import logging
3
+
4
+ import sys
5
+ sys.path.append("/home/user/app/src/sonicverse")
6
+
7
+ from huggingface_hub import login
8
+ import os
9
+
10
+ hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
11
+ if not hf_token:
12
+ raise ValueError("Missing HUGGINGFACE_HUB_TOKEN. Set it as a secret in your Space.")
13
+
14
+ login(token=hf_token)
15
+
16
+ import gradio as gr
17
+ import torch
18
+ import transformers
19
+ import torchaudio
20
+
21
+ from multi_token.model_utils import MultiTaskType
22
+ from multi_token.training import ModelArguments
23
+ from multi_token.inference import load_trained_lora_model
24
+ from multi_token.data_tools import encode_chat
25
+
26
+
27
+ @dataclass
28
+ class ServeArguments(ModelArguments):
29
+ load_bits: int = field(default=16)
30
+ max_new_tokens: int = field(default=128)
31
+ temperature: float = field(default=0.01)
32
+
33
+
34
+ # Load arguments and model
35
+ logging.getLogger().setLevel(logging.INFO)
36
+
37
+ parser = transformers.HfArgumentParser((ServeArguments,))
38
+ serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
39
+
40
+ model, tokenizer = load_trained_lora_model(
41
+ model_name_or_path=serve_args.model_name_or_path,
42
+ model_lora_path=serve_args.model_lora_path,
43
+ load_bits=serve_args.load_bits,
44
+ use_multi_task=MultiTaskType(serve_args.use_multi_task),
45
+ tasks_config=serve_args.tasks_config
46
+ )
47
+
48
+
49
+ def generate_caption(audio_file):
50
+ # waveform, sample_rate = torchaudio.load(audio_file)
51
+
52
+ req_json = {
53
+ "messages": [
54
+ {"role": "user", "content": "Describe the music. <sound>"}
55
+ ],
56
+ "sounds": [audio_file]
57
+ }
58
+
59
+ encoded_dict = encode_chat(req_json, tokenizer, model.modalities)
60
+
61
+ with torch.inference_mode():
62
+ output_ids = model.generate(
63
+ input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
64
+ max_new_tokens=serve_args.max_new_tokens,
65
+ use_cache=True,
66
+ do_sample=True,
67
+ temperature=serve_args.temperature,
68
+ modality_inputs={
69
+ m.name: [encoded_dict[m.name]] for m in model.modalities
70
+ },
71
+ )
72
+
73
+ outputs = tokenizer.decode(
74
+ output_ids[0, encoded_dict["input_ids"].shape[0]:],
75
+ skip_special_tokens=True
76
+ ).strip()
77
+
78
+ return outputs
79
+
80
+
81
+ demo = gr.Interface(
82
+ fn=generate_caption,
83
+ inputs=gr.Audio(type="filepath", label="Upload an audio file"),
84
+ outputs=gr.Textbox(label="Generated Caption"),
85
+ title="SonicVerse",
86
+ description="Upload an audio file to generate a caption using SonicVerse"
87
+ )
88
+
89
+ if __name__ == "__main__":
90
+ demo.launch()