Yjhhh committed on
Commit
f52a035
·
verified ·
1 Parent(s): 9a1cd7f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +85 -0
main.py CHANGED
@@ -8,13 +8,17 @@ from transformers import (
8
  AutoModelForCausalLM,
9
  TrainingArguments,
10
  Trainer,
 
 
11
  )
 
12
  from fastapi import FastAPI, HTTPException, Request
13
  from fastapi.responses import HTMLResponse
14
  import multiprocessing
15
  import uuid
16
  import torch
17
  from torch.utils.data import Dataset
 
18
 
19
  load_dotenv()
20
 
@@ -101,6 +105,10 @@ conversation_history = {}
101
  tokenizer_name = "unified_tokenizer"
102
  tokenizer = None
103
  unified_model = None
 
 
 
 
104
 
105
  @app.on_event("startup")
106
  async def startup_event():
@@ -318,8 +326,85 @@ def train_unified_model():
318
  trainer.train()
319
  unified_model.save_pretrained(model_path)
320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  if __name__ == "__main__":
322
  training_process = multiprocessing.Process(target=train_unified_model)
323
  training_process.start()
 
 
 
 
324
  import uvicorn
325
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
8
  AutoModelForCausalLM,
9
  TrainingArguments,
10
  Trainer,
11
+ AutoModelForTextToWaveform,
12
+ pipeline,
13
  )
14
+ from diffusers import FluxPipeline
15
  from fastapi import FastAPI, HTTPException, Request
16
  from fastapi.responses import HTMLResponse
17
  import multiprocessing
18
  import uuid
19
  import torch
20
  from torch.utils.data import Dataset
21
+ import numpy as np
22
 
23
  load_dotenv()
24
 
 
105
  # Name used for the shared text tokenizer.
  tokenizer_name = "unified_tokenizer"
106
  # Both start unset at import time.
  # NOTE(review): presumably assigned later (e.g. by the startup hook) -- confirm.
  tokenizer = None
107
  unified_model = None
108
  # MusicGen tokenizer and model are fetched from the "facebook/musicgen-small"
  # checkpoint as a module-level side effect, i.e. every import of this module
  # (including any re-import by a child process) pays for the load.
+ musicgen_tokenizer = AutoTokenizer.from_pretrained("facebook/musicgen-small")
109
+ musicgen_model = AutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")
110
  # FLUX.1-schnell image pipeline, loaded with bfloat16 weights.
+ image_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
111
  # Enables diffusers' model CPU offload on the pipeline.
  # NOTE(review): this is expected to need `accelerate` installed -- confirm.
+ image_pipeline.enable_model_cpu_offload()
112
 
113
  @app.on_event("startup")
114
  async def startup_event():
 
326
  trainer.train()
327
  unified_model.save_pretrained(model_path)
328
 
329
async def auto_learn():
    """Fine-tune the unified text model on items popped from Redis ``training_queue``.

    Runs forever and never returns. Each queue item is a JSON object with:

    * ``"tokenizers"`` -- mapping of tokenizer name to a vocabulary mapping.
      Only the first entry is used; its keys are added to the shared tokenizer
      when a ``tokenizer:<name>`` key exists in Redis.
    * ``"data"`` -- handed to ``SyntheticDataset`` together with the tokenizer.

    For every item the model is trained for 3 epochs (batch size 8) and then
    saved to ``models/unified_model``.

    Raises:
        json.JSONDecodeError, KeyError: if a queue item is malformed.
    """
    global tokenizer, unified_model
    import asyncio  # local import so this block does not depend on the file header

    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    model_name = "unified_model"
    model_path = f"models/{model_name}"
    idle_delay = 1.0  # seconds to wait between polls when there is nothing to do

    while True:
        training_data = redis_client.lpop("training_queue")
        if not training_data:
            # Yield to the event loop instead of hammering Redis in a tight loop.
            await asyncio.sleep(idle_delay)
            continue

        if tokenizer is None or unified_model is None:
            # Both globals start as None at import time. Put the item back at
            # the head of the list (same position lpop took it from) and retry
            # later rather than losing it to an AttributeError.
            redis_client.lpush("training_queue", training_data)
            await asyncio.sleep(idle_delay)
            continue

        item_data = json.loads(training_data)
        tokenizer_data = item_data["tokenizers"]

        # First tokenizer entry, or None when the mapping is empty
        # (list(...)[0] raised IndexError in that case).
        incoming_name = next(iter(tokenizer_data), None)
        if incoming_name is not None and redis_client.exists(f"tokenizer:{incoming_name}"):
            added = tokenizer.add_tokens(list(tokenizer_data[incoming_name].keys()))
            if added:
                # New token ids need matching embedding rows; without the
                # resize, training on them indexes past the embedding matrix.
                unified_model.resize_token_embeddings(len(tokenizer))

        data = item_data["data"]
        dataset = SyntheticDataset(tokenizer, data)

        training_args = TrainingArguments(
            output_dir="./results",
            per_device_train_batch_size=8,
            num_train_epochs=3,
        )
        trainer = Trainer(model=unified_model, args=training_args, train_dataset=dataset)
        # NOTE: train() blocks this coroutine (and its event loop) until done.
        trainer.train()
        unified_model.save_pretrained(model_path)

        # Give other tasks a turn before taking the next item.
        await asyncio.sleep(0)
