Spaces:
Sleeping
Sleeping
import os | |
from flask import Flask, request, jsonify | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from huggingface_hub import login | |
app = Flask(__name__) | |
def init_model(): | |
global model, tokenizer | |
hf_token = os.getenv("HF_TOKEN") # Чтение токена из переменной окружения | |
if hf_token is None: | |
raise ValueError("Hugging Face token is not set. Please set the HF_TOKEN environment variable.") | |
# Аутентификация с использованием токена | |
login(hf_token, add_to_git_credential=True) | |
# Загрузка модели и токенизатора без квантования и без распределения на CPU/диск | |
tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b", token=hf_token) | |
model = AutoModelForCausalLM.from_pretrained( | |
"IlyaGusev/saiga_gemma2_10b", | |
token=hf_token, | |
torch_dtype=torch.float16, # Использование float16 для уменьшения потребления памяти | |
device_map=None # Не использовать автоматическое распределение на CPU/диск | |
) | |
# Явное перемещение модели на GPU, если доступно | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
def generate_response(): | |
try: | |
data = request.get_json() | |
print(f"Received data: {data}") | |
prompt = data.get('prompt', '') | |
max_length = data.get('max_length', 100) | |
temperature = data.get('temperature', 0.7) | |
top_p = data.get('top_p', 0.85) | |
repetition_penalty = data.get('repetition_penalty', 1.1) | |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) | |
attention_mask = torch.ones_like(input_ids).to(model.device) | |
output = model.generate( | |
input_ids, | |
attention_mask=attention_mask, | |
max_length=max_length, | |
temperature=temperature, | |
top_p=top_p, | |
repetition_penalty=repetition_penalty, | |
do_sample=True, | |
num_return_sequences=1, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
print(f"Generated output: {output}") | |
response_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
print(f"Generated response: {response_text}") | |
return jsonify({"response": response_text}) | |
except Exception as e: | |
print(f"Error: {str(e)}") | |
return jsonify({"response": "Извините, произошла ошибка при генерации ответа."}), 500 | |
if __name__ == "__main__": | |
init_model() | |
app.run(host='0.0.0.0', port=7860) |