Update app.py
app.py
CHANGED
@@ -72,8 +72,19 @@ def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:
     )

     generated_output = model_gemma.generate(input, generation_config=generation_config)
-
-
+    decoded_output = tokenizer_gemma.decode(generated_output[0], skip_special_tokens=False)
+
+    # Extract the assistant's response (Gemma specific)
+    start_token = "<start_of_turn>model"
+    end_token = "<end_of_turn>"
+
+    start_index = decoded_output.find(start_token)
+    if start_index != -1:
+        start_index += len(start_token)
+        end_index = decoded_output.find(end_token, start_index)
+        assistant_response = decoded_output[start_index:].strip()
+        return assistant_response
+    return decoded_output
     #input_text = "Reapond to the users prompt: " + text
     #input = tokenizer_gemma(input_text, return_tensors="pt").to(device)
     #generated_output = model_gemma.generate(**input, max_length=MAX_GEMMA_LENGTH, early_stopping=True)
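Note that in the added block end_index is computed but never used: the slice decoded_output[start_index:] runs to the end of the string, so a trailing <end_of_turn> marker (kept in the output because skip_special_tokens=False) survives the strip(). Below is a minimal standalone sketch of the same turn-extraction logic that also applies end_index; extract_model_turn and the sample string are illustrative assumptions, not names taken from app.py.

# Sketch of the turn-extraction step above, with end_index actually applied.
# extract_model_turn is a hypothetical helper, not a function in app.py.
def extract_model_turn(decoded_output: str) -> str:
    start_token = "<start_of_turn>model"
    end_token = "<end_of_turn>"

    start_index = decoded_output.find(start_token)
    if start_index == -1:
        # No model turn found: fall back to the full decoded string,
        # mirroring the final return in the committed code.
        return decoded_output
    start_index += len(start_token)
    end_index = decoded_output.find(end_token, start_index)
    if end_index == -1:
        # Generation was cut off before <end_of_turn>; take the rest.
        return decoded_output[start_index:].strip()
    return decoded_output[start_index:end_index].strip()

# Illustrative decoded string in Gemma's chat-turn format (assumed, not from the app).
sample = "<start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\nHello!<end_of_turn>"
print(extract_model_turn(sample))  # -> "Hello!"

Run against a string in Gemma's chat-turn format, this returns just the model turn with both markers stripped, whereas the committed slice would leave the trailing <end_of_turn> in place.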