说明
- 用户输入多个“信息”
- 大语言模型中的嵌入模型(Embedding Model)将每条“信息”转换成一个浮点数数组,即向量(一维张量);
- 通过余弦相似度等算法计算两个向量的相似程度:结果越接近 1,两条信息的语义越相近
Ollama 使用步骤
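- 安装 Ollama: https://ollama.com/
- 下载模型: ollama pull nomic-embed-text
- Ollama 默认运行在 http://localhost:11434
- 嵌入接口: OpenAI 兼容接口 /v1/embeddings 与原生接口 /api/embed 的请求参数均为 model + input;旧版 /api/embeddings 的参数是 model + prompt,若只传 input 会得到空向量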
推荐的嵌入模型:
- nomic-embed-text: 768维,效果好,速度快
- mxbai-embed-large: 1024维,效果更好
- bge-m3: 1024维,多语言支持,处理中文文本时更合适
在 Spring Boot 中调用本地模型(测试代码)
- @Test
- @Disabled("需要本地运行 Ollama 服务")
- public void testOllamaEmbedding() {
- // Ollama API 地址
- String apiUrl = "http://localhost:11434/api/embeddings";
- String apiKey = ""; // Ollama 本地不需要 key
- String model = "nomic-embed-text"; // 或 mxbai-embed-large
- EmbeddingClient client = new EmbeddingClientImpl(apiUrl, apiKey);
- // 水果库
- List<Fruit> fruits = Arrays.asList(new Fruit("红富士苹果", "红色 甜 脆 苹果 新鲜"), new Fruit("青苹果", "绿色 酸 脆 苹果 清爽"),
- new Fruit("金帅苹果", "黄色 甜 软 苹果"), new Fruit("香蕉", "黄色 甜 软 香蕉 热带水果"), new Fruit("草莓", "红色 甜 小 草莓 多汁 浆果"),
- new Fruit("西瓜", "绿色外皮 红色果肉 甜 大 西瓜 多汁 夏天"), new Fruit("葡萄", "紫色 甜 小 葡萄 多汁 成串"));
- // 为每个水果生成嵌入向量
- for (Fruit fruit : fruits) {
- fruit.embedding = client.getEmbeddingVector(model, fruit.description);
- }
- // 用户搜索
- String query = "红色的甜水果";
- double[] queryVector = client.getEmbeddingVector(model, query);
- System.out.println("搜索: "" + query + """);
- System.out.println("向量维度: " + queryVector.length);
- System.out.println();
- // 按相似度排序
- fruits.sort(Comparator.comparingDouble(f -> -cosineSimilarity(queryVector, f.embedding)));
- // 输出结果
- System.out.println("搜索结果(按相似度排序):");
- for (Fruit f : fruits) {
- double sim = cosineSimilarity(queryVector, f.embedding);
- System.out.printf(" %s (%.4f): %s%n", f.name, sim, f.description);
- }
- }
- /**
- * 计算两个向量的余弦相似度
- */
- public static double cosineSimilarity(double[] vectorA, double[] vectorB) {
- if (vectorA.length != vectorB.length) {
- throw new IllegalArgumentException("向量维度必须相同");
- }
- double dotProduct = 0;
- double normA = 0;
- double normB = 0;
- for (int i = 0; i < vectorA.length; i++) {
- dotProduct += vectorA[i] * vectorB[i];
- normA += vectorA[i] * vectorA[i];
- normB += vectorB[i] * vectorB[i];
- }
- if (normA == 0 || normB == 0) {
- return 0;
- }
- return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
- }
核心方法(EmbeddingClientImpl 实现)
- @Slf4j
- public class EmbeddingClientImpl implements EmbeddingClient {
- private final RestTemplate restTemplate;
- private final String address;
- private final String key;
- public EmbeddingClientImpl(String address, String key) {
- PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager();
- connectionManager.setMaxTotal(100);
- connectionManager.setDefaultMaxPerRoute(20);
- // 设置请求配置
- RequestConfig requestConfig = RequestConfig.custom()
- .setConnectionRequestTimeout(Timeout.ofSeconds(30))
- .setResponseTimeout(Timeout.ofSeconds(300)) // 5分钟响应超时
- .build();
- // 使用 HttpClientBuilder 来构建 HttpClient
- HttpClient httpClient = HttpClientBuilder.create()
- .setConnectionManager(connectionManager)
- .setDefaultRequestConfig(requestConfig)
- .build();
- // 创建 HttpComponentsClientHttpRequestFactory
- HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
- requestFactory.setConnectTimeout(30000); // 30秒连接超时
- requestFactory.setConnectionRequestTimeout(30000);
- // 创建 RestTemplate,只使用 StringHttpMessageConverter 避免 Jackson 依赖问题
- this.restTemplate = new RestTemplate(requestFactory);
- // 清除默认的消息转换器,只保留字符串转换器
- this.restTemplate.setMessageConverters(
- Collections.singletonList(new StringHttpMessageConverter(StandardCharsets.UTF_8)));
- this.address = address;
- this.key = key;
- }
- @Override
- public String embedding(String model, String input) {
- long start = System.currentTimeMillis();
- String url = address;
- HttpHeaders headers = new HttpHeaders();
- headers.setContentType(MediaType.APPLICATION_JSON);
- headers.setAcceptCharset(Collections.singletonList(StandardCharsets.UTF_8));
- if (key != null && !key.isEmpty()) {
- headers.add("Authorization", "Bearer " + key);
- }
- // 将 request 转化为 body 字符串
- JSONObject jsonObject = new JSONObject();
- jsonObject.put("input", input);
- jsonObject.put("model", model);
- String body = jsonObject.toString();
- log.debug("Embedding Request Body: {}", body);
- // 请求
- HttpEntity<String> req = new HttpEntity<>(body, headers);
- ResponseEntity<String> result = restTemplate.postForEntity(url, req, String.class);
- if (!result.getStatusCode().equals(HttpStatus.OK)) {
- throw new RuntimeException("embeddings error, request: " + body + ", response: " + result.getBody());
- }
- log.info("embedding cost {} ms", System.currentTimeMillis() - start);
- return result.getBody();
- }
- /**
- * 获取文本嵌入向量
- * <p>
- * 解析 OpenAI 格式的响应,提取 embedding 向量
- *
- * 响应格式示例: <pre>
- * {
- * "object": "list",
- * "data": [{
- * "object": "embedding",
- * "index": 0,
- * "embedding": [0.0023064255, -0.009327292, ...]
- * }],
- * "model": "text-embedding-ada-002",
- * "usage": {"prompt_tokens": 8, "total_tokens": 8}
- * }
- * </pre>
- * @param model 模型名称
- * @param input 输入文本
- * @return 嵌入向量
- */
- @Override
- public double[] getEmbeddingVector(String model, String input) {
- String response = embedding(model, input);
- return parseEmbeddingVector(response);
- }
- /**
- * 解析嵌入向量响应
- * @param response JSON响应字符串
- * @return 向量数组
- */
- private double[] parseEmbeddingVector(String response) {
- try {
- JSONObject jsonResponse = JSONObject.parseObject(response);
- // OpenAI 格式
- if (jsonResponse.containsKey("data")) {
- JSONArray dataArray = jsonResponse.getJSONArray("data");
- if (dataArray != null && !dataArray.isEmpty()) {
- JSONObject firstData = dataArray.getJSONObject(0);
- JSONArray embeddingArray = firstData.getJSONArray("embedding");
- return jsonArrayToDoubleArray(embeddingArray);
- }
- }
- // Ollama 格式 (直接返回 embedding 数组)
- if (jsonResponse.containsKey("embedding")) {
- JSONArray embeddingArray = jsonResponse.getJSONArray("embedding");
- return jsonArrayToDoubleArray(embeddingArray);
- }
- // 阿里通义格式
- if (jsonResponse.containsKey("output")) {
- JSONObject output = jsonResponse.getJSONObject("output");
- if (output.containsKey("embeddings")) {
- JSONArray embeddings = output.getJSONArray("embeddings");
- if (!embeddings.isEmpty()) {
- JSONObject firstEmbedding = embeddings.getJSONObject(0);
- JSONArray embeddingArray = firstEmbedding.getJSONArray("embedding");
- return jsonArrayToDoubleArray(embeddingArray);
- }
- }
- }
- throw new RuntimeException("无法解析嵌入向量响应: " + response);
- }
- catch (Exception e) {
- log.error("解析嵌入向量失败: {}", response, e);
- throw new RuntimeException("解析嵌入向量失败", e);
- }
- }
- /**
- * 将 JSONArray 转换为 double 数组
- */
- private double[] jsonArrayToDoubleArray(JSONArray jsonArray) {
- double[] result = new double[jsonArray.size()];
- for (int i = 0; i < jsonArray.size(); i++) {
- result[i] = jsonArray.getDoubleValue(i);
- }
- return result;
- }
- }
复制代码
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |