File size: 5,131 Bytes
c40c75a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import React from "react";
import { Form, Select as AntSelect } from "antd";
import { TextInput, Text } from "@tremor/react";
import { Row, Col } from "antd";
import { Providers } from "../provider_info_helpers";

/** Props accepted by {@link LiteLLMModelNameField}. */
interface LiteLLMModelNameFieldProps {
  /**
   * Provider currently chosen in the surrounding form. It decides whether a
   * free-text input or the multi-select is rendered, is interpolated into the
   * wildcard option label, and is passed to `getPlaceholder`.
   */
  selectedProvider: Providers;
  /**
   * Model names offered as options in the multi-select. When the list is
   * empty a free-text input is rendered instead.
   */
  providerModels: string[];
  /** Returns the placeholder text shown in the free-text model input for a provider. */
  getPlaceholder: (provider: Providers) => string;
}

/** Select value meaning "use the name typed into the custom model name input". */
const CUSTOM_MODEL_VALUE = "custom";

/** Select value meaning "every model of the selected provider". */
const WILDCARD_MODEL_VALUE = "all-wildcard";

/** One entry of the `model_mappings` form field as written by this component. */
interface ModelMapping {
  public_name: string;
  litellm_model: string;
}

/**
 * Form section for choosing the LiteLLM model name(s) of a deployment.
 *
 * Must be rendered inside an antd `Form`: it reads and writes the `model`,
 * `model_name`, `model_mappings` and `custom_model_name` fields through
 * `Form.useFormInstance()`.
 *
 * - Azure, OpenAI-compatible and Ollama providers, and any provider without a
 *   known model list, get a free-text input bound to `model`.
 * - Every other provider gets a searchable multi-select with two extra
 *   entries: a "custom" entry (reveals a required text input) and a wildcard
 *   entry (clears `model_name` and `model_mappings`).
 */
const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
  selectedProvider,
  providerModels,
  getPlaceholder,
}) => {
  const form = Form.useFormInstance();

  // Name most recently written into the custom entry of `model_mappings`.
  // Starts as (and is reset to) the placeholder value so the entry can still be
  // located after its name has been replaced by what the user typed.
  const lastCustomNameRef = React.useRef<string>(CUSTOM_MODEL_VALUE);

  /** Keeps `model_mappings` in step with the models picked in the select. */
  const handleModelChange = (value: string | string[]) => {
    // Ensure value is always treated as an array.
    const values = Array.isArray(value) ? value : [value];

    // Once "custom" is deselected the remembered name must not match anything.
    if (!values.includes(CUSTOM_MODEL_VALUE)) {
      lastCustomNameRef.current = CUSTOM_MODEL_VALUE;
    }

    // Wildcard selection: there are no per-model names or mappings to keep.
    if (values.includes(WILDCARD_MODEL_VALUE)) {
      form.setFieldsValue({ model_name: undefined, model_mappings: [] });
      return;
    }

    // Only rewrite the mappings when the selection differs from the stored one.
    // NOTE(review): antd's Form.Item appears to update the field before it
    // forwards the event to this handler, so the stored value may already equal
    // `values` here and this update may never run — confirm where
    // `model_mappings` is otherwise populated before changing this guard.
    const currentModel = form.getFieldValue("model");
    if (JSON.stringify(currentModel) === JSON.stringify(values)) {
      return;
    }

    const mappings: ModelMapping[] = values.map((model) => {
      // Keep whatever was already typed for the custom entry.
      const name = model === CUSTOM_MODEL_VALUE ? lastCustomNameRef.current : model;
      return { public_name: name, litellm_model: name };
    });

    // Update both fields in one call to reduce re-renders.
    form.setFieldsValue({ model: values, model_mappings: mappings });
  };

  /**
   * Mirrors the custom model name input into `model_mappings`.
   *
   * The custom entry is found by the placeholder value or by the name written
   * on the previous keystroke; matching the placeholder alone would stop
   * working as soon as the first character replaced it. An emptied input puts
   * the placeholder back so the entry stays findable. Caveat: a selected model
   * whose name equals the previously typed custom name is matched as well.
   */
  const handleCustomModelNameChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const nextName = e.target.value || CUSTOM_MODEL_VALUE;
    const previousName = lastCustomNameRef.current;
    const isCustomEntry = (name: unknown): boolean =>
      name === CUSTOM_MODEL_VALUE || name === previousName;

    const stored: unknown = form.getFieldValue("model_mappings");
    const currentMappings: ModelMapping[] = Array.isArray(stored) ? stored : [];
    const updatedMappings = currentMappings.map((mapping) =>
      isCustomEntry(mapping.public_name) || isCustomEntry(mapping.litellm_model)
        ? { public_name: nextName, litellm_model: nextName }
        : mapping
    );

    lastCustomNameRef.current = nextName;
    form.setFieldsValue({ model_mappings: updatedMappings });
  };

  // These providers take an arbitrary deployment/model name, and a provider
  // without a known model list has nothing to offer in a select.
  const usesFreeTextModelName =
    selectedProvider === Providers.Azure ||
    selectedProvider === Providers.OpenAI_Compatible ||
    selectedProvider === Providers.Ollama ||
    providerModels.length === 0;

  return (
    <>
      <Form.Item
        label="LiteLLM Model Name(s)"
        tooltip="Actual model name used for making litellm.completion() / litellm.embedding() call."
        className="mb-0"
      >
        <Form.Item
          name="model"
          rules={[{ required: true, message: "Please select at least one model." }]}
          noStyle
        >
          {/* The control must be the direct child: Form.Item injects value/onChange
              into its child element, and a Fragment wrapper would receive them instead. */}
          {usesFreeTextModelName ? (
            <TextInput placeholder={getPlaceholder(selectedProvider)} />
          ) : (
            <AntSelect
              mode="multiple"
              allowClear
              showSearch
              placeholder="Select models"
              onChange={handleModelChange}
              optionFilterProp="label"
              filterOption={(input, option) =>
                (option?.label ?? "").toLowerCase().includes(input.toLowerCase())
              }
              options={[
                {
                  label: "Custom Model Name (Enter below)",
                  value: CUSTOM_MODEL_VALUE,
                },
                {
                  label: `All ${selectedProvider} Models (Wildcard)`,
                  value: WILDCARD_MODEL_VALUE,
                },
                ...providerModels.map((model) => ({
                  label: model,
                  value: model,
                })),
              ]}
              style={{ width: "100%" }}
            />
          )}
        </Form.Item>

        {/* Custom model name input, shown only while "custom" is selected. */}
        <Form.Item
          noStyle
          shouldUpdate={(prevValues, currentValues) =>
            prevValues.model !== currentValues.model
          }
        >
          {({ getFieldValue }) => {
            // `model` is a string for the free-text input and an array for the select.
            const selectedModels = getFieldValue("model") || [];
            const modelArray = Array.isArray(selectedModels) ? selectedModels : [selectedModels];
            return modelArray.includes(CUSTOM_MODEL_VALUE) ? (
              <Form.Item
                name="custom_model_name"
                rules={[{ required: true, message: "Please enter a custom model name." }]}
                className="mt-2"
              >
                <TextInput
                  placeholder="Enter custom model name"
                  onChange={handleCustomModelNameChange}
                />
              </Form.Item>
            ) : null;
          }}
        </Form.Item>
      </Form.Item>
      <Row>
        <Col span={10}></Col>
        <Col span={10}>
          <Text className="mb-3 mt-1">
            Actual model name used for making litellm.completion() call. We loadbalance models with the same public name
          </Text>
        </Col>
      </Row>
    </>
  );
};

export default LiteLLMModelNameField;