YALCINKAYA committed
Commit 8c39757 · 1 Parent(s): 188010c

extract_relevant_text

Files changed (1)
  1. app.py +22 -2
app.py CHANGED
@@ -2,6 +2,7 @@ import os
 from flask import Flask, jsonify, request
 from flask_cors import CORS
 from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+import re
 
 # Set the HF_HOME environment variable to a writable directory
 os.environ["HF_HOME"] = "/workspace/huggingface_cache"  # Change this to a writable path in your space
@@ -20,7 +21,7 @@ def get_model_and_tokenizer(model_id):
     try:
         print(f"Loading tokenizer for model_id: {model_id}")
         # Load the tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
 
         print(f"Loading model for model_id: {model_id}")
@@ -31,6 +32,25 @@ def get_model_and_tokenizer(model_id):
     except Exception as e:
         print(f"Error loading model: {e}")
 
+def extract_relevant_text(response):
+    """
+    This function extracts the first 'user' and 'assistant' blocks between
+    <|im_start|> and <|im_end|> in the generated response.
+    """
+    # Regex to match content between <|im_start|> and <|im_end|> tags
+    pattern = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)
+    matches = pattern.findall(response)
+
+    if len(matches) < 2:
+        return "Unable to extract sufficient data from the response."
+
+    # Assuming the first match is user and the second match is assistant
+    user_message = matches[0].strip()       # First <|im_start|> block
+    assistant_message = matches[1].strip()  # Second <|im_start|> block
+
+    # Format the extracted result
+    return f"user: {user_message}\nassistant: {assistant_message}"
+
 def generate_response(user_input, model_id):
     prompt = formatted_prompt(user_input)
 
@@ -56,7 +76,7 @@ def generate_response(user_input, model_id):
         # Generate response
         outputs = model.generate(**inputs, generation_config=generation_config)
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response
+        return extract_relevant_text(response)
     except Exception as e:
         print(f"Error generating response: {e}")
         return "Error generating response."