plvictor commited on
Commit
fb6d39e
·
verified ·
1 Parent(s): a350083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -82
app.py CHANGED
@@ -1,16 +1,21 @@
1
- import gradio as gr
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import os
 
 
5
 
6
- # Reduzir verbosidade dos warnings
7
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
 
8
 
9
- # TinyLlama - modelo leve e eficiente
10
  MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
 
12
- print("Carregando TinyLlama 1.1B...")
13
 
 
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
  model = AutoModelForCausalLM.from_pretrained(
16
  MODEL_NAME,
@@ -19,95 +24,142 @@ model = AutoModelForCausalLM.from_pretrained(
19
  low_cpu_mem_usage=True
20
  )
21
 
22
- # Configurar pad token
23
  if tokenizer.pad_token is None:
24
  tokenizer.pad_token = tokenizer.eos_token
25
 
26
- print("✅ Modelo carregado! Interface iniciando...")
27
 
28
- def chat_response(message, max_tokens, temperature):
29
- """Função principal de chat"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  try:
31
- # Template do TinyLlama
32
- prompt = f"<|system|>\nVocê é um assistente útil. Responda de forma clara e concisa.<|user|>\n{message}<|assistant|>\n"
33
-
34
- # Tokenizar
35
- inputs = tokenizer(
36
- prompt,
37
- return_tensors="pt",
38
- truncation=True,
39
- max_length=1200,
40
- padding=False
41
- )
42
-
43
- # Gerar resposta (sem early_stopping para evitar warning)
44
- with torch.no_grad():
45
- outputs = model.generate(
46
- inputs.input_ids,
47
- attention_mask=inputs.attention_mask,
48
- max_new_tokens=max_tokens,
49
- temperature=temperature,
50
- do_sample=True,
51
- top_p=0.9,
52
- repetition_penalty=1.1,
53
- pad_token_id=tokenizer.eos_token_id,
54
- eos_token_id=tokenizer.eos_token_id
 
 
55
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Extrair resposta
58
- new_tokens = outputs[0][len(inputs.input_ids[0]):]
59
- response = tokenizer.decode(new_tokens, skip_special_tokens=True)
60
 
61
- # Limpar resposta
62
- response = response.split("<|user|>")[0]
63
- response = response.split("<|system|>")[0]
64
- response = response.strip()
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- return response if response else "Não consegui gerar uma resposta. Tente reformular sua pergunta."
67
 
68
  except Exception as e:
69
- return f"Erro: {str(e)}"
70
-
71
- # Interface Gradio simples e funcional
72
- interface = gr.Interface(
73
- fn=chat_response,
74
- inputs=[
75
- gr.Textbox(
76
- label="💬 Sua pergunta",
77
- placeholder="Digite sua pergunta aqui...",
78
- lines=2
79
- ),
80
- gr.Slider(
81
- minimum=50,
82
- maximum=400,
83
- value=200,
84
- step=10,
85
- label="🔢 Tokens máximos"
86
- ),
87
- gr.Slider(
88
- minimum=0.1,
89
- maximum=1.2,
90
- value=0.7,
91
- step=0.1,
92
- label="🌡️ Criatividade"
93
- )
94
- ],
95
- outputs=gr.Textbox(
96
- label="🤖 Resposta do TinyLlama",
97
- lines=5
98
- ),
99
- title="🦙 TinyLlama Chat API",
100
- description="Modelo de IA leve (2.2GB) otimizado para Hugging Face Spaces gratuito",
101
- theme="default",
102
- # Sem examples para evitar cache/erros
103
- allow_flagging="never"
104
- )
105
 
106
  if __name__ == "__main__":
107
- print("🚀 Iniciando servidor...")
108
- interface.launch(
109
- server_name="0.0.0.0",
110
- server_port=7860,
111
- share=False,
112
- show_error=False
 
 
 
 
 
 
 
113
  )
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
  import os
6
+ import uvicorn
7
+ import threading
8
 
9
+ # Configurações
10
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
11
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
 
13
+ # Modelo
14
  MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
15
 
16
+ print("🦙 Carregando TinyLlama para API...")
17
 
18
+ # Carregar modelo
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
  model = AutoModelForCausalLM.from_pretrained(
21
  MODEL_NAME,
 
24
  low_cpu_mem_usage=True
25
  )
26
 
 
27
  if tokenizer.pad_token is None:
28
  tokenizer.pad_token = tokenizer.eos_token
29
 
30
+ print("✅ Modelo carregado! API iniciando...")
31
 
32
+ # FastAPI app
33
+ app = FastAPI(
34
+ title="TinyLlama Chat API",
35
+ description="API REST para TinyLlama 1.1B",
36
+ version="1.0.0"
37
+ )
38
+
39
+ # Modelos Pydantic
40
+ class ChatRequest(BaseModel):
41
+ message: str
42
+ max_tokens: int = 200
43
+ temperature: float = 0.7
44
+
45
+ class ChatResponse(BaseModel):
46
+ response: str
47
+ status: str = "success"
48
+
49
+ # Lock para thread safety
50
+ model_lock = threading.Lock()
51
+
52
+ def generate_response(message: str, max_tokens: int = 200, temperature: float = 0.7) -> str:
53
+ """Gerar resposta com o modelo"""
54
  try:
55
+ with model_lock:
56
+ prompt = f"<|system|>\nVocê é um assistente útil. Responda de forma clara e concisa.<|user|>\n{message}<|assistant|>\n"
57
+
58
+ inputs = tokenizer(
59
+ prompt,
60
+ return_tensors="pt",
61
+ truncation=True,
62
+ max_length=1000,
63
+ padding=False
64
+ )
65
+
66
+ with torch.no_grad():
67
+ outputs = model.generate(
68
+ inputs.input_ids,
69
+ max_new_tokens=min(max_tokens, 300),
70
+ temperature=max(0.1, min(temperature, 1.0)),
71
+ do_sample=True,
72
+ top_p=0.9,
73
+ repetition_penalty=1.1,
74
+ pad_token_id=tokenizer.eos_token_id,
75
+ eos_token_id=tokenizer.eos_token_id
76
+ )
77
+
78
+ response = tokenizer.decode(
79
+ outputs[0][len(inputs.input_ids[0]):],
80
+ skip_special_tokens=True
81
  )
82
+
83
+ # Limpar resposta
84
+ response = response.split("<|user|>")[0].split("<|system|>")[0].strip()
85
+
86
+ return response if response else "Não consegui gerar uma resposta."
87
+
88
+ except Exception as e:
89
+ raise HTTPException(status_code=500, detail=f"Erro na geração: {str(e)}")
90
+
91
+ # Endpoints da API
92
+
93
+ @app.get("/")
94
+ async def root():
95
+ """Endpoint raiz - informações da API"""
96
+ return {
97
+ "message": "TinyLlama Chat API",
98
+ "model": MODEL_NAME,
99
+ "endpoints": {
100
+ "POST /chat": "Enviar mensagem para o modelo",
101
+ "GET /health": "Verificar status da API",
102
+ "GET /docs": "Documentação interativa"
103
+ }
104
+ }
105
+
106
+ @app.get("/health")
107
+ async def health_check():
108
+ """Verificar se a API está funcionando"""
109
+ return {
110
+ "status": "healthy",
111
+ "model_loaded": True,
112
+ "model_name": MODEL_NAME
113
+ }
114
+
115
+ @app.post("/chat", response_model=ChatResponse)
116
+ async def chat_endpoint(request: ChatRequest):
117
+ """Endpoint principal para chat"""
118
+ if not request.message or not request.message.strip():
119
+ raise HTTPException(status_code=400, detail="Mensagem não pode estar vazia")
120
+
121
+ try:
122
+ response = generate_response(
123
+ message=request.message,
124
+ max_tokens=request.max_tokens,
125
+ temperature=request.temperature
126
+ )
127
 
128
+ return ChatResponse(response=response)
 
 
129
 
130
+ except Exception as e:
131
+ raise HTTPException(status_code=500, detail=str(e))
132
+
133
+ @app.get("/chat")
134
+ async def chat_get(message: str, max_tokens: int = 200, temperature: float = 0.7):
135
+ """Endpoint GET para chat (mais simples de testar)"""
136
+ if not message or not message.strip():
137
+ raise HTTPException(status_code=400, detail="Parâmetro 'message' é obrigatório")
138
+
139
+ try:
140
+ response = generate_response(
141
+ message=message,
142
+ max_tokens=max_tokens,
143
+ temperature=temperature
144
+ )
145
 
146
+ return {"response": response, "status": "success"}
147
 
148
  except Exception as e:
149
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  if __name__ == "__main__":
152
+ print("🚀 Iniciando servidor FastAPI...")
153
+ print("📡 API estará disponível em:")
154
+ print(" - GET / (informações)")
155
+ print(" - GET /health (status)")
156
+ print(" - POST /chat (principal)")
157
+ print(" - GET /chat (teste simples)")
158
+ print(" - GET /docs (documentação)")
159
+
160
+ uvicorn.run(
161
+ app,
162
+ host="0.0.0.0",
163
+ port=7860,
164
+ log_level="error" # Reduzir logs
165
  )