PitterTMYT commited on
Commit
3ec5e4c
·
verified ·
1 Parent(s): cde668a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -9
app.py CHANGED
@@ -1,13 +1,69 @@
1
- import gradio as gr
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
 
4
- tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b")
5
- model = AutoModelForCausalLM.from_pretrained("IlyaGusev/saiga_gemma2_10b")
6
 
7
- def generate_text(prompt):
8
- input_ids = tokenizer.encode(prompt, return_tensors="pt")
9
- output = model.generate(input_ids, max_length=50)
10
- return tokenizer.decode(output[0], skip_special_tokens=True)
11
 
12
- iface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
13
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from flask import Flask, request, jsonify
3
+ import torch
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from huggingface_hub import login
6
 
7
+ app = Flask(__name__)
 
8
 
9
+ def init_model():
10
+ global model, tokenizer
11
+ hf_token = os.getenv("HF_TOKEN") # Чтение токена из переменной окружения
 
12
 
13
+ if hf_token is None:
14
+ raise ValueError("Hugging Face token is not set. Please set the HF_TOKEN environment variable.")
15
+
16
+ # Аутентификация с использованием токена
17
+ login(hf_token, add_to_git_credential=True)
18
+
19
+ # Загрузка модели и токенизатора без квантования и без распределения на CPU/диск
20
+ tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b", token=hf_token)
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ "IlyaGusev/saiga_gemma2_10b",
23
+ token=hf_token,
24
+ torch_dtype=torch.float16, # Использование float16 для уменьшения потребления памяти
25
+ device_map=None # Не использовать автоматическое распределение на CPU/диск
26
+ )
27
+
28
+ # Явное перемещение модели на GPU, если доступно
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ model.to(device)
31
+
32
+ @app.route('/generate', methods=['POST'])
33
+ def generate_response():
34
+ try:
35
+ data = request.get_json()
36
+ print(f"Received data: {data}")
37
+
38
+ prompt = data.get('prompt', '')
39
+ max_length = data.get('max_length', 100)
40
+ temperature = data.get('temperature', 0.7)
41
+ top_p = data.get('top_p', 0.85)
42
+ repetition_penalty = data.get('repetition_penalty', 1.1)
43
+
44
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
45
+ attention_mask = torch.ones_like(input_ids).to(model.device)
46
+
47
+ output = model.generate(
48
+ input_ids,
49
+ attention_mask=attention_mask,
50
+ max_length=max_length,
51
+ temperature=temperature,
52
+ top_p=top_p,
53
+ repetition_penalty=repetition_penalty,
54
+ do_sample=True,
55
+ num_return_sequences=1,
56
+ pad_token_id=tokenizer.eos_token_id
57
+ )
58
+ print(f"Generated output: {output}")
59
+ response_text = tokenizer.decode(output[0], skip_special_tokens=True)
60
+ print(f"Generated response: {response_text}")
61
+ return jsonify({"response": response_text})
62
+
63
+ except Exception as e:
64
+ print(f"Error: {str(e)}")
65
+ return jsonify({"response": "Извините, произошла ошибка при генерации ответа."}), 500
66
+
67
+ if __name__ == "__main__":
68
+ init_model()
69
+ app.run(host='0.0.0.0', port=7860)