找回密码
 立即注册
首页 业界区 安全 大语言模型~Ollama本地模型和java一起体验LLM ...

大语言模型~Ollama本地模型和java一起体验LLM

裴涛 1 小时前
说明


  • 用户输入多个“信息”
  • 大语言模型将“信息”进行处理,转成数组;(一维张量,向量)
  • 通过余弦相似度等相关算法,计算两个向量是否相似
Ollama接口步骤


  • 安装 Ollama: https://ollama.ai/
  • 下载模型: ollama pull nomic-embed-text
  • Ollama 默认运行在 http://localhost:11434
推荐的嵌入模型:


  • nomic-embed-text: 768维,效果好,速度快
  • mxbai-embed-large: 1024维,效果更好
  • bge-m3: 多语言支持
1.png

springboot中调用本地模型
  1.     @Test
  2.         @Disabled("需要本地运行 Ollama 服务")
  3.         public void testOllamaEmbedding() {
  4.                 // Ollama API 地址
  5.                 String apiUrl = "http://localhost:11434/api/embeddings";
  6.                 String apiKey = ""; // Ollama 本地不需要 key
  7.                 String model = "nomic-embed-text"; // 或 mxbai-embed-large
  8.                 EmbeddingClient client = new EmbeddingClientImpl(apiUrl, apiKey);
  9.                 // 水果库
  10.                 List<Fruit> fruits = Arrays.asList(new Fruit("红富士苹果", "红色 甜 脆 苹果 新鲜"), new Fruit("青苹果", "绿色 酸 脆 苹果 清爽"),
  11.                                 new Fruit("金帅苹果", "黄色 甜 软 苹果"), new Fruit("香蕉", "黄色 甜 软 香蕉 热带水果"), new Fruit("草莓", "红色 甜 小 草莓 多汁 浆果"),
  12.                                 new Fruit("西瓜", "绿色外皮 红色果肉 甜 大 西瓜 多汁 夏天"), new Fruit("葡萄", "紫色 甜 小 葡萄 多汁 成串"));
  13.                 // 为每个水果生成嵌入向量
  14.                 for (Fruit fruit : fruits) {
  15.                         fruit.embedding = client.getEmbeddingVector(model, fruit.description);
  16.                 }
  17.                 // 用户搜索
  18.                 String query = "红色的甜水果";
  19.                 double[] queryVector = client.getEmbeddingVector(model, query);
  20.                 System.out.println("搜索: "" + query + """);
  21.                 System.out.println("向量维度: " + queryVector.length);
  22.                 System.out.println();
  23.                 // 按相似度排序
  24.                 fruits.sort(Comparator.comparingDouble(f -> -cosineSimilarity(queryVector, f.embedding)));
  25.                 // 输出结果
  26.                 System.out.println("搜索结果(按相似度排序):");
  27.                 for (Fruit f : fruits) {
  28.                         double sim = cosineSimilarity(queryVector, f.embedding);
  29.                         System.out.printf("  %s (%.4f): %s%n", f.name, sim, f.description);
  30.                 }
  31.         }
  32.     /**
  33.          * 计算两个向量的余弦相似度
  34.          */
  35.         public static double cosineSimilarity(double[] vectorA, double[] vectorB) {
  36.                 if (vectorA.length != vectorB.length) {
  37.                         throw new IllegalArgumentException("向量维度必须相同");
  38.                 }
  39.                 double dotProduct = 0;
  40.                 double normA = 0;
  41.                 double normB = 0;
  42.                 for (int i = 0; i < vectorA.length; i++) {
  43.                         dotProduct += vectorA[i] * vectorB[i];
  44.                         normA += vectorA[i] * vectorA[i];
  45.                         normB += vectorB[i] * vectorB[i];
  46.                 }
  47.                 if (normA == 0 || normB == 0) {
  48.                         return 0;
  49.                 }
  50.                 return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  51.         }
