File size: 3,121 Bytes
73d70b7
78ca0a1
83afe32
 
78ca0a1
 
 
 
3b0149d
 
 
83afe32
 
 
 
 
78ca0a1
 
 
 
 
83afe32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78ca0a1
 
 
 
 
95c5b4a
 
78ca0a1
 
 
 
 
95c5b4a
 
78ca0a1
 
d8ba489
73d70b7
d8ba489
 
 
 
95c5b4a
 
 
 
 
 
 
 
 
78ca0a1
95c5b4a
 
78ca0a1
 
 
 
 
 
 
95c5b4a
 
78ca0a1
 
95c5b4a
 
 
 
d8ba489
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from fastapi import FastAPI, Request, Form , Body
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Configura las plantillas Jinja2
templates = Jinja2Templates(directory="templates")

# Define el personaje
personaje = "rias"
user="user"
chat={
    personaje: f"hola soy {personaje} no esperaba verte por aqui",
    user:f"hola "
}
# Monta la carpeta 'static' para servir archivos estáticos
app.mount("/static", StaticFiles(directory="static"), name="static")

# Ruta para mostrar los personajes
@app.get("/", response_class=HTMLResponse)
async def read_html(request: Request):
    return templates.TemplateResponse("listapersonajes.html", {"request": request})

# Ruta dinámica para cada personaje
@app.get("/personajes/{personaje}", response_class=HTMLResponse)
async def personaje_detalle(request: Request, personaje: str):
    # El contexto es el nombre de la imagen que se usará
    context = {
        "character_image": f"{personaje}.jpg" , # Asume que el nombre de la imagen es igual al personaje
        "character_name": personaje.capitalize()  # Nombre del personaje con la primera letra en mayúscula

    }
    return templates.TemplateResponse("chat.html", {"request": request, **context})






# Cambia al nuevo modelo
model_name = "allura-org/MoE-Girl_400MA_1BT"

# Inicialización global
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # Utiliza automáticamente la GPU si está disponible
    torch_dtype=torch.float16  # Usa FP16 para eficiencia en GPUs
)

@app.post("/personajes/{personaje}/chat")
async def chat_with_character(request: Request, personaje: str, user_input: str = Body(...)):
    # Verificar que el user_input no esté vacío
    if not user_input:
        return JSONResponse(status_code=422, content={"message": "user_input is required"})
    
    # Crear el prompt dinámico con el formato esperado
    prompt = f"""<|im_start|>system
You are {personaje}, a sexy girl who has been dating the user for 2 months.<|im_end|>
<|im_start|>user
{user_input}<|im_end|>
<|im_start|>assistant
"""

    # Tokenizar el prompt
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

    # Generar la respuesta
    outputs = model.generate(
        **inputs,
        max_new_tokens=500,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7
    )

    # Decodificar la respuesta
    generated_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Filtrar la respuesta para extraer solo el texto del asistente
    response_text = generated_response.split("<|im_start|>assistant")[1].strip().split("<|im_end|>")[0].strip()

    # Devolver la respuesta al usuario
    return JSONResponse(content={"response": response_text})