# Optimized Devstral Inference Client
# Talks to an already-deployed Devstral model server over HTTP, so no server is
# started per request. Each client reuses one pooled aiohttp session.

import asyncio
import json
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

import modal

# Client-side Modal app. The model server is a separate deployment reached at
# ``base_url``; this app only hosts the request-making functions below.
app = modal.App("devstral-inference-client")

# Image for the client functions. Requests are sent with aiohttp; the other
# packages are installed but not imported by this module.
client_image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "openai>=1.76.0",
    "aiohttp>=3.9.0",
    "asyncio-throttle>=1.0.0",
)

# Model identifier sent in every chat-completions request.
MODEL_NAME = "mistralai/Devstral-Small-2505"
CHAT_COMPLETIONS_PATH = "/v1/chat/completions"


class DevstralAPIError(Exception):
    """Raised when the inference server answers with a non-200 HTTP status."""


class DevstralClient:
    """Async client for a Devstral OpenAI-compatible chat-completions endpoint.

    Use it as an async context manager (``async with DevstralClient(...) as c``)
    so that a single aiohttp session, and its connection pool, is shared by
    every request made through the client. Non-streaming responses to requests
    with ``temperature == 0.0`` can be cached in memory for the lifetime of the
    instance.
    """

    def __init__(self, base_url: str, api_key: str, request_timeout: float = 120.0):
        """Store connection settings; the HTTP session is opened in ``__aenter__``.

        Args:
            base_url: Server root URL; a trailing slash is tolerated.
            api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
            request_timeout: Total seconds allowed per HTTP request (including
                reading a whole stream). Defaults to the previous fixed 120 s.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.request_timeout = request_timeout
        self._session = None
        # Cache key (see _get_cache_key) -> generated text. Unbounded, per instance.
        self._response_cache: Dict[str, str] = {}
        # NOTE(review): not read or written anywhere in this module.
        self._conversation_cache: Dict[Any, Any] = {}

    async def __aenter__(self) -> "DevstralClient":
        """Open a persistent, pooled HTTP session and return the client."""
        import aiohttp

        connector = aiohttp.TCPConnector(
            limit=100,  # Connection pool size
            keepalive_timeout=300,  # Keep idle connections around for reuse
            enable_cleanup_closed=True,
        )
        self._session = aiohttp.ClientSession(
            connector=connector,
            timeout=aiohttp.ClientTimeout(total=self.request_timeout),
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session; exceptions from the ``async with`` body propagate."""
        if self._session is not None:
            await self._session.close()
            self._session = None

    def _require_session(self):
        """Return the open session or fail with a clear message if there is none."""
        if self._session is None:
            raise RuntimeError(
                "DevstralClient has no open session; use it via "
                "'async with DevstralClient(base_url, api_key) as client:'"
            )
        return self._session

    def _headers(self) -> Dict[str, str]:
        """Build the request headers (bearer auth + JSON content type)."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _get_cache_key(self, messages: List[Dict], **kwargs) -> str:
        """Return a deterministic cache key for a chat request.

        The key is the canonical (sorted-keys) JSON of the messages plus the
        sampling parameters ``temperature``, ``max_tokens`` and ``top_p``.
        Missing parameters fall back to 0.1, 500 and 0.95. Other keyword
        arguments are ignored. Because the key is the full JSON text rather
        than a hash of it, two different requests can never collide.
        """
        key_data = {
            "messages": messages,
            "temperature": kwargs.get("temperature", 0.1),
            "max_tokens": kwargs.get("max_tokens", 500),
            "top_p": kwargs.get("top_p", 0.95),
        }
        return json.dumps(key_data, sort_keys=True)

    async def generate_response(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 10000,
        stream: bool = False,
        use_cache: bool = True,
    ) -> Union[str, AsyncGenerator[str, None]]:
        """Send one chat request built from ``prompt`` and an optional system prompt.

        Args:
            prompt: User message content.
            system_prompt: If truthy, sent first as a ``system`` message.
            temperature: Sampling temperature sent to the server.
            max_tokens: Generation limit sent to the server.
            stream: If True, return an async generator of text deltas instead
                of the full text. Iterate it with
                ``async for piece in await client.generate_response(..., stream=True)``.
                Streaming responses are never cached.
            use_cache: When True and ``temperature == 0.0``, non-streaming
                results are looked up in / stored to the in-memory cache.

        Returns:
            The generated text, or an async generator of text pieces when
            ``stream`` is True.

        Raises:
            RuntimeError: If called outside ``async with``.
            DevstralAPIError: If the server returns a non-200 status.
        """
        self._require_session()

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": MODEL_NAME,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        headers = self._headers()

        if stream:
            # An async generator must be iterated, not awaited, so hand it back.
            return self._stream_response(payload, headers)

        # Only deterministic (temperature 0) requests are worth caching.
        if use_cache and temperature == 0.0:
            cache_key = self._get_cache_key(
                messages, temperature=temperature, max_tokens=max_tokens
            )
            if cache_key in self._response_cache:
                print("📎 Cache hit - returning cached response")
                return self._response_cache[cache_key]

        start_time = time.perf_counter()
        return await self._complete_response(payload, headers, use_cache, start_time)

    async def _complete_response(
        self, payload: Dict[str, Any], headers: Dict[str, str], use_cache: bool, start_time: float
    ) -> str:
        """POST a non-streaming request and return the first choice's message text."""
        session = self._require_session()
        async with session.post(
            f"{self.base_url}{CHAT_COMPLETIONS_PATH}", json=payload, headers=headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise DevstralAPIError(f"API Error {response.status}: {error_text}")
            result = await response.json()

        latency = (time.perf_counter() - start_time) * 1000
        # The OpenAI schema allows a null message content; normalise it to "".
        generated_text = result["choices"][0]["message"].get("content") or ""

        # Cache deterministic responses under the same key generate_response
        # looks up (messages + temperature + max_tokens).
        if use_cache and payload["temperature"] == 0.0:
            cache_key = self._get_cache_key(
                payload["messages"],
                temperature=payload["temperature"],
                max_tokens=payload["max_tokens"],
            )
            self._response_cache[cache_key] = generated_text

        print(f"⚡ Response generated in {latency:.2f}ms")
        return generated_text

    async def _stream_response(
        self, payload: Dict[str, Any], headers: Dict[str, str]
    ) -> AsyncGenerator[str, None]:
        """POST a streaming request and yield non-empty text deltas from the SSE body.

        Stops at the ``data: [DONE]`` sentinel or when the body ends. Lines
        that are not ``data:`` events or whose JSON fails to parse are skipped.
        """
        session = self._require_session()
        # Copy rather than mutate the caller's dict.
        payload = {**payload, "stream": True}
        async with session.post(
            f"{self.base_url}{CHAT_COMPLETIONS_PATH}", json=payload, headers=headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise DevstralAPIError(f"API Error {response.status}: {error_text}")

            # Buffer raw bytes and split on b"\n" before decoding. 0x0A never
            # occurs inside a multi-byte UTF-8 sequence, so a character split
            # across network chunks is reassembled before it is decoded.
            buffer = b""
            async for data, _end_of_http_chunk in response.content.iter_chunks():
                if not data:
                    continue
                buffer += data
                while b"\n" in buffer:
                    raw_line, buffer = buffer.split(b"\n", 1)
                    line = raw_line.decode("utf-8").strip()  # strip also drops "\r"
                    if not line.startswith("data:"):
                        continue
                    event_data = line[len("data:"):].strip()
                    if event_data == "[DONE]":
                        return
                    try:
                        event = json.loads(event_data)
                    except json.JSONDecodeError:
                        continue
                    choices = event.get("choices") or []
                    if not choices:
                        continue
                    content = (choices[0].get("delta") or {}).get("content")
                    # Deltas may carry only a role, or content=null; skip those.
                    if content:
                        yield content

    async def batch_generate(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 500,
        max_concurrent: int = 5,
    ) -> List[str]:
        """Generate a response per prompt with at most ``max_concurrent`` in flight.

        Results keep the order of ``prompts``. A prompt whose request raises
        yields the string ``"Error: <message>"`` in its slot instead of
        aborting the whole batch.
        """
        # A semaphore bounds simultaneous requests; values below 1 would deadlock.
        semaphore = asyncio.Semaphore(max(1, max_concurrent))

        async def generate_single(prompt: str) -> str:
            async with semaphore:
                return await self.generate_response(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )

        results = await asyncio.gather(
            *(generate_single(prompt) for prompt in prompts), return_exceptions=True
        )
        return [
            f"Error: {result}" if isinstance(result, Exception) else result
            for result in results
        ]


@app.function(
    image=client_image,
    timeout=600,  # 10 minutes
)
async def run_devstral_inference(
    base_url: str,
    api_key: str,
    prompts: List[str],
    system_prompt: Optional[str] = None,
    mode: str = "single",  # "single", "batch", "stream"
):
    """Run Devstral inference remotely in one of three modes.

    Returns ``{"response": str}`` for "single" / "stream" (first prompt only),
    ``{"responses": [str, ...]}`` for "batch", or ``{"error": str}`` when the
    mode is unknown or no prompt is given for "single" / "stream".
    """
    if mode not in ("single", "batch", "stream"):
        return {"error": f"Unknown mode: {mode!r}"}

    async with DevstralClient(base_url, api_key) as client:
        if mode == "batch":
            responses = await client.batch_generate(
                prompts=prompts,
                system_prompt=system_prompt,
                temperature=0.1,
                max_tokens=10000,
                max_concurrent=5,
            )
            return {"responses": responses}

        if not prompts:
            return {"error": "No prompts provided"}

        if mode == "single":
            response = await client.generate_response(
                prompt=prompts[0],
                system_prompt=system_prompt,
                temperature=0.1,
                max_tokens=10000,
            )
            return {"response": response}

        # mode == "stream": echo pieces as they arrive, then return the full text.
        pieces: List[str] = []
        async for chunk in await client.generate_response(
            prompt=prompts[0],
            system_prompt=system_prompt,
            temperature=0.1,
            stream=True,
        ):
            pieces.append(chunk)
            print(chunk, end="", flush=True)
        print()  # New line after streaming
        return {"response": "".join(pieces)}


# Convenient wrapper functions for different use cases


@app.function(image=client_image)
async def code_generation(prompt: str, base_url: str, api_key: str) -> str:
    """Generate code deterministically (temperature 0.0, cached per client)."""
    system_prompt = (
        "You are an expert software engineer. Generate clean, efficient, and "
        "well-documented code. Focus on best practices, performance, and "
        "maintainability. Include brief explanations for complex logic."
    )
    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.0,  # Deterministic for code
            max_tokens=10000,
            use_cache=True,  # Only effective for repeated prompts within this call
        )


@app.function(image=client_image)
async def chat_response(prompt: str, base_url: str, api_key: str) -> str:
    """Produce a conversational answer (temperature 0.3)."""
    system_prompt = (
        "You are a helpful, knowledgeable AI assistant. Provide clear, concise, "
        "and accurate responses. Be conversational but professional."
    )
    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.3,  # Slightly creative
            max_tokens=10000,
        )


