Maximofn committed on
Commit eae8d77 · 1 Parent(s): 8e2c98e

Replace HuggingFace Inference API with local Transformers model loading


- Switch from HuggingFace Inference Client to local model loading
- Use SmolLM2-1.7B-Instruct model instead of Qwen/Qwen2.5-72B-Instruct
- Add device detection and model loading with torch.bfloat16 (see the sketch after this list)
- Update model calling logic to use local model generation
- Improve token generation parameters
- Add print statements for model loading confirmation
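The device detection and bfloat16 loading described above follow the standard Transformers idiom. The sketch below is illustrative only; the float32 fallback for CPU-only machines is an assumption added here and is not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumption: fall back to float32 on CPU, since bfloat16 is mainly a GPU optimization.
dtype = torch.bfloat16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",  # lets Accelerate place the weights on the available device(s)
)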

Files changed (1)
  1. app.py +37 -19
app.py CHANGED
@@ -1,6 +1,7 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from huggingface_hub import InferenceClient
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
 from langchain_core.messages import HumanMessage, AIMessage
 from langgraph.checkpoint.memory import MemorySaver
@@ -10,15 +11,21 @@ import os
 from dotenv import load_dotenv
 load_dotenv()
 
-# HuggingFace token
-HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", os.getenv("HUGGINGFACE_TOKEN"))
-
-# Initialize the HuggingFace model
-model = InferenceClient(
-    model="Qwen/Qwen2.5-72B-Instruct",
-    api_key=os.getenv("HUGGINGFACE_TOKEN")
+# Initialize the model and tokenizer
+print("Cargando modelo y tokenizer...")
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+
+# Load the model in BF16 format for better performance and lower memory usage
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map="auto"  # This will automatically distribute the model across available GPUs
 )
 
+print(f"Modelo cargado en dispositivo: {device}")
+
 # Define the function that calls the model
 def call_model(state: MessagesState):
     """
@@ -30,24 +37,35 @@ def call_model(state: MessagesState):
     Returns:
         dict: A dictionary containing the generated text and the thread ID
     """
-    # Convert LangChain messages to HuggingFace format
-    hf_messages = []
+    # Convert LangChain messages to chat format
+    messages = []
     for msg in state["messages"]:
         if isinstance(msg, HumanMessage):
-            hf_messages.append({"role": "user", "content": msg.content})
+            messages.append({"role": "user", "content": msg.content})
         elif isinstance(msg, AIMessage):
-            hf_messages.append({"role": "assistant", "content": msg.content})
+            messages.append({"role": "assistant", "content": msg.content})
 
-    # Call the API
-    response = model.chat_completion(
-        messages=hf_messages,
-        temperature=0.5,
-        max_tokens=64,
-        top_p=0.7
+    # Prepare the input using the chat template
+    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+    # Generate response
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,  # Increase the number of tokens for longer responses
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
     )
 
+    # Decode and clean the response
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # Extract only the assistant's response (after the last user message)
+    response = response.split("Assistant:")[-1].strip()
+
     # Convert the response to LangChain format
-    ai_message = AIMessage(content=response.choices[0].message.content)
+    ai_message = AIMessage(content=response)
     return {"messages": state["messages"] + [ai_message]}
 
 # Define the graph
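Note on the new decoding step: splitting the decoded text on "Assistant:" assumes that literal prefix appears in the model's chat template output, which may not hold for SmolLM2's ChatML-style template. A template-agnostic alternative, shown as a sketch under the assumption that `tokenizer`, `model`, and `device` are loaded as in the diff above, is to decode only the tokens generated after the prompt:

prompt_messages = [{"role": "user", "content": "What is the capital of France?"}]

# add_generation_prompt=True appends the assistant turn marker so the model
# starts a fresh reply instead of continuing the user message.
input_text = tokenizer.apply_chat_template(
    prompt_messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

outputs = model.generate(
    inputs,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

# Decode only the newly generated tokens, skipping the prompt, so the result
# does not depend on any literal "Assistant:" marker in the template.
reply = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True).strip()
print(reply)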