YALCINKAYA committed on
Commit
34139ad
·
verified ·
1 Parent(s): 1e04073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  from flask import Flask, jsonify, request
4
  from flask_cors import CORS
5
- from transformers import GPTNeoForCausalLM, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
6
  # Set the HF_HOME environment variable to a writable directory
7
  os.environ["HF_HOME"] = "/workspace/huggingface_cache" # Change this to a writable path in your space
8
 
@@ -63,6 +63,8 @@ def generate_response(user_input, model_id):
63
 
64
 
65
  inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
 
 
66
  response = (tokenizer.decode(outputs[0], skip_special_tokens=True))
67
 
68
  cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()
 
2
  import torch
3
  from flask import Flask, jsonify, request
4
  from flask_cors import CORS
5
+ from transformers import GPTNeoForCausalLM, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
6
  # Set the HF_HOME environment variable to a writable directory
7
  os.environ["HF_HOME"] = "/workspace/huggingface_cache" # Change this to a writable path in your space
8
 
 
63
 
64
 
65
  inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
66
+
67
+ outputs = model.generate(**inputs, generation_config=generation_config)
68
  response = (tokenizer.decode(outputs[0], skip_special_tokens=True))
69
 
70
  cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()