复制代码
核心方法
  1. @Slf4j
  2. public class EmbeddingClientImpl implements EmbeddingClient {
  3.         private final RestTemplate restTemplate;
  4.         private final String address;
  5.         private final String key;
  6.         public EmbeddingClientImpl(String address, String key) {
  7.                 PoolingHttpClientConnectionManager connectionManager = new PoolingHttpClientConnectionManager();
  8.                 connectionManager.setMaxTotal(100);
  9.                 connectionManager.setDefaultMaxPerRoute(20);
  10.                 // 设置请求配置
  11.                 RequestConfig requestConfig = RequestConfig.custom()
  12.                                 .setConnectionRequestTimeout(Timeout.ofSeconds(30))
  13.                                 .setResponseTimeout(Timeout.ofSeconds(300)) // 5分钟响应超时
  14.                                 .build();
  15.                 // 使用 HttpClientBuilder 来构建 HttpClient
  16.                 HttpClient httpClient = HttpClientBuilder.create()
  17.                                 .setConnectionManager(connectionManager)
  18.                                 .setDefaultRequestConfig(requestConfig)
  19.                                 .build();
  20.                 // 创建 HttpComponentsClientHttpRequestFactory
  21.                 HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
  22.                 requestFactory.setConnectTimeout(30000); // 30秒连接超时
  23.                 requestFactory.setConnectionRequestTimeout(30000);
  24.                 // 创建 RestTemplate,只使用 StringHttpMessageConverter 避免 Jackson 依赖问题
  25.                 this.restTemplate = new RestTemplate(requestFactory);
  26.                 // 清除默认的消息转换器,只保留字符串转换器
  27.                 this.restTemplate.setMessageConverters(
  28.                                 Collections.singletonList(new StringHttpMessageConverter(StandardCharsets.UTF_8)));
  29.                 this.address = address;
  30.                 this.key = key;
  31.         }
  32.         @Override
  33.         public String embedding(String model, String input) {
  34.                 long start = System.currentTimeMillis();
  35.                 String url = address;
  36.                 HttpHeaders headers = new HttpHeaders();
  37.                 headers.setContentType(MediaType.APPLICATION_JSON);
  38.                 headers.setAcceptCharset(Collections.singletonList(StandardCharsets.UTF_8));
  39.                 if (key != null && !key.isEmpty()) {
  40.                         headers.add("Authorization", "Bearer " + key);
  41.                 }
  42.                 // 将 request 转化为 body 字符串
  43.                 JSONObject jsonObject = new JSONObject();
  44.                 jsonObject.put("input", input);
  45.                 jsonObject.put("model", model);
  46.                 String body = jsonObject.toString();
  47.                 log.debug("Embedding Request Body: {}", body);
  48.                 // 请求
  49.                 HttpEntity<String> req = new HttpEntity<>(body, headers);
  50.                 ResponseEntity<String> result = restTemplate.postForEntity(url, req, String.class);
  51.                 if (!result.getStatusCode().equals(HttpStatus.OK)) {
  52.                         throw new RuntimeException("embeddings error, request: " + body + ", response: " + result.getBody());
  53.                 }
  54.                 log.info("embedding cost {} ms", System.currentTimeMillis() - start);
  55.                 return result.getBody();
  56.         }
  57.         /**
  58.          * 获取文本嵌入向量
  59.          * <p>
  60.          * 解析 OpenAI 格式的响应,提取 embedding 向量
  61.          *
  62.          * 响应格式示例: <pre>
  63.          * {
  64.          *   "object": "list",
  65.          *   "data": [{
  66.          *     "object": "embedding",
  67.          *     "index": 0,
  68.          *     "embedding": [0.0023064255, -0.009327292, ...]
  69.          *   }],
  70.          *   "model": "text-embedding-ada-002",
  71.          *   "usage": {"prompt_tokens": 8, "total_tokens": 8}
  72.          * }
  73.          * </pre>
  74.          * @param model 模型名称
  75.          * @param input 输入文本
  76.          * @return 嵌入向量
  77.          */
  78.         @Override
  79.         public double[] getEmbeddingVector(String model, String input) {
  80.                 String response = embedding(model, input);
  81.                 return parseEmbeddingVector(response);
  82.         }
  83.         /**
  84.          * 解析嵌入向量响应
  85.          * @param response JSON响应字符串
  86.          * @return 向量数组
  87.          */
  88.         private double[] parseEmbeddingVector(String response) {
  89.                 try {
  90.                         JSONObject jsonResponse = JSONObject.parseObject(response);
  91.                         // OpenAI 格式
  92.                         if (jsonResponse.containsKey("data")) {
  93.                                 JSONArray dataArray = jsonResponse.getJSONArray("data");
  94.                                 if (dataArray != null && !dataArray.isEmpty()) {
  95.                                         JSONObject firstData = dataArray.getJSONObject(0);
  96.                                         JSONArray embeddingArray = firstData.getJSONArray("embedding");
  97.                                         return jsonArrayToDoubleArray(embeddingArray);
  98.                                 }
  99.                         }
  100.                         // Ollama 格式 (直接返回 embedding 数组)
  101.                         if (jsonResponse.containsKey("embedding")) {
  102.                                 JSONArray embeddingArray = jsonResponse.getJSONArray("embedding");
  103.                                 return jsonArrayToDoubleArray(embeddingArray);
  104.                         }
  105.                         // 阿里通义格式
  106.                         if (jsonResponse.containsKey("output")) {
  107.                                 JSONObject output = jsonResponse.getJSONObject("output");
  108.                                 if (output.containsKey("embeddings")) {
  109.                                         JSONArray embeddings = output.getJSONArray("embeddings");
  110.                                         if (!embeddings.isEmpty()) {
  111.                                                 JSONObject firstEmbedding = embeddings.getJSONObject(0);
  112.                                                 JSONArray embeddingArray = firstEmbedding.getJSONArray("embedding");
  113.                                                 return jsonArrayToDoubleArray(embeddingArray);
  114.                                         }
  115.                                 }
  116.                         }
  117.                         throw new RuntimeException("无法解析嵌入向量响应: " + response);
  118.                 }
  119.                 catch (Exception e) {
  120.                         log.error("解析嵌入向量失败: {}", response, e);
  121.                         throw new RuntimeException("解析嵌入向量失败", e);
  122.                 }
  123.         }
  124.         /**
  125.          * 将 JSONArray 转换为 double 数组
  126.          */
  127.         private double[] jsonArrayToDoubleArray(JSONArray jsonArray) {
  128.                 double[] result = new double[jsonArray.size()];
  129.                 for (int i = 0; i < jsonArray.size(); i++) {
  130.                         result[i] = jsonArray.getDoubleValue(i);
  131.                 }
  132.                 return result;
  133.         }
  134. }
复制代码
2.png


来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册