# demo_1/app_api.py
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
app = Flask(__name__)
# Load the model and tokenizer
model_name = "dicta-il/dictalm2.0-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Ensure the tokenizer has a pad token; if not, add one and resize the embeddings to match
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
# Set the device to load the model onto
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
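
# Note (an assumption, not part of the original code): bfloat16 inference on CPU depends on
# hardware and PyTorch support; if generation fails on a CPU-only machine, reloading the
# model with torch_dtype=torch.float32 is a safe fallback.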
@app.route('/chat', methods=['POST'])
def chat():
    data = request.json
    messages = data.get("messages", [])
    if not messages:
        return jsonify({"error": "No messages provided"}), 400

    # Combine the messages into a single prompt using the model's instruct template:
    # each user turn is wrapped in [INST] ... [/INST], and each assistant turn is
    # appended as plain text and closed with </s> before the next user turn.
    conversation = "<s>"
    for i, message in enumerate(messages):
        role = message["role"]
        content = message["content"]
        if role == "user":
            if i == 0:
                conversation += f"[INST] {content} [/INST]"
            else:
                conversation += f" [INST] {content} [/INST]"
        elif role == "assistant":
            conversation += f" {content}</s>"
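
    # Example (a sketch of the resulting prompt): for messages
    #   [{"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello"},
    #    {"role": "user", "content": "How are you?"}]
    # the string built above is roughly:
    #   <s>[INST] Hi [/INST] Hello</s> [INST] How are you? [/INST]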
    # Tokenize the combined conversation. The template above already contains the special
    # tokens, so skip the tokenizer's automatic ones to avoid a duplicate <s>.
    encoded = tokenizer(conversation, return_tensors="pt", add_special_tokens=False).to(device)

    # Generate a response with the model
    generated_ids = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_new_tokens=50,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True
    )

    # Decode only the newly generated tokens, dropping the prompt portion
    new_tokens = generated_ids[0][encoded['input_ids'].shape[1]:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return jsonify({"response": decoded})
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
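
# Example client call (a sketch; assumes the server is running locally on port 5000
# and that the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:5000/chat",
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#   )
#   print(resp.json()["response"])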