354
+
355
async def auto_learn_music():
    """Fine-tune MusicGen on items popped from Redis ``music_training_queue``.

    Runs forever and never returns. Each queue item is UTF-8 encoded JSON in
    one of two shapes:

    * ``{"prompts": [...], "labels": [...]}`` -- text prompts plus target audio
      codes. ``labels`` is converted to a ``torch.long`` tensor and passed to
      the model, which computes the loss itself.
      NOTE(review): MusicGen documents labels as
      ``(batch, sequence_length, num_codebooks)`` -- confirm producers match.
    * a bare list of prompts (the original format) -- carries no training
      target, so the item is logged and skipped. Previously this path raised
      ``KeyError: 'labels'`` because the tokenizer output has no such key.

    Raises:
        json.JSONDecodeError, UnicodeDecodeError: if a queue item is malformed.
    """
    global musicgen_tokenizer, musicgen_model
    import asyncio  # local imports so this block does not depend on the file header
    import logging

    log = logging.getLogger(__name__)
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
    # Built once so Adam's running moment estimates carry over between items
    # instead of being reset for every queue entry.
    optimizer = torch.optim.Adam(musicgen_model.parameters(), lr=5e-5)
    idle_delay = 1.0  # seconds to wait when the queue is empty

    while True:
        raw_item = redis_client.lpop("music_training_queue")
        if not raw_item:
            # Yield to the event loop instead of polling Redis in a tight loop.
            await asyncio.sleep(idle_delay)
            continue

        payload = json.loads(raw_item.decode("utf-8"))
        if isinstance(payload, dict):
            prompts = payload.get("prompts")
            label_codes = payload.get("labels")
        else:
            prompts, label_codes = payload, None

        if not prompts or label_codes is None:
            log.warning("Skipping music training item without prompts and labels")
            continue

        inputs = musicgen_tokenizer(prompts, return_tensors="pt", padding=True)
        labels = torch.as_tensor(label_codes, dtype=torch.long)

        musicgen_model.train()
        for _ in range(1):
            # Passing labels lets the model return its own loss, which already
            # accounts for MusicGen's per-codebook logits layout.
            outputs = musicgen_model(**inputs, labels=labels)
            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()

        # Give other tasks a turn before taking the next item.
        await asyncio.sleep(0)
373
+
374
# Endless worker: pops items from the Redis list "image_training_queue",
# renders each prompt with the FLUX pipeline, then runs one optimizer step
# per rendered image. Never returns.
# NOTE(review): indentation is not preserved in this view, so the exact
# nesting of the statements below is inferred -- confirm against main.py.
# NOTE(review): this is an `async def` with no `await` inside; it must be
# driven by an event loop to run at all, and the loop polls Redis without
# any pause when the queue is empty.
+ async def auto_learn_images():
375
+ global image_pipeline
376
+ redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD)
377
+ while True:
378
+ image_training_data = redis_client.lpop("image_training_queue")
379
+ if image_training_data:
380
  # The popped value is decoded as UTF-8 JSON and then iterated,
  # one prompt per element.
+ image_training_data = json.loads(image_training_data.decode("utf-8"))
381
+ for image_prompt in image_training_data:
382
  # Generator is re-seeded with 0 on every call, so the seed is the
  # same for every prompt.
+ image = image_pipeline(
383
+ image_prompt,
384
+ guidance_scale=0.0,
385
+ num_inference_steps=4,
386
+ max_sequence_length=256,
387
+ generator=torch.Generator("cpu").manual_seed(0)
388
+ ).images[0]
389
  # First generated image -> array -> tensor with a leading batch axis.
  # NOTE(review): presumably a uint8 height x width x channel array from
  # a PIL image -- confirm dtype/layout before feeding it to a model.
+ image_tensor = torch.tensor(np.array(image)).unsqueeze(0)
390
  # NOTE(review): diffusers documents FluxPipeline components such as
  # `transformer` and `vae`; a `.model` attribute looks undefined and
  # would raise AttributeError here -- confirm.
+ image_pipeline.model.train()
391
  # A fresh Adam optimizer is created for every image.
+ optimizer = torch.optim.Adam(image_pipeline.model.parameters(), lr=1e-5)
392
+ loss_fn = torch.nn.MSELoss()
393
  # Target is all zeros with the image tensor's shape, so the MSE loss
  # pushes the output toward zero rather than toward any reference image.
  # NOTE(review): looks unintended -- confirm the training objective.
+ target_tensor = torch.zeros_like(image_tensor)
394
  # Single pass: forward, loss, backward, optimizer step.
+ for epoch in range(1):
395
+ outputs = image_pipeline.model(image_tensor)
396
+ loss = loss_fn(outputs, target_tensor)
397
+ optimizer.zero_grad()
398
+ loss.backward()
399
+ optimizer.step()
400
+
401
+
402
  if __name__ == "__main__":
403
  training_process = multiprocessing.Process(target=train_unified_model)
404
  training_process.start()
405
+ music_training_process = multiprocessing.Process(target=auto_learn_music)
406
+ music_training_process.start()
407
+ image_training_process = multiprocessing.Process(target=auto_learn_images)
408
+ image_training_process.start()
409
  import uvicorn
410
  uvicorn.run(app, host="0.0.0.0", port=7860)