Do0rMaMu committed (verified)
Commit f3c7e66 · 1 Parent(s): a46aed5

Update main.py

Files changed (1)
  1. main.py +27 -17
main.py CHANGED
@@ -1,28 +1,38 @@
 from fastapi import FastAPI
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing import List, Dict
 from llama_cpp import Llama
 
-# Model loading with specified path and configuration
-llm = Llama(
-    model_path="Llama-3.2-3B-Instruct-Q8_0.gguf",  # Update the path as necessary
-    n_ctx=4096,
-    n_threads=2,
+# Load the Llama model with the specified path and configuration
+llm = Llama.from_pretrained(
+    repo_id="bartowski/Llama-3.2-3B-Instruct-GGUF",  # Replace with the actual model repository ID
+    filename="Llama-3.2-3B-Instruct-Q8_0.gguf",  # Replace with your actual model filename if necessary
+    n_ctx=4096,
+    n_threads=2,
 )
 
-# Pydantic object for validation
+# Define a Pydantic model for request validation
+class Message(BaseModel):
+    role: str  # "user" or "assistant"
+    content: str  # The actual message content
+
 class Validation(BaseModel):
-    user_prompt: str  # This will be the direct SQL query request or relevant prompt
-    max_tokens: int = 1024
-    temperature: float = 0.01
+    messages: List[Message] = Field(default_factory=list)  # List of previous messages in the conversation
+    max_tokens: int = 1024  # Maximum tokens for the response
+    temperature: float = 0.01  # Model response temperature for creativity
 
-# FastAPI application initialization
+# Initialize the FastAPI application
 app = FastAPI()
 
-# Endpoint for generating responses
+# Define the endpoint for generating responses
 @app.post("/generate_response")
 async def generate_response(item: Validation):
-    # Call the Llama model to generate a response directly based on the user's prompt
-    output = llm(item.user_prompt, max_tokens=item.max_tokens, temperature=item.temperature, echo=False)
-
-    # Extract and return the text from the response
-    return output['choices'][0]['text']
+    # Generate a response using the Llama model with the chat history
+    response = llm.create_chat_completion(
+        messages=[{"role": msg.role, "content": msg.content} for msg in item.messages],
+        max_tokens=item.max_tokens,
+        temperature=item.temperature
+    )
+
+    # Extract and return the response text
+    return {"response": response['choices'][0]['message']['content']}
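After this change the endpoint no longer accepts a single user_prompt string; it expects a JSON body with a messages list (role/content pairs), and it returns the generated text wrapped in a "response" key. Below is a minimal client sketch of the new request shape. The host, port, and example prompt are assumptions (they are not part of the commit); it assumes the app is being served locally, e.g. with uvicorn main:app --port 8000.

# Hypothetical client call against the updated /generate_response endpoint.
# Host, port, and prompt text are assumptions for illustration only.
import requests

payload = {
    "messages": [
        {"role": "user", "content": "List all tables in the sales database."}
    ],
    "max_tokens": 256,
    "temperature": 0.01,
}

resp = requests.post("http://localhost:8000/generate_response", json=payload)
print(resp.json()["response"])  # The new handler returns {"response": "..."} instead of a bare string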