Spaces:
Runtime error
Runtime error
# Copyright 2022 Tristan Behrens. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Lint as: python3 | |
from flask import Flask, render_template, request, send_file, jsonify, redirect, url_for | |
from PIL import Image | |
import os | |
import io | |
import random | |
import base64 | |
import torch | |
import wave | |
from source.logging import create_logger | |
from source.tokensequence import token_sequence_to_audio, token_sequence_to_image | |
from source import constants | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
logger = create_logger(__name__)

# Read the auth token from the "authtoken" environment variable.
# os.getenv returns None when the variable is unset; that None is passed
# straight through to from_pretrained below.
auth_token = os.getenv("authtoken")

# Repository id used for both the tokenizer and the model.
MODEL_ID = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

# Loading the model and its tokenizer.
logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=auth_token)
logger.info("Done.")

# Create the app.
logger.info("Creating app...")
app = Flask(__name__)
logger.info("Done.")
# Route for the main page. Registering it here is required: nothing else in
# this module adds a URL rule, so without the decorator Flask answers 404.
@app.route("/")
def index():
    """Render ``composer.html`` with the style, density and temperature choices.

    Returns:
        The rendered HTML page.
    """
    return render_template(
        "composer.html",
        compose_styles=constants.get_compose_styles_for_ui(),
        densities=constants.get_densities_for_ui(),
        temperatures=constants.get_temperatures_for_ui(),
    )
def _encode_wav_base64(sample_rate, audio_data):
    """Wrap mono 16-bit samples in an in-memory WAV file and base64-encode it.

    Args:
        sample_rate: Frame rate written to the WAV header.
        audio_data: Raw sample buffer handed to ``writeframes``; expected to
            hold int16 samples, because the sample width is fixed at 2 bytes.

    Returns:
        The complete WAV file as a base64 string.
    """
    buffer = io.BytesIO()
    # Wave_write.close() does not close a file object it was given, so the
    # buffer is still readable after the with-block; the with-block also
    # guarantees the writer is closed if writeframes raises.
    with wave.open(buffer, "wb") as wave_file:
        wave_file.setframerate(sample_rate)
        wave_file.setnchannels(1)
        wave_file.setsampwidth(2)
        wave_file.writeframes(audio_data)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _encode_png(image):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    # PNG is lossless: the former quality=70 argument is a JPEG option that
    # the PNG writer ignores, so it is omitted.
    image.save(buffer, "PNG")
    return buffer.getvalue()


# NOTE(review): the "/compose" path is assumed; it must match the URL that
# templates/composer.html posts its JSON to — confirm.
@app.route("/compose", methods=["POST"])
def compose():
    """Generate a piece from the posted parameters and return it in three forms.

    Expects a JSON object body with the keys ``music_style``, ``density`` and
    ``temperature``; each is mapped through the matching ``constants`` helper
    before generation. As a side effect the rendered image is also written to
    ``compose.png`` in the working directory (concurrent requests overwrite
    the same file).

    Returns:
        On success, JSON with ``tokens`` (the decoded token sequence),
        ``audio`` (a WAV data URL), ``image`` (a PNG data URL) and
        ``status`` ``"OK"``. If the body is not a JSON object or a key is
        missing, JSON with ``status`` ``"ERROR"`` and HTTP status 400.
    """
    # silent=True makes get_json() return None instead of raising on a
    # missing or malformed JSON body, so we can answer with a clean 400.
    params = request.get_json(silent=True)
    if not isinstance(params, dict):
        return jsonify({"status": "ERROR", "error": "Expected a JSON object body."}), 400
    required = ("music_style", "density", "temperature")
    missing = [key for key in required if key not in params]
    if missing:
        message = "Missing parameters: " + ", ".join(missing)
        return jsonify({"status": "ERROR", "error": message}), 400

    instruments = constants.get_instruments(params["music_style"])
    density = constants.get_density(params["density"])
    temperature = constants.get_temperature(params["temperature"])
    logger.info(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Get the audio data as an array of int16 and wrap it in a WAV file.
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _encode_wav_base64(sample_rate, audio_data)

    # Render the token sequence as a PIL image and encode it once as PNG;
    # the same bytes go to disk and into the response.
    image = token_sequence_to_image(generated_sequence)
    logger.debug(f"Saving image to harddrive... {type(image)}")
    png_bytes = _encode_png(image)
    with open("compose.png", "wb") as image_file:
        image_file.write(png_bytes)
    image_data_base64 = base64.b64encode(png_bytes).decode("utf-8")

    return jsonify({
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK",
    })
def generate_sequence(instruments, density, temperature, max_length=2048):
    """Sample a token sequence from the model for the given instruments.

    The prompt is ``PIECE_START`` followed by one
    ``TRACK_START INST=<instrument> DENSITY=<density>`` header per instrument,
    with the instruments in random order. Each piece is tokenized separately
    and the pieces are concatenated. The model then samples a continuation.

    Args:
        instruments: Iterable of instrument identifiers. It is copied into a
            new list before shuffling, so the caller's object is never
            reordered.
        density: Value written into every track header.
        temperature: Sampling temperature passed to ``model.generate``.
        max_length: Maximum total length in tokens, prompt included; the
            default is the previously hard-coded 2048.

    Returns:
        The decoded token sequence as a string.
    """
    # list() copies lists, tuples and array-likes alike; slicing would fail
    # to shuffle a tuple and would only create a view of a NumPy array.
    shuffled = list(instruments)
    random.shuffle(shuffled)

    # Collect the per-piece encodings and concatenate once instead of
    # re-allocating the prompt tensor on every loop iteration.
    prompt_parts = [tokenizer.encode("PIECE_START", return_tensors="pt")[0]]
    prompt_parts.extend(
        tokenizer.encode(f"TRACK_START INST={instrument} DENSITY={density}", return_tensors="pt")[0]
        for instrument in shuffled
    )
    prompt_ids = torch.cat(prompt_parts).unsqueeze(0)  # add the batch dimension

    # NOTE(review): sampling ends at the first generated TRACK_END id, even
    # when several track headers are in the prompt — confirm that is intended.
    generated_ids = model.generate(
        prompt_ids,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
        eos_token_id=tokenizer.encode("TRACK_END")[0],
    )[0]
    return tokenizer.decode(generated_ids)
if __name__ == "__main__":
    # Start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")