Update main.py
Browse files
main.py
CHANGED
@@ -8,13 +8,17 @@ from transformers import (
|
|
8 |
AutoModelForCausalLM,
|
9 |
TrainingArguments,
|
10 |
Trainer,
|
|
|
|
|
11 |
)
|
|
|
12 |
from fastapi import FastAPI, HTTPException, Request
|
13 |
from fastapi.responses import HTMLResponse
|
14 |
import multiprocessing
|
15 |
import uuid
|
16 |
import torch
|
17 |
from torch.utils.data import Dataset
|
|
|
18 |
|
19 |
load_dotenv()
|
20 |
|
@@ -101,6 +105,10 @@ conversation_history = {}
|
|
101 |
tokenizer_name = "unified_tokenizer"
|
102 |
tokenizer = None
|
103 |
unified_model = None
|
|
|
|
|
|
|
|
|
104 |
|
105 |
@app.on_event("startup")
|
106 |
async def startup_event():
|
@@ -318,8 +326,85 @@ def train_unified_model():
|
|
318 |
trainer.train()
|
319 |
unified_model.save_pretrained(model_path)
|
320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
if __name__ == "__main__":
|
322 |
training_process = multiprocessing.Process(target=train_unified_model)
|
323 |
training_process.start()
|
|
|
|
|
|
|
|
|
324 |
import uvicorn
|
325 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
8 |
AutoModelForCausalLM,
|
9 |
TrainingArguments,
|
10 |
Trainer,
|
11 |
+
AutoModelForTextToWaveform,
|
12 |
+
pipeline,
|
13 |
)
|
14 |
+
from diffusers import FluxPipeline
|
15 |
from fastapi import FastAPI, HTTPException, Request
|
16 |
from fastapi.responses import HTMLResponse
|
17 |
import multiprocessing
|
18 |
import uuid
|
19 |
import torch
|
20 |
from torch.utils.data import Dataset
|
21 |
+
import numpy as np
|
22 |
|
23 |
load_dotenv()
|
24 |
|
|
|
105 |
tokenizer_name = "unified_tokenizer"
|
106 |
tokenizer = None
|
107 |
unified_model = None
|
108 |
+
musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")
|
109 |
+
musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")
|
110 |
+
image_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
111 |
+
image_pipeline.enable_model_cpu_offload()
|
112 |
|
113 |
@app.on_event("startup")
|
114 |
async def startup_event():
|
|
|
326 |
trainer.train()
|
327 |
unified_model.save_pretrained(model_path)
|
328 |
|
329 |
+
async def auto_learn():
|
330 |
+
global tokenizer, unified_model
|
331 |
+
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
|
332 |
+
while True:
|
333 |
+
training_data = redis_client.lpop("training_queue")
|
334 |
+
if training_data:
|
335 |
+
item_data = json.loads(training_data)
|
336 |
+
tokenizer_data = item_data["tokenizers"]
|
337 |
+
tokenizer_name = list(tokenizer_data.keys())[0]
|
338 |
+
if redis_client.exists(f"tokenizer:{tokenizer_name}"):
|
339 |
+
tokenizer.add_tokens(list(tokenizer_data[tokenizer_name].keys()))
|
340 |
+
data = item_data["data"]
|
341 |
+
dataset = SyntheticDataset(tokenizer, data)
|
342 |
+
|
343 |
+
model_name = "unified_model"
|
344 |
+
model_path = f"models/{model_name}"
|
345 |
+
|
346 |
+
training_args = TrainingArguments(
|
347 |
+
output_dir="./results",
|
348 |
+
per_device_train_batch_size=8,
|
349 |
+
num_train_epochs=3,
|
350 |
+
)
|
351 |
+
trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
|
352 |
+
trainer.train()
|
353 |
+
unified_model.save_pretrained(model_path)
|
354 |
+
|
355 |
+
async def auto_learn_music():
|
356 |
+
global musicgen_tokenizer, musicgen_model
|
357 |
+
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
|
358 |
+
while True:
|
359 |
+
music_training_data = redis_client.lpop("music_training_queue")
|
360 |
+
if music_training_data:
|
361 |
+
music_training_data = json.loads(music_training_data.decode("utf-8"))
|
362 |
+
inputs = musicgen_tokenizer(music_training_data, return_tensors="pt", padding=True)
|
363 |
+
musicgen_model.train()
|
364 |
+
optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
|
365 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
366 |
+
|
367 |
+
for epoch in range(1):
|
368 |
+
outputs = musicgen_model(**inputs)
|
369 |
+
loss = loss_fn(outputs.logits, inputs['labels'])
|
370 |
+
optimizer.zero_grad()
|
371 |
+
loss.backward()
|
372 |
+
optimizer.step()
|
373 |
+
|
374 |
+
async def auto_learn_images():
|
375 |
+
global image_pipeline
|
376 |
+
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
|
377 |
+
while True:
|
378 |
+
image_training_data = redis_client.lpop("image_training_queue")
|
379 |
+
if image_training_data:
|
380 |
+
image_training_data = json.loads(image_training_data.decode("utf-8"))
|
381 |
+
for image_prompt in image_training_data:
|
382 |
+
image = image_pipeline(
|
383 |
+
image_prompt,
|
384 |
+
guidance_scale=0.0,
|
385 |
+
num_inference_steps=4,
|
386 |
+
max_sequence_length=256,
|
387 |
+
generator=torch.Generator("cpu").manual_seed(0)
|
388 |
+
).images[0]
|
389 |
+
image_tensor = torch.tensor(np.array(image)).unsqueeze(0)
|
390 |
+
image_pipeline.model.train()
|
391 |
+
optimizer = torch.optim.Adam(image_pipeline.model.parameters(), lr=1e-5)
|
392 |
+
loss_fn = torch.nn.MSELoss()
|
393 |
+
target_tensor = torch.zeros_like(image_tensor)
|
394 |
+
for epoch in range(1):
|
395 |
+
outputs = image_pipeline.model(image_tensor)
|
396 |
+
loss = loss_fn(outputs, target_tensor)
|
397 |
+
optimizer.zero_grad()
|
398 |
+
loss.backward()
|
399 |
+
optimizer.step()
|
400 |
+
|
401 |
+
|
402 |
if __name__ == "__main__":
|
403 |
training_process = multiprocessing.Process(target=train_unified_model)
|
404 |
training_process.start()
|
405 |
+
music_training_process = multiprocessing.Process(target=auto_learn_music)
|
406 |
+
music_training_process.start()
|
407 |
+
image_training_process = multiprocessing.Process(target=auto_learn_images)
|
408 |
+
image_training_process.start()
|
409 |
import uvicorn
|
410 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|