File size: 2,843 Bytes
3ec5e4c
 
 
213f021
3ec5e4c
213f021
3ec5e4c
213f021
3ec5e4c
 
 
213f021
3ec5e4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from flask import Flask, request, jsonify
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

app = Flask(__name__)

def init_model():
    global model, tokenizer
    hf_token = os.getenv("HF_TOKEN")  # Чтение токена из переменной окружения

    if hf_token is None:
        raise ValueError("Hugging Face token is not set. Please set the HF_TOKEN environment variable.")

    # Аутентификация с использованием токена
    login(hf_token, add_to_git_credential=True)

    # Загрузка модели и токенизатора без квантования и без распределения на CPU/диск
    tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b", token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        "IlyaGusev/saiga_gemma2_10b",
        token=hf_token,
        torch_dtype=torch.float16,  # Использование float16 для уменьшения потребления памяти
        device_map=None  # Не использовать автоматическое распределение на CPU/диск
    )

    # Явное перемещение модели на GPU, если доступно
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

@app.route('/generate', methods=['POST'])
def generate_response():
    try:
        data = request.get_json()
        print(f"Received data: {data}")

        prompt = data.get('prompt', '')
        max_length = data.get('max_length', 100)
        temperature = data.get('temperature', 0.7)
        top_p = data.get('top_p', 0.85)
        repetition_penalty = data.get('repetition_penalty', 1.1)

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        attention_mask = torch.ones_like(input_ids).to(model.device)

        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )
        print(f"Generated output: {output}")
        response_text = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f"Generated response: {response_text}")
        return jsonify({"response": response_text})

    except Exception as e:
        print(f"Error: {str(e)}")
        return jsonify({"response": "Извините, произошла ошибка при генерации ответа."}), 500

if __name__ == "__main__":
    init_model()
    app.run(host='0.0.0.0', port=7860)