Chris4K committed
Commit 4279e53 · verified · 1 Parent(s): 9296210

Update app.py

Files changed (1): app.py (+15, -18)
app.py CHANGED
@@ -700,36 +700,33 @@ async def generate(
 ):
     """
     Generate a text response based on the provided context and chat history.
-
-    The generation process can be customized using various parameters in the config:
-    - temperature: Controls randomness (0.0 to 2.0)
-    - max_new_tokens: Maximum length of generated text
-    - top_p: Nucleus sampling parameter
-    - top_k: Top-k sampling parameter
-    - strategy: Generation strategy to use
-    - num_samples: Number of samples for applicable strategies
-
-    Generation Strategies:
-    - default: Standard generation
-    - majority_voting: Generates multiple responses and uses the most common one
-    - best_of_n: Generates multiple responses and picks the best
-    - beam_search: Uses beam search for coherent generation
-    - dvts: Dynamic vocabulary tree search
     """
     try:
         chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
         user_input = request.messages[-1].content
-
+
+        # Extract or set defaults for additional arguments
         config = request.config or GenerationConfig()
+        model_kwargs = {
+            "temperature": config.temperature if hasattr(config, "temperature") else 0.7,
+            "max_new_tokens": config.max_new_tokens if hasattr(config, "max_new_tokens") else 100,
+            # Add other model kwargs as needed
+        }
 
+        # Explicitly pass additional required arguments
         response = await asyncio.to_thread(
             generator.generate_with_context,
             context=request.context or "",
             user_input=user_input,
             chat_history=chat_history,
-            model_kwargs=config
+            model_kwargs=model_kwargs,
+            max_history_turns=config.max_history_turns if hasattr(config, "max_history_turns") else 3,
+            strategy=config.strategy if hasattr(config, "strategy") else "default",
+            num_samples=config.num_samples if hasattr(config, "num_samples") else 5,
+            depth=config.depth if hasattr(config, "depth") else 3,
+            breadth=config.breadth if hasattr(config, "breadth") else 2,
        )
-
+
         return GenerationResponse(
             id=str(uuid.uuid4()),
             content=response
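
A side note on the pattern introduced here: each `config.x if hasattr(config, "x") else default` expression can be collapsed to `getattr(config, "x", default)`, the idiomatic one-liner for "attribute or fallback". A minimal sketch of the same call using getattr, assuming the `generate_with_context` signature shown in the diff; it is a drop-in for the body of `generate` above:

    # Same behavior as the diff, written with getattr instead of
    # hasattr guards. Fragment belongs inside the async handler.
    model_kwargs = {
        "temperature": getattr(config, "temperature", 0.7),
        "max_new_tokens": getattr(config, "max_new_tokens", 100),
    }

    response = await asyncio.to_thread(
        generator.generate_with_context,
        context=request.context or "",
        user_input=user_input,
        chat_history=chat_history,
        model_kwargs=model_kwargs,
        max_history_turns=getattr(config, "max_history_turns", 3),
        strategy=getattr(config, "strategy", "default"),
        num_samples=getattr(config, "num_samples", 5),
        depth=getattr(config, "depth", 3),
        breadth=getattr(config, "breadth", 2),
    )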
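
The hasattr guards suggest GenerationConfig may not declare all of these fields. If it is a Pydantic model (a guess; the class definition is not in this hunk), declaring every field with a default would make the guards unnecessary. A hypothetical sketch, with field names taken from this diff and the removed docstring; the top_p and top_k defaults are assumptions:

    from pydantic import BaseModel

    class GenerationConfig(BaseModel):
        # Hypothetical definition; the real one lives elsewhere in app.py.
        temperature: float = 0.7
        max_new_tokens: int = 100
        top_p: float = 0.95        # assumed default
        top_k: int = 50            # assumed default
        strategy: str = "default"  # default | majority_voting | best_of_n | beam_search | dvts
        num_samples: int = 5
        max_history_turns: int = 3
        depth: int = 3
        breadth: int = 2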
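
For reference, a request that exercises the new arguments might look like the following. The route path and port are assumptions (the route decorator is not visible in this hunk); only the payload shape is inferred from how the handler reads `request`:

    import requests

    payload = {
        "context": "You are a helpful assistant.",
        "messages": [{"role": "user", "content": "Hello!"}],
        # Any omitted config field falls back to the handler defaults
        "config": {"temperature": 0.3, "strategy": "best_of_n", "num_samples": 5},
    }
    # "/generate" and port 7860 are hypothetical
    r = requests.post("http://localhost:7860/generate", json=payload, timeout=120)
    print(r.json()["content"])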