@app.function(image=client_image)
async def document_analysis(prompt: str, base_url: str, api_key: str) -> str:
    """Analyse or summarise a document (temperature 0.1, up to 800 tokens)."""
    system_prompt = (
        "You are an expert document analyst. Provide thorough, structured "
        "analysis with key insights, summaries, and actionable recommendations. "
        "Use clear formatting and bullet points."
    )
    async with DevstralClient(base_url, api_key) as client:
        return await client.generate_response(
            prompt=prompt,
            system_prompt=system_prompt,
            temperature=0.1,  # Factual and consistent
            max_tokens=800,
        )


# Local test client for development
@app.local_entrypoint()
def main(
    base_url: str = "https://abhinav-bhatnagar--devstral-vllm-deployment-serve.modal.run",
    api_key: str = "",
    mode: str = "single",
):
    """Exercise the client functions; ``mode`` is single, batch, stream or specialized.

    The API key is taken from ``--api-key`` or, if empty, from the
    ``DEVSTRAL_API_KEY`` environment variable. It is deliberately not stored
    in source.
    """
    # NOTE(review): a key was previously hard-coded here; rotate it.
    api_key = api_key or os.environ.get("DEVSTRAL_API_KEY", "")
    if not api_key:
        raise SystemExit("No API key: pass --api-key or set DEVSTRAL_API_KEY.")

    test_prompts = [
        "Write a Python function to calculate the Fibonacci sequence using memoization.",
        "Explain the difference between REST and GraphQL APIs.",
        "What are the key benefits of using Docker containers?",
        "How does machine learning differ from traditional programming?",
        "Write a SQL query to find the top 5 customers by total order value.",
    ]

    print(f"🚀 Testing Devstral inference in {mode} mode...")
    print(f"📡 Connecting to: {base_url}")

    if mode in ("single", "stream"):
        result = run_devstral_inference.remote(
            base_url=base_url,
            api_key=api_key,
            prompts=[test_prompts[0]],
            system_prompt="You are a helpful coding assistant.",
            mode=mode,
        )
        label = "Single" if mode == "single" else "Streaming"
        print(f"✅ {label} inference result:")
        print(result.get("response", result))
    elif mode == "batch":
        batch_prompts = test_prompts[:3]  # Test with 3 prompts
        result = run_devstral_inference.remote(
            base_url=base_url,
            api_key=api_key,
            prompts=batch_prompts,
            system_prompt="You are a knowledgeable AI assistant.",
            mode="batch",
        )
        print("✅ Batch inference results:")
        for number, (prompt, response) in enumerate(
            zip(batch_prompts, result.get("responses", [])), start=1
        ):
            print(f"\nPrompt {number}: {prompt}")
            print(f"Response: {response}")
    elif mode == "specialized":
        print("\n📝 Testing Code Generation:")
        code_result = code_generation.remote(
            prompt="Create a Python class for a binary search tree with insert, search, and delete methods.",
            base_url=base_url,
            api_key=api_key,
        )
        print(code_result)

        print("\n💬 Testing Chat Response:")
        chat_result = chat_response.remote(
            prompt="What's the best way to learn machine learning for beginners?",
            base_url=base_url,
            api_key=api_key,
        )
        print(chat_result)

        print("\n📄 Testing Document Analysis:")
        analysis_result = document_analysis.remote(
            prompt="Summarize the trade-offs between monolithic and microservice architectures.",
            base_url=base_url,
            api_key=api_key,
        )
        print(analysis_result)
    else:
        print(f"Unknown mode {mode!r}; expected single, batch, stream or specialized.")

    print("\n🎉 Testing completed!")


if __name__ == "__main__":
    # Allows `python this_file.py [mode]`. `.remote()` needs a running app:
    # `modal run` provides one, plain `python` does not, so start it here.
    import sys

    cli_mode = sys.argv[1] if len(sys.argv) > 1 else "single"
    with app.run():
        main(mode=cli_mode)