Update app.py
Browse files
app.py
CHANGED
@@ -700,36 +700,33 @@ async def generate(
|
|
700 |
):
|
701 |
"""
|
702 |
Generate a text response based on the provided context and chat history.
|
703 |
-
|
704 |
-
The generation process can be customized using various parameters in the config:
|
705 |
-
- temperature: Controls randomness (0.0 to 2.0)
|
706 |
-
- max_new_tokens: Maximum length of generated text
|
707 |
-
- top_p: Nucleus sampling parameter
|
708 |
-
- top_k: Top-k sampling parameter
|
709 |
-
- strategy: Generation strategy to use
|
710 |
-
- num_samples: Number of samples for applicable strategies
|
711 |
-
|
712 |
-
Generation Strategies:
|
713 |
-
- default: Standard generation
|
714 |
-
- majority_voting: Generates multiple responses and uses the most common one
|
715 |
-
- best_of_n: Generates multiple responses and picks the best
|
716 |
-
- beam_search: Uses beam search for coherent generation
|
717 |
-
    - dvts: Diverse verifier tree search (DVTS)
|
718 |
"""
|
719 |
try:
|
720 |
chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
|
721 |
user_input = request.messages[-1].content
|
722 |
-
|
|
|
723 |
config = request.config or GenerationConfig()
|
|
|
|
|
|
|
|
|
|
|
724 |
|
|
|
725 |
response = await asyncio.to_thread(
|
726 |
generator.generate_with_context,
|
727 |
context=request.context or "",
|
728 |
user_input=user_input,
|
729 |
chat_history=chat_history,
|
730 |
-
model_kwargs=
|
|
|
|
|
|
|
|
|
|
|
731 |
)
|
732 |
-
|
733 |
return GenerationResponse(
|
734 |
id=str(uuid.uuid4()),
|
735 |
content=response
|
|
|
700 |
):
|
701 |
"""
|
702 |
Generate a text response based on the provided context and chat history.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
703 |
"""
|
704 |
try:
|
705 |
chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
|
706 |
user_input = request.messages[-1].content
|
707 |
+
|
708 |
+
# Extract or set defaults for additional arguments
|
709 |
config = request.config or GenerationConfig()
|
710 |
+
model_kwargs = {
|
711 |
+
"temperature": config.temperature if hasattr(config, "temperature") else 0.7,
|
712 |
+
"max_new_tokens": config.max_new_tokens if hasattr(config, "max_new_tokens") else 100,
|
713 |
+
# Add other model kwargs as needed
|
714 |
+
}
|
715 |
|
716 |
+
# Explicitly pass additional required arguments
|
717 |
response = await asyncio.to_thread(
|
718 |
generator.generate_with_context,
|
719 |
context=request.context or "",
|
720 |
user_input=user_input,
|
721 |
chat_history=chat_history,
|
722 |
+
model_kwargs=model_kwargs,
|
723 |
+
max_history_turns=config.max_history_turns if hasattr(config, "max_history_turns") else 3,
|
724 |
+
strategy=config.strategy if hasattr(config, "strategy") else "default",
|
725 |
+
num_samples=config.num_samples if hasattr(config, "num_samples") else 5,
|
726 |
+
depth=config.depth if hasattr(config, "depth") else 3,
|
727 |
+
breadth=config.breadth if hasattr(config, "breadth") else 2,
|
728 |
)
|
729 |
+
|
730 |
return GenerationResponse(
|
731 |
id=str(uuid.uuid4()),
|
732 |
content=response
|